diff --git a/compiler/src/dotty/tools/dotc/typer/Typer.scala b/compiler/src/dotty/tools/dotc/typer/Typer.scala index 572f4fda4f4e..bb5a9d33bf3c 100644 --- a/compiler/src/dotty/tools/dotc/typer/Typer.scala +++ b/compiler/src/dotty/tools/dotc/typer/Typer.scala @@ -1876,6 +1876,17 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer */ var paramIndex = Map[Name, Int]() + def containsParamRef(tree: untpd.Tree, params: List[untpd.ValDef]): Boolean = + import untpd.* + val acc = new UntypedTreeAccumulator[Boolean]: + def apply(x: Boolean, t: Tree)(using Context) = + if x then true + else t match + case _: untpd.TypedSplice => false + case Ident(name) => params.exists(_.name == name) + case _ => foldOver(x, t) + acc(false, tree) + /** Infer parameter type from the body of the function * * 1. If function is of the form @@ -1910,24 +1921,25 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer for (param <- params; idx <- paramIndices(param, args)) yield param.name -> idx }.toMap - if (paramIndex.size == params.length) + if (paramIndex.size == params.length) then expr match case untpd.TypedSplice(expr1) => expr1.tpe - case _ => + case _ if !containsParamRef(expr, params) => val outerCtx = ctx val nestedCtx = outerCtx.fresh.setNewTyperState() - inContext(nestedCtx) { - val protoArgs = args.map(_.withType(WildcardType)) + inContext(nestedCtx): + // try to type expr with fresh unknown arguments. + val protoArgs = args.map(arg => untpd.Ident(UniqueName.fresh()).withSpan(arg.span)) val callProto = FunProto(protoArgs, WildcardType)(this, app.applyKind) val expr1 = typedExpr(expr, callProto) if nestedCtx.reporter.hasErrors then NoType - else inContext(outerCtx) { + else inContext(outerCtx): nestedCtx.typerState.commit() fnBody = cpy.Apply(fnBody)(untpd.TypedSplice(expr1), args) expr1.tpe - } - } + case _ => + NoType else NoType case _ => NoType diff --git a/tests/pos/i24689b.scala b/tests/pos/i24689b.scala index a87c601a3352..d773ded8e8fa 100644 --- a/tests/pos/i24689b.scala +++ b/tests/pos/i24689b.scala @@ -9,11 +9,10 @@ object Test1: foo(transfer(_)) def transfer(in: B): Unit = ??? -// TODO: need to fix callee type in typedFunctionValue -// object Test2: -// def foo[T <: (B => Unit)](f: T) = ??? -// def transfer(in: A): Unit = -// foo(in => transfer(in)) -// foo(transfer) -// foo(transfer(_)) -// def transfer(in: B): Unit = ??? +object Test2: + def foo[T <: (B => Unit)](f: T) = ??? + def transfer(in: A): Unit = + foo(in => transfer(in)) + foo(transfer) + foo(transfer(_)) + def transfer(in: B): Unit = ???