Skip to content

Add restricted capabilities x.only[C] #23485

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 12 commits into
base: main
Choose a base branch
from
Open
4 changes: 0 additions & 4 deletions compiler/src/dotty/tools/dotc/ast/Desugar.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2321,10 +2321,6 @@ object desugar {
Annotated(
AppliedTypeTree(ref(defn.SeqType), t),
New(ref(defn.RepeatedAnnot.typeRef), Nil :: Nil))
else if op.name == nme.CC_REACH then
Annotated(t, New(ref(defn.ReachCapabilityAnnot.typeRef), Nil :: Nil))
else if op.name == nme.CC_READONLY then
Annotated(t, New(ref(defn.ReadOnlyCapabilityAnnot.typeRef), Nil :: Nil))
else
assert(ctx.mode.isExpr || ctx.reporter.errorsReported || ctx.mode.is(Mode.Interactive), ctx.mode)
Select(t, op.name)
Expand Down
9 changes: 9 additions & 0 deletions compiler/src/dotty/tools/dotc/ast/untpd.scala
Original file line number Diff line number Diff line change
Expand Up @@ -550,6 +550,15 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo {
annot.putAttachment(RetainsAnnot, ())
Annotated(parent, annot)

def makeReachAnnot()(using Context): Tree =
New(ref(defn.ReachCapabilityAnnot.typeRef), Nil :: Nil)

def makeReadOnlyAnnot()(using Context): Tree =
New(ref(defn.ReadOnlyCapabilityAnnot.typeRef), Nil :: Nil)

def makeOnlyAnnot(qid: Tree)(using Context) =
New(AppliedTypeTree(ref(defn.OnlyCapabilityAnnot.typeRef), qid :: Nil), Nil :: Nil)

def makeConstructor(tparams: List[TypeDef], vparamss: List[List[ValDef]], rhs: Tree = EmptyTree)(using Context): DefDef =
DefDef(nme.CONSTRUCTOR, joinParams(tparams, vparamss), TypeTree(), rhs)

Expand Down
10 changes: 10 additions & 0 deletions compiler/src/dotty/tools/dotc/cc/CCState.scala
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,16 @@ class CCState:
def start(): Unit =
iterCount = 1

private var mySepCheck = false

/** Are we currently running separation checks? */
def isSepCheck = mySepCheck

def inSepCheck(op: => Unit): Unit =
val saved = mySepCheck
mySepCheck = true
try op finally mySepCheck = saved

// ------ Global counters -----------------------

/** Next CaptureSet.Var id */
Expand Down
203 changes: 180 additions & 23 deletions compiler/src/dotty/tools/dotc/cc/Capability.scala

Large diffs are not rendered by default.

74 changes: 49 additions & 25 deletions compiler/src/dotty/tools/dotc/cc/CaptureOps.scala
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,8 @@ extension (tp: Type)
tp1.toCapability.reach
case ReadOnlyCapability(tp1) =>
tp1.toCapability.readOnly
case OnlyCapability(tp1, cls) =>
tp1.toCapability.restrict(cls)
case ref: TermRef if ref.isCapRef =>
GlobalCap
case ref: Capability if ref.isTrackableRef =>
Expand Down Expand Up @@ -288,7 +290,7 @@ extension (tp: Type)
def forceBoxStatus(boxed: Boolean)(using Context): Type = tp.widenDealias match
case tp @ CapturingType(parent, refs) if tp.isBoxed != boxed =>
val refs1 = tp match
case ref: Capability if ref.isTracked || ref.isReach || ref.isReadOnly =>
case ref: Capability if ref.isTracked || ref.isInstanceOf[DerivedCapability] =>
ref.singletonCaptureSet
case _ => refs
CapturingType(parent, refs1, boxed)
Expand Down Expand Up @@ -374,7 +376,7 @@ extension (tp: Type)

def derivesFromCapability(using Context): Boolean = derivesFromCapTrait(defn.Caps_Capability)
def derivesFromMutable(using Context): Boolean = derivesFromCapTrait(defn.Caps_Mutable)
def derivesFromSharedCapability(using Context): Boolean = derivesFromCapTrait(defn.Caps_SharedCapability)
def derivesFromSharedCapability(using Context): Boolean = derivesFromCapTrait(defn.Caps_Sharable)

/** Drop @retains annotations everywhere */
def dropAllRetains(using Context): Type = // TODO we should drop retains from inferred types before unpickling
Expand Down Expand Up @@ -440,6 +442,30 @@ extension (tp: Type)
def dropUseAndConsumeAnnots(using Context): Type =
tp.dropAnnot(defn.UseAnnot).dropAnnot(defn.ConsumeAnnot)

/** If `tp` is a function or method, a type of the same kind with the given
* argument and result types.
*/
def derivedFunctionOrMethod(argTypes: List[Type], resType: Type)(using Context): Type = tp match
case tp @ AppliedType(tycon, args) if defn.isNonRefinedFunction(tp) =>
val args1 = argTypes :+ resType
if args.corresponds(args1)(_ eq _) then tp
else tp.derivedAppliedType(tycon, args1)
case tp @ defn.RefinedFunctionOf(rinfo) =>
val rinfo1 = rinfo.derivedFunctionOrMethod(argTypes, resType)
if rinfo1 eq rinfo then tp
else if rinfo1.isInstanceOf[PolyType] then tp.derivedRefinedType(refinedInfo = rinfo1)
else rinfo1.toFunctionType(alwaysDependent = true)
case tp: MethodType =>
tp.derivedLambdaType(paramInfos = argTypes, resType = resType)
case tp: PolyType =>
assert(argTypes.isEmpty)
tp.derivedLambdaType(resType = resType)
case _ =>
tp

def classifier(using Context): ClassSymbol =
tp.classSymbols.map(_.classifier).foldLeft(defn.AnyClass)(leastClassifier)

extension (tp: MethodType)
/** A method marks an existential scope unless it is the prefix of a curried method */
def marksExistentialScope(using Context): Boolean =
Expand Down Expand Up @@ -471,6 +497,16 @@ extension (cls: ClassSymbol)
val selfType = bc.givenSelfType
bc.is(CaptureChecked) && selfType.exists && selfType.captureSet.elems == refs.elems

def isClassifiedCapabilityClass(using Context): Boolean =
cls.derivesFrom(defn.Caps_Capability) && cls.parentSyms.contains(defn.Caps_Classifier)

def classifier(using Context): ClassSymbol =
if cls.derivesFrom(defn.Caps_Capability) then
cls.baseClasses
.filter(_.parentSyms.contains(defn.Caps_Classifier))
.foldLeft(defn.AnyClass)(leastClassifier)
else defn.AnyClass

extension (sym: Symbol)

/** This symbol is one of `retains` or `retainsCap` */
Expand Down Expand Up @@ -585,7 +621,6 @@ abstract class AnnotatedCapability(annotCls: Context ?=> ClassSymbol):
def unapply(tree: AnnotatedType)(using Context): Option[Type] = tree match
case AnnotatedType(parent: Type, ann) if ann.hasSymbol(annotCls) => Some(parent)
case _ => None

end AnnotatedCapability

/** An extractor for `ref @readOnlyCapability`, which is used to express
Expand All @@ -603,6 +638,17 @@ object ReachCapability extends AnnotatedCapability(defn.ReachCapabilityAnnot)
*/
object MaybeCapability extends AnnotatedCapability(defn.MaybeCapabilityAnnot)

object OnlyCapability:
def apply(tp: Type, cls: ClassSymbol)(using Context): AnnotatedType =
AnnotatedType(tp,
Annotation(defn.OnlyCapabilityAnnot.typeRef.appliedTo(cls.typeRef), Nil, util.Spans.NoSpan))

def unapply(tree: AnnotatedType)(using Context): Option[(Type, ClassSymbol)] = tree match
case AnnotatedType(parent: Type, ann) if ann.hasSymbol(defn.OnlyCapabilityAnnot) =>
Some((parent, ann.tree.tpe.argTypes.head.classSymbol.asClass))
case _ => None
end OnlyCapability

/** An extractor for all kinds of function types as well as method and poly types.
* It includes aliases of function types such as `=>`. TODO: Can we do without?
* @return 1st half: The argument types or empty if this is a type function
Expand All @@ -616,28 +662,6 @@ object FunctionOrMethod:
case defn.RefinedFunctionOf(rinfo) => unapply(rinfo)
case _ => None

/** If `tp` is a function or method, a type of the same kind with the given
* argument and result types.
*/
extension (self: Type)
def derivedFunctionOrMethod(argTypes: List[Type], resType: Type)(using Context): Type = self match
case self @ AppliedType(tycon, args) if defn.isNonRefinedFunction(self) =>
val args1 = argTypes :+ resType
if args.corresponds(args1)(_ eq _) then self
else self.derivedAppliedType(tycon, args1)
case self @ defn.RefinedFunctionOf(rinfo) =>
val rinfo1 = rinfo.derivedFunctionOrMethod(argTypes, resType)
if rinfo1 eq rinfo then self
else if rinfo1.isInstanceOf[PolyType] then self.derivedRefinedType(refinedInfo = rinfo1)
else rinfo1.toFunctionType(alwaysDependent = true)
case self: MethodType =>
self.derivedLambdaType(paramInfos = argTypes, resType = resType)
case self: PolyType =>
assert(argTypes.isEmpty)
self.derivedLambdaType(resType = resType)
case _ =>
self

/** An extractor for a contains argument */
object ContainsImpl:
def unapply(tree: TypeApply)(using Context): Option[(Tree, Tree)] =
Expand Down
82 changes: 78 additions & 4 deletions compiler/src/dotty/tools/dotc/cc/CaptureSet.scala
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,14 @@ sealed abstract class CaptureSet extends Showable:
*/
def owner: Symbol

/** If this set is a variable: Drop capabilities that are known to be empty
* This is called during separation checking so that capabilities that turn
* out to be always empty because of conflicting clasisifiers don't contribute
* to peaks. We can't do it before that since classifiers are set during
* capture checking.
*/
def dropEmpties()(using Context): this.type

/** Is this capture set definitely non-empty? */
final def isNotEmpty: Boolean = !elems.isEmpty

Expand Down Expand Up @@ -210,6 +218,7 @@ sealed abstract class CaptureSet extends Showable:

protected def addIfHiddenOrFail(elem: Capability)(using ctx: Context, vs: VarState): Boolean =
elems.exists(_.maxSubsumes(elem, canAddHidden = true))
|| elem.isKnownEmpty
|| failWith(IncludeFailure(this, elem))

/** If this is a variable, add `cs` as a dependent set */
Expand Down Expand Up @@ -403,11 +412,33 @@ sealed abstract class CaptureSet extends Showable:

def maybe(using Context): CaptureSet = map(MaybeMap())

def restrict(cls: ClassSymbol)(using Context): CaptureSet = map(RestrictMap(cls))

def readOnly(using Context): CaptureSet =
val res = map(ReadOnlyMap())
if mutability != Ignored then res.mutability = Reader
res

def transClassifiers(using Context): Classifiers =
def elemClassifiers =
(ClassifiedAs(Nil) /: elems.map(_.transClassifiers))(joinClassifiers)
if ccState.isSepCheck then
dropEmpties()
elemClassifiers
else if isConst then
elemClassifiers
else
UnknownClassifier

def tryClassifyAs(cls: ClassSymbol)(using Context): Boolean =
elems.forall(_.tryClassifyAs(cls))

def adoptClassifier(cls: ClassSymbol)(using Context): Unit =
for elem <- elems do
elem.stripReadOnly match
case fresh: FreshCap => fresh.hiddenSet.adoptClassifier(cls)
case _ =>

/** A bad root `elem` is inadmissible as a member of this set. What is a bad roots depends
* on the value of `rootLimit`.
* If the limit is null, all capture roots are good.
Expand Down Expand Up @@ -557,6 +588,8 @@ object CaptureSet:

def owner = NoSymbol

def dropEmpties()(using Context) = this

private var isComplete = true

def setMutable()(using Context): Unit =
Expand Down Expand Up @@ -641,6 +674,16 @@ object CaptureSet:

def isMaybeSet = false // overridden in BiMapped

private var emptiesDropped = false

def dropEmpties()(using Context): this.type =
if !emptiesDropped then
emptiesDropped = true
for elem <- elems do
if elem.isKnownEmpty then
elems -= empty
this

/** A handler to be invoked if the root reference `cap` is added to this set */
var rootAddedHandler: () => Context ?=> Unit = () => ()

Expand All @@ -649,6 +692,25 @@ object CaptureSet:
*/
private[CaptureSet] var rootLimit: Symbol | Null = null

private var myClassifier: ClassSymbol = defn.AnyClass
def classifier: ClassSymbol = myClassifier

private def narrowClassifier(cls: ClassSymbol)(using Context): Unit =
val newClassifier = leastClassifier(classifier, cls)
if newClassifier == defn.NothingClass then
println(i"conflicting classifications for $this, was $classifier, now $cls")
myClassifier = newClassifier

override def adoptClassifier(cls: ClassSymbol)(using Context): Unit =
if !classifier.isSubClass(cls) then // serves as recursion brake
narrowClassifier(cls)
super.adoptClassifier(cls)

override def tryClassifyAs(cls: ClassSymbol)(using Context): Boolean =
classifier.isSubClass(cls)
|| super.tryClassifyAs(cls)
&& { narrowClassifier(cls); true }

/** A handler to be invoked when new elems are added to this set */
var newElemAddedHandler: Capability => Context ?=> Unit = _ => ()

Expand Down Expand Up @@ -680,14 +742,15 @@ object CaptureSet:
addIfHiddenOrFail(elem)
else if !levelOK(elem) then
failWith(IncludeFailure(this, elem, levelError = true)) // or `elem` is not visible at the level of the set.
else if !elem.tryClassifyAs(classifier) then
failWith(IncludeFailure(this, elem))
else
// id == 108 then assert(false, i"trying to add $elem to $this")
assert(elem.isWellformed, elem)
assert(!this.isInstanceOf[HiddenSet] || summon[VarState].isSeparating, summon[VarState])
includeElem(elem)
if isBadRoot(rootLimit, elem) then
rootAddedHandler()
newElemAddedHandler(elem)
val normElem = if isMaybeSet then elem else elem.stripMaybe
// assert(id != 5 || elems.size != 3, this)
val res = deps.forall: dep =>
Expand Down Expand Up @@ -1344,9 +1407,10 @@ object CaptureSet:

/** A template for maps on capabilities where f(c) <: c and f(f(c)) = c */
private abstract class NarrowingCapabilityMap(using Context) extends BiTypeMap:

def apply(t: Type) = mapOver(t)

protected def isSameMap(other: BiTypeMap) = other.getClass == getClass

override def fuse(next: BiTypeMap)(using Context) = next match
case next: Inverse if next.inverse.getClass == getClass => Some(IdentityTypeMap)
case next: NarrowingCapabilityMap if next.getClass == getClass => Some(this)
Expand All @@ -1358,8 +1422,8 @@ object CaptureSet:
def inverse = NarrowingCapabilityMap.this
override def toString = NarrowingCapabilityMap.this.toString ++ ".inverse"
override def fuse(next: BiTypeMap)(using Context) = next match
case next: NarrowingCapabilityMap if next.inverse.getClass == getClass => Some(IdentityTypeMap)
case next: NarrowingCapabilityMap if next.getClass == getClass => Some(this)
case next: NarrowingCapabilityMap if isSameMap(next.inverse) => Some(IdentityTypeMap)
case next: NarrowingCapabilityMap if isSameMap(next) => Some(this)
case _ => None

lazy val inverse = Inverse()
Expand All @@ -1375,6 +1439,13 @@ object CaptureSet:
override def mapCapability(c: Capability, deep: Boolean) = c.readOnly
override def toString = "ReadOnly"

private class RestrictMap(val cls: ClassSymbol)(using Context) extends NarrowingCapabilityMap:
override def mapCapability(c: Capability, deep: Boolean) = c.restrict(cls)
override def toString = "Restrict"
override def isSameMap(other: BiTypeMap) = other match
case other: RestrictMap => cls == other.cls
case _ => false

/* Not needed:
def ofClass(cinfo: ClassInfo, argTypes: List[Type])(using Context): CaptureSet =
CaptureSet.empty
Expand Down Expand Up @@ -1402,6 +1473,9 @@ object CaptureSet:
case Reach(c1) =>
c1.widen.deepCaptureSet(includeTypevars = true)
.showing(i"Deep capture set of $c: ${c1.widen} = ${result}", capt)
case Restricted(c1, cls) =>
if cls == defn.NothingClass then CaptureSet.empty
else c1.captureSetOfInfo.restrict(cls) // todo: should we simplify using subsumption here?
case ReadOnly(c1) =>
c1.captureSetOfInfo.readOnly
case Maybe(c1) =>
Expand Down
2 changes: 2 additions & 0 deletions compiler/src/dotty/tools/dotc/cc/CapturingType.scala
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ object CapturingType:
apply(parent1, refs ++ refs1, boxed)
case _ =>
if parent.derivesFromMutable then refs.setMutable()
val classifier = parent.classifier
refs.adoptClassifier(parent.classifier)
AnnotatedType(parent, CaptureAnnotation(refs, boxed)(defn.RetainsAnnot))

/** An extractor for CapturingTypes. Capturing types are recognized if
Expand Down
13 changes: 11 additions & 2 deletions compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,13 @@ object CheckCaptures:
report.error(em"Cannot form a reach capability from `cap`", ann.srcPos)
case ReadOnlyCapability(ref) =>
check(ref)
case OnlyCapability(ref, cls) =>
if !cls.isClassifiedCapabilityClass then
report.error(
em"""${ref.showRef}.only[${cls.name}] is not well-formed since $cls is not a classifier class.
|A classifier class is a class extending `caps.Capability` and directly extending `caps.Classifier`.""",
ann.srcPos)
check(ref)
case tpe =>
report.error(em"$elem: $tpe is not a legal element of a capture set", ann.srcPos)
ann.retainedSet.retainedElementsRaw.foreach(check)
Expand Down Expand Up @@ -1290,7 +1297,7 @@ class CheckCaptures extends Recheck, SymTransformer:
case ExistentialSubsumesFailure(ex, other) =>
def since =
if other.isTerminalCapability then ""
else " since that capability is not a SharedCapability"
else " since that capability is not a `Sharable` capability"
i"""the existential capture root in ${ex.originalBinder.resType}
|cannot subsume the capability $other$since"""
case MutAdaptFailure(cs, lo, hi) =>
Expand Down Expand Up @@ -2014,7 +2021,9 @@ class CheckCaptures extends Recheck, SymTransformer:
end checker

checker.traverse(unit)(using ctx.withOwner(defn.RootClass))
if ccConfig.useSepChecks then SepCheck(this).traverse(unit)
if ccConfig.useSepChecks then
ccState.inSepCheck:
SepCheck(this).traverse(unit)
if !ctx.reporter.errorsReported then
// We dont report errors here if previous errors were reported, because other
// errors often result in bad applied types, but flagging these bad types gives
Expand Down
Loading
Loading