@@ -22,7 +22,7 @@ import scala.virtualization.lms.common.Functions
22
22
*/
23
23
trait Adts extends Functions {
24
24
25
- type Adt = scala.js.macroimpl. Adts .Adt
25
+ type Adt = AdtsImpl .Adt
26
26
27
27
def adt_construct [A : Manifest ](fields : (String , Rep [_])* ): Rep [A ]
28
28
def adt_select [A : Manifest , B : Manifest ](obj : Rep [A ], label : String ): Rep [B ]
@@ -34,11 +34,13 @@ trait Adts extends Functions {
34
34
* case class Point(x: Int, y: Int) extends Adt
35
35
* val Point = adt[Point]
36
36
* // Point is a staged smart constructor taking two Rep[Int] and returning a Rep[Point]
37
+ *
38
+ * val p1: Rep[Point] = Point(unit(1), unit(2))
37
39
* }}}
38
40
*
39
41
* @return a staged smart constructor for the data type T
40
42
*/
41
- def adt [T <: Adt ] = macro scala.js.macroimpl. Adts .adt[T ]
43
+ def adt [T <: Adt ] = macro AdtsImpl .adt[T ]
42
44
43
45
/**
44
46
* {{{
@@ -55,6 +57,234 @@ trait Adts extends Functions {
55
57
*
56
58
* @return an object with staged members for the type T
57
59
*/
58
- def adtOps [T <: Adt ](o : Rep [T ]) = macro scala.js.macroimpl.Adts .ops[T , Rep ]
60
+ def adtOps [T <: Adt ](o : Rep [T ]) = macro AdtsImpl .ops[T , Rep ]
61
+
62
+ }
63
+
64
+ object AdtsImpl {
65
+
66
+ import scala .reflect .macros .Context
67
+
68
+ trait Adt
69
+
70
+ def adt [U <: Adt : c.WeakTypeTag ](c : Context ) =
71
+ c.Expr [Any ](new Generator [c.type ](c).construct[U ])
72
+
73
+
74
+ def ops [U <: Adt : c.WeakTypeTag , R [_]](c : Context )(o : c.Expr [R [U ]]) =
75
+ c.Expr [Any ](new Generator [c.type ](c).ops(o))
76
+
77
+
78
+ class Generator [C <: Context ](val c : C ) {
79
+ import c .universe ._
80
+
81
+ /**
82
+ * @return The whole class hierarchy the type `A` belongs to. Works only with closed class hierarchies.
83
+ * The symbols are sorted by alphabetic order.
84
+ */
85
+ def wholeHierarchy [A <: Adt : WeakTypeTag ]: Seq [ClassSymbol ] = {
86
+
87
+ val rootClass : ClassSymbol =
88
+ weakTypeOf[A ].baseClasses
89
+ // Take up to `Adt` super type
90
+ .takeWhile(_.asClass.toType != typeOf[Adt ])
91
+ // Filter out type ancestors automatically added to case classes
92
+ .filterNot { s =>
93
+ val tpe = s.asClass.toType
94
+ tpe =:= typeOf[Equals ] || tpe =:= typeOf[Serializable ] || tpe =:= typeOf[java.io.Serializable ] || tpe =:= typeOf[Product ]
95
+ }.last.asClass // We know there is at least one element in the list because of `baseClasses`
96
+
97
+ def subHierarchy (base : ClassSymbol ): List [ClassSymbol ] = {
98
+ base.typeSignature // Needed before calling knownDirectSubclasses (SI-7046)
99
+ base.knownDirectSubclasses.foldLeft(List (base)) { (result, symbol) =>
100
+ val clazz = symbol.asClass
101
+ if (clazz.isCaseClass) clazz :: result
102
+ else if (clazz.isSealed && (clazz.isTrait || clazz.isAbstractClass)) subHierarchy(clazz) ++ result
103
+ else c.abort(c.enclosingPosition, " A class hierarchy may only contain case classes, sealed traits and sealed abstract classes" )
104
+ }
105
+ }
106
+
107
+ subHierarchy(rootClass)
108
+ .sortBy(_.name.decoded)
109
+ .ensuring(_.nonEmpty, s " Oops: whole hierarchy of $rootClass is empty " )
110
+ }
111
+
112
+ /**
113
+ * @return The class hierarchy of the type `A`, meaning, `A` and all its subclasses
114
+ */
115
+ def hierarchy [A <: Adt : WeakTypeTag ]: Seq [ClassSymbol ] =
116
+ wholeHierarchy[A ]
117
+ .filter(_.toType <:< weakTypeOf[A ])
118
+ .ensuring(_.nonEmpty, s " Oops: hierarchy of ${weakTypeOf[A ].typeSymbol.asClass} is empty! (whole hierarchy is: ${wholeHierarchy[A ]}) " )
119
+
120
+ case class Member (name : String , term : TermName , tpe : Type )
121
+
122
+ object Member {
123
+ def apply (symbol : Symbol ) = {
124
+ // Trim because case classes members introduce a trailing space
125
+ val nameStr = symbol.name.decoded.trim
126
+ new Member (nameStr, newTermName(nameStr), symbol.typeSignature)
127
+ }
128
+ }
129
+
130
+ /** @return The members of the type `tpe` */
131
+ def listMembers (tpe : Type ): List [Member ] =
132
+ tpe.typeSymbol.typeSignature.declarations.toList.collect { case x : TermSymbol if x.isVal && x.isCaseAccessor => Member (x) }
133
+
134
+ /**
135
+ * Expands to a value providing staged operations on algebraic data types.
136
+ *
137
+ * Applied to an object `r` of a record type `R`, it expands to the following:
138
+ *
139
+ * {{{
140
+ * class $1 {
141
+ * // `f1`, `f2`, ... are fields of `r`
142
+ * def f1: Rep[F1] = ...
143
+ * def f2: Rep[F2] = ...
144
+ * def copy(f1: Rep[F1] = r.f1, f2: Rep[F2] = r.f2, ...): Rep[R] = ...
145
+ * }
146
+ * new $1
147
+ * }}}
148
+ *
149
+ * Applied to an object `s` of a sum type `S`, it expands to the following:
150
+ *
151
+ * {{{
152
+ * class $1 {
153
+ * def === (that: Rep[S]): Rep[Boolean] = ...
154
+ * // `R1`, `R2`, ... are variants of `S`
155
+ * def fold[A](r1: Rep[R1] => Rep[A], r2: Rep[R2] => Rep[A], ...): Rep[A] = ...
156
+ * }
157
+ * new $1
158
+ * }}}
159
+ */
160
+ // TODO Simplify the expansion
161
+ def ops [U <: Adt : c.WeakTypeTag , R [_]](obj : c.Expr [R [U ]]) = {
162
+ val anon = newTypeName(c.fresh)
163
+ val wrapper = newTypeName(c.fresh)
164
+ val ctor = newTermName(c.fresh)
165
+
166
+ val U = weakTypeOf[U ]
167
+ val members = listMembers(U )
168
+ if (! U .typeSymbol.isClass) {
169
+ c.abort(c.enclosingPosition, s " $U must be a sealed trait, an abstract class or a case class " )
170
+ }
171
+ val typeSymbol = U .typeSymbol.asClass
172
+ if (! (typeSymbol.isCaseClass || (typeSymbol.isSealed && (typeSymbol.isTrait || typeSymbol.isAbstractClass)))) {
173
+ c.abort(c.enclosingPosition, s " $U must be a sealed trait, an abstract class or a case class " )
174
+ }
175
+
176
+ val objName = typeSymbol.name
177
+
178
+ val defGetters = for (member <- members) yield q " def ${member.term}: Rep[ ${member.tpe}] = adt_select[ $U, ${member.tpe}]( $obj , ${member.name}) "
179
+
180
+ val paramsCopy = for (member <- members) yield q " val ${member.term}: Rep[ ${member.tpe}] = adt_select[ $U, ${member.tpe}]( $obj , ${member.name}) "
181
+
182
+ val paramsConstruct = for (member <- members) yield q " ${member.term}"
183
+
184
+ val defCopy = q """
185
+ def copy(.. $paramsCopy): Rep[ $objName] = $ctor(.. $paramsConstruct)
186
+ """
187
+
188
+ val variants = U .baseClasses.drop(1 ).filter(bc => bc.asClass.toType <:< typeOf[Adt ] && bc.asClass.toType != typeOf[Adt ])
189
+
190
+ // TODO Review this code
191
+ def getFields (params : Seq [Member ], root : String , list : List [Tree ]): List [Tree ] = params match {
192
+ case Nil =>
193
+ if (! variants.isEmpty){
194
+ val variant = root+ " $variant"
195
+ q " $variant" :: list
196
+ }else {
197
+ list
198
+ }
199
+ case param +: tail =>
200
+ if (param.tpe <:< typeOf[Adt ]) {
201
+ val paramMembers = listMembers(param.tpe)
202
+ val l = getFields(paramMembers, root + param.name + " ." , list)
203
+ getFields(tail, root, l)
204
+ } else {
205
+ val name = root + param.name
206
+ getFields(tail, root, q """ $name""" :: list)
207
+ }
208
+ }
209
+
210
+ val fieldsObj = getFields(members, " " , List ())
211
+
212
+ val defEqual =
213
+ q """
214
+ def === (bis: Rep[ $objName]): Rep[Boolean] = {
215
+ adt_equal( $obj, bis, Seq(.. $fieldsObj), Seq(.. $fieldsObj))
216
+ }
217
+ """
218
+
219
+ val variants2 = wholeHierarchy[U ].filter(_.isCaseClass).map(s => s -> newTermName(c.fresh()))
220
+
221
+ val paramsFold = for ((param, symbol) <- variants2) yield q " val $symbol: (Rep[ $param] => Rep[A]) "
222
+
223
+ val paramsFoldLambda = for ((_, symbol) <- variants2) yield q " doLambda( $symbol) "
224
+
225
+ val paramsFoldName = for (param <- paramsFoldLambda) yield q " $param.asInstanceOf[Rep[ $U => A]] "
226
+
227
+ val defFold = q """ def fold[A : Manifest](.. $paramsFold): Rep[A] = {
228
+ adt_fold( $obj, Seq(.. $paramsFoldName))
229
+ }
230
+ """
231
+
232
+ if (typeSymbol.isCaseClass) {
233
+ q """
234
+ class $anon {
235
+ val $ctor = adt[ $objName]
236
+ .. $defGetters
237
+ $defCopy
238
+ $defEqual
239
+ }
240
+ class $wrapper extends $anon{}
241
+ new $wrapper
242
+ """
243
+ } else {
244
+ q """
245
+ class $anon {
246
+ $defFold
247
+ $defEqual
248
+ }
249
+ class $wrapper extends $anon{}
250
+ new $wrapper
251
+ """
252
+ }
253
+ }
59
254
60
- }
255
+ /**
256
+ * Expands to a staged smart constructor.
257
+ *
258
+ * Applied to a record type (case class) `C` it expands to the following smart constructor:
259
+ *
260
+ * {{{
261
+ * (f1: Rep[F1], f2: Rep[F2], ...) => ...: Rep[C]
262
+ * }}}
263
+ */
264
+ // TODO Simplify the expansion
265
+ def construct [U <: Adt : c.WeakTypeTag ]: c.Tree = {
266
+ val U = weakTypeOf[U ]
267
+ if (U .typeSymbol.asClass.isCaseClass) {
268
+ val members = listMembers(U )
269
+ val objName = U .typeSymbol.name
270
+ val paramsDef = for (member <- members) yield q " val ${member.term}: Rep[ ${member.tpe}] "
271
+ val paramsConstruct = for (member <- members) yield q " ${member.name} -> ${member.term}"
272
+ val paramsType = for (member <- members) yield tq " Rep[ ${member.tpe}] "
273
+ val allParams = {
274
+ val variants = wholeHierarchy[U ].filter(_.isCaseClass)
275
+ if (variants.size == 1 ) paramsConstruct else {
276
+ val variant = variants.indexOf(U .typeSymbol)
277
+ paramsConstruct :+ q """ " $$ variant" -> unit( $variant) """
278
+ }
279
+ }
280
+ q """
281
+ new ${newTypeName(" Function" + paramsType.length)}[.. $paramsType, Rep[ $objName]] {
282
+ def apply(.. $paramsDef) = adt_construct[ $objName](.. $allParams)
283
+ }
284
+ """
285
+ } else {
286
+ c.abort(c.enclosingPosition, s " $U must be a case class " )
287
+ }
288
+ }
289
+ }
290
+ }
0 commit comments