Skip to content
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions compiler/src/dotty/tools/dotc/cc/CaptureOps.scala
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,11 @@ extension (tp: Type)
case _ =>
tp

def isCapturingType(using Context): Boolean =
tp match
case CapturingType(_, _) => true
case _ => false

extension (sym: Symbol)

/** Does this symbol allow results carrying the universal capability?
Expand Down
182 changes: 143 additions & 39 deletions compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala
Original file line number Diff line number Diff line change
Expand Up @@ -42,12 +42,20 @@ object CheckCaptures:
end Pre

/** A class describing environments.
* @param owner the current owner
* @param captured the caputure set containing all references to tracked free variables outside of boxes
* @param isBoxed true if the environment is inside a box (in which case references are not counted)
* @param outer0 the next enclosing environment
* @param owner the current owner
* @param nestedInOwner true if the environment is a temporary one nested in the owner's environment,
* and does not have an actual owner symbol (this happens when doing box adaptation).
* @param captured the caputure set containing all references to tracked free variables outside of boxes
* @param isBoxed true if the environment is inside a box (in which case references are not counted)
* @param outer0 the next enclosing environment
*/
case class Env(owner: Symbol, captured: CaptureSet, isBoxed: Boolean, outer0: Env | Null):
case class Env(
owner: Symbol,
nestedInOwner: Boolean,
captured: CaptureSet,
isBoxed: Boolean,
outer0: Env | Null
):
def outer = outer0.nn

def isOutermost = outer0 == null
Expand Down Expand Up @@ -204,7 +212,7 @@ class CheckCaptures extends Recheck, SymTransformer:
report.error(i"$header included in allowed capture set ${res.blocking}", pos)

/** The current environment */
private var curEnv: Env = Env(NoSymbol, CaptureSet.empty, isBoxed = false, null)
private var curEnv: Env = Env(NoSymbol, false, CaptureSet.empty, isBoxed = false, null)

private val myCapturedVars: util.EqHashMap[Symbol, CaptureSet] = EqHashMap()

Expand Down Expand Up @@ -249,8 +257,12 @@ class CheckCaptures extends Recheck, SymTransformer:
if !cs.isAlwaysEmpty then
forallOuterEnvsUpTo(ctx.owner.topLevelClass) { env =>
val included = cs.filter {
case ref: TermRef => env.owner.isProperlyContainedIn(ref.symbol.owner)
case ref: ThisType => env.owner.isProperlyContainedIn(ref.cls)
case ref: TermRef =>
(env.nestedInOwner || env.owner != ref.symbol.owner)
&& env.owner.isContainedIn(ref.symbol.owner)
case ref: ThisType =>
(env.nestedInOwner || env.owner != ref.cls)
&& env.owner.isContainedIn(ref.cls)
case _ => false
}
capt.println(i"Include call capture $included in ${env.owner}")
Expand Down Expand Up @@ -439,7 +451,7 @@ class CheckCaptures extends Recheck, SymTransformer:
if !Synthetics.isExcluded(sym) then
val saved = curEnv
val localSet = capturedVars(sym)
if !localSet.isAlwaysEmpty then curEnv = Env(sym, localSet, isBoxed = false, curEnv)
if !localSet.isAlwaysEmpty then curEnv = Env(sym, false, localSet, isBoxed = false, curEnv)
try super.recheckDefDef(tree, sym)
finally
interpolateVarsIn(tree.tpt)
Expand All @@ -455,7 +467,7 @@ class CheckCaptures extends Recheck, SymTransformer:
val localSet = capturedVars(cls)
for parent <- impl.parents do // (1)
checkSubset(capturedVars(parent.tpe.classSymbol), localSet, parent.srcPos)
if !localSet.isAlwaysEmpty then curEnv = Env(cls, localSet, isBoxed = false, curEnv)
if !localSet.isAlwaysEmpty then curEnv = Env(cls, false, localSet, isBoxed = false, curEnv)
try
val thisSet = cls.classInfo.selfType.captureSet.withDescription(i"of the self type of $cls")
checkSubset(localSet, thisSet, tree.srcPos) // (2)
Expand Down Expand Up @@ -502,7 +514,7 @@ class CheckCaptures extends Recheck, SymTransformer:
override def recheck(tree: Tree, pt: Type = WildcardType)(using Context): Type =
if tree.isTerm && pt.isBoxedCapturing then
val saved = curEnv
curEnv = Env(curEnv.owner, CaptureSet.Var(), isBoxed = true, curEnv)
curEnv = Env(curEnv.owner, false, CaptureSet.Var(), isBoxed = true, curEnv)
try super.recheck(tree, pt)
finally curEnv = saved
else
Expand Down Expand Up @@ -593,25 +605,124 @@ class CheckCaptures extends Recheck, SymTransformer:

/** Adapt function type `actual`, which is `aargs -> ares` (possibly with dependencies)
* to `expected` type.
* It returns the adapted type along with the additionally captured variable
* during adaptation.
* @param reconstruct how to rebuild the adapted function type
*/
def adaptFun(actual: Type, aargs: List[Type], ares: Type, expected: Type,
covariant: Boolean,
reconstruct: (List[Type], Type) => Type): Type =
val (eargs, eres) = expected.dealias match
case defn.FunctionOf(eargs, eres, _, _) => (eargs, eres)
case _ => (aargs.map(_ => WildcardType), WildcardType)
val aargs1 = aargs.zipWithConserve(eargs)(adapt(_, _, !covariant))
val ares1 = adapt(ares, eres, covariant)
if (ares1 eq ares) && (aargs1 eq aargs) then actual
else reconstruct(aargs1, ares1)

def adapt(actual: Type, expected: Type, covariant: Boolean): Type = actual.dealias match
case actual @ CapturingType(parent, refs) =>
val parent1 = adapt(parent, expected, covariant)
if actual.isBoxed != expected.isBoxedCapturing then
covariant: Boolean, boxed: Boolean,
reconstruct: (List[Type], Type) => Type): (Type, CaptureSet) =
val saved = curEnv
curEnv = Env(curEnv.owner, true, CaptureSet.Var(), isBoxed = false, if boxed then null else curEnv)

try
val (eargs, eres) = expected.dealias match
case defn.FunctionOf(eargs, eres, _, _) => (eargs, eres)
case expected => expected.stripped match
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do you need a separate match on expected.stripped?. Can't you match everything with expected.dealias?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The reason why stripped is used here is because the expected type can have a capture set but dealias does not drop the capturing annotation. However, I think what we want here is expected.dealias.stripCapturing, just like what we do for type functions. I have changed this pattern matching in both adaptFun and adaptTypeFun to make them match with each other and have a better rationale.

case expected: MethodType => (expected.paramInfos, expected.resType)
case expected @ RefinedType(_, _, rinfo: MethodType) if defn.isFunctionType(expected) => (rinfo.paramInfos, rinfo.resType)
case _ =>
(aargs.map(_ => WildcardType), WildcardType)
val aargs1 = aargs.zipWithConserve(eargs) { (aarg, earg) => adapt(aarg, earg, !covariant) }
val ares1 = adapt(ares, eres, covariant)

val resTp =
if (ares1 eq ares) && (aargs1 eq aargs) then actual
else reconstruct(aargs1, ares1)

curEnv.captured.asVar.markSolved()
(resTp, curEnv.captured)
finally
curEnv = saved

/** Adapt type function type `actual` to the expected type.
* @see [[adaptFun]]
*/
def adaptTypeFun(
actual: Type, ares: Type, expected: Type,
covariant: Boolean, boxed: Boolean,
reconstruct: Type => Type): (Type, CaptureSet) =
val saved = curEnv
curEnv = Env(curEnv.owner, true, CaptureSet.Var(), isBoxed = false, if boxed then null else curEnv)

try
val eres = expected.dealias.stripCapturing match
case RefinedType(_, _, rinfo: PolyType) => rinfo.resType
case _ => WildcardType

val ares1 = adapt(ares, eres, covariant)

val resTp =
if ares1 eq ares then actual
else reconstruct(ares1)

curEnv.captured.asVar.markSolved()
(resTp, curEnv.captured)
finally
curEnv = saved
end adaptTypeFun

def adaptInfo(actual: Type, expected: Type, covariant: Boolean): String =
val arrow = if covariant then "~~>" else "<~~"
i"adapting $actual $arrow $expected"

/** Destruct a capturing type `tp` to a tuple (cs, tp0, boxed),
* where `tp0` is not a capturing type.
*
* If `tp` is a nested capturing type, the return tuple always represents
* the innermost capturing type. The outer capture annotations can be
* reconstructed with the returned function.
*/
def destructCapturingType(tp: Type, reconstruct: Type => Type = x => x): ((Type, CaptureSet, Boolean), Type => Type) =
tp.dealias match
case tp @ CapturingType(parent, cs) =>
if parent.dealias.isCapturingType then
destructCapturingType(parent, res => reconstruct(tp.derivedCapturingType(res, cs)))
else
((parent, cs, tp.isBoxed), reconstruct)
case actual =>
((actual, CaptureSet(), false), reconstruct)

def adapt(actual: Type, expected: Type, covariant: Boolean): Type = trace(adaptInfo(actual, expected, covariant), recheckr, show = true) {
if expected.isInstanceOf[WildcardType] then actual
else
val ((parent, cs, actualIsBoxed), recon) = destructCapturingType(actual)

val needsAdaptation = actualIsBoxed != expected.isBoxedCapturing
val insertBox = needsAdaptation && covariant != actualIsBoxed

val (parent1, cs1) = parent match {
case actual @ AppliedType(tycon, args) if defn.isNonRefinedFunction(actual) =>
val (parent1, cs1) = adaptFun(parent, args.init, args.last, expected, covariant, insertBox,
(aargs1, ares1) => actual.derivedAppliedType(tycon, aargs1 :+ ares1))
(parent1, cs1 ++ cs)
case actual @ RefinedType(_, _, rinfo: MethodType) if defn.isFunctionType(actual) =>
// TODO Find a way to combine handling of generic and dependent function types (here and elsewhere)
val (parent1, cs1) = adaptFun(parent, rinfo.paramInfos, rinfo.resType, expected, covariant, insertBox,
(aargs1, ares1) =>
rinfo.derivedLambdaType(paramInfos = aargs1, resType = ares1)
.toFunctionType(isJava = false, alwaysDependent = true))
(parent1, cs1 ++ cs)
case actual: MethodType =>
val (parent1, cs1) = adaptFun(parent, actual.paramInfos, actual.resType, expected, covariant, insertBox,
(aargs1, ares1) =>
actual.derivedLambdaType(paramInfos = aargs1, resType = ares1))
(parent1, cs1 ++ cs)
case actual @ RefinedType(p, nme, rinfo: PolyType) if defn.isFunctionOrPolyType(actual) =>
val (parent1, cs1) = adaptTypeFun(parent, rinfo.resType, expected, covariant, insertBox,
ares1 =>
val rinfo1 = rinfo.derivedLambdaType(rinfo.paramNames, rinfo.paramInfos, ares1)
val actual1 = actual.derivedRefinedType(p, nme, rinfo1)
actual1
)
(parent1, cs1 ++ cs)
case _ =>
(parent, cs)
}

if needsAdaptation then
val criticalSet = // the set which is not allowed to have `*`
if covariant then refs // can't box with `*`
if covariant then cs1 // can't box with `*`
else expected.captureSet // can't unbox with `*`
if criticalSet.isUniversal then
// We can't box/unbox the universal capability. Leave `actual` as it is
Expand All @@ -627,20 +738,13 @@ class CheckCaptures extends Recheck, SymTransformer:
|since one of their capture sets contains the root capability `*`""",
pos)
}
if covariant == actual.isBoxed then markFree(refs, pos)
CapturingType(parent1, refs, boxed = !actual.isBoxed)
if !insertBox then // unboxing
markFree(criticalSet, pos)
recon(CapturingType(parent1, cs1, !actualIsBoxed))
else
actual.derivedCapturingType(parent1, refs)
case actual @ AppliedType(tycon, args) if defn.isNonRefinedFunction(actual) =>
adaptFun(actual, args.init, args.last, expected, covariant,
(aargs1, ares1) => actual.derivedAppliedType(tycon, aargs1 :+ ares1))
case actual @ RefinedType(_, _, rinfo: MethodType) if defn.isFunctionType(actual) =>
// TODO Find a way to combine handling of generic and dependent function types (here and elsewhere)
adaptFun(actual, rinfo.paramInfos, rinfo.resType, expected, covariant,
(aargs1, ares1) =>
rinfo.derivedLambdaType(paramInfos = aargs1, resType = ares1)
.toFunctionType(isJava = false, alwaysDependent = true))
case _ => actual
recon(CapturingType(parent1, cs1, actualIsBoxed))
}


var actualw = actual.widenDealias
actual match
Expand Down
38 changes: 38 additions & 0 deletions tests/neg-custom-args/captures/box-adapt-boxing.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
trait Cap

def main(io: {*} Cap, fs: {*} Cap): Unit = {
val test1: {} Unit -> Unit = _ => { // error
type Op = [T] -> ({io} T -> Unit) -> Unit
val f: ({io} Cap) -> Unit = ???
val op: Op = ???
op[{io} Cap](f)
// expected type of f: {io} (box {io} Cap) -> Unit
// actual type: ({io} Cap) -> Unit
// adapting f to the expected type will also
// charge the environment with {io}
}

val test2: {} Unit -> Unit = _ => {
type Box[X] = X
type Op0[X] = Box[X] -> Unit
type Op1[X] = Unit -> Box[X]
val f: Unit -> ({io} Cap) -> Unit = ???
val test: {} Op1[{io} Op0[{io} Cap]] = f
// expected: {} Unit -> box {io} (box {io} Cap) -> Unit
// actual: Unit -> ({io} Cap) -> Unit
//
// although adapting `({io} Cap) -> Unit` to
// `box {io} (box {io} Cap) -> Unit` will leak the
// captured variables {io}, but since it is inside a box,
// we will charge neither the outer type nor the environment
}

val test3 = {
type Box[X] = X
type Id[X] = Box[X] -> Unit
type Op[X] = Unit -> Box[X]
val f: Unit -> ({io} Cap) -> Unit = ???
val g: Op[{fs} Id[{io} Cap]] = f // error
val h: {} Op[{io} Id[{io} Cap]] = f
}
}
29 changes: 29 additions & 0 deletions tests/neg-custom-args/captures/box-adapt-cases.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
trait Cap { def use(): Int }

def test1(): Unit = {
type Id[X] = [T] -> (op: X => T) -> T

val x: Id[{*} Cap] = ???
x(cap => cap.use()) // error
}

def test2(io: {*} Cap): Unit = {
type Id[X] = [T] -> (op: X -> T) -> T

val x: Id[{io} Cap] = ???
x(cap => cap.use()) // error
}

def test3(io: {*} Cap): Unit = {
type Id[X] = [T] -> (op: {io} X -> T) -> T

val x: Id[{io} Cap] = ???
x(cap => cap.use()) // ok
}

def test4(io: {*} Cap, fs: {*} Cap): Unit = {
type Id[X] = [T] -> (op: {io} X -> T) -> T

val x: Id[{io, fs} Cap] = ???
x(cap => cap.use()) // error
}
14 changes: 14 additions & 0 deletions tests/neg-custom-args/captures/box-adapt-cov.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
trait Cap

def test1(io: {*} Cap) = {
type Op[X] = [T] -> Unit -> X
val f: Op[{io} Cap] = ???
val x: [T] -> Unit -> ({io} Cap) = f // error
}

def test2(io: {*} Cap) = {
type Op[X] = [T] -> Unit -> {io} X
val f: Op[{io} Cap] = ???
val x: Unit -> ({io} Cap) = f[Unit] // error
val x1: {io} Unit -> ({io} Cap) = f[Unit] // ok
}
19 changes: 19 additions & 0 deletions tests/neg-custom-args/captures/box-adapt-cs.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
trait Cap { def use(): Int }

def test1(io: {*} Cap): Unit = {
type Id[X] = [T] -> (op: {io} X -> T) -> T

val x: Id[{io} Cap] = ???
val f: ({*} Cap) -> Unit = ???
x(f) // ok
// actual: {*} Cap -> Unit
// expected: {io} box {io} Cap -> Unit
}

def test2(io: {*} Cap): Unit = {
type Id[X] = [T] -> (op: {*} X -> T) -> T

val x: Id[{*} Cap] = ???
val f: ({io} Cap) -> Unit = ???
x(f) // error
}
23 changes: 23 additions & 0 deletions tests/neg-custom-args/captures/box-adapt-depfun.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
trait Cap { def use(): Int }

def test1(io: {*} Cap): Unit = {
type Id[X] = [T] -> (op: {io} X -> T) -> T

val x: Id[{io} Cap] = ???
x(cap => cap.use()) // ok
}

def test2(io: {*} Cap): Unit = {
type Id[X] = [T] -> (op: {io} (x: X) -> T) -> T

val x: Id[{io} Cap] = ???
x(cap => cap.use())
// should work when the expected type is a dependent function
}

def test3(io: {*} Cap): Unit = {
type Id[X] = [T] -> (op: {} (x: X) -> T) -> T

val x: Id[{io} Cap] = ???
x(cap => cap.use()) // error
}
13 changes: 13 additions & 0 deletions tests/neg-custom-args/captures/box-adapt-typefun.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
trait Cap { def use(): Int }

def test1(io: {*} Cap): Unit = {
type Op[X] = [T] -> X -> Unit
val f: [T] -> ({io} Cap) -> Unit = ???
val op: Op[{io} Cap] = f // error
}

def test2(io: {*} Cap): Unit = {
type Lazy[X] = [T] -> Unit -> X
val f: Lazy[{io} Cap] = ???
val test: [T] -> Unit -> ({io} Cap) = f // error
}
Loading