@@ -117,85 +117,150 @@ class CheckCaptures extends Recheck:
117117
118118 override  def  transformType (tp : Type , inferred : Boolean , boxed : Boolean )(using  Context ):  Type  = 
119119
120-  def  addInnerVars ( tp : Type ) :   Type   =  tp  match 
121-  case  tp  @   AppliedType (tycon, args)  => 
122-  tp.derivedAppliedType(tycon, args.map(addVars(_, boxed  =  true ))) 
123-  case  tp  @   RefinedType (core, rname, rinfo)  => 
124-  val   rinfo1   =  addVars(rinfo )
125-  if  defn.isFunctionType(tp)  then 
126-  rinfo1.toFunctionType(isJava  =   false , alwaysDependent  =   true ) 
127-    else 
128-   tp.derivedRefinedType(addInnerVars(core), rname, rinfo1 )
129-  case  tp :  MethodType   => 
130-  tp.derivedLambdaType( 
131-   paramInfos  =  tp.paramInfos.mapConserve(addVars(_)), 
132-   resType  =  addVars(tp.resType)) 
133-    case   tp :  PolyType   => 
134-  tp.derivedLambdaType( 
135-  resType  =  addVars(tp.resType)) 
136-    case   tp :  ExprType   => 
137-  tp.derivedExprType(resType  =  addVars(tp.resType)) 
138-    case  _  => 
139-  tp 
140- 
141-  /**  Turn plain function types into dependent function types, so that  
142-  * we can refer to the parameter in capture sets  
120+  def  depFun ( tycon : Type ,  argTypes :  List [ Type ],  resType :  Type ) :   Type   = 
121+  MethodType .companion( 
122+   isContextual  =  defn.isContextFunctionClass(tycon.classSymbol), 
123+   isErased  =  defn.isErasedFunctionClass(tycon.classSymbol) 
124+  )(argTypes, resType )
125+  .toFunctionType(isJava  =   false , alwaysDependent  =   true ) 
126+ 
127+  def   box ( tp :  Type ) :   Type   =  tp  match 
128+  case   CapturingType (parent, refs,  false )  =>   CapturingType (parent, refs,  true )
129+  case  _  =>  tp 
130+ 
131+  /**  Perform the following transformation steps everywhere in a type:  
132+  * 1. Drop retains annotations  
133+  * 2. Turn plain function types into dependent function types, so that  
134+  *   we can refer to their parameter in capture sets. Currently this is  
135+  *  only done at the toplevel, i.e. for function types that are not  
136+  * themselves argument types of other function types. Without this restriction  
137+  *   boxmap-paper.scala fails. Need to figure out why.  
138+  * 3. Refine other class types C by adding capture set variables to their parameter getters  
139+  *   (see addCaptureRefinements)  
140+  * 4. Add capture set variables to all types that can be tracked 
141+   *  
142+  * Polytype bounds are only cleaned using step 1, but not otherwise transformed.  
143143 */  
144-  def  addFunctionRefinements (tp : Type ):  Type  =  tp match 
145-  case  tp @  AppliedType (tycon, args) => 
146-  if  defn.isNonRefinedFunction(tp) then 
147-  MethodType .companion(
148-  isContextual =  defn.isContextFunctionClass(tycon.classSymbol),
149-  isErased =  defn.isErasedFunctionClass(tycon.classSymbol)
150-  )(args.init, addFunctionRefinements(args.last))
151-  .toFunctionType(isJava =  false , alwaysDependent =  true )
152-  .showing(i " add function refinement  $tp -->  $result" , capt)
153-  else 
154-  tp.derivedAppliedType(tycon, args.map(addFunctionRefinements(_)))
155-  case  tp @  RefinedType (core, rname, rinfo) if  ! defn.isFunctionType(tp) => 
156-  tp.derivedRefinedType(
157-  addFunctionRefinements(core), rname, addFunctionRefinements(rinfo))
158-  case  tp : MethodOrPoly  => 
159-  tp.derivedLambdaType(resType =  addFunctionRefinements(tp.resType))
160-  case  tp : ExprType  => 
161-  tp.derivedExprType(resType =  addFunctionRefinements(tp.resType))
162-  case  _ => 
163-  tp
144+  def  mapInferred  =  new  TypeMap : 
164145
165-  /**  Refine a possibly applied class type C where the class has tracked parameters 
166-  * x_1: T_1, ..., x_n: T_n to C { val x_1: CV_1 T_1, ..., val x_n: CV_n T_n } 
167-  * where CV_1, ..., CV_n are fresh capture sets. 
168-  */  
169-  def  addCaptureRefinements (tp : Type ):  Type  =  tp.stripped match 
170-  case  _ : TypeRef  |  _ : AppliedType  if  tp.typeSymbol.isClass => 
171-  val  cls  =  tp.typeSymbol.asClass
172-  cls.paramGetters.foldLeft(tp) { (core, getter) => 
173-  if  getter.termRef.isTracked then 
174-  val  getterType  =  tp.memberInfo(getter).strippedDealias
175-  RefinedType (core, getter.name, CapturingType (getterType, CaptureSet .Var (), boxed =  false ))
176-  .showing(i " add capture refinement  $tp -->  $result" , capt)
177-  else 
178-  core
179-  }
180-  case  _ => 
181-  tp
146+  /**  Drop @retains annotations everywhere */  
147+  object  cleanup  extends  TypeMap : 
148+  def  apply (t : Type ) =  t match 
149+  case  AnnotatedType (parent, annot) if  annot.symbol ==  defn.RetainsAnnot  => 
150+  apply(parent)
151+  case  _ => 
152+  mapOver(t)
182153
183-  def  addVars (tp : Type , boxed : Boolean  =  false ):  Type  = 
184-  var  tp1  =  addInnerVars(tp)
185-  val  tp2  =  addCaptureRefinements(tp1)
186-  if  tp1.canHaveInferredCapture
187-  then  CapturingType (tp2, CaptureSet .Var (), boxed)
188-  else  tp2
154+  /**  Refine a possibly applied class type C where the class has tracked parameters 
155+  * x_1: T_1, ..., x_n: T_n to C { val x_1: CV_1 T_1, ..., val x_n: CV_n T_n } 
156+  * where CV_1, ..., CV_n are fresh capture sets. 
157+  */  
158+  def  addCaptureRefinements (tp : Type ):  Type  =  tp match 
159+  case  _ : TypeRef  |  _ : AppliedType  if  tp.typeParams.isEmpty => 
160+  tp.typeSymbol match 
161+  case  cls : ClassSymbol  if  ! defn.isFunctionClass(cls) => 
162+  cls.paramGetters.foldLeft(tp) { (core, getter) => 
163+  if  getter.termRef.isTracked then 
164+  val  getterType  =  tp.memberInfo(getter).strippedDealias
165+  RefinedType (core, getter.name, CapturingType (getterType, CaptureSet .Var (), boxed =  false ))
166+  .showing(i " add capture refinement  $tp -->  $result" , capt)
167+  else 
168+  core
169+  }
170+  case  _ =>  tp
171+  case  _ =>  tp
172+ 
173+  /**  Should a capture set variable be added on type `tp`? */  
174+  def  canHaveInferredCapture (tp : Type ):  Boolean  = 
175+  tp.typeParams.isEmpty &&  tp.match 
176+  case  tp : (TypeRef  |  AppliedType ) => 
177+  val  sym  =  tp.typeSymbol
178+  if  sym.isClass then  ! sym.isValueClass &&  sym !=  defn.AnyClass 
179+  else  canHaveInferredCapture(tp.superType.dealias)
180+  case  tp : (RefinedOrRecType  |  MatchType ) => 
181+  canHaveInferredCapture(tp.underlying)
182+  case  tp : AndType  => 
183+  canHaveInferredCapture(tp.tp1) &&  canHaveInferredCapture(tp.tp2)
184+  case  tp : OrType  => 
185+  canHaveInferredCapture(tp.tp1) ||  canHaveInferredCapture(tp.tp2)
186+  case  _ => 
187+  false 
188+ 
189+  /**  Add a capture set variable to `tp` if necessary, or maybe pull out 
190+  * an embedded capture set variables from a part of `tp`. 
191+  */  
192+  def  addVar (tp : Type ) =  tp match 
193+  case  tp @  RefinedType (parent @  CapturingType (parent1, refs, boxed), rname, rinfo) => 
194+  CapturingType (tp.derivedRefinedType(parent1, rname, rinfo), refs, boxed)
195+  case  tp : RecType  => 
196+  tp.parent match 
197+  case  CapturingType (parent1, refs, boxed) => 
198+  CapturingType (tp.derivedRecType(parent1), refs, boxed)
199+  case  _ => 
200+  tp //  can return `tp` here since unlike RefinedTypes, RecTypes are never created
201+  //  by `mapInferred`. Hence if the underlying type admits capture variables
202+  //  a variable was already added, and the first case above would apply.
203+  case  AndType (CapturingType (parent1, refs1, boxed1), CapturingType (parent2, refs2, boxed2)) => 
204+  assert(refs1.asVar.elems.isEmpty)
205+  assert(refs2.asVar.elems.isEmpty)
206+  assert(boxed1 ==  boxed2)
207+  CapturingType (AndType (parent1, parent2), refs1, boxed1)
208+  case  tp @  OrType (CapturingType (parent1, refs1, boxed1), CapturingType (parent2, refs2, boxed2)) => 
209+  assert(refs1.asVar.elems.isEmpty)
210+  assert(refs2.asVar.elems.isEmpty)
211+  assert(boxed1 ==  boxed2)
212+  CapturingType (OrType (parent1, parent2, tp.isSoft), refs1, boxed1)
213+  case  tp @  OrType (CapturingType (parent1, refs1, boxed1), tp2) => 
214+  CapturingType (OrType (parent1, tp2, tp.isSoft), refs1, boxed1)
215+  case  tp @  OrType (tp1, CapturingType (parent2, refs2, boxed2)) => 
216+  CapturingType (OrType (tp1, parent2, tp.isSoft), refs2, boxed2)
217+  case  _ if  canHaveInferredCapture(tp) => 
218+  CapturingType (tp, CaptureSet .Var (), boxed =  false )
219+  case  _ => 
220+  tp
189221
190-  if  inferred then 
191-  val  cleanup  =  new  TypeMap : 
192-  def  apply (t : Type ) =  t match 
222+  var  isTopLevel  =  true 
223+ 
224+  def  mapNested (ts : List [Type ]):  List [Type ] = 
225+  val  saved  =  isTopLevel
226+  isTopLevel =  false 
227+  try  ts.mapConserve(this ) finally  isTopLevel =  saved
228+ 
229+  def  apply (t : Type ) = 
230+  val  t1  =  t match 
193231 case  AnnotatedType (parent, annot) if  annot.symbol ==  defn.RetainsAnnot  => 
194232 apply(parent)
233+  case  tp @  AppliedType (tycon, args) => 
234+  val  tycon1  =  this (tycon)
235+  if  defn.isNonRefinedFunction(tp) then 
236+  val  args1  =  mapNested(args.init)
237+  val  res1  =  this (args.last)
238+  if  isTopLevel then 
239+  depFun(tycon1, args1, res1)
240+  .showing(i " add function refinement  $tp -->  $result" , capt)
241+  else 
242+  tp.derivedAppliedType(tycon1, args1 :+  res1)
243+  else 
244+  tp.derivedAppliedType(tycon1, args.mapConserve(arg =>  box(this (arg))))
245+  case  tp @  RefinedType (core, rname, rinfo) if  defn.isFunctionType(tp) => 
246+  apply(rinfo).toFunctionType(isJava =  false , alwaysDependent =  true )
247+  case  tp : MethodType  => 
248+  tp.derivedLambdaType(
249+  paramInfos =  mapNested(tp.paramInfos),
250+  resType =  this (tp.resType))
251+  case  tp : TypeLambda  => 
252+  //  Don't recurse into parameter bounds, just cleanup any stray retains annotations
253+  tp.derivedLambdaType(
254+  paramInfos =  tp.paramInfos.mapConserve(cleanup(_).bounds),
255+  resType =  this (tp.resType))
195256 case  _ => 
196257 mapOver(t)
197-  addVars(addFunctionRefinements(cleanup(tp)), boxed)
198-  .showing(i " reinfer  $tp -->  $result" , capt)
258+  addVar(addCaptureRefinements(t1))
259+  end  mapInferred 
260+ 
261+  if  inferred then 
262+  val  tp1  =  mapInferred(tp)
263+  if  boxed then  box(tp1) else  tp1
199264 else 
200265 def  setBoxed (t : Type ) =  t match 
201266 case  AnnotatedType (_, annot) if  annot.symbol ==  defn.RetainsAnnot  => 
0 commit comments