diff --git a/compiler/src/main/scala/org/scalajs/nscplugin/GenJSCode.scala b/compiler/src/main/scala/org/scalajs/nscplugin/GenJSCode.scala index 788e8dc61b..397d84cca2 100644 --- a/compiler/src/main/scala/org/scalajs/nscplugin/GenJSCode.scala +++ b/compiler/src/main/scala/org/scalajs/nscplugin/GenJSCode.scala @@ -5361,6 +5361,16 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G) case UNWRAP_FROM_THROWABLE => // js.special.unwrapFromThrowable(arg) js.UnwrapFromThrowable(genArgs1) + + case LINKTIME_IF => + // linkingInfo.linkTimeIf(cond, thenp, elsep) + assert(args.size == 3, + s"Expected exactly 3 arguments for JS primitive $code but got " + + s"${args.size} at $pos") + val condp = genLinkTimeTree(args(0)) + val thenp = genExpr(args(1)) + val elsep = genExpr(args(2)) + js.LinkTimeIf(condp, thenp, elsep)(toIRType(tree.tpe)) } } @@ -6827,8 +6837,103 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G) js.ApplyStatic(js.ApplyFlags.empty, className, method, Nil)(toIRType(sym.tpe)) } } + + private def genLinkTimeTree(cond: Tree)( + implicit pos: Position): js.LinkTimeTree = { + import js.LinkTimeOp._ + val dummy = js.LinkTimeTree.Property("dummy", toIRType(cond.tpe)) + cond match { + case Literal(Constant(b: Boolean)) => + js.LinkTimeTree.BooleanConst(b) + + case Literal(Constant(i: Int)) => + js.LinkTimeTree.IntConst(i) + + case Literal(_) => + reporter.error(cond.pos, + s"Invalid literal $cond inside linkTimeIf. " + + "Only boolean and int values can be used in linkTimeIf.") + dummy + + case Ident(name) => + reporter.error(cond.pos, + s"Invalid identifier $name inside linkTimeIf. " + + "Only @linkTimeProperty annotated values can be used in linkTimeIf.") + dummy + + // !x + case Apply(Select(t, nme.UNARY_!), Nil) if cond.symbol == definitions.Boolean_not => + val lt = genLinkTimeTree(t) + js.LinkTimeTree.BinaryOp(Boolean_==, lt, js.LinkTimeTree.BooleanConst(false)) + + // if(foo()) (...) + case Apply(prop, Nil) => + getLinkTimeProperty(prop).getOrElse { + reporter.error(prop.pos, + s"Invalid identifier inside linkTimeIf. " + + "Only @linkTimeProperty annotated values can be used in linkTimeIf.") + dummy + } + + // if(lhs rhs) (...) + case Apply(Select(cond1, comp), List(cond2)) => + val tpe = toIRType(cond.tpe) + val c1 = genLinkTimeTree(cond1) + val c2 = genLinkTimeTree(cond2) + val dummyOp = -1 + val op: Code = + if (c1.tpe == jstpe.IntType) { + comp match { + case nme.EQ => Int_== + case nme.NE => Int_!= + case nme.GT => Int_> + case nme.GE => Int_>= + case nme.LT => Int_< + case nme.LE => Int_<= + case _ => + reporter.error(cond.pos, + s"Invalid operation '$comp' inside linkTimeIf. " + + "Only '==', '!=', '>', '>=', '<', '<=' " + + "operations are allowed for integer values in linkTimeIf.") + dummyOp + } + } else if (c1.tpe == jstpe.BooleanType) { + comp match { + case nme.EQ => Boolean_== + case nme.NE => Boolean_!= + case nme.ZAND => Boolean_&& + case nme.ZOR => Boolean_|| + case _ => + reporter.error(cond.pos, + s"Invalid operation '$comp' inside linkTimeIf. " + + "Only '==', '!=', '&&', and '||' operations are allowed for boolean values in linkTimeIf.") + dummyOp + } + } else { + dummyOp + } + if (op == dummyOp) dummy + else js.LinkTimeTree.BinaryOp(op, c1, c2) + + case t => + reporter.error(t.pos, + s"Only @linkTimeProperty annotated values, int and boolean constants, " + + "and binary operations are allowd in linkTimeIf.") + dummy + } + } } + private def getLinkTimeProperty(tree: Tree): Option[js.LinkTimeTree.Property] = { + tree.symbol.getAnnotation(LinkTimePropertyAnnotation) + .flatMap(_.args.headOption) + .flatMap { + case Literal(Constant(v: String)) => + Some(js.LinkTimeTree.Property(v, toIRType(tree.symbol.tpe.resultType))(tree.pos)) + case _ => None + } + } + private lazy val hasNewCollections = !scala.util.Properties.versionNumberString.startsWith("2.12.") diff --git a/compiler/src/main/scala/org/scalajs/nscplugin/JSDefinitions.scala b/compiler/src/main/scala/org/scalajs/nscplugin/JSDefinitions.scala index 43fa33aedd..37f1e25f63 100644 --- a/compiler/src/main/scala/org/scalajs/nscplugin/JSDefinitions.scala +++ b/compiler/src/main/scala/org/scalajs/nscplugin/JSDefinitions.scala @@ -72,6 +72,8 @@ trait JSDefinitions { lazy val JSGlobalScopeAnnotation = getRequiredClass("scala.scalajs.js.annotation.JSGlobalScope") lazy val JSOperatorAnnotation = getRequiredClass("scala.scalajs.js.annotation.JSOperator") + lazy val LinkTimePropertyAnnotation = getRequiredClass("scala.scalajs.js.annotation.linkTimeProperty") + lazy val JSImportNamespaceObject = getRequiredModule("scala.scalajs.js.annotation.JSImport.Namespace") lazy val ExposedJSMemberAnnot = getRequiredClass("scala.scalajs.js.annotation.internal.ExposedJSMember") @@ -128,6 +130,9 @@ trait JSDefinitions { lazy val DynamicImportThunkClass = getRequiredClass("scala.scalajs.runtime.DynamicImportThunk") lazy val DynamicImportThunkClass_apply = getMemberMethod(DynamicImportThunkClass, nme.apply) + lazy val LinkingInfoClass = getRequiredModule("scala.scalajs.LinkingInfo") + lazy val LinkingInfoClass_linkTimeIf = getMemberMethod(LinkingInfoClass, newTermName("linkTimeIf")) + lazy val Tuple2_apply = getMemberMethod(TupleClass(2).companionModule, nme.apply) // This is a def, since similar symbols (arrayUpdateMethod, etc.) are in runDefinitions diff --git a/compiler/src/main/scala/org/scalajs/nscplugin/JSPrimitives.scala b/compiler/src/main/scala/org/scalajs/nscplugin/JSPrimitives.scala index df5ff293db..42e331459b 100644 --- a/compiler/src/main/scala/org/scalajs/nscplugin/JSPrimitives.scala +++ b/compiler/src/main/scala/org/scalajs/nscplugin/JSPrimitives.scala @@ -70,7 +70,9 @@ abstract class JSPrimitives { final val UNWRAP_FROM_THROWABLE = WRAP_AS_THROWABLE + 1 // js.special.unwrapFromThrowable final val DEBUGGER = UNWRAP_FROM_THROWABLE + 1 // js.special.debugger - final val LastJSPrimitiveCode = DEBUGGER + final val LINKTIME_IF = DEBUGGER + 1 // LinkingInfo.linkTimeIf + + final val LastJSPrimitiveCode = LINKTIME_IF /** Initialize the map of primitive methods (for GenJSCode) */ def init(): Unit = initWithPrimitives(addPrimitive) @@ -123,6 +125,8 @@ abstract class JSPrimitives { addPrimitive(Special_wrapAsThrowable, WRAP_AS_THROWABLE) addPrimitive(Special_unwrapFromThrowable, UNWRAP_FROM_THROWABLE) addPrimitive(Special_debugger, DEBUGGER) + + addPrimitive(LinkingInfoClass_linkTimeIf, LINKTIME_IF) } def isJavaScriptPrimitive(code: Int): Boolean = diff --git a/compiler/src/test/scala/org/scalajs/nscplugin/test/LinkTimeIfTest.scala b/compiler/src/test/scala/org/scalajs/nscplugin/test/LinkTimeIfTest.scala new file mode 100644 index 0000000000..e6554611fe --- /dev/null +++ b/compiler/src/test/scala/org/scalajs/nscplugin/test/LinkTimeIfTest.scala @@ -0,0 +1,114 @@ +/* + * Scala.js (https://www.scala-js.org/) + * + * Copyright EPFL. + * + * Licensed under Apache License 2.0 + * (https://www.apache.org/licenses/LICENSE-2.0). + * + * See the NOTICE file distributed with this work for + * additional information regarding copyright ownership. + */ + +package org.scalajs.nscplugin.test + +import util._ + +import org.junit.Test +import org.junit.Assert._ + +class LinkTimeIfTest extends TestHelpers { + override def preamble: String = "import scala.scalajs.LinkingInfo._" + + // scalastyle:off line.size.limit + @Test + def linkTimeErrorInvalidOp(): Unit = { + """ + object A { + def foo = + linkTimeIf((esVersion + 1) < ESVersion.ES2015) { } { } + } + """ hasErrors + """ + |newSource1.scala:4: error: Invalid operation '$plus' inside linkTimeIf. Only '==', '!=', '>', '>=', '<', '<=' operations are allowed for integer values in linkTimeIf. + | linkTimeIf((esVersion + 1) < ESVersion.ES2015) { } { } + | ^ + """ + + """ + object A { + def foo = + linkTimeIf(productionMode | true) { } { } + } + """ hasErrors + """ + |newSource1.scala:4: error: Invalid operation '$bar' inside linkTimeIf. Only '==', '!=', '&&', and '||' operations are allowed for boolean values in linkTimeIf. + | linkTimeIf(productionMode | true) { } { } + | ^ + """ + } + + @Test + def linkTimeErrorInvalidEntities(): Unit = { + """ + object A { + def foo(x: String) = { + val bar = 1 + linkTimeIf(bar == 0) { } { } + } + } + """ hasErrors + """ + |newSource1.scala:5: error: Invalid identifier bar inside linkTimeIf. Only @linkTimeProperty annotated values can be used in linkTimeIf. + | linkTimeIf(bar == 0) { } { } + | ^ + """ + + """ + object A { + def foo(x: String) = + linkTimeIf("foo" == x) { } { } + } + """ hasErrors + """ + |newSource1.scala:4: error: Invalid literal "foo" inside linkTimeIf. Only boolean and int values can be used in linkTimeIf. + | linkTimeIf("foo" == x) { } { } + | ^ + |newSource1.scala:4: error: Invalid identifier x inside linkTimeIf. Only @linkTimeProperty annotated values can be used in linkTimeIf. + | linkTimeIf("foo" == x) { } { } + | ^ + """ + + """ + object A { + def bar = true + def foo(x: String) = + linkTimeIf(bar || !bar) { } { } + } + """ hasErrors + """ + |newSource1.scala:5: error: Invalid identifier inside linkTimeIf. Only @linkTimeProperty annotated values can be used in linkTimeIf. + | linkTimeIf(bar || !bar) { } { } + | ^ + |newSource1.scala:5: error: Invalid identifier inside linkTimeIf. Only @linkTimeProperty annotated values can be used in linkTimeIf. + | linkTimeIf(bar || !bar) { } { } + | ^ + """ + } + + @Test + def linkTimeCondInvalidTree(): Unit = { + """ + object A { + def bar = true + def foo(x: String) = + linkTimeIf(if(bar) true else false) { } { } + } + """ hasErrors + """ + |newSource1.scala:5: error: Only @linkTimeProperty annotated values, int and boolean constants, and binary operations are allowd in linkTimeIf. + | linkTimeIf(if(bar) true else false) { } { } + | ^ + """ + } +} diff --git a/ir/shared/src/main/scala/org/scalajs/ir/Hashers.scala b/ir/shared/src/main/scala/org/scalajs/ir/Hashers.scala index b1f69595a3..a9929ec0d8 100644 --- a/ir/shared/src/main/scala/org/scalajs/ir/Hashers.scala +++ b/ir/shared/src/main/scala/org/scalajs/ir/Hashers.scala @@ -206,6 +206,13 @@ object Hashers { mixTree(elsep) mixType(tree.tpe) + case LinkTimeIf(cond, thenp, elsep) => + mixTag(TagLinkTimeIf) + mixLinkTimeTree(cond) + mixTree(thenp) + mixTree(elsep) + mixType(tree.tpe) + case While(cond, body) => mixTag(TagWhile) mixTree(cond) @@ -700,6 +707,27 @@ object Hashers { digestStream.writeInt(pos.column) } + private def mixLinkTimeTree(cond: LinkTimeTree): Unit = { + cond match { + case LinkTimeTree.BinaryOp(op, lhs, rhs) => + mixTag(TagLinkTimeTreeBinary) + digestStream.writeInt(op) + mixLinkTimeTree(lhs) + mixLinkTimeTree(rhs) + case LinkTimeTree.Property(name, tpe) => + mixTag(TagLinkTimeProperty) + digestStream.writeUTF(name) + mixType(tpe) + case LinkTimeTree.BooleanConst(v) => + mixTag(TagLinkTimeBooleanConst) + digestStream.writeBoolean(v) + case LinkTimeTree.IntConst(v) => + mixTag(TagLinkTimeIntConst) + digestStream.writeInt(v) + } + mixPos(cond.pos) + } + @inline final def mixTag(tag: Int): Unit = mixInt(tag) diff --git a/ir/shared/src/main/scala/org/scalajs/ir/Printers.scala b/ir/shared/src/main/scala/org/scalajs/ir/Printers.scala index 318d69355f..97f57fdce7 100644 --- a/ir/shared/src/main/scala/org/scalajs/ir/Printers.scala +++ b/ir/shared/src/main/scala/org/scalajs/ir/Printers.scala @@ -134,6 +134,7 @@ object Printers { case node: MemberDef => print(node) case node: JSConstructorBody => printBlock(node.allStats) case node: TopLevelExportDef => print(node) + case node: LinkTimeTree => print(node) } } @@ -218,6 +219,15 @@ object Printers { printBlock(elsep) } + case LinkTimeIf(cond, thenp, elsep) => + print("linkTimeIf (") + print(cond) + print(") ") + + printBlock(thenp) + print(" else ") + printBlock(elsep) + case While(cond, body) => print("while (") print(cond) @@ -1181,6 +1191,38 @@ object Printers { } } + def print(cond: LinkTimeTree): Unit = { + import LinkTimeOp._ + cond match { + case LinkTimeTree.BinaryOp(op, lhs, rhs) => + print(lhs) + print(" ") + print(op match { + case Boolean_== => "==" + case Boolean_!= => "!=" + case Boolean_|| => "||" + case Boolean_&& => "&&" + + case Int_== => "==" + case Int_!= => "!=" + case Int_< => "<" + case Int_<= => "<=" + case Int_> => ">" + case Int_>= => ">=" + }) + print(" ") + print(rhs) + case LinkTimeTree.BooleanConst(v) => + if (v) print("true") else print("false") + case LinkTimeTree.IntConst(v) => + print(v.toString) + case LinkTimeTree.Property(name, _) => + print("prop[") + print(name) + print("]") + } + } + def print(s: String): Unit = out.write(s) diff --git a/ir/shared/src/main/scala/org/scalajs/ir/Serializers.scala b/ir/shared/src/main/scala/org/scalajs/ir/Serializers.scala index bad2b82fa5..d6f0b7920e 100644 --- a/ir/shared/src/main/scala/org/scalajs/ir/Serializers.scala +++ b/ir/shared/src/main/scala/org/scalajs/ir/Serializers.scala @@ -274,6 +274,11 @@ object Serializers { writeTree(cond); writeTree(thenp); writeTree(elsep) writeType(tree.tpe) + case LinkTimeIf(cond, thenp, elsep) => + writeTagAndPos(TagLinkTimeIf) + writeLinkTimeTree(cond); writeTree(thenp); writeTree(elsep) + writeType(tree.tpe) + case While(cond, body) => writeTagAndPos(TagWhile) writeTree(cond); writeTree(body) @@ -1006,6 +1011,33 @@ object Serializers { buffer.writeInt(strings.size) strings.foreach(writeString) } + + def writeLinkTimeTree(cond: LinkTimeTree): Unit = { + import buffer._ + + def writeTagAndPos(tag: Int) = { + writeByte(tag) + writePosition(cond.pos) + } + + cond match { + case LinkTimeTree.Property(name, tpe) => + writeTagAndPos(TagLinkTimeProperty) + writeString(name) + writeType(tpe) + case LinkTimeTree.BooleanConst(v) => + writeTagAndPos(TagLinkTimeBooleanConst) + writeBoolean(v) + case LinkTimeTree.IntConst(v) => + writeTagAndPos(TagLinkTimeIntConst) + writeInt(v) + case LinkTimeTree.BinaryOp(op, lhs, rhs) => + writeTagAndPos(TagLinkTimeTreeBinary) + writeByte(op) + writeLinkTimeTree(lhs) + writeLinkTimeTree(rhs) + } + } } private final class Deserializer(buf: ByteBuffer) { @@ -1147,6 +1179,14 @@ object Serializers { case TagReturn => Return(readTree(), readLabelIdent()) case TagIf => If(readTree(), readTree(), readTree())(readType()) + + case TagLinkTimeIf => + val linkTimeCond = readLinkTimeTree() + val thenp = readTree() + val elsep = readTree() + val tpe = readType() + LinkTimeIf(linkTimeCond, thenp, elsep)(tpe) + case TagWhile => While(readTree(), readTree()) case TagDoWhile => @@ -2116,6 +2156,25 @@ object Serializers { res } + + private def readLinkTimeTree(): LinkTimeTree = { + val tag = readByte() + implicit val pos = readPosition() + tag match { + case TagLinkTimeTreeBinary => + LinkTimeTree.BinaryOp( + readByte(), + readLinkTimeTree(), + readLinkTimeTree() + ) + case TagLinkTimeProperty => + LinkTimeTree.Property(readString(), readType()) + case TagLinkTimeIntConst => + LinkTimeTree.IntConst(readInt()) + case TagLinkTimeBooleanConst => + LinkTimeTree.BooleanConst(readBoolean()) + } + } } /** Hacks for backwards compatible deserializing. */ diff --git a/ir/shared/src/main/scala/org/scalajs/ir/Tags.scala b/ir/shared/src/main/scala/org/scalajs/ir/Tags.scala index 3c3162245b..1e54e7e806 100644 --- a/ir/shared/src/main/scala/org/scalajs/ir/Tags.scala +++ b/ir/shared/src/main/scala/org/scalajs/ir/Tags.scala @@ -127,6 +127,10 @@ private[ir] object Tags { final val TagWrapAsThrowable = TagJSNewTarget + 1 final val TagUnwrapFromThrowable = TagWrapAsThrowable + 1 + // New in 1.17 + + final val TagLinkTimeIf = TagUnwrapFromThrowable + 1 + // Tags for member defs final val TagFieldDef = 1 @@ -199,4 +203,11 @@ private[ir] object Tags { final val TagJSNativeLoadSpecImport = TagJSNativeLoadSpecGlobal + 1 final val TagJSNativeLoadSpecImportWithGlobalFallback = TagJSNativeLoadSpecImport + 1 + // Tags for LinkTimeTree + + final val TagLinkTimeProperty = 0 + final val TagLinkTimeBooleanConst = TagLinkTimeProperty + 1 + final val TagLinkTimeIntConst = TagLinkTimeBooleanConst + 1 + final val TagLinkTimeTreeBinary = TagLinkTimeIntConst + 1 + } diff --git a/ir/shared/src/main/scala/org/scalajs/ir/Transformers.scala b/ir/shared/src/main/scala/org/scalajs/ir/Transformers.scala index 6d30327786..46b1c6a941 100644 --- a/ir/shared/src/main/scala/org/scalajs/ir/Transformers.scala +++ b/ir/shared/src/main/scala/org/scalajs/ir/Transformers.scala @@ -64,6 +64,10 @@ object Transformers { If(transformExpr(cond), transform(thenp, isStat), transform(elsep, isStat))(tree.tpe) + case LinkTimeIf(cond, thenp, elsep) => + LinkTimeIf(cond, transform(thenp, isStat), + transform(elsep, isStat))(tree.tpe) + case While(cond, body) => While(transformExpr(cond), transformStat(body)) diff --git a/ir/shared/src/main/scala/org/scalajs/ir/Traversers.scala b/ir/shared/src/main/scala/org/scalajs/ir/Traversers.scala index 8a8909cdce..a5f07934e6 100644 --- a/ir/shared/src/main/scala/org/scalajs/ir/Traversers.scala +++ b/ir/shared/src/main/scala/org/scalajs/ir/Traversers.scala @@ -48,6 +48,10 @@ object Traversers { traverse(thenp) traverse(elsep) + case LinkTimeIf(_, thenp, elsep) => + traverse(thenp) + traverse(elsep) + case While(cond, body) => traverse(cond) traverse(body) diff --git a/ir/shared/src/main/scala/org/scalajs/ir/Trees.scala b/ir/shared/src/main/scala/org/scalajs/ir/Trees.scala index 411f6b9a95..43d751dd9c 100644 --- a/ir/shared/src/main/scala/org/scalajs/ir/Trees.scala +++ b/ir/shared/src/main/scala/org/scalajs/ir/Trees.scala @@ -168,6 +168,48 @@ object Trees { sealed case class If(cond: Tree, thenp: Tree, elsep: Tree)(val tpe: Type)( implicit val pos: Position) extends Tree + sealed case class LinkTimeIf(cond: LinkTimeTree, thenp: Tree, + elsep: Tree)(val tpe: Type)(implicit val pos: Position) extends Tree + + sealed abstract class LinkTimeTree extends IRNode { + val pos: Position + val tpe: Type + } + + object LinkTimeTree { + final case class BinaryOp(op: LinkTimeOp.Code, lhs: LinkTimeTree, rhs: LinkTimeTree)( + implicit val pos: Position) extends LinkTimeTree { + val tpe = BooleanType + } + + final case class Property(name: String, tpe: Type)(implicit val pos: Position) + extends LinkTimeTree + + final case class IntConst(v: Int)(implicit val pos: Position) extends LinkTimeTree { + val tpe = IntType + } + + final case class BooleanConst(v: Boolean)(implicit val pos: Position) extends LinkTimeTree { + val tpe = BooleanType + } + } + + object LinkTimeOp { + type Code = Int + + final val Boolean_== = 1 + final val Boolean_!= = 2 + final val Boolean_&& = 3 + final val Boolean_|| = 4 + + final val Int_== = 5 + final val Int_!= = 6 + final val Int_< = 7 + final val Int_<= = 8 + final val Int_> = 9 + final val Int_>= = 10 + } + sealed case class While(cond: Tree, body: Tree)( implicit val pos: Position) extends Tree { // cannot be in expression position, unless it is infinite diff --git a/ir/shared/src/test/scala/org/scalajs/ir/PrintersTest.scala b/ir/shared/src/test/scala/org/scalajs/ir/PrintersTest.scala index 590a24c209..f396bff9b8 100644 --- a/ir/shared/src/test/scala/org/scalajs/ir/PrintersTest.scala +++ b/ir/shared/src/test/scala/org/scalajs/ir/PrintersTest.scala @@ -194,6 +194,46 @@ class PrintersTest { If(ref("x", BooleanType), ref("y", BooleanType), b(false))(BooleanType)) } + @Test def printLinkTimeIf(): Unit = { + assertPrintEquals( + """ + |linkTimeIf (prop[foo] == 1) { + | 1 + |} else { + | 2 + |} + """, + LinkTimeIf( + LinkTimeTree.BinaryOp( + LinkTimeOp.Int_==, + LinkTimeTree.Property("foo", IntType), + LinkTimeTree.IntConst(1) + ), + i(1), + i(2) + )(IntType) + ) + + assertPrintEquals( + """ + |linkTimeIf (prop[foo] != true) { + | 1 + |} else { + | 2 + |} + """, + LinkTimeIf( + LinkTimeTree.BinaryOp( + LinkTimeOp.Boolean_!=, + LinkTimeTree.Property("foo", BooleanType), + LinkTimeTree.BooleanConst(true) + ), + i(1), + i(2) + )(IntType) + ) + } + @Test def printWhile(): Unit = { assertPrintEquals( """ diff --git a/library/src/main/scala/scala/scalajs/LinkingInfo.scala b/library/src/main/scala/scala/scalajs/LinkingInfo.scala index bf1bfa9c00..1d37634696 100644 --- a/library/src/main/scala/scala/scalajs/LinkingInfo.scala +++ b/library/src/main/scala/scala/scalajs/LinkingInfo.scala @@ -12,6 +12,8 @@ package scala.scalajs +import scala.scalajs.js.annotation.linkTimeProperty + object LinkingInfo { import scala.scalajs.runtime.linkingInfo @@ -44,7 +46,7 @@ object LinkingInfo { * * @see [[developmentMode]] */ - @inline + @inline @linkTimeProperty("core/productionMode") def productionMode: Boolean = linkingInfo.productionMode @@ -122,7 +124,7 @@ object LinkingInfo { * useES2018Feature() * }}} */ - @inline + @inline @linkTimeProperty("core/esVersion") def esVersion: Int = linkingInfo.esVersion @@ -326,4 +328,33 @@ object LinkingInfo { */ final val ES2021 = 12 } + + /** Link-time conditional branching. + * + * The `linkTimeIf` expression will be evaluated at link-time, and only the + * branch that needs to be executed will be linked. The other branch will be + * removed during the linking process. + * + * The condition `cond` can be constructed using: + * - Symbols annotated with `@linkTimeProperty` + * - Integer or boolean constants + * - Binary operators that return a boolean value + * + * Example usage: + * {{{ + * def pow(x: Double, y: Double): Double = + * linkTimeIf(esVersion >= ESVersion.ES2016) { + * (x.asInstanceOf[js.Dynamic] ** y.asInstanceOf[js.Dynamic]) + * .asInstanceOf[Double] + * } { + * Math.pow(x, y) + * } + * }}} + * + * If `LinkingInfo.esVersion` is `ESVersion.ES2016` or later, + * the first branch will be linked and the second branch will be removed, + * regardless of whether the optimizer is enabled. + */ + def linkTimeIf[T](cond: Boolean)(thenp: T)(elsep: T): T = + throw new Error("stub") } diff --git a/library/src/main/scala/scala/scalajs/js/annotation/linkTimeProperty.scala b/library/src/main/scala/scala/scalajs/js/annotation/linkTimeProperty.scala new file mode 100644 index 0000000000..fd722bf6ac --- /dev/null +++ b/library/src/main/scala/scala/scalajs/js/annotation/linkTimeProperty.scala @@ -0,0 +1,29 @@ +/* + * Scala.js (https://www.scala-js.org/) + * + * Copyright EPFL. + * + * Licensed under Apache License 2.0 + * (https://www.apache.org/licenses/LICENSE-2.0). + * + * See the NOTICE file distributed with this work for + * additional information regarding copyright ownership. + */ + +package scala.scalajs.js.annotation + +/** Specifies that the annotated entity can be replaced by a value known at linktime. + * + * When an entity is annotated with `@linkTimeProperty`, it can be used in the + * condition of `LinkingInfo.linkTimeIf`. During linking, the annotated entity + * will be replaced by a value determined at link time. + * + * The link-time value is resolved using the `name` parameter of the annotation + * by the `org.scalajs.linker.standard.LinkTimeProperties`. + * + * @param name The name used to resolve the link-time value. + * + * @see [[LinkingInfo.linkTimeIf]] + * @see [[LinkTimeProperties]] + */ +private[scalajs] class linkTimeProperty(name: String) extends scala.annotation.StaticAnnotation diff --git a/linker/shared/src/main/scala/org/scalajs/linker/analyzer/Analyzer.scala b/linker/shared/src/main/scala/org/scalajs/linker/analyzer/Analyzer.scala index 3931d2ab58..960624f2b2 100644 --- a/linker/shared/src/main/scala/org/scalajs/linker/analyzer/Analyzer.scala +++ b/linker/shared/src/main/scala/org/scalajs/linker/analyzer/Analyzer.scala @@ -49,7 +49,8 @@ final class Analyzer(config: CommonPhaseConfig, initial: Boolean, new InfoLoader(irLoader, if (!checkIR) InfoLoader.NoIRCheck else if (initial) InfoLoader.InitialIRCheck - else InfoLoader.InternalIRCheck + else InfoLoader.InternalIRCheck, + config.coreSpec.linkTimeProperties ) } diff --git a/linker/shared/src/main/scala/org/scalajs/linker/analyzer/InfoLoader.scala b/linker/shared/src/main/scala/org/scalajs/linker/analyzer/InfoLoader.scala index 7eeb5d197b..5670a108e9 100644 --- a/linker/shared/src/main/scala/org/scalajs/linker/analyzer/InfoLoader.scala +++ b/linker/shared/src/main/scala/org/scalajs/linker/analyzer/InfoLoader.scala @@ -25,11 +25,13 @@ import org.scalajs.logging._ import org.scalajs.linker.checker.ClassDefChecker import org.scalajs.linker.frontend.IRLoader import org.scalajs.linker.interface.LinkingException +import org.scalajs.linker.standard.LinkTimeProperties import org.scalajs.linker.CollectionsCompat.MutableMapCompatOps import Platform.emptyThreadSafeMap -private[analyzer] final class InfoLoader(irLoader: IRLoader, irCheckMode: InfoLoader.IRCheckMode) { +private[analyzer] final class InfoLoader(irLoader: IRLoader, irCheckMode: InfoLoader.IRCheckMode, + linkTimeProperties: LinkTimeProperties) { private var logger: Logger = _ private val cache = emptyThreadSafeMap[ClassName, InfoLoader.ClassInfoCache] @@ -45,7 +47,7 @@ private[analyzer] final class InfoLoader(irLoader: IRLoader, irCheckMode: InfoLo if (irLoader.classExists(className)) { val infoCache = cache.getOrElseUpdate(className, new InfoLoader.ClassInfoCache(className, irLoader, irCheckMode)) - Some(infoCache.loadInfo(logger)) + Some(infoCache.loadInfo(logger, linkTimeProperties)) } else { None } @@ -75,7 +77,8 @@ private[analyzer] object InfoLoader { private var prevJSCtorInfo: Option[Infos.ReachabilityInfo] = None private var prevJSMethodPropDefInfos: List[Infos.ReachabilityInfo] = Nil - def loadInfo(logger: Logger)(implicit ec: ExecutionContext): Future[Infos.ClassInfo] = synchronized { + def loadInfo(logger: Logger, linkTimeProperties: LinkTimeProperties)( + implicit ec: ExecutionContext): Future[Infos.ClassInfo] = synchronized { /* If the cache was already used in this run, the classDef and info are * already correct, no matter what the versions say. */ @@ -92,7 +95,8 @@ private[analyzer] object InfoLoader { case InfoLoader.InitialIRCheck => val errorCount = ClassDefChecker.check(tree, - postBaseLinker = false, postOptimizer = false, logger) + postBaseLinker = false, postOptimizer = false, + logger, linkTimeProperties) if (errorCount != 0) { throw new LinkingException( s"There were $errorCount ClassDef checking errors.") @@ -100,7 +104,8 @@ private[analyzer] object InfoLoader { case InfoLoader.InternalIRCheck => val errorCount = ClassDefChecker.check(tree, - postBaseLinker = true, postOptimizer = true, logger) + postBaseLinker = true, postOptimizer = true, + logger, linkTimeProperties) if (errorCount != 0) { throw new LinkingException( s"There were $errorCount ClassDef checking errors after optimizing. " + @@ -108,7 +113,7 @@ private[analyzer] object InfoLoader { } } - generateInfos(tree) + generateInfos(tree, linkTimeProperties) } } } @@ -116,13 +121,13 @@ private[analyzer] object InfoLoader { info } - private def generateInfos(classDef: ClassDef): Infos.ClassInfo = { + private def generateInfos(classDef: ClassDef, linkTimeProperties: LinkTimeProperties): Infos.ClassInfo = { val referencedFieldClasses = Infos.genReferencedFieldClasses(classDef.fields) - prevMethodInfos = genMethodInfos(classDef.methods, prevMethodInfos) - prevJSCtorInfo = genJSCtorInfo(classDef.jsConstructor, prevJSCtorInfo) + prevMethodInfos = genMethodInfos(classDef.methods, prevMethodInfos, linkTimeProperties) + prevJSCtorInfo = genJSCtorInfo(classDef.jsConstructor, prevJSCtorInfo, linkTimeProperties) prevJSMethodPropDefInfos = - genJSMethodPropDefInfos(classDef.jsMethodProps, prevJSMethodPropDefInfos) + genJSMethodPropDefInfos(classDef.jsMethodProps, prevJSMethodPropDefInfos, linkTimeProperties) val exportedMembers = prevJSCtorInfo.toList ::: prevJSMethodPropDefInfos @@ -130,7 +135,7 @@ private[analyzer] object InfoLoader { * and usually quite small when they exist. */ val topLevelExports = classDef.topLevelExportDefs - .map(Infos.generateTopLevelExportInfo(classDef.name.name, _)) + .map(Infos.generateTopLevelExportInfo(classDef.name.name, _, linkTimeProperties)) val jsNativeMembers = classDef.jsNativeMembers .map(m => m.name.name -> m.jsNativeLoadSpec).toMap @@ -150,7 +155,7 @@ private[analyzer] object InfoLoader { } private def genMethodInfos(methods: List[MethodDef], - prevMethodInfos: MethodInfos): MethodInfos = { + prevMethodInfos: MethodInfos, linkTimeProperties: LinkTimeProperties): MethodInfos = { val builders = Array.fill(MemberNamespace.Count)(Map.newBuilder[MethodName, Infos.MethodInfo]) @@ -158,7 +163,7 @@ private[analyzer] object InfoLoader { val info = prevMethodInfos(method.flags.namespace.ordinal) .get(method.methodName) .filter(_.version.sameVersion(method.version)) - .getOrElse(Infos.generateMethodInfo(method)) + .getOrElse(Infos.generateMethodInfo(method, linkTimeProperties)) builders(method.flags.namespace.ordinal) += method.methodName -> info } @@ -167,16 +172,18 @@ private[analyzer] object InfoLoader { } private def genJSCtorInfo(jsCtor: Option[JSConstructorDef], - prevJSCtorInfo: Option[Infos.ReachabilityInfo]): Option[Infos.ReachabilityInfo] = { + prevJSCtorInfo: Option[Infos.ReachabilityInfo], + linkTimeProperties: LinkTimeProperties): Option[Infos.ReachabilityInfo] = { jsCtor.map { ctor => prevJSCtorInfo .filter(_.version.sameVersion(ctor.version)) - .getOrElse(Infos.generateJSConstructorInfo(ctor)) + .getOrElse(Infos.generateJSConstructorInfo(ctor, linkTimeProperties)) } } private def genJSMethodPropDefInfos(jsMethodProps: List[JSMethodPropDef], - prevJSMethodPropDefInfos: List[Infos.ReachabilityInfo]): List[Infos.ReachabilityInfo] = { + prevJSMethodPropDefInfos: List[Infos.ReachabilityInfo], + linkTimeProperties: LinkTimeProperties): List[Infos.ReachabilityInfo] = { /* For JS method and property definitions, we use their index in the list of * `linkedClass.exportedMembers` as their identity. We cannot use their name * because the name itself is a `Tree`. @@ -190,13 +197,13 @@ private[analyzer] object InfoLoader { if (prevJSMethodPropDefInfos.size != jsMethodProps.size) { // Regenerate everything. - jsMethodProps.map(Infos.generateJSMethodPropDefInfo(_)) + jsMethodProps.map(Infos.generateJSMethodPropDefInfo(_, linkTimeProperties)) } else { for { (prevInfo, member) <- prevJSMethodPropDefInfos.zip(jsMethodProps) } yield { if (prevInfo.version.sameVersion(member.version)) prevInfo - else Infos.generateJSMethodPropDefInfo(member) + else Infos.generateJSMethodPropDefInfo(member, linkTimeProperties) } } } diff --git a/linker/shared/src/main/scala/org/scalajs/linker/analyzer/Infos.scala b/linker/shared/src/main/scala/org/scalajs/linker/analyzer/Infos.scala index a6ed3f6a6e..7c6e2443bd 100644 --- a/linker/shared/src/main/scala/org/scalajs/linker/analyzer/Infos.scala +++ b/linker/shared/src/main/scala/org/scalajs/linker/analyzer/Infos.scala @@ -24,6 +24,7 @@ import org.scalajs.ir.Version import org.scalajs.linker.backend.emitter.Transients._ import org.scalajs.linker.standard.LinkedTopLevelExport import org.scalajs.linker.standard.ModuleSet.ModuleID +import org.scalajs.linker.standard.LinkTimeProperties object Infos { @@ -486,43 +487,44 @@ object Infos { /** Generates the [[MethodInfo]] of a * [[org.scalajs.ir.Trees.MethodDef Trees.MethodDef]]. */ - def generateMethodInfo(methodDef: MethodDef): MethodInfo = - new GenInfoTraverser(methodDef.version).generateMethodInfo(methodDef) + def generateMethodInfo(methodDef: MethodDef, linkTimeProperties: LinkTimeProperties): MethodInfo = + new GenInfoTraverser(methodDef.version, linkTimeProperties).generateMethodInfo(methodDef) /** Generates the [[ReachabilityInfo]] of a * [[org.scalajs.ir.Trees.JSConstructorDef Trees.JSConstructorDef]]. */ - def generateJSConstructorInfo(ctorDef: JSConstructorDef): ReachabilityInfo = - new GenInfoTraverser(ctorDef.version).generateJSConstructorInfo(ctorDef) + def generateJSConstructorInfo(ctorDef: JSConstructorDef, linkTimeProperties: LinkTimeProperties): ReachabilityInfo = + new GenInfoTraverser(ctorDef.version, linkTimeProperties).generateJSConstructorInfo(ctorDef) /** Generates the [[ReachabilityInfo]] of a * [[org.scalajs.ir.Trees.JSMethodDef Trees.JSMethodDef]]. */ - def generateJSMethodInfo(methodDef: JSMethodDef): ReachabilityInfo = - new GenInfoTraverser(methodDef.version).generateJSMethodInfo(methodDef) + def generateJSMethodInfo(methodDef: JSMethodDef, linkTimeProperties: LinkTimeProperties): ReachabilityInfo = + new GenInfoTraverser(methodDef.version, linkTimeProperties).generateJSMethodInfo(methodDef) /** Generates the [[ReachabilityInfo]] of a * [[org.scalajs.ir.Trees.JSPropertyDef Trees.JSPropertyDef]]. */ - def generateJSPropertyInfo(propertyDef: JSPropertyDef): ReachabilityInfo = - new GenInfoTraverser(propertyDef.version).generateJSPropertyInfo(propertyDef) + def generateJSPropertyInfo(propertyDef: JSPropertyDef, linkTimeProperties: LinkTimeProperties): ReachabilityInfo = + new GenInfoTraverser(propertyDef.version, linkTimeProperties).generateJSPropertyInfo(propertyDef) - def generateJSMethodPropDefInfo(member: JSMethodPropDef): ReachabilityInfo = member match { - case methodDef: JSMethodDef => generateJSMethodInfo(methodDef) - case propertyDef: JSPropertyDef => generateJSPropertyInfo(propertyDef) - } + def generateJSMethodPropDefInfo(member: JSMethodPropDef, linkTimeProperties: LinkTimeProperties): ReachabilityInfo = + member match { + case methodDef: JSMethodDef => generateJSMethodInfo(methodDef, linkTimeProperties) + case propertyDef: JSPropertyDef => generateJSPropertyInfo(propertyDef, linkTimeProperties) + } /** Generates the [[MethodInfo]] for the top-level exports. */ def generateTopLevelExportInfo(enclosingClass: ClassName, - topLevelExportDef: TopLevelExportDef): TopLevelExportInfo = { - val info = new GenInfoTraverser(Version.Unversioned) + topLevelExportDef: TopLevelExportDef, linkTimeProperties: LinkTimeProperties): TopLevelExportInfo = { + val info = new GenInfoTraverser(Version.Unversioned, linkTimeProperties) .generateTopLevelExportInfo(enclosingClass, topLevelExportDef) new TopLevelExportInfo(info, ModuleID(topLevelExportDef.moduleID), topLevelExportDef.topLevelExportName) } - private final class GenInfoTraverser(version: Version) extends Traverser { + private final class GenInfoTraverser(version: Version, linkTimeProperties: LinkTimeProperties) extends Traverser { private val builder = new ReachabilityInfoBuilder(version) def generateMethodInfo(methodDef: MethodDef): MethodInfo = { @@ -605,6 +607,12 @@ object Infos { } traverse(rhs) + case LinkTimeIf(cond, thenp, elsep) => + if (linkTimeProperties.evaluateLinkTimeTree(cond)) + traverse(thenp) + else + traverse(elsep) + // In all other cases, we'll have to call super.traverse() case _ => tree match { diff --git a/linker/shared/src/main/scala/org/scalajs/linker/backend/emitter/ClassEmitter.scala b/linker/shared/src/main/scala/org/scalajs/linker/backend/emitter/ClassEmitter.scala index 6351e43614..b377c5b606 100644 --- a/linker/shared/src/main/scala/org/scalajs/linker/backend/emitter/ClassEmitter.scala +++ b/linker/shared/src/main/scala/org/scalajs/linker/backend/emitter/ClassEmitter.scala @@ -40,6 +40,7 @@ private[emitter] final class ClassEmitter(sjsGen: SJSGen) { import sjsGen._ import jsGen._ import config._ + import coreSpec._ import nameGen._ import varGen._ diff --git a/linker/shared/src/main/scala/org/scalajs/linker/backend/emitter/CoreJSLib.scala b/linker/shared/src/main/scala/org/scalajs/linker/backend/emitter/CoreJSLib.scala index 2c27c125b9..085d3b005f 100644 --- a/linker/shared/src/main/scala/org/scalajs/linker/backend/emitter/CoreJSLib.scala +++ b/linker/shared/src/main/scala/org/scalajs/linker/backend/emitter/CoreJSLib.scala @@ -63,6 +63,7 @@ private[emitter] object CoreJSLib { import sjsGen._ import jsGen._ import config._ + import coreSpec._ import nameGen._ import varGen._ import esFeatures._ diff --git a/linker/shared/src/main/scala/org/scalajs/linker/backend/emitter/Emitter.scala b/linker/shared/src/main/scala/org/scalajs/linker/backend/emitter/Emitter.scala index 07e4dee5f8..debb5eed2d 100644 --- a/linker/shared/src/main/scala/org/scalajs/linker/backend/emitter/Emitter.scala +++ b/linker/shared/src/main/scala/org/scalajs/linker/backend/emitter/Emitter.scala @@ -37,6 +37,7 @@ final class Emitter(config: Emitter.Config, prePrinter: Emitter.PrePrinter) { import Emitter._ import config._ + import coreSpec._ require(!config.minify || prePrinter == PrePrinter.Off, "When using the 'minify' option, the prePrinter must be Off.") @@ -1088,23 +1089,16 @@ object Emitter { /** Configuration for the Emitter. */ final class Config private ( - val semantics: Semantics, - val moduleKind: ModuleKind, - val esFeatures: ESFeatures, + val coreSpec: CoreSpec, val jsHeader: String, val internalModulePattern: ModuleID => String, val optimizeBracketSelects: Boolean, val trackAllGlobalRefs: Boolean, val minify: Boolean ) { - private def this( - semantics: Semantics, - moduleKind: ModuleKind, - esFeatures: ESFeatures) = { + private def this(coreSpec: CoreSpec) = { this( - semantics, - moduleKind, - esFeatures, + coreSpec, jsHeader = "", internalModulePattern = "./" + _.id, optimizeBracketSelects = true, @@ -1113,18 +1107,17 @@ object Emitter { ) } + // val semantics = coreSpec.semantics + // val moduleKind = coreSpec.moduleKind + // val esFeatures = coreSpec.esFeatures + // val linkTimeProperties = coreSpec.linkTimeProperties + private[emitter] val topLevelGlobalRefTracking: GlobalRefTracking = if (trackAllGlobalRefs) GlobalRefTracking.All else GlobalRefTracking.Dangerous - def withSemantics(f: Semantics => Semantics): Config = - copy(semantics = f(semantics)) - - def withModuleKind(moduleKind: ModuleKind): Config = - copy(moduleKind = moduleKind) - - def withESFeatures(f: ESFeatures => ESFeatures): Config = - copy(esFeatures = f(esFeatures)) + def withCoreSpec(coreSpec: CoreSpec): Config = + copy(coreSpec = coreSpec) def withJSHeader(jsHeader: String): Config = { require(StandardConfig.isValidJSHeader(jsHeader), jsHeader) @@ -1144,24 +1137,21 @@ object Emitter { copy(minify = minify) private def copy( - semantics: Semantics = semantics, - moduleKind: ModuleKind = moduleKind, - esFeatures: ESFeatures = esFeatures, + coreSpec: CoreSpec = coreSpec, jsHeader: String = jsHeader, internalModulePattern: ModuleID => String = internalModulePattern, optimizeBracketSelects: Boolean = optimizeBracketSelects, trackAllGlobalRefs: Boolean = trackAllGlobalRefs, minify: Boolean = minify ): Config = { - new Config(semantics, moduleKind, esFeatures, jsHeader, - internalModulePattern, optimizeBracketSelects, trackAllGlobalRefs, - minify) + new Config(coreSpec, jsHeader, internalModulePattern, + optimizeBracketSelects, trackAllGlobalRefs, minify) } } object Config { def apply(coreSpec: CoreSpec): Config = - new Config(coreSpec.semantics, coreSpec.moduleKind, coreSpec.esFeatures) + new Config(coreSpec) } sealed trait PrePrinter { @@ -1257,7 +1247,7 @@ object Emitter { ancestors: List[ClassName], moduleContext: ModuleContext) private def symbolRequirements(config: Config): SymbolRequirement = { - import config.semantics._ + import config.coreSpec.semantics._ import CheckedBehavior._ val factory = SymbolRequirement.factory("emitter") @@ -1313,7 +1303,7 @@ object Emitter { callMethod(BoxedDoubleClass, hashCodeMethodName), callMethod(BoxedStringClass, hashCodeMethodName), - cond(!config.esFeatures.allowBigIntsForLongs) { + cond(!config.coreSpec.esFeatures.allowBigIntsForLongs) { multiple( instanceTests(LongImpl.RuntimeLongClass), instantiateClass(LongImpl.RuntimeLongClass, LongImpl.AllConstructors.toList), diff --git a/linker/shared/src/main/scala/org/scalajs/linker/backend/emitter/FunctionEmitter.scala b/linker/shared/src/main/scala/org/scalajs/linker/backend/emitter/FunctionEmitter.scala index 953a54241f..9f1a852eaf 100644 --- a/linker/shared/src/main/scala/org/scalajs/linker/backend/emitter/FunctionEmitter.scala +++ b/linker/shared/src/main/scala/org/scalajs/linker/backend/emitter/FunctionEmitter.scala @@ -251,6 +251,7 @@ private[emitter] class FunctionEmitter(sjsGen: SJSGen) { import sjsGen._ import jsGen._ import config._ + import coreSpec._ import nameGen._ import varGen._ @@ -1283,6 +1284,10 @@ private[emitter] class FunctionEmitter(sjsGen: SJSGen) { case IdentityHashCode(expr) => test(expr) case GetClass(arg) => testNPE(arg) + case LinkTimeIf(cond, thenp, elsep) => + if (linkTimeProperties.evaluateLinkTimeTree(cond)) test(thenp) + else test(elsep) + // Expressions preserving pureness (modulo NPE) but requiring that expr be a var case WrapAsThrowable(expr @ (VarRef(_) | Transient(JSVarRef(_, _)))) => test(expr) case UnwrapFromThrowable(expr @ (VarRef(_) | Transient(JSVarRef(_, _)))) => testNPE(expr) @@ -2204,6 +2209,12 @@ private[emitter] class FunctionEmitter(sjsGen: SJSGen) { js.If(transformExprNoChar(cond), transformExpr(thenp, tree.tpe), transformExpr(elsep, tree.tpe)) + case LinkTimeIf(cond, thenp, elsep) => + if (linkTimeProperties.evaluateLinkTimeTree(cond)) + transformExpr(thenp, tree.tpe) + else + transformExpr(elsep, tree.tpe) + // Scala expressions case New(className, ctor, args) => diff --git a/linker/shared/src/main/scala/org/scalajs/linker/backend/emitter/JSGen.scala b/linker/shared/src/main/scala/org/scalajs/linker/backend/emitter/JSGen.scala index 4da09323c5..df5aa9fa41 100644 --- a/linker/shared/src/main/scala/org/scalajs/linker/backend/emitter/JSGen.scala +++ b/linker/shared/src/main/scala/org/scalajs/linker/backend/emitter/JSGen.scala @@ -25,6 +25,7 @@ import org.scalajs.linker.interface.ESVersion private[emitter] final class JSGen(val config: Emitter.Config) { import config._ + import coreSpec._ /** Should we use ECMAScript classes for JavaScript classes and Throwable * classes? diff --git a/linker/shared/src/main/scala/org/scalajs/linker/backend/emitter/KnowledgeGuardian.scala b/linker/shared/src/main/scala/org/scalajs/linker/backend/emitter/KnowledgeGuardian.scala index 562ad1e519..3097a9aabf 100644 --- a/linker/shared/src/main/scala/org/scalajs/linker/backend/emitter/KnowledgeGuardian.scala +++ b/linker/shared/src/main/scala/org/scalajs/linker/backend/emitter/KnowledgeGuardian.scala @@ -146,7 +146,7 @@ private[emitter] final class KnowledgeGuardian(config: Emitter.Config) { private def computeStaticFieldMirrors( moduleSet: ModuleSet): Map[ClassName, Map[FieldName, List[String]]] = { - if (config.moduleKind != ModuleKind.NoModule) { + if (config.coreSpec.moduleKind != ModuleKind.NoModule) { Map.empty } else { var result = Map.empty[ClassName, Map[FieldName, List[String]]] diff --git a/linker/shared/src/main/scala/org/scalajs/linker/backend/emitter/SJSGen.scala b/linker/shared/src/main/scala/org/scalajs/linker/backend/emitter/SJSGen.scala index 2a9b3cf93e..2f2f112ca8 100644 --- a/linker/shared/src/main/scala/org/scalajs/linker/backend/emitter/SJSGen.scala +++ b/linker/shared/src/main/scala/org/scalajs/linker/backend/emitter/SJSGen.scala @@ -38,6 +38,7 @@ private[emitter] final class SJSGen( import jsGen._ import config._ + import coreSpec._ import nameGen._ import varGen._ diff --git a/linker/shared/src/main/scala/org/scalajs/linker/backend/emitter/VarGen.scala b/linker/shared/src/main/scala/org/scalajs/linker/backend/emitter/VarGen.scala index 11d93244d9..ba20e783b1 100644 --- a/linker/shared/src/main/scala/org/scalajs/linker/backend/emitter/VarGen.scala +++ b/linker/shared/src/main/scala/org/scalajs/linker/backend/emitter/VarGen.scala @@ -97,7 +97,7 @@ private[emitter] final class VarGen(jsGen: JSGen, nameGen: NameGen, val ident = globalVarIdent(field, scope, origName) val varDef = genLet(ident, mutable = true, value) - if (config.moduleKind == ModuleKind.ESModule && !moduleContext.public) { + if (config.coreSpec.moduleKind == ModuleKind.ESModule && !moduleContext.public) { val setterIdent = globalVarIdent(setterField, scope) val x = Ident("x") val setter = FunctionDef(setterIdent, List(ParamDef(x)), None, { @@ -117,7 +117,7 @@ private[emitter] final class VarGen(jsGen: JSGen, nameGen: NameGen, def needToUseGloballyMutableVarSetter[T](scope: T)( implicit moduleContext: ModuleContext, globalKnowledge: GlobalKnowledge, scopeType: Scope[T]): Boolean = { - config.moduleKind == ModuleKind.ESModule && + config.coreSpec.moduleKind == ModuleKind.ESModule && globalKnowledge.getModule(scopeType.reprClass(scope)) != moduleContext.moduleID } @@ -125,7 +125,7 @@ private[emitter] final class VarGen(jsGen: JSGen, nameGen: NameGen, origName: OriginalName = NoOriginalName)( implicit moduleContext: ModuleContext, globalKnowledge: GlobalKnowledge, pos: Position): Tree = { - assert(config.moduleKind == ModuleKind.ESModule) + assert(config.coreSpec.moduleKind == ModuleKind.ESModule) val ident = globalVarIdent(field, scope, origName) foldSameModule[T, Tree](scope) { @@ -163,7 +163,7 @@ private[emitter] final class VarGen(jsGen: JSGen, nameGen: NameGen, } { moduleID => val moduleName = config.internalModulePattern(moduleID) - val moduleTree = config.moduleKind match { + val moduleTree = config.coreSpec.moduleKind match { case ModuleKind.NoModule => /* If we get here, it means that what we are trying to import is in a * different module than the module we're currently generating @@ -279,7 +279,7 @@ private[emitter] final class VarGen(jsGen: JSGen, nameGen: NameGen, if (moduleContext.public) { WithGlobals(tree :: Nil) } else { - val exportStat = config.moduleKind match { + val exportStat = config.coreSpec.moduleKind match { case ModuleKind.NoModule => throw new AssertionError("non-public module in NoModule mode") diff --git a/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/Emitter.scala b/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/Emitter.scala index 0a477f6a59..41d40df4bf 100644 --- a/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/Emitter.scala +++ b/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/Emitter.scala @@ -77,7 +77,7 @@ final class Emitter(config: Emitter.Config) { val moduleInitializers = module.initializers.toList implicit val ctx: WasmContext = - Preprocessor.preprocess(sortedClasses, topLevelExports) + Preprocessor.preprocess(sortedClasses, module.topLevelExports, config.coreSpec) CoreWasmLib.genPreClasses() genExternalModuleImports(module) diff --git a/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/FunctionEmitter.scala b/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/FunctionEmitter.scala index 5fecd5dbd2..66754e3e31 100644 --- a/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/FunctionEmitter.scala +++ b/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/FunctionEmitter.scala @@ -445,6 +445,8 @@ private class FunctionEmitter private ( case t: JSSuperMethodCall => genJSSuperMethodCall(t) case t: JSNewTarget => genJSNewTarget(t) + case t: LinkTimeIf => genLinkTimeIf(t, expectedType) + // Records (only generated by the optimizer) case t: RecordSelect => genRecordSelect(t) case t: RecordValue => genRecordValue(t) @@ -3004,6 +3006,14 @@ private class FunctionEmitter private ( NoType } + private def genLinkTimeIf(tree: LinkTimeIf, expectedType: Type): Type = { + if (ctx.coreSpec.linkTimeProperties.evaluateLinkTimeTree(tree.cond)) + genTree(tree.thenp, expectedType) + else + genTree(tree.elsep, expectedType) + expectedType + } + /*--------------------------------------------------------------------* * HERE BE DRAGONS --- Handling of TryFinally, Labeled and Return --- * *--------------------------------------------------------------------*/ diff --git a/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/Preprocessor.scala b/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/Preprocessor.scala index 1aa1ea6f2f..6a5ec1baae 100644 --- a/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/Preprocessor.scala +++ b/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/Preprocessor.scala @@ -19,13 +19,13 @@ import org.scalajs.ir.Trees._ import org.scalajs.ir.Types._ import org.scalajs.ir.{ClassKind, Traversers} -import org.scalajs.linker.standard.{LinkedClass, LinkedTopLevelExport} +import org.scalajs.linker.standard.{LinkedClass, LinkedTopLevelExport, CoreSpec} import EmbeddedConstants._ import WasmContext._ object Preprocessor { - def preprocess(classes: List[LinkedClass], tles: List[LinkedTopLevelExport]): WasmContext = { + def preprocess(classes: List[LinkedClass], tles: List[LinkedTopLevelExport], coreSpec: CoreSpec): WasmContext = { val staticFieldMirrors = computeStaticFieldMirrors(tles) val specialInstanceTypes = computeSpecialInstanceTypes(classes) @@ -62,7 +62,7 @@ object Preprocessor { // sort for stability val reflectiveProxyIDs = definedReflectiveProxyNames.toList.sorted.zipWithIndex.toMap - new WasmContext(classInfos, reflectiveProxyIDs, itableBucketCount) + new WasmContext(classInfos, reflectiveProxyIDs, itableBucketCount, coreSpec) } private def computeStaticFieldMirrors( diff --git a/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/WasmContext.scala b/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/WasmContext.scala index aa230752a3..aac809ae3f 100644 --- a/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/WasmContext.scala +++ b/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/WasmContext.scala @@ -27,6 +27,7 @@ import org.scalajs.linker.interface.ModuleInitializer import org.scalajs.linker.interface.unstable.ModuleInitializerImpl import org.scalajs.linker.standard.LinkedTopLevelExport import org.scalajs.linker.standard.LinkedClass +import org.scalajs.linker.standard.CoreSpec import org.scalajs.linker.backend.webassembly.ModuleBuilder import org.scalajs.linker.backend.webassembly.{Instructions => wa} @@ -40,7 +41,8 @@ import org.scalajs.ir.OriginalName final class WasmContext( classInfo: Map[ClassName, WasmContext.ClassInfo], reflectiveProxies: Map[MethodName, Int], - val itablesLength: Int + val itablesLength: Int, + val coreSpec: CoreSpec ) { import WasmContext._ diff --git a/linker/shared/src/main/scala/org/scalajs/linker/checker/ClassDefChecker.scala b/linker/shared/src/main/scala/org/scalajs/linker/checker/ClassDefChecker.scala index 5da8ba0a6f..6f2831ffaa 100644 --- a/linker/shared/src/main/scala/org/scalajs/linker/checker/ClassDefChecker.scala +++ b/linker/shared/src/main/scala/org/scalajs/linker/checker/ClassDefChecker.scala @@ -25,10 +25,12 @@ import org.scalajs.logging._ import org.scalajs.linker.checker.ErrorReporter._ import org.scalajs.linker.standard.LinkedClass +import org.scalajs.linker.standard.LinkTimeProperties /** Checker for the validity of the IR. */ private final class ClassDefChecker(classDef: ClassDef, - postBaseLinker: Boolean, postOptimizer: Boolean, reporter: ErrorReporter) { + postBaseLinker: Boolean, postOptimizer: Boolean, + reporter: ErrorReporter, linkTimeProperties: LinkTimeProperties) { import ClassDefChecker._ import reporter.reportError @@ -877,6 +879,13 @@ private final class ClassDefChecker(classDef: ClassDef, transient.traverse(new Traversers.Traverser { override def traverse(tree: Tree): Unit = checkTree(tree, env) }) + + case LinkTimeIf(cond, thenp, elsep) => + if (cond.tpe != BooleanType) + reportError(i"Link-time condition must be typed as boolean, but ${cond.tpe} is found.") + checkLinkTimeTree(cond) + checkTree(thenp, env) + checkTree(elsep, env) } newEnv @@ -920,6 +929,31 @@ private final class ClassDefChecker(classDef: ClassDef, if (!declaredLabelNamesPerMethod.add(label.name)) reportError(i"Duplicate label named ${label.name}.") } + + private def checkLinkTimeTree(tree: LinkTimeTree): Unit = { + implicit val ctx = ErrorContext(tree) + import LinkTimeOp._ + tree match { + case LinkTimeTree.BinaryOp(op, lhs, rhs) => + if (lhs.tpe != rhs.tpe) + reportError(i"Type mismatch for binary operation: ${lhs.tpe} and ${rhs.tpe}.") + op match { + case Boolean_!= | Boolean_== | Boolean_&& | Boolean_|| => + if (lhs.tpe != BooleanType) + reportError(i"Invalid operand type for Boolean operation: ${lhs.tpe}.") + case Int_!= | Int_== | Int_< | Int_<= | Int_> | Int_>= => + if (lhs.tpe != IntType) + reportError(i"Invalid operand type for Integer operation: ${lhs.tpe}.") + } + checkLinkTimeTree(lhs) + checkLinkTimeTree(rhs) + case prop: LinkTimeTree.Property => + if (!linkTimeProperties.exist(prop.name, prop.tpe)) { + reportError(i"link-time property '${prop.name}' of ${prop.tpe} not found.") + } + case _ => + } + } } object ClassDefChecker { @@ -927,13 +961,16 @@ object ClassDefChecker { * * @return Count of IR checking errors (0 in case of success) */ - def check(classDef: ClassDef, postBaseLinker: Boolean, postOptimizer: Boolean, logger: Logger): Int = { + def check(classDef: ClassDef, postBaseLinker: Boolean, postOptimizer: Boolean, + logger: Logger, linkTimeProperties: LinkTimeProperties): Int = { val reporter = new LoggerErrorReporter(logger) - new ClassDefChecker(classDef, postBaseLinker, postOptimizer, reporter).checkClassDef() + new ClassDefChecker(classDef, postBaseLinker, postOptimizer, reporter, + linkTimeProperties).checkClassDef() reporter.errorCount } - def check(linkedClass: LinkedClass, postOptimizer: Boolean, logger: Logger): Int = { + def check(linkedClass: LinkedClass, postOptimizer: Boolean, + logger: Logger, linkTimeProperties: LinkTimeProperties): Int = { // Rebuild a ClassDef out of the LinkedClass import linkedClass._ implicit val pos = linkedClass.pos @@ -954,7 +991,7 @@ object ClassDefChecker { topLevelExportDefs = Nil )(optimizerHints) - check(classDef, postBaseLinker = true, postOptimizer, logger) + check(classDef, postBaseLinker = true, postOptimizer, logger, linkTimeProperties) } private class Env( diff --git a/linker/shared/src/main/scala/org/scalajs/linker/checker/IRChecker.scala b/linker/shared/src/main/scala/org/scalajs/linker/checker/IRChecker.scala index eb12a06b3a..2e7aad0578 100644 --- a/linker/shared/src/main/scala/org/scalajs/linker/checker/IRChecker.scala +++ b/linker/shared/src/main/scala/org/scalajs/linker/checker/IRChecker.scala @@ -25,10 +25,12 @@ import org.scalajs.logging._ import org.scalajs.linker.frontend.LinkingUnit import org.scalajs.linker.standard.LinkedClass +import org.scalajs.linker.standard.LinkTimeProperties import org.scalajs.linker.checker.ErrorReporter._ /** Checker for the validity of the IR. */ -private final class IRChecker(unit: LinkingUnit, reporter: ErrorReporter) { +private final class IRChecker(unit: LinkingUnit, + linkTimeProperties: LinkTimeProperties, reporter: ErrorReporter) { import IRChecker._ import reporter.reportError @@ -268,6 +270,13 @@ private final class IRChecker(unit: LinkingUnit, reporter: ErrorReporter) { typecheckExpect(thenp, env, tpe) typecheckExpect(elsep, env, tpe) + case LinkTimeIf(cond, thenp, elsep) => + val tpe = tree.tpe + if (linkTimeProperties.evaluateLinkTimeTree(cond)) + typecheckExpect(thenp, env, tpe) + else + typecheckExpect(elsep, env, tpe) + case While(cond, body) => typecheckExpect(cond, env, BooleanType) typecheck(body, env) @@ -837,9 +846,10 @@ object IRChecker { * * @return Count of IR checking errors (0 in case of success) */ - def check(unit: LinkingUnit, logger: Logger): Int = { + def check(unit: LinkingUnit, + linkTimeProperties: LinkTimeProperties, logger: Logger): Int = { val reporter = new LoggerErrorReporter(logger) - new IRChecker(unit, reporter).check() + new IRChecker(unit, linkTimeProperties, reporter).check() reporter.errorCount } } diff --git a/linker/shared/src/main/scala/org/scalajs/linker/frontend/BaseLinker.scala b/linker/shared/src/main/scala/org/scalajs/linker/frontend/BaseLinker.scala index f120dff28a..1eb004097c 100644 --- a/linker/shared/src/main/scala/org/scalajs/linker/frontend/BaseLinker.scala +++ b/linker/shared/src/main/scala/org/scalajs/linker/frontend/BaseLinker.scala @@ -56,7 +56,8 @@ final class BaseLinker(config: CommonPhaseConfig, checkIR: Boolean) { } yield { if (checkIR) { logger.time("Linker: Check IR") { - val errorCount = IRChecker.check(linkResult, logger) + val errorCount = IRChecker.check( + linkResult, config.coreSpec.linkTimeProperties, logger) if (errorCount != 0) { throw new LinkingException( s"There were $errorCount IR checking errors.") diff --git a/linker/shared/src/main/scala/org/scalajs/linker/frontend/optimizer/OptimizerCore.scala b/linker/shared/src/main/scala/org/scalajs/linker/frontend/optimizer/OptimizerCore.scala index 59663392bf..6d6a421718 100644 --- a/linker/shared/src/main/scala/org/scalajs/linker/frontend/optimizer/OptimizerCore.scala +++ b/linker/shared/src/main/scala/org/scalajs/linker/frontend/optimizer/OptimizerCore.scala @@ -692,6 +692,12 @@ private[optimizer] abstract class OptimizerCore( case LoadJSConstructor(className) => transformJSLoadCommon(ImportTarget.Class(className), tree) + case LinkTimeIf(cond, thenp, elsep) => + if (config.coreSpec.linkTimeProperties.evaluateLinkTimeTree(cond)) + transform(thenp, isStat) + else + transform(elsep, isStat) + // Trees that need not be transformed case _:Skip | _:Debugger | _:StoreModule | diff --git a/linker/shared/src/main/scala/org/scalajs/linker/standard/CoreSpec.scala b/linker/shared/src/main/scala/org/scalajs/linker/standard/CoreSpec.scala index 3c4c979adc..69f2ab18e3 100644 --- a/linker/shared/src/main/scala/org/scalajs/linker/standard/CoreSpec.scala +++ b/linker/shared/src/main/scala/org/scalajs/linker/standard/CoreSpec.scala @@ -53,6 +53,10 @@ final class CoreSpec private ( def withTargetIsWebAssembly(targetIsWebAssembly: Boolean): CoreSpec = copy(targetIsWebAssembly = targetIsWebAssembly) + /** Link-time resolved properties */ + + val linkTimeProperties = new LinkTimeProperties( + semantics, esFeatures) override def equals(that: Any): Boolean = that match { case that: CoreSpec => diff --git a/linker/shared/src/main/scala/org/scalajs/linker/standard/LinkTimeProperties.scala b/linker/shared/src/main/scala/org/scalajs/linker/standard/LinkTimeProperties.scala new file mode 100644 index 0000000000..c7f0cbf3ca --- /dev/null +++ b/linker/shared/src/main/scala/org/scalajs/linker/standard/LinkTimeProperties.scala @@ -0,0 +1,99 @@ +/* + * Scala.js (https://www.scala-js.org/) + * + * Copyright EPFL. + * + * Licensed under Apache License 2.0 + * (https://www.apache.org/licenses/LICENSE-2.0). + * + * See the NOTICE file distributed with this work for + * additional information regarding copyright ownership. + */ + +package org.scalajs.linker.standard + +import org.scalajs.ir.{Types => jstpe} +import org.scalajs.ir.Trees.LinkTimeTree +import org.scalajs.ir.Trees.LinkTimeOp +import org.scalajs.linker.interface.{Semantics, ESFeatures} + +final class LinkTimeProperties private[standard] ( + semantics: Semantics, + esFeatures: ESFeatures +) { + import LinkTimeProperties._ + import LinkTimeProperties.ResolvedLinkTimeTree._ + + private val linkTimeProperties: Map[String, ResolvedLinkTimeTree] = Map( + // Must be in sync with the arguments of @linkTimeProperty("...") + // for the fields in `scala.scalajs.LinkingInfo`. + "core/productionMode" -> BooleanValue(semantics.productionMode), + "core/esVersion" -> IntValue(esFeatures.esVersion.edition) + ) + + def evaluateLinkTimeTree(cond: LinkTimeTree): Boolean = { + eval(cond) match { + case BooleanValue(v) => v + case IntValue(v) => + throw new IllegalArgumentException( + "Link-time condition must be evaluated to be a boolean value, but int is found.") + } + } + + def exist(name: String, tpe: jstpe.Type): Boolean = + linkTimeProperties.get(name).exists { + case IntValue(_) => tpe == jstpe.IntType + case BooleanValue(_) => tpe == jstpe.BooleanType + } + + private def eval(cond: LinkTimeTree): ResolvedLinkTimeTree = cond match { + case LinkTimeTree.BinaryOp(op, lhs, rhs) => + ResolvedLinkTimeTree.BooleanValue { + (eval(lhs), eval(rhs)) match { + case (IntValue(l), IntValue(r)) => + op match { + case LinkTimeOp.Int_== => l == r + case LinkTimeOp.Int_!= => l != r + case LinkTimeOp.Int_< => l < r + case LinkTimeOp.Int_<= => l <= r + case LinkTimeOp.Int_> => l > r + case LinkTimeOp.Int_>= => l >= r + case _ => + throw new IllegalArgumentException(s"Invalid operation $op for int values.") + } + case (BooleanValue(l), BooleanValue(r)) => + op match { + case LinkTimeOp.Boolean_== => l == r + case LinkTimeOp.Boolean_!= => l != r + case LinkTimeOp.Boolean_|| => l || r + case LinkTimeOp.Boolean_&& => l && r + case _ => + throw new IllegalArgumentException(s"Invalid operation $op for boolean values.") + } + case _ => + throw new IllegalArgumentException("Type mismatch: binary operation with different types " + + "is not allowed in linkTimeIf.") + } + } + case LinkTimeTree.BooleanConst(v) => BooleanValue(v) + case LinkTimeTree.IntConst(v) => IntValue(v) + case LinkTimeTree.Property(name, _) => resolveLinkTimeProperty(name) + } + + private def resolveLinkTimeProperty(prop: String): ResolvedLinkTimeTree = + linkTimeProperties.getOrElse(prop, throw new IllegalArgumentException(s"link time property not found: '$prop'")) +} + +object LinkTimeProperties { + private sealed abstract class ResolvedLinkTimeTree + private object ResolvedLinkTimeTree { + case class IntValue(v: Int) extends ResolvedLinkTimeTree + case class BooleanValue(v: Boolean) extends ResolvedLinkTimeTree + } + + private[linker] val Defaults: LinkTimeProperties = + new LinkTimeProperties( + Semantics.Defaults, + ESFeatures.Defaults + ) +} diff --git a/linker/shared/src/test/scala/org/scalajs/linker/AnalyzerTest.scala b/linker/shared/src/test/scala/org/scalajs/linker/AnalyzerTest.scala index f797ad25a1..96437fc964 100644 --- a/linker/shared/src/test/scala/org/scalajs/linker/AnalyzerTest.scala +++ b/linker/shared/src/test/scala/org/scalajs/linker/AnalyzerTest.scala @@ -806,6 +806,63 @@ class AnalyzerTest { assertFalse(I2barMethodInfo.isAbstractReachable) } } + + @Test + def linkTimeIfReachable(): AsyncResult = await { + val mainMethodName = m("main", Nil, IntRef) + val fooMethodName = m("foo", Nil, IntRef) + val barMethodName = m("bar", Nil, IntRef) + + val productionMode = true + + /* linkTimeIf(productionMode) { + * this.foo() + * } { + * this.bar() + * } + */ + val mainBody = LinkTimeIf( + LinkTimeTree.BinaryOp( + LinkTimeOp.Boolean_==, + LinkTimeTree.Property( + "core/productionMode", BooleanType), + LinkTimeTree.BooleanConst(productionMode)), + Apply(EAF, This()(ClassType("A")), fooMethodName, Nil)(IntType), + Apply(EAF, This()(ClassType("A")), barMethodName, Nil)(IntType) + )(IntType) + + val classDefs = Seq( + classDef("A", superClass = Some(ObjectClass), + methods = List( + trivialCtor("A"), + MethodDef(EMF, mainMethodName, NON, Nil, IntType, + Some(mainBody))(EOH, UNV), + MethodDef(EMF, fooMethodName, NON, Nil, IntType, + Some(Null()))(EOH, UNV), + MethodDef(EMF, barMethodName, NON, Nil, IntType, + Some(Null()))(EOH, UNV) + ) + ) + ) + + val analysisFuture = computeAnalysis(classDefs, + reqsFactory.instantiateClass("A", NoArgConstructorName) ++ + reqsFactory.callMethod("A", mainMethodName), + config = StandardConfig().withSemantics(_.withProductionMode(productionMode)) + ) + + for (analysis <- analysisFuture) yield { + assertNoError(analysis) + + val AfooMethodInfo = analysis.classInfos("A") + .methodInfos(MemberNamespace.Public)(fooMethodName) + assertTrue(AfooMethodInfo.isReachable) + + val AbarMethodInfo = analysis.classInfos("A") + .methodInfos(MemberNamespace.Public)(barMethodName) + assertFalse(AbarMethodInfo.isReachable) + } + } } object AnalyzerTest { diff --git a/linker/shared/src/test/scala/org/scalajs/linker/BaseLinkerTest.scala b/linker/shared/src/test/scala/org/scalajs/linker/BaseLinkerTest.scala index 2f4c74f00a..89c1383847 100644 --- a/linker/shared/src/test/scala/org/scalajs/linker/BaseLinkerTest.scala +++ b/linker/shared/src/test/scala/org/scalajs/linker/BaseLinkerTest.scala @@ -30,6 +30,7 @@ import org.scalajs.linker.standard._ import org.scalajs.linker.testutils.TestIRBuilder._ import org.scalajs.linker.testutils.LinkingUtils._ +import org.scalajs.linker.interface.{Semantics, ESFeatures} class BaseLinkerTest { import scala.concurrent.ExecutionContext.Implicits.global @@ -87,7 +88,7 @@ class BaseLinkerTest { for (moduleSet <- linkToModuleSet(classDefs, MainTestModuleInitializers, config = config)) yield { val clazz = findClass(moduleSet, BoxedIntegerClass).get val errorCount = ClassDefChecker.check(clazz, postOptimizer = false, - new ScalaConsoleLogger(Level.Error)) + new ScalaConsoleLogger(Level.Error), LinkTimeProperties.Defaults) assertEquals(0, errorCount) } } diff --git a/linker/shared/src/test/scala/org/scalajs/linker/OptimizerTest.scala b/linker/shared/src/test/scala/org/scalajs/linker/OptimizerTest.scala index d81ce35df5..f7b9491958 100644 --- a/linker/shared/src/test/scala/org/scalajs/linker/OptimizerTest.scala +++ b/linker/shared/src/test/scala/org/scalajs/linker/OptimizerTest.scala @@ -565,6 +565,74 @@ class OptimizerTest { } } + @Test + def removeUnreachableLinkTimeIfBranch(): AsyncResult = await { + val methodName = m("method", Nil, I) + val methodBody = LinkTimeIf( + LinkTimeTree.BinaryOp( + LinkTimeOp.Boolean_==, + LinkTimeTree.Property("core/productionMode", BooleanType), + LinkTimeTree.BooleanConst(true)), + int(1), int(0))(IntType) + val classDefs = Seq( + classDef("Foo", kind = ClassKind.Class, superClass = Some(ObjectClass), + methods = List( + trivialCtor("Foo"), + MethodDef(EMF, methodName, NON, Nil, IntType, Some(methodBody))(EOH, UNV) + )), + mainTestClassDef({ + consoleLog(Apply(EAF, New("Foo", NoArgConstructorName, Nil), methodName, Nil)(IntType)) + }) + ) + for { + moduleSet <- linkToModuleSet( + classDefs, MainTestModuleInitializers, + config = StandardConfig().withSemantics((_.withProductionMode(true))) + ) + } yield { + findClass(moduleSet, ClassName("Foo")).get + .methods.find(_.name.name == methodName).get + .body.get match { + case IntLiteral(1) => // ok + case t => + fail(s"Unexpected body: $t") + } + } + } + + @Test + def removeUnreachableCalleeByLinkTimeIf(): AsyncResult = await { + val methodName = m("method", Nil, I) + val classDefs = Seq( + classDef("Foo", kind = ClassKind.Class, superClass = Some(ObjectClass), + methods = List( + trivialCtor("Foo"), + // def method(): Int = 0 + MethodDef(EMF, methodName, NON, Nil, IntType, Some(int(0)))(EOH, UNV)) + ), + mainTestClassDef({ + LinkTimeIf( + LinkTimeTree.BinaryOp( + LinkTimeOp.Boolean_==, + LinkTimeTree.Property("core/productionMode", BooleanType), + LinkTimeTree.BooleanConst(true)), + consoleLog(str("prod")), + consoleLog(Apply(EAF, New("Foo", NoArgConstructorName, Nil), + methodName, Nil)(IntType)) + )(NoType) + }) + ) + + for { + moduleSet <- linkToModuleSet( + classDefs, MainTestModuleInitializers, + config = StandardConfig().withSemantics((_.withProductionMode(true))) + ) + } yield { + assertFalse(findClass(moduleSet, ClassName("Foo")).isDefined) + } + } + def inlineFlagsTestCommon(optimizerHints: OptimizerHints, applyFlags: ApplyFlags, expectInline: Boolean): AsyncResult = await { val classDefs = Seq( diff --git a/linker/shared/src/test/scala/org/scalajs/linker/checker/ClassDefCheckerTest.scala b/linker/shared/src/test/scala/org/scalajs/linker/checker/ClassDefCheckerTest.scala index 62fbd15bff..5c70f153e4 100644 --- a/linker/shared/src/test/scala/org/scalajs/linker/checker/ClassDefCheckerTest.scala +++ b/linker/shared/src/test/scala/org/scalajs/linker/checker/ClassDefCheckerTest.scala @@ -23,8 +23,8 @@ import org.scalajs.ir.Types._ import org.scalajs.logging.NullLogger -import org.scalajs.linker.interface.{LinkingException, StandardConfig} -import org.scalajs.linker.standard.{StandardLinkerFrontend, SymbolRequirement} +import org.scalajs.linker.interface.{LinkingException, StandardConfig, Semantics, ESFeatures} +import org.scalajs.linker.standard.{StandardLinkerFrontend, SymbolRequirement, LinkTimeProperties} import org.scalajs.linker.testutils._ import org.scalajs.linker.testutils.TestIRBuilder._ @@ -457,11 +457,73 @@ class ClassDefCheckerTest { testIsInstanceOfError(ArrayType(ArrayTypeRef(IntRef, 1), nullable = true)) testAsInstanceOfError(ArrayType(ArrayTypeRef(IntRef, 1), nullable = false)) } + + @Test + def linkTimeIfTest(): Unit = { + def linkTimeIf(cond: LinkTimeTree): ClassDef = { + classDef( + "Foo", + kind = ClassKind.ModuleClass, + superClass = Some(ObjectClass), + methods = List( + trivialCtor("Foo"), + MethodDef(EMF, MethodName("foo", Nil, VoidRef), NON, Nil, NoType, Some { + LinkTimeIf( + cond, + consoleLog(StringLiteral("foo")), + consoleLog(StringLiteral("bar")) + )(NoType) + })(EOH, UNV) + ) + ) + } + + assertError( + linkTimeIf(LinkTimeTree.Property("core/esVersion", IntType)), + "Link-time condition must be typed as boolean, but int is found." + ) + + assertError( + linkTimeIf( + LinkTimeTree.BinaryOp(LinkTimeOp.Int_==, + LinkTimeTree.IntConst(0), LinkTimeTree.Property("core/productionMode", BooleanType)) + ), + "Type mismatch for binary operation: int and boolean." + ) + + assertError( + linkTimeIf( + LinkTimeTree.BinaryOp(LinkTimeOp.Int_==, + LinkTimeTree.IntConst(0), LinkTimeTree.Property("core/productionMode", BooleanType)) + ), + "Type mismatch for binary operation: int and boolean." + ) + + assertError( + linkTimeIf( + LinkTimeTree.BinaryOp(LinkTimeOp.Boolean_==, + LinkTimeTree.IntConst(0), LinkTimeTree.Property("core/esVersion", IntType)) + ), + "Invalid operand type for Boolean operation: int." + ) + + assertError( + linkTimeIf(LinkTimeTree.Property("prop-not-found", BooleanType)), + "link-time property 'prop-not-found' of boolean not found." + ) + + assertError( + linkTimeIf(LinkTimeTree.Property("core/esVersion", BooleanType)), + "link-time property 'core/esVersion' of boolean not found." + ) + } } private object ClassDefCheckerTest { private def assertError(clazz: ClassDef, expectMsg: String, - allowReflectiveProxies: Boolean = false, allowTransients: Boolean = false) = { + allowReflectiveProxies: Boolean = false, allowTransients: Boolean = false, + linkTimeProperties: LinkTimeProperties = LinkTimeProperties.Defaults + ) = { var seen = false val reporter = new ErrorReporter { def reportError(msg: String)(implicit ctx: ErrorReporter.ErrorContext) = { @@ -471,7 +533,7 @@ private object ClassDefCheckerTest { } } - new ClassDefChecker(clazz, allowReflectiveProxies, allowTransients, reporter).checkClassDef() + new ClassDefChecker(clazz, allowReflectiveProxies, allowTransients, reporter, linkTimeProperties).checkClassDef() assertTrue("no errors reported", seen) } } diff --git a/linker/shared/src/test/scala/org/scalajs/linker/standard/LinkTimePropertiesTest.scala b/linker/shared/src/test/scala/org/scalajs/linker/standard/LinkTimePropertiesTest.scala new file mode 100644 index 0000000000..b7237e4a85 --- /dev/null +++ b/linker/shared/src/test/scala/org/scalajs/linker/standard/LinkTimePropertiesTest.scala @@ -0,0 +1,128 @@ +/* + * Scala.js (https://www.scala-js.org/) + * + * Copyright EPFL. + * + * Licensed under Apache License 2.0 + * (https://www.apache.org/licenses/LICENSE-2.0). + * + * See the NOTICE file distributed with this work for + * additional information regarding copyright ownership. + */ + +package org.scalajs.linker.standard + +import org.junit.Test +import org.junit.Assert._ + +import org.scalajs.ir.{Trees => js, Types => jstpe} +import org.scalajs.ir.Position +import org.scalajs.linker.interface.{Semantics, ESFeatures, ESVersion} + +import js.LinkTimeOp._ +import js.LinkTimeTree._ + +class LinkTimePropertiesTest { + private def evalTrue(prop: LinkTimeProperties, cond: js.LinkTimeTree): Unit = + assertTrue(prop.evaluateLinkTimeTree(cond)) + private def evalFalse(prop: LinkTimeProperties, cond: js.LinkTimeTree): Unit = + assertFalse(prop.evaluateLinkTimeTree(cond)) + + private def i(x: Int) = js.LinkTimeTree.IntConst(x) + private def b(x: Boolean) = js.LinkTimeTree.BooleanConst(x) + private def p(x: String, t: jstpe.Type) = js.LinkTimeTree.Property(x, t) + + private val productionMode = + p("core/productionMode", jstpe.BooleanType) + private val esVersion = + p("core/esVersion", jstpe.IntType) + + private implicit val noPos: Position = Position.NoPosition + + @Test + def evaluateLinkTimeTreeConst(): Unit = { + val prop = new LinkTimeProperties( + Semantics.Defaults, + ESFeatures.Defaults + ) + + evalTrue(prop, b(true)) + evalFalse(prop, b(false)) + + evalTrue(prop, BinaryOp(Boolean_==, b(true), b(true))) + evalTrue(prop, BinaryOp(Boolean_!=, b(true), b(false))) + evalFalse(prop, BinaryOp(Boolean_==, b(true), b(false))) + evalFalse(prop, BinaryOp(Boolean_!=, b(true), b(true))) + + evalTrue(prop, BinaryOp(Int_!=, i(0), i(3))) + evalTrue(prop, BinaryOp(Int_==, i(0), i(0))) + evalTrue(prop, BinaryOp(Int_>, i(1), i(0))) + evalTrue(prop, BinaryOp(Int_>=, i(0), i(0))) + evalTrue(prop, BinaryOp(Int_<, i(0), i(1))) + evalTrue(prop, BinaryOp(Int_<=, i(0), i(0))) + } + + @Test + def resolveLinkTimeProperty(): Unit = { + val sem = Semantics.Defaults.withProductionMode(true) + val esFeatures = ESFeatures.Defaults.withESVersion(ESVersion.ES2015) + val prop = new LinkTimeProperties(sem, esFeatures) + + evalTrue(prop, productionMode) + evalTrue(prop, BinaryOp(Boolean_==, productionMode, b(sem.productionMode))) + evalTrue(prop, BinaryOp(Boolean_!=, b(!sem.productionMode), productionMode)) + evalTrue(prop, BinaryOp(Int_==, i(esFeatures.esVersion.edition), esVersion)) + evalTrue(prop, BinaryOp(Int_>, i(ESVersion.ES2016.edition), esVersion)) + } + + @Test + def linkTimeConditionNested(): Unit = { + val sem = Semantics.Defaults.withProductionMode(true) + val esFeatures = ESFeatures.Defaults.withESVersion(ESVersion.ES2015) + val prop = new LinkTimeProperties(sem, esFeatures) + + // esVersion >= ESVersion.ES2015 && esVersion <= ESVersion.ES2019 + evalTrue(prop, + BinaryOp( + Boolean_&&, + BinaryOp(Int_>=, esVersion, i(ESVersion.ES2015.edition)), + BinaryOp(Int_<=, esVersion, i(ESVersion.ES2019.edition)) + ) + ) + + // (esVersion > ESVersion.ES5_1 && esVersion < ESVersion.ES2015) || productionMode + evalTrue(prop, + BinaryOp( + Boolean_||, + BinaryOp(Boolean_&&, + BinaryOp(Int_>, esVersion, i(ESVersion.ES5_1.edition)), + BinaryOp(Int_<, esVersion, i(ESVersion.ES2015.edition)) + ), + productionMode + ) + ) + } + + @Test + def linkTimePropertyNotFound(): Unit = { + val prop = new LinkTimeProperties(Semantics.Defaults, ESFeatures.Defaults) + val tree = BinaryOp( + Boolean_||, + p("prop/notFound", jstpe.BooleanType), + b(true) + ) + assertThrows(classOf[IllegalArgumentException], () => prop.evaluateLinkTimeTree(tree)) + } + + @Test + def linkTimePropertyInvalidInput(): Unit = { + val prop = new LinkTimeProperties(Semantics.Defaults, ESFeatures.Defaults) + def test(tree: js.LinkTimeTree): Unit = + assertThrows(classOf[IllegalArgumentException], () => prop.evaluateLinkTimeTree(tree)) + + test(p("core/esVersion", jstpe.IntType)) + test(BinaryOp(Boolean_||, i(1), b(true))) + test(BinaryOp(Boolean_||, i(1), i(10))) + test(BinaryOp(Int_==, b(true), b(false))) + } +} diff --git a/test-suite/js/src/test/scala/org/scalajs/testsuite/library/LinkTimeIfTest.scala b/test-suite/js/src/test/scala/org/scalajs/testsuite/library/LinkTimeIfTest.scala new file mode 100644 index 0000000000..189985be35 --- /dev/null +++ b/test-suite/js/src/test/scala/org/scalajs/testsuite/library/LinkTimeIfTest.scala @@ -0,0 +1,88 @@ +/* + * Scala.js (https://www.scala-js.org/) + * + * Copyright EPFL. + * + * Licensed under Apache License 2.0 + * (https://www.apache.org/licenses/LICENSE-2.0). + * + * See the NOTICE file distributed with this work for + * additional information regarding copyright ownership. + */ + +package org.scalajs.testsuite.library + +import scala.scalajs.js +import scala.scalajs.LinkingInfo._ + +import org.junit.Test +import org.junit.Assert._ +import org.junit.Assume._ + +import org.scalajs.testsuite.utils.Platform + +class LinkTimeIfTest { + @Test def linkTimeIfConst(): Unit = { + // boolean const + assertEquals(1, linkTimeIf(true) { 1 } { 2 }) + assertEquals(2, linkTimeIf(false) { 1 } { 2 }) + } + + @Test def linkTimeIfProp(): Unit = { + locally { + val cond = Platform.isInProductionMode + assertEquals(cond, linkTimeIf(productionMode) { true } { false }) + } + + locally { + val cond = !Platform.isInProductionMode + assertEquals(cond, linkTimeIf(!productionMode) { true } { false }) + } + } + + @Test def linkTimIfIntProp(): Unit = { + locally { + val cond = Platform.assumedESVersion >= ESVersion.ES2015 + assertEquals(cond, linkTimeIf(esVersion >= ESVersion.ES2015) { true } { false }) + } + + locally { + val cond = !(Platform.assumedESVersion < ESVersion.ES2015) + assertEquals(cond, linkTimeIf(!(esVersion < ESVersion.ES2015)) { true } { false }) + } + } + + @Test def linkTimeIfNested(): Unit = { + locally { + val cond = + Platform.isInProductionMode && + Platform.assumedESVersion >= ESVersion.ES2015 + assertEquals(cond, + linkTimeIf(productionMode && esVersion >= ESVersion.ES2015) { true } { false }) + } + + locally { + val cond = + Platform.assumedESVersion >= ESVersion.ES2015 && + Platform.assumedESVersion < ESVersion.ES2019 && + Platform.isInProductionMode + assertEquals(cond, + linkTimeIf( + esVersion >= ESVersion.ES2015 && + esVersion < ESVersion.ES2019 && + productionMode + ) { true } { false }) + } + } + + @Test def exponentOp(): Unit = { + def pow(x: Double, y: Double): Double = + linkTimeIf(esVersion >= ESVersion.ES2016) { + (x.asInstanceOf[js.Dynamic] ** y.asInstanceOf[js.Dynamic]) + .asInstanceOf[Double] + } { + Math.pow(x, y) + } + assertEquals(pow(2.0, 8.0), 256.0, 0) + } +}