Skip to content

Commit 99d3f81

Browse files
committed
add EnumerateSingletons - listing all objects extending a sealed trait
1 parent f8a810e commit 99d3f81

File tree

4 files changed

+102
-14
lines changed

4 files changed

+102
-14
lines changed

project/Build.scala

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,15 +12,10 @@ object MyBuild extends Build{
1212
description := "Composable Records and type-indexed Maps for Scala",
1313
libraryDependencies ++= Seq(
1414
"org.scalatest" %% "scalatest" % "3.0.0-RC4" % "test",
15-
"org.scala-lang" % "scala-compiler" % scalaVersion.value % "provided"//,
16-
) ++
17-
CrossVersion.partialVersion(scalaVersion.value).collect{
18-
case (2, 10) =>
19-
Seq(
20-
"org.typelevel" %% "macro-compat" % "1.1.0",
21-
compilerPlugin("org.scalamacros" % "paradise" % "2.1.0" cross CrossVersion.full)
22-
)
23-
}.toSeq.flatten,
15+
"org.typelevel" %% "macro-compat" % "1.1.0",
16+
"org.scala-lang" % "scala-compiler" % scalaVersion.value % "provided",
17+
compilerPlugin("org.scalamacros" % "paradise" % "2.1.0" cross CrossVersion.full)
18+
),
2419
scalacOptions ++= Seq("-feature", "-deprecation", "-unchecked"),
2520
//scalacOptions ++= Seq("-Xprint:patmat", "-Xshow-phases"),
2621
testOptions in Test += Tests.Argument(TestFrameworks.ScalaTest, "-oFD"),

readme.txt

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,8 @@ http://cvogt.org/scala-extensions/
33

44
Contents:
55

6-
Type-level constraints (org.cvogt.constraints)
7-
- CaseClass and SingletonObject type classes
8-
- Comparisons: <:<, =:=, >:>, !=:=, !<:<, !>:>, e.g. String !=:= And
9-
- Boolean Algebra: True, False, ==, !, &&, ||, Implies, Xor
10-
- Subset tests: In, NotIn, e.g. Int NotIn (Any,AnyRef,AnyVal)
6+
Type-level helpers
7+
- EnumerateSingletons - listing all objects extending a sealed trait
118

129
Collection extensions (org.cvogt.collection)
1310
- distinctBy - remove duplicates by key
@@ -28,6 +25,12 @@ Debug (org.cvogt.scala.debug)
2825
Type safety
2926
- safe"..." alternative to s"..." that requires explicit toString conversions rather than implicit
3027

28+
Type-level constraints (org.cvogt.constraints)
29+
- CaseClass and SingletonObject type classes
30+
- Comparisons: <:<, =:=, >:>, !=:=, !<:<, !>:>, e.g. String !=:= And
31+
- Boolean Algebra: True, False, ==, !, &&, ||, Implies, Xor
32+
- Subset tests: In, NotIn, e.g. Int NotIn (Any,AnyRef,AnyVal)
33+
3134
Others
3235
- alternative `->` that works as constructor, extractor, type
3336

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
package org.cvogt.scala
2+
import scala.reflect.macros.blackbox.Context
3+
import scala.language.experimental.macros
4+
import macrocompat.bundle
5+
6+
object EnumerateSingletons {
7+
/** singleton objects transitively extending the given class or trait */
8+
def apply[A]: Set[A] = macro EnumerateSingletonsMacros.enumerateSingletonsMacros[A]
9+
}
10+
11+
@bundle
12+
class EnumerateSingletonsMacros( val c: Context ) {
13+
import c.universe._
14+
def enumerateSingletonsMacros[T: c.WeakTypeTag]: Tree = {
15+
val T = weakTypeOf[T].typeSymbol.asClass
16+
val ( subs, verifiers ) = knownTransitiveSubclassesAndVerifiers( T )
17+
val (singletons, classes) = subs.partition( _.isModuleClass )
18+
val nonClosed = classes.filterNot(_.isSealed).filterNot(_.isFinal)
19+
if(nonClosed.nonEmpty){
20+
c.error(
21+
c.enclosingPosition,
22+
"EnumerateSingleton requires all transitive subclasses to be sealed or final. These are not: " ++ nonClosed.mkString(", ")
23+
)
24+
}
25+
val trees = singletons.map( _.module ).map( m => q"$m" )
26+
val tree = q"""{
27+
..$verifiers
28+
_root_.scala.collection.immutable.Set[$T](..$trees)
29+
}"""
30+
tree
31+
}
32+
33+
/** Generates a list of all singleton objects extending the given class directly or transitively. */
34+
private def knownTransitiveSubclassesAndVerifiers( sym: ClassSymbol ): ( Set[ClassSymbol], List[Tree] ) = {
35+
val direct = knownDirectSubclassesAndVerifier( sym )
36+
direct._1.map( knownTransitiveSubclassesAndVerifiers ).fold(
37+
direct
38+
)(
39+
( l, r ) => ( l._1 ++ r._1, l._2 ++ r._2 )
40+
)
41+
}
42+
43+
private def knownDirectSubclassesAndVerifier( T: ClassSymbol ): ( Set[ClassSymbol], List[Tree] ) = {
44+
val subs = T.knownDirectSubclasses
45+
46+
// hack to detect breakage of knownDirectSubclasses as suggested in
47+
// https://gitter.im/scala/scala/archives/2015/05/05 and
48+
// https://gist.github.com/retronym/639080041e3fecf58ba9
49+
val global = c.universe.asInstanceOf[scala.tools.nsc.Global]
50+
def checkSubsPostTyper = if ( subs != T.knownDirectSubclasses )
51+
c.error(
52+
c.macroApplication.pos,
53+
s"""No child classes found for $T. If there clearly are child classes,
54+
Try moving the call lower in the file, into a separate file, a sibbling package, a separate sbt sub project or else.
55+
This is caused by https://issues.scala-lang.org/browse/SI-7046 and can only be avoided by manually moving the call.
56+
It is triggered when a macro call happend in a place, where typechecking of $T hasn't been completed yet.
57+
Completion is required in order to find subclasses.
58+
"""
59+
)
60+
61+
val checkSubsPostTyperTypTree =
62+
new global.TypeTreeWithDeferredRefCheck()( () => { checkSubsPostTyper; global.TypeTree( global.NoType ) } ).asInstanceOf[TypTree]
63+
64+
val name = TypeName( c.freshName( "VerifyKnownDirectSubclassesPostTyper" ) )
65+
66+
(
67+
subs.map( _.asClass ).toSet,
68+
List( q"type ${name} = $checkSubsPostTyperTypTree" )
69+
)
70+
}
71+
}
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
package org.cvogt.scala.test
2+
3+
import org.cvogt.scala.EnumerateSingletons
4+
5+
import org.scalatest.FunSuite
6+
7+
sealed trait A
8+
case object B extends A
9+
sealed trait C extends A
10+
case object D extends C
11+
12+
class EnumerateSingletonsTest extends FunSuite {
13+
test( "works for hierarchies" ) {
14+
val s = EnumerateSingletons[A]
15+
assert(
16+
s === Set[A]( B, D )
17+
)
18+
}
19+
}

0 commit comments

Comments
 (0)