Skip to content

Commit d5e74f7

Browse files
authored
Merge pull request #14134 from dotty-staging/pure-funs
Distinguish between pure and impure function types
2 parents cf5fa2c + bab020a commit d5e74f7

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

63 files changed

+594
-446
lines changed

compiler/src/dotty/tools/dotc/ast/untpd.scala

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,13 +70,13 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo {
7070
case class InterpolatedString(id: TermName, segments: List[Tree])(implicit @constructorOnly src: SourceFile)
7171
extends TermTree
7272

73-
/** A function type */
73+
/** A function type or closure */
7474
case class Function(args: List[Tree], body: Tree)(implicit @constructorOnly src: SourceFile) extends Tree {
7575
override def isTerm: Boolean = body.isTerm
7676
override def isType: Boolean = body.isType
7777
}
7878

79-
/** A function type with `implicit`, `erased`, or `given` modifiers */
79+
/** A function type or closure with `implicit`, `erased`, or `given` modifiers */
8080
class FunctionWithMods(args: List[Tree], body: Tree, val mods: Modifiers)(implicit @constructorOnly src: SourceFile)
8181
extends Function(args, body)
8282

@@ -217,6 +217,8 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo {
217217
case class Transparent()(implicit @constructorOnly src: SourceFile) extends Mod(Flags.Transparent)
218218

219219
case class Infix()(implicit @constructorOnly src: SourceFile) extends Mod(Flags.Infix)
220+
221+
case class Impure()(implicit @constructorOnly src: SourceFile) extends Mod(Flags.Impure)
220222
}
221223

222224
/** Modifiers and annotations for definitions

compiler/src/dotty/tools/dotc/cc/CaptureOps.scala

Lines changed: 5 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,13 @@ def retainedElems(tree: Tree)(using Context): List[Tree] = tree match
1717
case Apply(_, Typed(SeqLiteral(elems, _), _) :: Nil) => elems
1818
case _ => Nil
1919

20+
class IllegalCaptureRef(tpe: Type) extends Exception
21+
2022
extension (tree: Tree)
2123

22-
def toCaptureRef(using Context): CaptureRef = tree.tpe.asInstanceOf[CaptureRef]
24+
def toCaptureRef(using Context): CaptureRef = tree.tpe match
25+
case ref: CaptureRef => ref
26+
case tpe => throw IllegalCaptureRef(tpe)
2327

2428
def toCaptureSet(using Context): CaptureSet =
2529
tree.getAttachment(Captures) match
@@ -59,20 +63,6 @@ extension (tp: Type)
5963

6064
def isBoxedCapturing(using Context) = !tp.boxedCaptured.isAlwaysEmpty
6165

62-
def canHaveInferredCapture(using Context): Boolean = tp match
63-
case tp: TypeRef if tp.symbol.isClass =>
64-
!tp.symbol.isValueClass && tp.symbol != defn.AnyClass
65-
case _: TypeVar | _: TypeParamRef =>
66-
false
67-
case tp: TypeProxy =>
68-
tp.superType.canHaveInferredCapture
69-
case tp: AndType =>
70-
tp.tp1.canHaveInferredCapture && tp.tp2.canHaveInferredCapture
71-
case tp: OrType =>
72-
tp.tp1.canHaveInferredCapture || tp.tp2.canHaveInferredCapture
73-
case _ =>
74-
false
75-
7666
def stripCapturing(using Context): Type = tp.dealiasKeepAnnots match
7767
case CapturingType(parent, _, _) =>
7868
parent.stripCapturing

compiler/src/dotty/tools/dotc/cc/CaptureSet.scala

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,12 @@ sealed abstract class CaptureSet extends Showable:
5656
assert(v.isConst)
5757
Const(v.elems)
5858

59+
final def isUniversal(using Context) =
60+
elems.exists {
61+
case ref: TermRef => ref.symbol == defn.captureRoot
62+
case _ => false
63+
}
64+
5965
/** Cast to variable. @pre: !isConst */
6066
def asVar: Var =
6167
assert(!isConst)

compiler/src/dotty/tools/dotc/cc/CapturingType.scala

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,22 @@ object CapturingType:
1212
else AnnotatedType(parent, CaptureAnnotation(refs, boxed))
1313

1414
def unapply(tp: AnnotatedType)(using Context): Option[(Type, CaptureSet, Boolean)] =
15-
if ctx.phase == Phases.checkCapturesPhase && tp.annot.symbol == defn.RetainsAnnot then
15+
if ctx.phase == Phases.checkCapturesPhase then EventuallyCapturingType.unapply(tp)
16+
else None
17+
18+
end CapturingType
19+
20+
object EventuallyCapturingType:
21+
22+
def unapply(tp: AnnotatedType)(using Context): Option[(Type, CaptureSet, Boolean)] =
23+
if tp.annot.symbol == defn.RetainsAnnot then
1624
tp.annot match
1725
case ann: CaptureAnnotation => Some((tp.parent, ann.refs, ann.boxed))
18-
case ann => Some((tp.parent, ann.tree.toCaptureSet, ann.tree.isBoxedCapturing))
26+
case ann =>
27+
try Some((tp.parent, ann.tree.toCaptureSet, ann.tree.isBoxedCapturing))
28+
catch case ex: IllegalCaptureRef => None
1929
else None
2030

21-
end CapturingType
31+
end EventuallyCapturingType
32+
33+

compiler/src/dotty/tools/dotc/core/Definitions.scala

Lines changed: 82 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ import Flags._, Scopes._, Decorators._, NameOps._, Periods._, NullOpsDecorator._
88
import unpickleScala2.Scala2Unpickler.ensureConstructor
99
import scala.collection.mutable
1010
import collection.mutable
11-
import Denotations.SingleDenotation
11+
import Denotations.{SingleDenotation, staticRef}
1212
import util.{SimpleIdentityMap, SourceFile, NoSource}
1313
import typer.ImportInfo.RootRef
1414
import Comments.CommentsContext
@@ -86,7 +86,7 @@ class Definitions {
8686
*
8787
* FunctionN traits follow this template:
8888
*
89-
* trait FunctionN[T0,...T{N-1}, R] extends Object {
89+
* trait FunctionN[-T0,...-T{N-1}, +R] extends Object {
9090
* def apply($x0: T0, ..., $x{N_1}: T{N-1}): R
9191
* }
9292
*
@@ -96,46 +96,65 @@ class Definitions {
9696
*
9797
* ContextFunctionN traits follow this template:
9898
*
99-
* trait ContextFunctionN[T0,...,T{N-1}, R] extends Object {
99+
* trait ContextFunctionN[-T0,...,-T{N-1}, +R] extends Object {
100100
* def apply(using $x0: T0, ..., $x{N_1}: T{N-1}): R
101101
* }
102102
*
103103
* ErasedFunctionN traits follow this template:
104104
*
105-
* trait ErasedFunctionN[T0,...,T{N-1}, R] extends Object {
105+
* trait ErasedFunctionN[-T0,...,-T{N-1}, +R] extends Object {
106106
* def apply(erased $x0: T0, ..., $x{N_1}: T{N-1}): R
107107
* }
108108
*
109109
* ErasedContextFunctionN traits follow this template:
110110
*
111-
* trait ErasedContextFunctionN[T0,...,T{N-1}, R] extends Object {
111+
* trait ErasedContextFunctionN[-T0,...,-T{N-1}, +R] extends Object {
112112
* def apply(using erased $x0: T0, ..., $x{N_1}: T{N-1}): R
113113
* }
114114
*
115115
* ErasedFunctionN and ErasedContextFunctionN erase to Function0.
116+
*
117+
* EffXYZFunctionN afollow this template:
118+
*
119+
* type EffXYZFunctionN[-T0,...,-T{N-1}, +R] = {*} XYZFunctionN[T0,...,T{N-1}, R]
116120
*/
117-
def newFunctionNTrait(name: TypeName): ClassSymbol = {
121+
private def newFunctionNType(name: TypeName): Symbol = {
122+
val impure = name.startsWith("Impure")
118123
val completer = new LazyType {
119124
def complete(denot: SymDenotation)(using Context): Unit = {
120-
val cls = denot.asClass.classSymbol
121-
val decls = newScope
122125
val arity = name.functionArity
123-
val paramNamePrefix = tpnme.scala ++ str.NAME_JOIN ++ name ++ str.EXPAND_SEPARATOR
124-
val argParamRefs = List.tabulate(arity) { i =>
125-
enterTypeParam(cls, paramNamePrefix ++ "T" ++ (i + 1).toString, Contravariant, decls).typeRef
126-
}
127-
val resParamRef = enterTypeParam(cls, paramNamePrefix ++ "R", Covariant, decls).typeRef
128-
val methodType = MethodType.companion(
129-
isContextual = name.isContextFunction,
130-
isImplicit = false,
131-
isErased = name.isErasedFunction)
132-
decls.enter(newMethod(cls, nme.apply, methodType(argParamRefs, resParamRef), Deferred))
133-
denot.info =
134-
ClassInfo(ScalaPackageClass.thisType, cls, ObjectType :: Nil, decls)
126+
if impure then
127+
val argParamNames = List.tabulate(arity)(tpnme.syntheticTypeParamName)
128+
val argVariances = List.fill(arity)(Contravariant)
129+
val underlyingName = name.asSimpleName.drop(6)
130+
val underlyingClass = ScalaPackageVal.requiredClass(underlyingName)
131+
denot.info = TypeAlias(
132+
HKTypeLambda(argParamNames :+ "R".toTypeName, argVariances :+ Covariant)(
133+
tl => List.fill(arity + 1)(TypeBounds.empty),
134+
tl => CapturingType(underlyingClass.typeRef.appliedTo(tl.paramRefs),
135+
CaptureSet.universal, boxed = false)
136+
))
137+
else
138+
val cls = denot.asClass.classSymbol
139+
val decls = newScope
140+
val paramNamePrefix = tpnme.scala ++ str.NAME_JOIN ++ name ++ str.EXPAND_SEPARATOR
141+
val argParamRefs = List.tabulate(arity) { i =>
142+
enterTypeParam(cls, paramNamePrefix ++ "T" ++ (i + 1).toString, Contravariant, decls).typeRef
143+
}
144+
val resParamRef = enterTypeParam(cls, paramNamePrefix ++ "R", Covariant, decls).typeRef
145+
val methodType = MethodType.companion(
146+
isContextual = name.isContextFunction,
147+
isImplicit = false,
148+
isErased = name.isErasedFunction)
149+
decls.enter(newMethod(cls, nme.apply, methodType(argParamRefs, resParamRef), Deferred))
150+
denot.info =
151+
ClassInfo(ScalaPackageClass.thisType, cls, ObjectType :: Nil, decls)
135152
}
136153
}
137-
val flags = Trait | NoInits
138-
newPermanentClassSymbol(ScalaPackageClass, name, flags, completer)
154+
if impure then
155+
newPermanentSymbol(ScalaPackageClass, name, EmptyFlags, completer)
156+
else
157+
newPermanentClassSymbol(ScalaPackageClass, name, Trait | NoInits, completer)
139158
}
140159

141160
private def newMethod(cls: ClassSymbol, name: TermName, info: Type, flags: FlagSet = EmptyFlags): TermSymbol =
@@ -209,7 +228,7 @@ class Definitions {
209228
val cls = ScalaPackageVal.moduleClass.asClass
210229
cls.info.decls.openForMutations.useSynthesizer(
211230
name =>
212-
if (name.isTypeName && name.isSyntheticFunction) newFunctionNTrait(name.asTypeName)
231+
if (name.isTypeName && name.isSyntheticFunction) newFunctionNType(name.asTypeName)
213232
else NoSymbol)
214233
cls
215234
}
@@ -1273,37 +1292,54 @@ class Definitions {
12731292

12741293
@tu lazy val TupleType: Array[TypeRef] = mkArityArray("scala.Tuple", MaxTupleArity, 1)
12751294

1295+
/** Cached function types of arbitary arities.
1296+
* Function types are created on demand with newFunctionNTrait, which is
1297+
* called from a synthesizer installed in ScalaPackageClass.
1298+
*/
12761299
private class FunType(prefix: String):
12771300
private var classRefs: Array[TypeRef] = new Array(22)
1301+
12781302
def apply(n: Int): TypeRef =
12791303
while n >= classRefs.length do
12801304
val classRefs1 = new Array[TypeRef](classRefs.length * 2)
12811305
Array.copy(classRefs, 0, classRefs1, 0, classRefs.length)
12821306
classRefs = classRefs1
1307+
val funName = s"scala.$prefix$n"
12831308
if classRefs(n) == null then
1284-
classRefs(n) = requiredClassRef(prefix + n.toString)
1309+
classRefs(n) =
1310+
if prefix.startsWith("Impure")
1311+
then staticRef(funName.toTypeName).symbol.typeRef
1312+
else requiredClassRef(funName)
12851313
classRefs(n)
1286-
1287-
private val erasedContextFunType = FunType("scala.ErasedContextFunction")
1288-
private val contextFunType = FunType("scala.ContextFunction")
1289-
private val erasedFunType = FunType("scala.ErasedFunction")
1290-
private val funType = FunType("scala.Function")
1291-
1292-
def FunctionClass(n: Int, isContextual: Boolean = false, isErased: Boolean = false)(using Context): Symbol =
1293-
( if isContextual && isErased then erasedContextFunType(n)
1294-
else if isContextual then contextFunType(n)
1295-
else if isErased then erasedFunType(n)
1296-
else funType(n)
1297-
).symbol.asClass
1314+
end FunType
1315+
1316+
private def funTypeIdx(isContextual: Boolean, isErased: Boolean, isImpure: Boolean): Int =
1317+
(if isContextual then 1 else 0)
1318+
+ (if isErased then 2 else 0)
1319+
+ (if isImpure then 4 else 0)
1320+
1321+
private val funTypeArray: IArray[FunType] =
1322+
val arr = Array.ofDim[FunType](8)
1323+
val choices = List(false, true)
1324+
for contxt <- choices; erasd <- choices; impure <- choices do
1325+
var str = "Function"
1326+
if contxt then str = "Context" + str
1327+
if erasd then str = "Erased" + str
1328+
if impure then str = "Impure" + str
1329+
arr(funTypeIdx(contxt, erasd, impure)) = FunType(str)
1330+
IArray.unsafeFromArray(arr)
1331+
1332+
def FunctionSymbol(n: Int, isContextual: Boolean = false, isErased: Boolean = false, isImpure: Boolean = false)(using Context): Symbol =
1333+
funTypeArray(funTypeIdx(isContextual, isErased, isImpure))(n).symbol
12981334

12991335
@tu lazy val Function0_apply: Symbol = Function0.requiredMethod(nme.apply)
13001336

1301-
@tu lazy val Function0: Symbol = FunctionClass(0)
1302-
@tu lazy val Function1: Symbol = FunctionClass(1)
1303-
@tu lazy val Function2: Symbol = FunctionClass(2)
1337+
@tu lazy val Function0: Symbol = FunctionSymbol(0)
1338+
@tu lazy val Function1: Symbol = FunctionSymbol(1)
1339+
@tu lazy val Function2: Symbol = FunctionSymbol(2)
13041340

1305-
def FunctionType(n: Int, isContextual: Boolean = false, isErased: Boolean = false)(using Context): TypeRef =
1306-
FunctionClass(n, isContextual && !ctx.erasedTypes, isErased).typeRef
1341+
def FunctionType(n: Int, isContextual: Boolean = false, isErased: Boolean = false, isImpure: Boolean = false)(using Context): TypeRef =
1342+
FunctionSymbol(n, isContextual && !ctx.erasedTypes, isErased, isImpure).typeRef
13071343

13081344
lazy val PolyFunctionClass = requiredClass("scala.PolyFunction")
13091345
def PolyFunctionType = PolyFunctionClass.typeRef
@@ -1345,6 +1381,10 @@ class Definitions {
13451381
*/
13461382
def isFunctionClass(cls: Symbol): Boolean = scalaClassName(cls).isFunction
13471383

1384+
/** Is a function class, or an impure function type alias */
1385+
def isFunctionSymbol(sym: Symbol): Boolean =
1386+
sym.isType && (sym.owner eq ScalaPackageClass) && sym.name.isFunction
1387+
13481388
/** Is a function class where
13491389
* - FunctionN for N >= 0 and N != XXL
13501390
*/
@@ -1550,7 +1590,7 @@ class Definitions {
15501590
new PerRun(Function2SpecializedReturnTypes.map(_.symbol))
15511591

15521592
def isSpecializableFunction(cls: ClassSymbol, paramTypes: List[Type], retType: Type)(using Context): Boolean =
1553-
paramTypes.length <= 2 && cls.derivesFrom(FunctionClass(paramTypes.length))
1593+
paramTypes.length <= 2 && cls.derivesFrom(FunctionSymbol(paramTypes.length))
15541594
&& isSpecializableFunctionSAM(paramTypes, retType)
15551595

15561596
/** If the Single Abstract Method of a Function class has this type, is it specializable? */

compiler/src/dotty/tools/dotc/core/Flags.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -314,8 +314,8 @@ object Flags {
314314
/** A Scala 2x super accessor / an unpickled Scala 2.x class */
315315
val (SuperParamAliasOrScala2x @ _, SuperParamAlias @ _, Scala2x @ _) = newFlags(26, "<super-param-alias>", "<scala-2.x>")
316316

317-
/** A parameter with a default value */
318-
val (_, HasDefault @ _, _) = newFlags(27, "<hasdefault>")
317+
/** A parameter with a default value / an impure untpd.Function type */
318+
val (_, HasDefault @ _, Impure @ _) = newFlags(27, "<hasdefault>", "<{*}>")
319319

320320
/** An extension method, or a collective extension instance */
321321
val (Extension @ _, ExtensionMethod @ _, _) = newFlags(28, "<extension>")

compiler/src/dotty/tools/dotc/core/NameOps.scala

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -197,20 +197,25 @@ object NameOps {
197197
else collectDigits(acc * 10 + d, idx + 1)
198198
collectDigits(0, suffixStart + 8)
199199

200-
/** name[0..suffixStart) == `str` */
201-
private def isPreceded(str: String, suffixStart: Int) =
202-
str.length == suffixStart && name.firstPart.startsWith(str)
200+
private def isFunctionPrefix(suffixStart: Int, mustHave: String = ""): Boolean =
201+
suffixStart >= 0
202+
&& {
203+
val first = name.firstPart
204+
var found = mustHave.isEmpty
205+
def skip(idx: Int, str: String) =
206+
if first.startsWith(str, idx) then
207+
if str == mustHave then found = true
208+
idx + str.length
209+
else idx
210+
skip(skip(skip(0, "Impure"), "Erased"), "Context") == suffixStart
211+
&& found
212+
}
203213

204214
/** Same as `funArity`, except that it returns -1 if the prefix
205215
* is not one of "", "Context", "Erased", "ErasedContext"
206216
*/
207217
private def checkedFunArity(suffixStart: Int): Int =
208-
if suffixStart == 0
209-
|| isPreceded("Context", suffixStart)
210-
|| isPreceded("Erased", suffixStart)
211-
|| isPreceded("ErasedContext", suffixStart)
212-
then funArity(suffixStart)
213-
else -1
218+
if isFunctionPrefix(suffixStart) then funArity(suffixStart) else -1
214219

215220
/** Is a function name, i.e one of FunctionXXL, FunctionN, ContextFunctionN, ErasedFunctionN, ErasedContextFunctionN for N >= 0
216221
*/
@@ -222,19 +227,14 @@ object NameOps {
222227
*/
223228
def isPlainFunction: Boolean = functionArity >= 0
224229

225-
/** Is an context function name, i.e one of ContextFunctionN or ErasedContextFunctionN for N >= 0
226-
*/
227-
def isContextFunction: Boolean =
230+
/** Is a function name that contains `mustHave` as a substring */
231+
private def isSpecificFunction(mustHave: String): Boolean =
228232
val suffixStart = functionSuffixStart
229-
(isPreceded("Context", suffixStart) || isPreceded("ErasedContext", suffixStart))
230-
&& funArity(suffixStart) >= 0
233+
isFunctionPrefix(suffixStart, mustHave) && funArity(suffixStart) >= 0
231234

232-
/** Is an erased function name, i.e. one of ErasedFunctionN, ErasedContextFunctionN for N >= 0
233-
*/
234-
def isErasedFunction: Boolean =
235-
val suffixStart = functionSuffixStart
236-
(isPreceded("Erased", suffixStart) || isPreceded("ErasedContext", suffixStart))
237-
&& funArity(suffixStart) >= 0
235+
def isContextFunction: Boolean = isSpecificFunction("Context")
236+
def isErasedFunction: Boolean = isSpecificFunction("Erased")
237+
def isImpureFunction: Boolean = isSpecificFunction("Impure")
238238

239239
/** Is a synthetic function name, i.e. one of
240240
* - FunctionN for N > 22

compiler/src/dotty/tools/dotc/core/StdNames.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -741,6 +741,8 @@ object StdNames {
741741
val XOR : N = "^"
742742
val ZAND : N = "&&"
743743
val ZOR : N = "||"
744+
val PUREARROW: N = "->"
745+
val PURECTXARROW: N = "?->"
744746

745747
// unary operators
746748
val UNARY_PREFIX: N = "unary_"

compiler/src/dotty/tools/dotc/core/TypeComparer.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2413,7 +2413,7 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
24132413
case tp1: TypeVar if tp1.isInstantiated =>
24142414
tp1.underlying & tp2
24152415
case CapturingType(parent1, refs1, _) =>
2416-
if subCaptures(tp2.captureSet, refs1, frozenConstraint).isOK then
2416+
if subCaptures(tp2.captureSet, refs1, frozen = true).isOK then
24172417
parent1 & tp2
24182418
else
24192419
tp1.derivedCapturingType(parent1 & tp2, refs1)

0 commit comments

Comments
 (0)