Skip to content
88 changes: 54 additions & 34 deletions compiler/src/dotty/tools/dotc/cc/CaptureOps.scala
Original file line number Diff line number Diff line change
Expand Up @@ -221,19 +221,22 @@ extension (tp: Type)
case tp: SingletonCaptureRef => tp.captureSetOfInfo
case _ => CaptureSet.ofType(tp, followResult = false)

/** The deep capture set of a type.
* For singleton capabilities `x` and reach capabilities `x*`, this is `{x*}`, provided
* the underlying capture set resulting from traversing the type is non-empty.
* For other types this is the union of all covariant capture sets embedded
* in the type, as computed by `CaptureSet.ofTypeDeeply`.
/** The deep capture set of a type. This is by default the union of all
* covariant capture sets embedded in the widened type, as computed by
* `CaptureSet.ofTypeDeeply`. If that set is nonempty, and the type is
* a singleton capability `x` or a reach capability `x*`, the deep capture
* set can be narrowed to`{x*}`.
*/
def deepCaptureSet(using Context): CaptureSet =
val dcs = CaptureSet.ofTypeDeeply(tp)
if dcs.isAlwaysEmpty then dcs
val dcs = CaptureSet.ofTypeDeeply(tp.widen.stripCapturing)
if dcs.isAlwaysEmpty then tp.captureSet
else tp match
case tp @ ReachCapability(_) => tp.singletonCaptureSet
case tp: SingletonCaptureRef => tp.reach.singletonCaptureSet
case _ => dcs
case tp @ ReachCapability(_) =>
tp.singletonCaptureSet
case tp: SingletonCaptureRef if tp.isTrackableRef =>
tp.reach.singletonCaptureSet
case _ =>
tp.captureSet ++ dcs

/** A type capturing `ref` */
def capturing(ref: CaptureRef)(using Context): Type =
Expand Down Expand Up @@ -273,6 +276,29 @@ extension (tp: Type)
case _ =>
tp

/** The first element of this path type */
final def pathRoot(using Context): Type = tp.dealiasKeepAnnots match
case tp1: NamedType if tp1.symbol.owner.isClass => tp1.prefix.pathRoot
case ReachCapability(tp1) => tp1.pathRoot
case _ => tp

/** If this part starts with `C.this`, the class `C`.
* Otherwise, if it starts with a reference `r`, `r`'s owner.
* Otherwise NoSymbol.
*/
final def pathOwner(using Context): Symbol = pathRoot match
case tp1: NamedType => tp1.symbol.owner
case tp1: ThisType => tp1.cls
case _ => NoSymbol

final def isParamPath(using Context): Boolean = tp.dealias match
case tp1: NamedType =>
tp1.prefix match
case _: ThisType | NoPrefix =>
tp1.symbol.is(Param) || tp1.symbol.is(ParamAccessor)
case prefix => prefix.isParamPath
case _ => false

/** If this is a unboxed capturing type with nonempty capture set, its boxed version.
* Or, if type is a TypeBounds of capturing types, the version where the bounds are boxed.
* The identity for all other types.
Expand Down Expand Up @@ -468,29 +494,23 @@ extension (tp: Type)
end CheckContraCaps

object narrowCaps extends TypeMap:
/** Has the variance been flipped at this point? */
private var isFlipped: Boolean = false

def apply(t: Type) =
val saved = isFlipped
try
if variance <= 0 then isFlipped = true
t.dealias match
case t1 @ CapturingType(p, cs) if cs.isUniversal && !isFlipped =>
t1.derivedCapturingType(apply(p), ref.reach.singletonCaptureSet)
case t1 @ FunctionOrMethod(args, res @ Existential(_, _))
if args.forall(_.isAlwaysPure) =>
// Also map existentials in results to reach capabilities if all
// preceding arguments are known to be always pure
apply(t1.derivedFunctionOrMethod(args, Existential.toCap(res)))
case Existential(_, _) =>
t
case _ => t match
case t @ CapturingType(p, cs) =>
t.derivedCapturingType(apply(p), cs) // don't map capture set variables
case t =>
mapOver(t)
finally isFlipped = saved
if variance <= 0 then t
else t.dealiasKeepAnnots match
case t @ CapturingType(p, cs) if cs.isUniversal =>
t.derivedCapturingType(apply(p), ref.reach.singletonCaptureSet)
case t @ AnnotatedType(parent, ann) =>
// Don't map annotations, which includes capture sets
t.derivedAnnotatedType(this(parent), ann)
case t @ FunctionOrMethod(args, res @ Existential(_, _))
if args.forall(_.isAlwaysPure) =>
// Also map existentials in results to reach capabilities if all
// preceding arguments are known to be always pure
apply(t.derivedFunctionOrMethod(args, Existential.toCap(res)))
case Existential(_, _) =>
t
case _ =>
mapOver(t)
end narrowCaps

ref match
Expand Down Expand Up @@ -639,8 +659,8 @@ object CapsOfApply:
class AnnotatedCapability(annot: Context ?=> ClassSymbol):
def apply(tp: Type)(using Context) =
AnnotatedType(tp, Annotation(annot, util.Spans.NoSpan))
def unapply(tree: AnnotatedType)(using Context): Option[SingletonCaptureRef] = tree match
case AnnotatedType(parent: SingletonCaptureRef, ann) if ann.symbol == annot => Some(parent)
def unapply(tree: AnnotatedType)(using Context): Option[CaptureRef] = tree match
case AnnotatedType(parent: CaptureRef, ann) if ann.symbol == annot => Some(parent)
case _ => None

/** An extractor for `ref @annotation.internal.reachCapability`, which is used to express
Expand Down
26 changes: 17 additions & 9 deletions compiler/src/dotty/tools/dotc/cc/CaptureSet.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1064,7 +1064,7 @@ object CaptureSet:
case ref: (TermRef | TermParamRef) if ref.isMaxCapability =>
if ref.isTrackableRef then ref.singletonCaptureSet
else CaptureSet.universal
case ReachCapability(ref1) => deepCaptureSet(ref1.widen)
case ReachCapability(ref1) => ref1.widen.deepCaptureSet
.showing(i"Deep capture set of $ref: ${ref1.widen} = $result", capt)
case _ => ofType(ref.underlying, followResult = true)

Expand Down Expand Up @@ -1115,17 +1115,25 @@ object CaptureSet:

/** The deep capture set of a type is the union of all covariant occurrences of
* capture sets. Nested existential sets are approximated with `cap`.
* NOTE: The traversal logic needs to be in sync with narrowCaps in CaptureOps, which
* replaces caps with reach capabilties.
*/
def ofTypeDeeply(tp: Type)(using Context): CaptureSet =
val collect = new TypeAccumulator[CaptureSet]:
def apply(cs: CaptureSet, t: Type) = t.dealias match
case t @ CapturingType(p, cs1) =>
val cs2 = apply(cs, p)
if variance > 0 then cs2 ++ cs1 else cs2
case t @ Existential(_, _) =>
apply(cs, Existential.toCap(t))
case _ =>
foldOver(cs, t)
def apply(cs: CaptureSet, t: Type) =
if variance <= 0 then cs
else t.dealias match
case t @ CapturingType(p, cs1) =>
this(cs, p) ++ cs1
case t @ AnnotatedType(parent, ann) =>
this(cs, parent)
case t @ FunctionOrMethod(args, res @ Existential(_, _))
if args.forall(_.isAlwaysPure) =>
this(cs, Existential.toCap(res))
case t @ Existential(_, _) =>
cs
case _ =>
foldOver(cs, t)
collect(CaptureSet.empty, tp)

type AssumedContains = immutable.Map[TypeRef, SimpleIdentitySet[CaptureRef]]
Expand Down
Loading
Loading