Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 19 additions & 7 deletions compiler/src/dotty/tools/dotc/typer/Typer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1876,6 +1876,17 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
*/
var paramIndex = Map[Name, Int]()

def containsParamRef(tree: untpd.Tree, params: List[untpd.ValDef]): Boolean =
import untpd.*
val acc = new UntypedTreeAccumulator[Boolean]:
def apply(x: Boolean, t: Tree)(using Context) =
if x then true
else t match
case _: untpd.TypedSplice => false
case Ident(name) => params.exists(_.name == name)
case _ => foldOver(x, t)
acc(false, tree)

/** Infer parameter type from the body of the function
*
* 1. If function is of the form
Expand Down Expand Up @@ -1910,24 +1921,25 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
for (param <- params; idx <- paramIndices(param, args))
yield param.name -> idx
}.toMap
if (paramIndex.size == params.length)
if (paramIndex.size == params.length) then
expr match
case untpd.TypedSplice(expr1) =>
expr1.tpe
case _ =>
case _ if !containsParamRef(expr, params) =>
val outerCtx = ctx
val nestedCtx = outerCtx.fresh.setNewTyperState()
inContext(nestedCtx) {
val protoArgs = args.map(_.withType(WildcardType))
inContext(nestedCtx):
// try to type expr with fresh unknown arguments.
val protoArgs = args.map(arg => untpd.Ident(UniqueName.fresh()).withSpan(arg.span))
val callProto = FunProto(protoArgs, WildcardType)(this, app.applyKind)
val expr1 = typedExpr(expr, callProto)
if nestedCtx.reporter.hasErrors then NoType
else inContext(outerCtx) {
else inContext(outerCtx):
nestedCtx.typerState.commit()
fnBody = cpy.Apply(fnBody)(untpd.TypedSplice(expr1), args)
expr1.tpe
}
}
case _ =>
NoType
else NoType
case _ =>
NoType
Expand Down
15 changes: 7 additions & 8 deletions tests/pos/i24689b.scala
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,10 @@ object Test1:
foo(transfer(_))
def transfer(in: B): Unit = ???

// TODO: need to fix callee type in typedFunctionValue
// object Test2:
// def foo[T <: (B => Unit)](f: T) = ???
// def transfer(in: A): Unit =
// foo(in => transfer(in))
// foo(transfer)
// foo(transfer(_))
// def transfer(in: B): Unit = ???
object Test2:
def foo[T <: (B => Unit)](f: T) = ???
def transfer(in: A): Unit =
foo(in => transfer(in))
foo(transfer)
foo(transfer(_))
def transfer(in: B): Unit = ???
Loading