Skip to content
11 changes: 7 additions & 4 deletions compiler/src/dotty/tools/dotc/ast/Desugar.scala
Original file line number Diff line number Diff line change
Expand Up @@ -403,6 +403,7 @@ object desugar {
val isCaseObject = mods.is(Case) && isObject
val isEnum = mods.isEnumClass && !mods.is(Module)
def isEnumCase = mods.isEnumCase
def isNonEnumCase = !isEnumCase && (isCaseClass || isCaseObject)
val isValueClass = parents.nonEmpty && isAnyVal(parents.head)
// This is not watertight, but `extends AnyVal` will be replaced by `inline` later.

Expand Down Expand Up @@ -483,7 +484,8 @@ object desugar {
val enumCompanionRef = TermRefTree()
val enumImport =
Import(enumCompanionRef, enumCases.flatMap(caseIds).map(ImportSelector(_)))
(enumImport :: enumStats, enumCases, enumCompanionRef)
val enumSpecMethods = EnumGetters()
(enumImport :: enumSpecMethods :: enumStats, enumCases, enumCompanionRef)
}
else (stats, Nil, EmptyTree)
}
Expand Down Expand Up @@ -621,10 +623,8 @@ object desugar {
var parents1 = parents
if (isEnumCase && parents.isEmpty)
parents1 = enumClassTypeRef :: Nil
if (isCaseClass | isCaseObject)
if (isNonEnumCase || isEnum)
parents1 = parents1 :+ scalaDot(str.Product.toTypeName) :+ scalaDot(nme.Serializable.toTypeName)
if (isEnum)
parents1 = parents1 :+ ref(defn.EnumClass.typeRef)

// derived type classes of non-module classes go to their companions
val (clsDerived, companionDerived) =
Expand Down Expand Up @@ -890,6 +890,9 @@ object desugar {
}
}

def enumGetters(getters: EnumGetters)(using Context): Tree =
flatTree(DesugarEnums.enumBaseMeths).withSpan(getters.span)

/** Transform extension construct to list of extension methods */
def extMethods(ext: ExtMethods)(using Context): Tree = flatTree {
for mdef <- ext.methods yield
Expand Down
6 changes: 6 additions & 0 deletions compiler/src/dotty/tools/dotc/ast/DesugarEnums.scala
Original file line number Diff line number Diff line change
Expand Up @@ -322,6 +322,12 @@ object DesugarEnums {
(ordinal, Nil)
}

def enumBaseMeths(using Context): List[Tree] =
if isJavaEnum then
enumLabelMeth(EmptyTree) :: Nil
else
ordinalMeth(EmptyTree) :: enumLabelMeth(EmptyTree) :: Nil

def param(name: TermName, typ: Type)(using Context): ValDef = param(name, TypeTree(typ))
def param(name: TermName, tpt: Tree)(using Context): ValDef = ValDef(name, tpt, EmptyTree).withFlags(Param)

Expand Down
4 changes: 4 additions & 0 deletions compiler/src/dotty/tools/dotc/ast/untpd.scala
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo {
case class Export(expr: Tree, selectors: List[ImportSelector])(implicit @constructorOnly src: SourceFile) extends Tree
case class ExtMethods(tparams: List[TypeDef], vparamss: List[List[ValDef]], methods: List[DefDef])(implicit @constructorOnly src: SourceFile) extends Tree
case class MacroTree(expr: Tree)(implicit @constructorOnly src: SourceFile) extends Tree
case class EnumGetters()(implicit @constructorOnly src: SourceFile) extends Tree

case class ImportSelector(imported: Ident, renamed: Tree = EmptyTree, bound: Tree = EmptyTree)(implicit @constructorOnly src: SourceFile) extends Tree {
// TODO: Make bound a typed tree?
Expand Down Expand Up @@ -700,6 +701,7 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo {
cpy.Export(tree)(transform(expr), selectors)
case ExtMethods(tparams, vparamss, methods) =>
cpy.ExtMethods(tree)(transformSub(tparams), vparamss.mapConserve(transformSub(_)), transformSub(methods))
case enums: EnumGetters => enums
case ImportSelector(imported, renamed, bound) =>
cpy.ImportSelector(tree)(transformSub(imported), transform(renamed), transform(bound))
case Number(_, _) | TypedSplice(_) =>
Expand Down Expand Up @@ -761,6 +763,8 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo {
this(x, expr)
case ExtMethods(tparams, vparamss, methods) =>
this(vparamss.foldLeft(this(x, tparams))(apply), methods)
case EnumGetters() =>
x
case ImportSelector(imported, renamed, bound) =>
this(this(this(x, imported), renamed), bound)
case Number(_, _) =>
Expand Down
3 changes: 2 additions & 1 deletion compiler/src/dotty/tools/dotc/reporting/ErrorMessageID.scala
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,8 @@ enum ErrorMessageID extends java.lang.Enum[ErrorMessageID] {
ModifierNotAllowedForDefinitionID,
CannotExtendJavaEnumID,
InvalidReferenceInImplicitNotFoundAnnotationID,
TraitMayNotDefineNativeMethodID
TraitMayNotDefineNativeMethodID,
EnumGettersRedefinitionID

def errorNumber = ordinal - 2
}
4 changes: 4 additions & 0 deletions compiler/src/dotty/tools/dotc/reporting/messages.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2031,6 +2031,10 @@ import ast.tpd
def explain = ""
}

class EnumGettersRedefinition(decl: Symbol)(using Context) extends NamingMsg(EnumGettersRedefinitionID):
def msg = em"redefinition of $decl: ${decl.info} in an ${hl("enum")}"
def explain = em"users may not supply their own definition for $decl when inside an ${hl("enum")}"

class DoubleDefinition(decl: Symbol, previousDecl: Symbol, base: Symbol)(using Context) extends NamingMsg(DoubleDefinitionID) {
def msg = {
def nameAnd = if (decl.name != previousDecl.name) " name and" else ""
Expand Down
2 changes: 2 additions & 0 deletions compiler/src/dotty/tools/dotc/transform/SymUtils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,8 @@ object SymUtils {
self
}

def isScalaEnum(using Context): Boolean = self.is(Enum, butNot=JavaDefined)

/** Does this symbol refer to anonymous classes synthesized by enum desugaring? */
def isEnumAnonymClass(using Context): Boolean =
self.isAnonymousClass && (self.owner.name.eq(nme.DOLLAR_NEW) || self.owner.is(CaseVal))
Expand Down
32 changes: 31 additions & 1 deletion compiler/src/dotty/tools/dotc/transform/SyntheticMembers.scala
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ class SyntheticMembers(thisPhase: DenotTransformer) {
lazy val accessors =
if (isDerivedValueClass(clazz)) clazz.paramAccessors.take(1) // Tail parameters can only be `erased`
else clazz.caseAccessors
val isEnumCase = clazz.derivesFrom(defn.EnumClass) && clazz != defn.EnumClass
val isEnumCase = clazz.classParents.exists(_.typeSymbol.isScalaEnum)
val isEnumValue = isEnumCase && clazz.isAnonymousClass && clazz.classParents.head.classSymbol.is(Enum)
val isNonJavaEnumValue = isEnumValue && !clazz.derivesFrom(defn.JavaEnumClass)

Expand Down Expand Up @@ -513,6 +513,34 @@ class SyntheticMembers(thisPhase: DenotTransformer) {
Match(param, cases)
}

/** For an enum T:
*
* def enumLabel(x: MirroredMonoType) = x.enumLabel
*
* For sealed trait with children of normalized types C_1, ..., C_n:
*
* def enumLabel(x: MirroredMonoType) = x match {
* case _: C_1 => "C_1"
* ...
* case _: C_n => "C_n"
* }
*
* Here, the normalized type of a class C is C[?, ...., ?] with
* a wildcard for each type parameter. The normalized type of an object
* O is O.type.
*/
def enumLabelBody(cls: Symbol, param: Tree)(using Context): Tree =
if (cls.is(Enum)) param.select(nme.enumLabel).ensureApplied
else {
val cases =
for ((child, idx) <- cls.children.zipWithIndex) yield {
val patType = if (child.isTerm) child.termRef else child.rawTypeRef
val pat = Typed(untpd.Ident(nme.WILDCARD).withType(patType), TypeTree(patType))
CaseDef(pat, EmptyTree, Literal(Constant(child.name.toString)))
}
Match(param, cases)
}

/** - If `impl` is the companion of a generic sum, add `deriving.Mirror.Sum` parent
* and `MirroredMonoType` and `ordinal` members.
* - If `impl` is the companion of a generic product, add `deriving.Mirror.Product` parent
Expand Down Expand Up @@ -564,6 +592,8 @@ class SyntheticMembers(thisPhase: DenotTransformer) {
addParent(defn.Mirror_SumClass.typeRef)
addMethod(nme.ordinal, MethodType(monoType.typeRef :: Nil, defn.IntType), cls,
ordinalBody(_, _))
addMethod(nme.enumLabel, MethodType(monoType.typeRef :: Nil, defn.StringType), cls,
enumLabelBody(_, _))
}

if (clazz.is(Module)) {
Expand Down
7 changes: 6 additions & 1 deletion compiler/src/dotty/tools/dotc/typer/Checking.scala
Original file line number Diff line number Diff line change
Expand Up @@ -889,7 +889,12 @@ trait Checking {
if (decl.matches(other) && !javaFieldMethodPair) {
def doubleDefError(decl: Symbol, other: Symbol): Unit =
if (!decl.info.isErroneous && !other.info.isErroneous)
report.error(DoubleDefinition(decl, other, cls), decl.srcPos)
if decl.owner.is(Enum, butNot=JavaDefined|Case) && decl.span.isSynthetic && (
decl.name == nme.ordinal || decl.name == nme.enumLabel)
then
report.error(EnumGettersRedefinition(decl), other.srcPos)
else
report.error(DoubleDefinition(decl, other, cls), decl.srcPos)
if (decl is Synthetic) doubleDefError(other, decl)
else doubleDefError(decl, other)
}
Expand Down
24 changes: 17 additions & 7 deletions compiler/src/dotty/tools/dotc/typer/Namer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -343,17 +343,20 @@ class Namer { typer: Typer =>
tree.pushAttachment(ExpandedTree, expanded)
}
tree match {
case tree: DefTree => record(desugar.defTree(tree))
case tree: PackageDef => record(desugar.packageDef(tree))
case tree: ExtMethods => record(desugar.extMethods(tree))
case _ =>
case tree: DefTree => record(desugar.defTree(tree))
case tree: PackageDef => record(desugar.packageDef(tree))
case tree: ExtMethods => record(desugar.extMethods(tree))
case tree: EnumGetters => record(desugar.enumGetters(tree))
case _ =>
}
}

/** The expanded version of this tree, or tree itself if not expanded */
def expanded(tree: Tree)(using Context): Tree = tree match {
case _: DefTree | _: PackageDef | _: ExtMethods => tree.attachmentOrElse(ExpandedTree, tree)
case _ => tree
case _: DefTree | _: PackageDef | _: ExtMethods | _: EnumGetters =>
tree.attachmentOrElse(ExpandedTree, tree)
case _ =>
tree
}

/** For all class definitions `stat` in `xstats`: If the companion class is
Expand Down Expand Up @@ -925,11 +928,17 @@ class Namer { typer: Typer =>

val TypeDef(name, impl @ Template(constr, _, self, _)) = original

private val (params, rest): (List[Tree], List[Tree]) = impl.body.span {
private val (params, restOfBody): (List[Tree], List[Tree]) = impl.body.span {
case td: TypeDef => td.mods.is(Param)
case vd: ValDef => vd.mods.is(ParamAccessor)
case _ => false
}
private val (restAfterParents, rest): (List[Tree], List[Tree]) =
if original.mods.isEnumClass then
val (imports :: getters :: Nil, stats): @unchecked = restOfBody.splitAt(2)
(getters :: Nil, imports :: stats) // enum getters desugaring needs to test if a parent is java.lang.Enum
else
(Nil, restOfBody)

def init(): Context = index(params)

Expand Down Expand Up @@ -1196,6 +1205,7 @@ class Namer { typer: Typer =>
cls.setNoInitsFlags(parentsKind(parents), untpd.bodyKind(rest))
if (cls.isNoInitsClass) cls.primaryConstructor.setFlag(StableRealizable)
processExports(using localCtx)
index(restAfterParents)(using localCtx)
}
}

Expand Down
5 changes: 4 additions & 1 deletion compiler/src/dotty/tools/dotc/typer/Typer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2122,7 +2122,7 @@ class Typer extends Namer
.withType(dummy.termRef)
if (!cls.isOneOf(AbstractOrTrait) && !ctx.isAfterTyper)
checkRealizableBounds(cls, cdef.sourcePos.withSpan(cdef.nameSpan))
if cls.derivesFrom(defn.EnumClass) then
if cls.isScalaEnum || firstParent.isScalaEnum then
checkEnum(cdef, cls, firstParent)
val cdef1 = assignType(cpy.TypeDef(cdef)(name, impl1), cls)

Expand Down Expand Up @@ -2635,6 +2635,9 @@ class Typer extends Namer
case (stat: untpd.ExtMethods) :: rest =>
val xtree = stat.removeAttachment(ExpandedTree).get
traverse(xtree :: rest)
case (stat: untpd.EnumGetters) :: rest =>
val xtree = stat.removeAttachment(ExpandedTree).get
traverse(xtree :: rest)
case stat :: rest =>
val stat1 = typed(stat)(using ctx.exprContext(stat, exprOwner))
checkStatementPurity(stat1)(stat, exprOwner)
Expand Down
3 changes: 2 additions & 1 deletion library/src-bootstrapped/scala/Enum.scala
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package scala

/** A base trait of all enum classes */
/** A Product that also describes a label and ordinal */
@deprecated("scala.Enum is no longer supported", "3.0.0-M1")
trait Enum extends Product, Serializable:

/** A string uniquely identifying a case of an enum */
Expand Down
76 changes: 76 additions & 0 deletions library/src-bootstrapped/scala/deriving.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
package scala

import quoted._

object deriving {

/** Mirrors allows typelevel access to enums, case classes and objects, and their sealed parents.
*/
sealed trait Mirror {

/** The mirrored *-type */
type MirroredMonoType

/** The name of the type */
type MirroredLabel <: String

/** The names of the product elements */
type MirroredElemLabels <: Tuple
}

object Mirror {

/** The Mirror for a sum type */
trait Sum extends Mirror { self =>
/** The ordinal number of the case class of `x`. For enums, `ordinal(x) == x.ordinal` */
def ordinal(x: MirroredMonoType): Int
/** The case label of the case class of `x`. For enums, `enumLabel(x) == x.enumLabel` */
def enumLabel(x: MirroredMonoType): String
}

/** The Mirror for a product type */
trait Product extends Mirror {

/** Create a new instance of type `T` with elements taken from product `p`. */
def fromProduct(p: scala.Product): MirroredMonoType
}

trait Singleton extends Product {
type MirroredMonoType = this.type
type MirroredType = this.type
type MirroredElemTypes = EmptyTuple
type MirroredElemLabels = EmptyTuple
def fromProduct(p: scala.Product) = this
}

/** A proxy for Scala 2 singletons, which do not inherit `Singleton` directly */
class SingletonProxy(val value: AnyRef) extends Product {
type MirroredMonoType = value.type
type MirroredType = value.type
type MirroredElemTypes = EmptyTuple
type MirroredElemLabels = EmptyTuple
def fromProduct(p: scala.Product) = value
}

type Of[T] = Mirror { type MirroredType = T; type MirroredMonoType = T ; type MirroredElemTypes <: Tuple }
type ProductOf[T] = Mirror.Product { type MirroredType = T; type MirroredMonoType = T ; type MirroredElemTypes <: Tuple }
type SumOf[T] = Mirror.Sum { type MirroredType = T; type MirroredMonoType = T; type MirroredElemTypes <: Tuple }
}

/** Helper class to turn arrays into products */
class ArrayProduct(val elems: Array[AnyRef]) extends Product {
def this(size: Int) = this(new Array[AnyRef](size))
def canEqual(that: Any): Boolean = true
def productElement(n: Int) = elems(n)
def productArity = elems.length
override def productIterator: Iterator[Any] = elems.iterator
def update(n: Int, x: Any) = elems(n) = x.asInstanceOf[AnyRef]
}

/** The empty product */
object EmptyProduct extends ArrayProduct(Array.emptyObjectArray)

/** Helper method to select a product element */
def productElement[T](x: Any, idx: Int) =
x.asInstanceOf[Product].productElement(idx).asInstanceOf[T]
}
3 changes: 3 additions & 0 deletions library/src-non-bootstrapped/scala/Enum.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,8 @@ package scala
/** A base trait of all enum classes */
trait Enum extends Product, Serializable:

/** A string uniquely identifying a case of an enum */
def enumLabel: String

/** A number uniquely identifying a case of an enum */
def ordinal: Int

This file was deleted.

Loading