Skip to content

Commit 8ec5c4e

Browse files
committed
Cleanup transformation of inferred types
1 parent 8e96a9f commit 8ec5c4e

File tree

5 files changed

+143
-94
lines changed

5 files changed

+143
-94
lines changed

compiler/src/dotty/tools/dotc/cc/CaptureOps.scala

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -63,20 +63,6 @@ extension (tp: Type)
6363

6464
def isBoxedCapturing(using Context) = !tp.boxedCaptured.isAlwaysEmpty
6565

66-
def canHaveInferredCapture(using Context): Boolean = tp match
67-
case tp: TypeRef if tp.symbol.isClass =>
68-
!tp.symbol.isValueClass && tp.symbol != defn.AnyClass
69-
case _: TypeVar | _: TypeParamRef =>
70-
false
71-
case tp: TypeProxy =>
72-
tp.superType.canHaveInferredCapture
73-
case tp: AndType =>
74-
tp.tp1.canHaveInferredCapture && tp.tp2.canHaveInferredCapture
75-
case tp: OrType =>
76-
tp.tp1.canHaveInferredCapture || tp.tp2.canHaveInferredCapture
77-
case _ =>
78-
false
79-
8066
def stripCapturing(using Context): Type = tp.dealiasKeepAnnots match
8167
case CapturingType(parent, _, _) =>
8268
parent.stripCapturing

compiler/src/dotty/tools/dotc/typer/CheckCaptures.scala

Lines changed: 136 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -117,85 +117,150 @@ class CheckCaptures extends Recheck:
117117

118118
override def transformType(tp: Type, inferred: Boolean, boxed: Boolean)(using Context): Type =
119119

120-
def addInnerVars(tp: Type): Type = tp match
121-
case tp @ AppliedType(tycon, args) =>
122-
tp.derivedAppliedType(tycon, args.map(addVars(_, boxed = true)))
123-
case tp @ RefinedType(core, rname, rinfo) =>
124-
val rinfo1 = addVars(rinfo)
125-
if defn.isFunctionType(tp) then
126-
rinfo1.toFunctionType(isJava = false, alwaysDependent = true)
127-
else
128-
tp.derivedRefinedType(addInnerVars(core), rname, rinfo1)
129-
case tp: MethodType =>
130-
tp.derivedLambdaType(
131-
paramInfos = tp.paramInfos.mapConserve(addVars(_)),
132-
resType = addVars(tp.resType))
133-
case tp: PolyType =>
134-
tp.derivedLambdaType(
135-
resType = addVars(tp.resType))
136-
case tp: ExprType =>
137-
tp.derivedExprType(resType = addVars(tp.resType))
138-
case _ =>
139-
tp
140-
141-
/** Turn plain function types into dependent function types, so that
142-
* we can refer to the parameter in capture sets
120+
def depFun(tycon: Type, argTypes: List[Type], resType: Type): Type =
121+
MethodType.companion(
122+
isContextual = defn.isContextFunctionClass(tycon.classSymbol),
123+
isErased = defn.isErasedFunctionClass(tycon.classSymbol)
124+
)(argTypes, resType)
125+
.toFunctionType(isJava = false, alwaysDependent = true)
126+
127+
def box(tp: Type): Type = tp match
128+
case CapturingType(parent, refs, false) => CapturingType(parent, refs, true)
129+
case _ => tp
130+
131+
/** Perform the following transformation steps everywhere in a type:
132+
* 1. Drop retains annotations
133+
* 2. Turn plain function types into dependent function types, so that
134+
* we can refer to their parameter in capture sets. Currently this is
135+
* only done at the toplevel, i.e. for function types that are not
136+
* themselves argument types of other function types. Without this restriction
137+
* boxmap-paper.scala fails. Need to figure out why.
138+
* 3. Refine other class types C by adding capture set variables to their parameter getters
139+
* (see addCaptureRefinements)
140+
* 4. Add capture set variables to all types that can be tracked
141+
*
142+
* Polytype bounds are only cleaned using step 1, but not otherwise transformed.
143143
*/
144-
def addFunctionRefinements(tp: Type): Type = tp match
145-
case tp @ AppliedType(tycon, args) =>
146-
if defn.isNonRefinedFunction(tp) then
147-
MethodType.companion(
148-
isContextual = defn.isContextFunctionClass(tycon.classSymbol),
149-
isErased = defn.isErasedFunctionClass(tycon.classSymbol)
150-
)(args.init, addFunctionRefinements(args.last))
151-
.toFunctionType(isJava = false, alwaysDependent = true)
152-
.showing(i"add function refinement $tp --> $result", capt)
153-
else
154-
tp.derivedAppliedType(tycon, args.map(addFunctionRefinements(_)))
155-
case tp @ RefinedType(core, rname, rinfo) if !defn.isFunctionType(tp) =>
156-
tp.derivedRefinedType(
157-
addFunctionRefinements(core), rname, addFunctionRefinements(rinfo))
158-
case tp: MethodOrPoly =>
159-
tp.derivedLambdaType(resType = addFunctionRefinements(tp.resType))
160-
case tp: ExprType =>
161-
tp.derivedExprType(resType = addFunctionRefinements(tp.resType))
162-
case _ =>
163-
tp
144+
def mapInferred = new TypeMap:
164145

165-
/** Refine a possibly applied class type C where the class has tracked parameters
166-
* x_1: T_1, ..., x_n: T_n to C { val x_1: CV_1 T_1, ..., val x_n: CV_n T_n }
167-
* where CV_1, ..., CV_n are fresh capture sets.
168-
*/
169-
def addCaptureRefinements(tp: Type): Type = tp.stripped match
170-
case _: TypeRef | _: AppliedType if tp.typeSymbol.isClass =>
171-
val cls = tp.typeSymbol.asClass
172-
cls.paramGetters.foldLeft(tp) { (core, getter) =>
173-
if getter.termRef.isTracked then
174-
val getterType = tp.memberInfo(getter).strippedDealias
175-
RefinedType(core, getter.name, CapturingType(getterType, CaptureSet.Var(), boxed = false))
176-
.showing(i"add capture refinement $tp --> $result", capt)
177-
else
178-
core
179-
}
180-
case _ =>
181-
tp
146+
/** Drop @retains annotations everywhere */
147+
object cleanup extends TypeMap:
148+
def apply(t: Type) = t match
149+
case AnnotatedType(parent, annot) if annot.symbol == defn.RetainsAnnot =>
150+
apply(parent)
151+
case _ =>
152+
mapOver(t)
182153

183-
def addVars(tp: Type, boxed: Boolean = false): Type =
184-
var tp1 = addInnerVars(tp)
185-
val tp2 = addCaptureRefinements(tp1)
186-
if tp1.canHaveInferredCapture
187-
then CapturingType(tp2, CaptureSet.Var(), boxed)
188-
else tp2
154+
/** Refine a possibly applied class type C where the class has tracked parameters
155+
* x_1: T_1, ..., x_n: T_n to C { val x_1: CV_1 T_1, ..., val x_n: CV_n T_n }
156+
* where CV_1, ..., CV_n are fresh capture sets.
157+
*/
158+
def addCaptureRefinements(tp: Type): Type = tp match
159+
case _: TypeRef | _: AppliedType if tp.typeParams.isEmpty =>
160+
tp.typeSymbol match
161+
case cls: ClassSymbol if !defn.isFunctionClass(cls) =>
162+
cls.paramGetters.foldLeft(tp) { (core, getter) =>
163+
if getter.termRef.isTracked then
164+
val getterType = tp.memberInfo(getter).strippedDealias
165+
RefinedType(core, getter.name, CapturingType(getterType, CaptureSet.Var(), boxed = false))
166+
.showing(i"add capture refinement $tp --> $result", capt)
167+
else
168+
core
169+
}
170+
case _ => tp
171+
case _ => tp
172+
173+
/** Should a capture set variable be added on type `tp`? */
174+
def canHaveInferredCapture(tp: Type): Boolean =
175+
tp.typeParams.isEmpty && tp.match
176+
case tp: (TypeRef | AppliedType) =>
177+
val sym = tp.typeSymbol
178+
if sym.isClass then !sym.isValueClass && sym != defn.AnyClass
179+
else canHaveInferredCapture(tp.superType.dealias)
180+
case tp: (RefinedOrRecType | MatchType) =>
181+
canHaveInferredCapture(tp.underlying)
182+
case tp: AndType =>
183+
canHaveInferredCapture(tp.tp1) && canHaveInferredCapture(tp.tp2)
184+
case tp: OrType =>
185+
canHaveInferredCapture(tp.tp1) || canHaveInferredCapture(tp.tp2)
186+
case _ =>
187+
false
188+
189+
/** Add a capture set variable to `tp` if necessary, or maybe pull out
190+
* an embedded capture set variables from a part of `tp`.
191+
*/
192+
def addVar(tp: Type) = tp match
193+
case tp @ RefinedType(parent @ CapturingType(parent1, refs, boxed), rname, rinfo) =>
194+
CapturingType(tp.derivedRefinedType(parent1, rname, rinfo), refs, boxed)
195+
case tp: RecType =>
196+
tp.parent match
197+
case CapturingType(parent1, refs, boxed) =>
198+
CapturingType(tp.derivedRecType(parent1), refs, boxed)
199+
case _ =>
200+
tp // can return `tp` here since unlike RefinedTypes, RecTypes are never created
201+
// by `mapInferred`. Hence if the underlying type admits capture variables
202+
// a variable was already added, and the first case above would apply.
203+
case AndType(CapturingType(parent1, refs1, boxed1), CapturingType(parent2, refs2, boxed2)) =>
204+
assert(refs1.asVar.elems.isEmpty)
205+
assert(refs2.asVar.elems.isEmpty)
206+
assert(boxed1 == boxed2)
207+
CapturingType(AndType(parent1, parent2), refs1, boxed1)
208+
case tp @ OrType(CapturingType(parent1, refs1, boxed1), CapturingType(parent2, refs2, boxed2)) =>
209+
assert(refs1.asVar.elems.isEmpty)
210+
assert(refs2.asVar.elems.isEmpty)
211+
assert(boxed1 == boxed2)
212+
CapturingType(OrType(parent1, parent2, tp.isSoft), refs1, boxed1)
213+
case tp @ OrType(CapturingType(parent1, refs1, boxed1), tp2) =>
214+
CapturingType(OrType(parent1, tp2, tp.isSoft), refs1, boxed1)
215+
case tp @ OrType(tp1, CapturingType(parent2, refs2, boxed2)) =>
216+
CapturingType(OrType(tp1, parent2, tp.isSoft), refs2, boxed2)
217+
case _ if canHaveInferredCapture(tp) =>
218+
CapturingType(tp, CaptureSet.Var(), boxed = false)
219+
case _ =>
220+
tp
189221

190-
if inferred then
191-
val cleanup = new TypeMap:
192-
def apply(t: Type) = t match
222+
var isTopLevel = true
223+
224+
def mapNested(ts: List[Type]): List[Type] =
225+
val saved = isTopLevel
226+
isTopLevel = false
227+
try ts.mapConserve(this) finally isTopLevel = saved
228+
229+
def apply(t: Type) =
230+
val t1 = t match
193231
case AnnotatedType(parent, annot) if annot.symbol == defn.RetainsAnnot =>
194232
apply(parent)
233+
case tp @ AppliedType(tycon, args) =>
234+
val tycon1 = this(tycon)
235+
if defn.isNonRefinedFunction(tp) then
236+
val args1 = mapNested(args.init)
237+
val res1 = this(args.last)
238+
if isTopLevel then
239+
depFun(tycon1, args1, res1)
240+
.showing(i"add function refinement $tp --> $result", capt)
241+
else
242+
tp.derivedAppliedType(tycon1, args1 :+ res1)
243+
else
244+
tp.derivedAppliedType(tycon1, args.mapConserve(arg => box(this(arg))))
245+
case tp @ RefinedType(core, rname, rinfo) if defn.isFunctionType(tp) =>
246+
apply(rinfo).toFunctionType(isJava = false, alwaysDependent = true)
247+
case tp: MethodType =>
248+
tp.derivedLambdaType(
249+
paramInfos = mapNested(tp.paramInfos),
250+
resType = this(tp.resType))
251+
case tp: TypeLambda =>
252+
// Don't recurse into parameter bounds, just cleanup any stray retains annotations
253+
tp.derivedLambdaType(
254+
paramInfos = tp.paramInfos.mapConserve(cleanup(_).bounds),
255+
resType = this(tp.resType))
195256
case _ =>
196257
mapOver(t)
197-
addVars(addFunctionRefinements(cleanup(tp)), boxed)
198-
.showing(i"reinfer $tp --> $result", capt)
258+
addVar(addCaptureRefinements(t1))
259+
end mapInferred
260+
261+
if inferred then
262+
val tp1 = mapInferred(tp)
263+
if boxed then box(tp1) else tp1
199264
else
200265
def setBoxed(t: Type) = t match
201266
case AnnotatedType(_, annot) if annot.symbol == defn.RetainsAnnot =>

tests/neg-custom-args/captures/capt1.check

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ longer explanation available when compiling with `-explain`
4040
-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/capt1.scala:31:24 ----------------------------------------
4141
31 | val z2 = h[() -> Cap](() => x)(() => C()) // error
4242
| ^^^^^^^
43-
| Found: {x} () -> ? Cap
43+
| Found: {x} () -> Cap
4444
| Required: () -> Cap
4545

4646
longer explanation available when compiling with `-explain`

tests/pos-custom-args/captures/lazyref.scala

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,14 @@
1-
class CC
2-
type Cap = {*} CC
1+
@annotation.capability class Cap
32

4-
class LazyRef[T](val elem: {*} () => T):
3+
class LazyRef[T](val elem: () => T):
54
val get = elem
6-
def map[U](f: {*} T => U): {f, this} LazyRef[U] =
5+
def map[U](f: T => U): {f, this} LazyRef[U] =
76
new LazyRef(() => f(elem()))
87

9-
def map[A, B](ref: {*} LazyRef[A], f: {*} A => B): {f, ref} LazyRef[B] =
8+
def map[A, B](ref: {*} LazyRef[A], f: A => B): {f, ref} LazyRef[B] =
109
new LazyRef(() => f(ref.elem()))
1110

12-
def mapc[A, B]: (ref: {*} LazyRef[A], f: {*} A => B) => {f, ref} LazyRef[B] =
11+
def mapc[A, B]: (ref: {*} LazyRef[A], f: A => B) => {f, ref} LazyRef[B] =
1312
(ref1, f1) => map[A, B](ref1, f1)
1413

1514
def test(cap1: Cap, cap2: Cap) =

tests/pos-custom-args/captures/pairs.scala

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11

2-
class C
3-
type Cap = {*} C
2+
@annotation.capability class Cap
43

54
object Generic:
65

0 commit comments

Comments
 (0)