From c399a080a86503a67dd235a02561feb8b8d96c65 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Doeraene?= Date: Tue, 22 Apr 2025 10:02:38 +0200 Subject: [PATCH 1/3] Bump the version to 1.20.0-SNAPSHOT for the upcoming changes. As well as the IR version to 1.20-SNAPSHOT. --- ir/shared/src/main/scala/org/scalajs/ir/ScalaJSVersions.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ir/shared/src/main/scala/org/scalajs/ir/ScalaJSVersions.scala b/ir/shared/src/main/scala/org/scalajs/ir/ScalaJSVersions.scala index 4de34d7f0b..23292cbcdc 100644 --- a/ir/shared/src/main/scala/org/scalajs/ir/ScalaJSVersions.scala +++ b/ir/shared/src/main/scala/org/scalajs/ir/ScalaJSVersions.scala @@ -17,8 +17,8 @@ import java.util.concurrent.ConcurrentHashMap import scala.util.matching.Regex object ScalaJSVersions extends VersionChecks( - current = "1.19.1-SNAPSHOT", - binaryEmitted = "1.19" + current = "1.20.0-SNAPSHOT", + binaryEmitted = "1.20-SNAPSHOT" ) /** Helper class to allow for testing of logic. */ From 5e842d868dd16c0f15bfbb4c583e6f8eac24c87e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Doeraene?= Date: Fri, 3 Jan 2025 17:22:15 +0100 Subject: [PATCH 2/3] Fix #4997: Add `linkTimeIf` for link-time conditional branching. Thanks to our optimizer's ability to inline, constant-fold, and then eliminate dead code, we have been able to write link-time conditional branches for a long time. Typical examples include polyfills, as illustrated in the documentation of `LinkingInfo`: if (esVersion >= ESVersion.ES2018 || featureTest()) useES2018Feature() else usePolyfill() which gets folded away to nothing but useES2018Feature() when linking for ES2018+. However, this only works because both branches can *link* during the initial reachability analysis. We cannot use the same technique when one of the branches would refuse to link in the first place. The canonical example is the usage of the JS `**` operator, which does not link below ES2016. The following snippet produces good code when linking for ES2016+, but does not link at all for ES2015: def pow(x: Double, y: Double): Double = { if (esVersion >= ESVersion.ES2016) { (x.asInstanceOf[js.Dynamic] ** y.asInstanceOf[js.Dynamic]) .asInstanceOf[Double] } { Math.pow(x, y) } } --- This commit introduces `LinkingInfo.linkTimeIf`, a conditional branch that is guaranteed by spec to be resolved at link-time. Using a `linkTimeIf` instead of the `if` in `def pow`, we can successfully link the fallback branch on ES2015, because the then branch is not even followed by the reachability analysis. In order to provide that guarantee, the corresponding `LinkTimeIf` IR node has strong requirements on its condition. It must be a "link-time expression", which is guaranteed to be resolved at link-time. A link-time expression tree must be of the form: * A `Literal` (of type `int`, `boolean` or `string`, although `string`s are not actually usable here). * A `LinkTimeProperty`. * One of the boolean operators. * One of the int comparison operators. * A nested `LinkTimeIf` (used to encode short-circuiting boolean `&&` and `||`). The `ClassDefChecker` validates the above property, and ensures that link-time expression trees are *well-typed*. Normally that is the job of the IR checker. Here we *can* do in `ClassDefChecker` because we only have the 3 primitive types to deal with; and we *must* do it then, because the reachability analysis itself is only sound if all link-time expression trees are well-typed. The reachability analysis algorithm itself is not affected by `LinkTimeIf`. Instead, we resolve link-time branches when building the `Infos` of methods. We follow only the branch that is taken. This means that `Infos` builders now require access to the `linkTimeProperties` derived from the `coreSpec`, but that is the only additional piece of complexity in that area. `LinkTimeIf`s nodes are later removed from the trees during desugaring. --- At the language and compiler level, we introduce `LinkingInfo.linkTimeIf` as a primitive for `LinkTimeIf`. We need a dedicated method to compile link-time expression trees, which does incur some duplication, unfortunately. Other than that, `linkTimeIf` is straightforward, by itself. The problem is that the whole point of `linkTimeIf` is that we can refer to *link-time properties*, and not just literals. However, our link-time properties are all hidden behind regular method calls, such as `LinkInfo.esVersion`. For optimizer-based branching with `if`s, that is fine, as the method is always inlined, and the optimizer can then see the constant. However, for `linkTimeIf`, that does not work, as it does not follow the requirements of a link-time expression tree. If we were on Scala 3 only, we could declare `esVersion` and its friends as an `inline def`, as follows: inline def esVersion: Int = linkTimePropertyInt("core/esVersion") The `inline` keyword is guaranteed by the language to be resolved at *compile*-time. Since the `linkTimePropertyInt` method is itself a primitive replaced by a `LinkTimeProperty`, by the time we reach our backend, we would see the latter, and all would be well. The same cannot be said for the `@inline` optimizer hint, which is all we have. We therefore add another language-level feature: `@linkTimeProperty`. This annotation can (currently) only be used in our own library. By contract, it must only be used on a method whose body is the corresponding `linkTimePropertyX` primitive. With it, we can define `esVersion` as: @inline @linkTimeProperty("core/esVersion") def esVersion: Int = linkTimePropertyInt("core/esVersion") That annotation makes the body public, in a way. That means the compiler back-end can now replace *call sites* to `esVersion` by the `LinkTimeProperty`. Semantically, `@linkTimeProperty` does nothing more than guaranteed inlining (with strong restrictions on the shape of body). Co-authored-by: Rikito Taniguchi --- .../org/scalajs/nscplugin/GenJSCode.scala | 87 ++++++++++ .../org/scalajs/nscplugin/JSDefinitions.scala | 3 + .../org/scalajs/nscplugin/JSPrimitives.scala | 4 +- .../nscplugin/test/LinkTimeIfTest.scala | 109 +++++++++++++ .../main/scala/org/scalajs/ir/Hashers.scala | 7 + .../main/scala/org/scalajs/ir/Printers.scala | 9 ++ .../scala/org/scalajs/ir/Serializers.scala | 16 +- .../src/main/scala/org/scalajs/ir/Tags.scala | 3 + .../scala/org/scalajs/ir/Transformers.scala | 3 + .../scala/org/scalajs/ir/Traversers.scala | 5 + .../src/main/scala/org/scalajs/ir/Trees.scala | 32 ++++ .../scala/org/scalajs/ir/PrintersTest.scala | 55 +++++++ .../scala/scala/scalajs/LinkingInfo.scala | 47 +++++- .../scalajs/annotation/linkTimeProperty.scala | 33 ++++ .../scalajs/linker/analyzer/Analyzer.scala | 2 +- .../scalajs/linker/analyzer/InfoLoader.scala | 39 +++-- .../org/scalajs/linker/analyzer/Infos.scala | 150 +++++++++++------- .../backend/wasmemitter/FunctionEmitter.scala | 3 +- .../linker/checker/ClassDefChecker.scala | 70 +++++++- .../scalajs/linker/checker/FeatureSet.scala | 6 +- .../scalajs/linker/checker/IRChecker.scala | 35 +++- .../scalajs/linker/frontend/BaseLinker.scala | 5 +- .../scalajs/linker/frontend/Desugarer.scala | 18 ++- .../linker/frontend/LinkTimeEvaluator.scala | 129 +++++++++++++++ .../org/scalajs/linker/frontend/Refiner.scala | 5 +- .../frontend/optimizer/OptimizerCore.scala | 3 +- .../org/scalajs/linker/AnalyzerTest.scala | 125 ++++++++++++++- .../org/scalajs/linker/IRCheckerTest.scala | 1 + .../linker/checker/ClassDefCheckerTest.scala | 78 +++++++++ .../LinkTimeEvaluatorTest.scala | 102 ++++++++++++ .../linker/testutils/TestIRBuilder.scala | 1 + .../testsuite/library/LinkTimeIfTest.scala | 95 +++++++++++ 32 files changed, 1178 insertions(+), 102 deletions(-) create mode 100644 compiler/src/test/scala/org/scalajs/nscplugin/test/LinkTimeIfTest.scala create mode 100644 library/src/main/scala/scala/scalajs/annotation/linkTimeProperty.scala create mode 100644 linker/shared/src/main/scala/org/scalajs/linker/frontend/LinkTimeEvaluator.scala create mode 100644 linker/shared/src/test/scala/org/scalajs/linker/frontend/modulesplitter/LinkTimeEvaluatorTest.scala create mode 100644 test-suite/js/src/test/scala/org/scalajs/testsuite/library/LinkTimeIfTest.scala diff --git a/compiler/src/main/scala/org/scalajs/nscplugin/GenJSCode.scala b/compiler/src/main/scala/org/scalajs/nscplugin/GenJSCode.scala index e46b1dc14f..dc1348ea22 100644 --- a/compiler/src/main/scala/org/scalajs/nscplugin/GenJSCode.scala +++ b/compiler/src/main/scala/org/scalajs/nscplugin/GenJSCode.scala @@ -5511,6 +5511,16 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G) js.UnaryOp(js.UnaryOp.UnwrapFromThrowable, js.UnaryOp(js.UnaryOp.CheckNotNull, genArgs1)) + case LINKTIME_IF => + // LinkingInfo.linkTimeIf(cond, thenp, elsep) + val cond = genLinkTimeExpr(args(0)) + val thenp = genExpr(args(1)) + val elsep = genExpr(args(2)) + val tpe = + if (isStat) jstpe.VoidType + else toIRType(tree.tpe) + js.LinkTimeIf(cond, thenp, elsep)(tpe) + case LINKTIME_PROPERTY => // LinkingInfo.linkTimePropertyXXX("...") val arg = genArgs1 @@ -5529,6 +5539,83 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G) } } + private def genLinkTimeExpr(tree: Tree): js.Tree = { + import scalaPrimitives._ + + implicit val pos = tree.pos + + def invalid(): js.Tree = { + reporter.error(tree.pos, + "Illegal expression in the condition of a linkTimeIf. " + + "Valid expressions are: boolean and int primitives; " + + "references to link-time properties; " + + "primitive operations on booleans; " + + "and comparisons on ints.") + js.BooleanLiteral(false) + } + + tree match { + case Literal(c) => + c.tag match { + case BooleanTag => js.BooleanLiteral(c.booleanValue) + case IntTag => js.IntLiteral(c.intValue) + case _ => invalid() + } + + case Apply(fun @ Select(receiver, _), args) => + fun.symbol.getAnnotation(LinkTimePropertyAnnotation) match { + case Some(annotation) => + val propName = annotation.constantAtIndex(0).get.stringValue + js.LinkTimeProperty(propName)(toIRType(tree.tpe)) + + case None if isPrimitive(fun.symbol) => + val code = getPrimitive(fun.symbol) + + def genLhs: js.Tree = genLinkTimeExpr(receiver) + def genRhs: js.Tree = genLinkTimeExpr(args.head) + + def unaryOp(op: js.UnaryOp.Code): js.Tree = + js.UnaryOp(op, genLhs) + def binaryOp(op: js.BinaryOp.Code): js.Tree = + js.BinaryOp(op, genLhs, genRhs) + + toIRType(receiver.tpe) match { + case jstpe.BooleanType => + (code: @switch) match { + case ZNOT => unaryOp(js.UnaryOp.Boolean_!) + case EQ => binaryOp(js.BinaryOp.Boolean_==) + case NE | XOR => binaryOp(js.BinaryOp.Boolean_!=) + case OR => binaryOp(js.BinaryOp.Boolean_|) + case AND => binaryOp(js.BinaryOp.Boolean_&) + case ZOR => js.LinkTimeIf(genLhs, js.BooleanLiteral(true), genRhs)(jstpe.BooleanType) + case ZAND => js.LinkTimeIf(genLhs, genRhs, js.BooleanLiteral(false))(jstpe.BooleanType) + case _ => invalid() + } + + case jstpe.IntType => + (code: @switch) match { + case EQ => binaryOp(js.BinaryOp.Int_==) + case NE => binaryOp(js.BinaryOp.Int_!=) + case LT => binaryOp(js.BinaryOp.Int_<) + case LE => binaryOp(js.BinaryOp.Int_<=) + case GT => binaryOp(js.BinaryOp.Int_>) + case GE => binaryOp(js.BinaryOp.Int_>=) + case _ => invalid() + } + + case _ => + invalid() + } + + case None => // if !isPrimitive + invalid() + } + + case _ => + invalid() + } + } + /** Gen JS code for a primitive JS call (to a method of a subclass of js.Any) * This is the typed Scala.js to JS bridge feature. Basically it boils * down to calling the method without name mangling. But other aspects diff --git a/compiler/src/main/scala/org/scalajs/nscplugin/JSDefinitions.scala b/compiler/src/main/scala/org/scalajs/nscplugin/JSDefinitions.scala index 2b0c5590d9..58c4910233 100644 --- a/compiler/src/main/scala/org/scalajs/nscplugin/JSDefinitions.scala +++ b/compiler/src/main/scala/org/scalajs/nscplugin/JSDefinitions.scala @@ -135,10 +135,13 @@ trait JSDefinitions { lazy val Runtime_dynamicImport = getMemberMethod(RuntimePackageModule, newTermName("dynamicImport")) lazy val LinkingInfoModule = getRequiredModule("scala.scalajs.LinkingInfo") + lazy val LinkingInfo_linkTimeIf = getMemberMethod(LinkingInfoModule, newTermName("linkTimeIf")) lazy val LinkingInfo_linkTimePropertyBoolean = getMemberMethod(LinkingInfoModule, newTermName("linkTimePropertyBoolean")) lazy val LinkingInfo_linkTimePropertyInt = getMemberMethod(LinkingInfoModule, newTermName("linkTimePropertyInt")) lazy val LinkingInfo_linkTimePropertyString = getMemberMethod(LinkingInfoModule, newTermName("linkTimePropertyString")) + lazy val LinkTimePropertyAnnotation = getRequiredClass("scala.scalajs.annotation.linkTimeProperty") + lazy val DynamicImportThunkClass = getRequiredClass("scala.scalajs.runtime.DynamicImportThunk") lazy val DynamicImportThunkClass_apply = getMemberMethod(DynamicImportThunkClass, nme.apply) diff --git a/compiler/src/main/scala/org/scalajs/nscplugin/JSPrimitives.scala b/compiler/src/main/scala/org/scalajs/nscplugin/JSPrimitives.scala index 90aa1b1513..cf6f896453 100644 --- a/compiler/src/main/scala/org/scalajs/nscplugin/JSPrimitives.scala +++ b/compiler/src/main/scala/org/scalajs/nscplugin/JSPrimitives.scala @@ -71,7 +71,8 @@ abstract class JSPrimitives { final val WRAP_AS_THROWABLE = JS_TRY_CATCH + 1 // js.special.wrapAsThrowable final val UNWRAP_FROM_THROWABLE = WRAP_AS_THROWABLE + 1 // js.special.unwrapFromThrowable final val DEBUGGER = UNWRAP_FROM_THROWABLE + 1 // js.special.debugger - final val LINKTIME_PROPERTY = DEBUGGER + 1 // LinkingInfo.linkTimePropertyXXX + final val LINKTIME_IF = DEBUGGER + 1 // LinkingInfo.linkTimeIf + final val LINKTIME_PROPERTY = LINKTIME_IF + 1 // LinkingInfo.linkTimePropertyXXX final val LastJSPrimitiveCode = LINKTIME_PROPERTY @@ -128,6 +129,7 @@ abstract class JSPrimitives { addPrimitive(Special_unwrapFromThrowable, UNWRAP_FROM_THROWABLE) addPrimitive(Special_debugger, DEBUGGER) + addPrimitive(LinkingInfo_linkTimeIf, LINKTIME_IF) addPrimitive(LinkingInfo_linkTimePropertyBoolean, LINKTIME_PROPERTY) addPrimitive(LinkingInfo_linkTimePropertyInt, LINKTIME_PROPERTY) addPrimitive(LinkingInfo_linkTimePropertyString, LINKTIME_PROPERTY) diff --git a/compiler/src/test/scala/org/scalajs/nscplugin/test/LinkTimeIfTest.scala b/compiler/src/test/scala/org/scalajs/nscplugin/test/LinkTimeIfTest.scala new file mode 100644 index 0000000000..881c0e9a2f --- /dev/null +++ b/compiler/src/test/scala/org/scalajs/nscplugin/test/LinkTimeIfTest.scala @@ -0,0 +1,109 @@ +/* + * Scala.js (https://www.scala-js.org/) + * + * Copyright EPFL. + * + * Licensed under Apache License 2.0 + * (https://www.apache.org/licenses/LICENSE-2.0). + * + * See the NOTICE file distributed with this work for + * additional information regarding copyright ownership. + */ + +package org.scalajs.nscplugin.test + +import util._ + +import org.junit.Test +import org.junit.Assert._ + +// scalastyle:off line.size.limit + +class LinkTimeIfTest extends TestHelpers { + override def preamble: String = "import scala.scalajs.LinkingInfo._" + + private final val IllegalLinkTimeIfArgMessage = { + "Illegal expression in the condition of a linkTimeIf. " + + "Valid expressions are: boolean and int primitives; " + + "references to link-time properties; " + + "primitive operations on booleans; " + + "and comparisons on ints." + } + + @Test + def linkTimeErrorInvalidOp(): Unit = { + """ + object A { + def foo = + linkTimeIf((esVersion + 1) < ESVersion.ES2015) { } { } + } + """ hasErrors + s""" + |newSource1.scala:4: error: $IllegalLinkTimeIfArgMessage + | linkTimeIf((esVersion + 1) < ESVersion.ES2015) { } { } + | ^ + """ + } + + @Test + def linkTimeErrorInvalidEntities(): Unit = { + """ + object A { + def foo(x: String) = { + val bar = 1 + linkTimeIf(bar == 0) { } { } + } + } + """ hasErrors + s""" + |newSource1.scala:5: error: $IllegalLinkTimeIfArgMessage + | linkTimeIf(bar == 0) { } { } + | ^ + """ + + // String comparison is a `BinaryOp.===`, which is not allowed + """ + object A { + def foo(x: String) = + linkTimeIf("foo" == x) { } { } + } + """ hasErrors + s""" + |newSource1.scala:4: error: $IllegalLinkTimeIfArgMessage + | linkTimeIf("foo" == x) { } { } + | ^ + """ + + """ + object A { + def bar = true + def foo(x: String) = + linkTimeIf(bar || !bar) { } { } + } + """ hasErrors + s""" + |newSource1.scala:5: error: $IllegalLinkTimeIfArgMessage + | linkTimeIf(bar || !bar) { } { } + | ^ + |newSource1.scala:5: error: $IllegalLinkTimeIfArgMessage + | linkTimeIf(bar || !bar) { } { } + | ^ + """ + } + + @Test + def linkTimeCondInvalidTree(): Unit = { + """ + object A { + def bar = true + def foo(x: String) = + linkTimeIf(if (bar) true else false) { } { } + } + """ hasErrors + s""" + |newSource1.scala:5: error: $IllegalLinkTimeIfArgMessage + | linkTimeIf(if (bar) true else false) { } { } + | ^ + """ + } +} diff --git a/ir/shared/src/main/scala/org/scalajs/ir/Hashers.scala b/ir/shared/src/main/scala/org/scalajs/ir/Hashers.scala index ad94d65549..599e9e8c1c 100644 --- a/ir/shared/src/main/scala/org/scalajs/ir/Hashers.scala +++ b/ir/shared/src/main/scala/org/scalajs/ir/Hashers.scala @@ -206,6 +206,13 @@ object Hashers { mixTree(elsep) mixType(tree.tpe) + case LinkTimeIf(cond, thenp, elsep) => + mixTag(TagLinkTimeIf) + mixTree(cond) + mixTree(thenp) + mixTree(elsep) + mixType(tree.tpe) + case While(cond, body) => mixTag(TagWhile) mixTree(cond) diff --git a/ir/shared/src/main/scala/org/scalajs/ir/Printers.scala b/ir/shared/src/main/scala/org/scalajs/ir/Printers.scala index c69ad1447c..9a05ed7788 100644 --- a/ir/shared/src/main/scala/org/scalajs/ir/Printers.scala +++ b/ir/shared/src/main/scala/org/scalajs/ir/Printers.scala @@ -93,6 +93,7 @@ object Printers { protected def printBlock(tree: Tree): Unit = { val trees = tree match { case Block(trees) => trees + case Skip() => Nil case _ => tree :: Nil } printBlock(trees) @@ -232,6 +233,14 @@ object Printers { printBlock(elsep) } + case LinkTimeIf(cond, thenp, elsep) => + print("link-time-if (") + print(cond) + print(") ") + printBlock(thenp) + print(" else ") + printBlock(elsep) + case While(cond, body) => print("while (") print(cond) diff --git a/ir/shared/src/main/scala/org/scalajs/ir/Serializers.scala b/ir/shared/src/main/scala/org/scalajs/ir/Serializers.scala index 7cc64e28e1..628630dfa1 100644 --- a/ir/shared/src/main/scala/org/scalajs/ir/Serializers.scala +++ b/ir/shared/src/main/scala/org/scalajs/ir/Serializers.scala @@ -297,6 +297,11 @@ object Serializers { writeTree(cond); writeTree(thenp); writeTree(elsep) writeType(tree.tpe) + case LinkTimeIf(cond, thenp, elsep) => + writeTagAndPos(TagLinkTimeIf) + writeTree(cond); writeTree(thenp); writeTree(elsep) + writeType(tree.tpe) + case While(cond, body) => writeTagAndPos(TagWhile) writeTree(cond); writeTree(body) @@ -1196,9 +1201,14 @@ object Serializers { Assign(lhs.asInstanceOf[AssignLhs], rhs) - case TagReturn => Return(readTree(), readLabelName()) - case TagIf => If(readTree(), readTree(), readTree())(readType()) - case TagWhile => While(readTree(), readTree()) + case TagReturn => + Return(readTree(), readLabelName()) + case TagIf => + If(readTree(), readTree(), readTree())(readType()) + case TagLinkTimeIf => + LinkTimeIf(readTree(), readTree(), readTree())(readType()) + case TagWhile => + While(readTree(), readTree()) case TagDoWhile => if (!hacks.useBelow(13)) diff --git a/ir/shared/src/main/scala/org/scalajs/ir/Tags.scala b/ir/shared/src/main/scala/org/scalajs/ir/Tags.scala index bc7d2982b0..dc2862b7ec 100644 --- a/ir/shared/src/main/scala/org/scalajs/ir/Tags.scala +++ b/ir/shared/src/main/scala/org/scalajs/ir/Tags.scala @@ -135,6 +135,9 @@ private[ir] object Tags { final val TagNewLambda = TagApplyTypedClosure + 1 final val TagJSAwait = TagNewLambda + 1 + // New in 1.20 + final val TagLinkTimeIf = TagJSAwait + 1 + // Tags for member defs final val TagFieldDef = 1 diff --git a/ir/shared/src/main/scala/org/scalajs/ir/Transformers.scala b/ir/shared/src/main/scala/org/scalajs/ir/Transformers.scala index 27d9086435..e95a154e1c 100644 --- a/ir/shared/src/main/scala/org/scalajs/ir/Transformers.scala +++ b/ir/shared/src/main/scala/org/scalajs/ir/Transformers.scala @@ -60,6 +60,9 @@ object Transformers { case If(cond, thenp, elsep) => If(transform(cond), transform(thenp), transform(elsep))(tree.tpe) + case LinkTimeIf(cond, thenp, elsep) => + LinkTimeIf(transform(cond), transform(thenp), transform(elsep))(tree.tpe) + case While(cond, body) => While(transform(cond), transform(body)) diff --git a/ir/shared/src/main/scala/org/scalajs/ir/Traversers.scala b/ir/shared/src/main/scala/org/scalajs/ir/Traversers.scala index d5782da074..15c9da9093 100644 --- a/ir/shared/src/main/scala/org/scalajs/ir/Traversers.scala +++ b/ir/shared/src/main/scala/org/scalajs/ir/Traversers.scala @@ -48,6 +48,11 @@ object Traversers { traverse(thenp) traverse(elsep) + case LinkTimeIf(cond, thenp, elsep) => + traverse(cond) + traverse(thenp) + traverse(elsep) + case While(cond, body) => traverse(cond) traverse(body) diff --git a/ir/shared/src/main/scala/org/scalajs/ir/Trees.scala b/ir/shared/src/main/scala/org/scalajs/ir/Trees.scala index ccc3b56196..23a2eb7118 100644 --- a/ir/shared/src/main/scala/org/scalajs/ir/Trees.scala +++ b/ir/shared/src/main/scala/org/scalajs/ir/Trees.scala @@ -168,6 +168,38 @@ object Trees { sealed case class If(cond: Tree, thenp: Tree, elsep: Tree)(val tpe: Type)( implicit val pos: Position) extends Tree + /** Link-time `if` expression. + * + * The `cond` must be a well-typed link-time tree of type `boolean`. + * + * A link-time tree is a `Tree` matching the following sub-grammar: + * + * {{{ + * link-time-tree ::= + * BooleanLiteral + * | IntLiteral + * | StringLiteral + * | LinkTimeProperty + * | UnaryOp(link-time-unary-op, link-time-tree) + * | BinaryOp(link-time-binary-op, link-time-tree, link-time-tree) + * | LinkTimeIf(link-time-tree, link-time-tree, link-time-tree) + * + * link-time-unary-op ::= + * Boolean_! + * + * link-time-binary-op ::= + * Boolean_== | Boolean_!= | Boolean_| | Boolean_& + * | Int_== | Int_!= | Int_< | Int_<= | Int_> | Int_>= + * }}} + * + * Note: nested `LinkTimeIf` nodes in the `cond` are used to encode + * short-circuiting boolean `&&` and `||`, just like we do with regular + * `If` nodes. + */ + sealed case class LinkTimeIf(cond: Tree, thenp: Tree, elsep: Tree)( + val tpe: Type)(implicit val pos: Position) + extends Tree + sealed case class While(cond: Tree, body: Tree)( implicit val pos: Position) extends Tree { val tpe = cond match { diff --git a/ir/shared/src/test/scala/org/scalajs/ir/PrintersTest.scala b/ir/shared/src/test/scala/org/scalajs/ir/PrintersTest.scala index 060bf4fdb8..fd49eb406e 100644 --- a/ir/shared/src/test/scala/org/scalajs/ir/PrintersTest.scala +++ b/ir/shared/src/test/scala/org/scalajs/ir/PrintersTest.scala @@ -202,6 +202,61 @@ class PrintersTest { If(ref("x", BooleanType), ref("y", BooleanType), b(false))(BooleanType)) } + @Test def printLinkTimeIf(): Unit = { + assertPrintEquals( + """ + |link-time-if (true) { + | 5 + |} else { + | 6 + |} + """, + LinkTimeIf(b(true), i(5), i(6))(IntType)) + + assertPrintEquals( + """ + |link-time-if (true) { + | 5 + |} else { + |} + """, + LinkTimeIf(b(true), i(5), Skip())(VoidType)) + + assertPrintEquals( + """ + |link-time-if (true) { + | 5 + |} else { + | link-time-if (false) { + | 6 + | } else { + | 7 + | } + |} + """, + LinkTimeIf(b(true), i(5), LinkTimeIf(b(false), i(6), i(7))(IntType))(IntType)) + + assertPrintEquals( + """ + |link-time-if (x) { + | true + |} else { + | y + |} + """, + LinkTimeIf(ref("x", BooleanType), b(true), ref("y", BooleanType))(BooleanType)) + + assertPrintEquals( + """ + |link-time-if (x) { + | y + |} else { + | false + |} + """, + LinkTimeIf(ref("x", BooleanType), ref("y", BooleanType), b(false))(BooleanType)) + } + @Test def printWhile(): Unit = { assertPrintEquals( """ diff --git a/library/src/main/scala/scala/scalajs/LinkingInfo.scala b/library/src/main/scala/scala/scalajs/LinkingInfo.scala index ea9d6c1a2f..0a7218fb44 100644 --- a/library/src/main/scala/scala/scalajs/LinkingInfo.scala +++ b/library/src/main/scala/scala/scalajs/LinkingInfo.scala @@ -12,6 +12,8 @@ package scala.scalajs +import scala.scalajs.annotation.linkTimeProperty + object LinkingInfo { /** Returns true if we are linking for production, false otherwise. @@ -42,7 +44,7 @@ object LinkingInfo { * * @see [[developmentMode]] */ - @inline + @inline @linkTimeProperty("core/productionMode") def productionMode: Boolean = linkTimePropertyBoolean("core/productionMode") @@ -120,7 +122,7 @@ object LinkingInfo { * useES2018Feature() * }}} */ - @inline + @inline @linkTimeProperty("core/esVersion") def esVersion: Int = linkTimePropertyInt("core/esVersion") @@ -218,7 +220,7 @@ object LinkingInfo { * implementationWithoutES2015Semantics() * }}} */ - @inline + @inline @linkTimeProperty("core/useECMAScript2015Semantics") def useECMAScript2015Semantics: Boolean = linkTimePropertyBoolean("core/useECMAScript2015Semantics") @@ -252,15 +254,50 @@ object LinkingInfo { * implementationOptimizedForJavaScript() * }}} */ - @inline + @inline @linkTimeProperty("core/isWebAssembly") def isWebAssembly: Boolean = linkTimePropertyBoolean("core/isWebAssembly") /** Version of the linker. */ - @inline + @inline @linkTimeProperty("core/linkerVersion") def linkerVersion: String = linkTimePropertyString("core/linkerVersion") + /** Link-time conditional branching. + * + * A `linkTimeIf` expression behaves like an `if`, but it is guaranteed to + * be resolved at link-time. This prevents the unused branch to be linked at + * all. It can therefore reference APIs or language features that would + * otherwise fail to link. + * + * The condition `cond` can be constructed using: + * + * - Calls to methods annotated with `@linkTimeProperty` + * - Integer or boolean constants + * - Binary operators that return a boolean value + * + * A typical use case is to leverage the `**` operator on JavaScript + * `bigint`s if it is available, and otherwise fall back on using Scala + * `BigInt`s. Indeed, the `**` operator refuses to link when the target + * `esVersion` is too low. + * + * {{{ + * // Returns true iff 2^x < 10^y, for x and y positive integers + * def compareTwoPowTenPow(x: Int, y: Int): Boolean = { + * import scala.scalajs.LinkingInfo._ + * linkTimeIf(esVersion >= ESVersion.ES2020) { + * // JS bigints are available, and a fortiori their ** operator + * (js.BigInt(2) ** js.BigInt(x)) < (js.BigInt(10) ** js.BigInt(y)) + * } { + * // Fall back on Scala's BigInt's, which use a lot more code size + * BigInt(2).pow(x) < BigInt(10).pow(y) + * } + * } + * }}} + */ + def linkTimeIf[T](cond: Boolean)(thenp: T)(elsep: T): T = + throw new Error("stub") + /** Constants for the value of `esVersion`. */ object ESVersion { /** ECMAScrîpt 5.1. */ diff --git a/library/src/main/scala/scala/scalajs/annotation/linkTimeProperty.scala b/library/src/main/scala/scala/scalajs/annotation/linkTimeProperty.scala new file mode 100644 index 0000000000..6b93167c88 --- /dev/null +++ b/library/src/main/scala/scala/scalajs/annotation/linkTimeProperty.scala @@ -0,0 +1,33 @@ +/* + * Scala.js (https://www.scala-js.org/) + * + * Copyright EPFL. + * + * Licensed under Apache License 2.0 + * (https://www.apache.org/licenses/LICENSE-2.0). + * + * See the NOTICE file distributed with this work for + * additional information regarding copyright ownership. + */ + +package scala.scalajs.annotation + +/** Publicly marks the annotated method as being a link-time property. + * + * When an entity is annotated with `@linkTimeProperty`, its body must be a + * link-time property with the same `name`. The annotation makes that body + * "public", and it can therefore be inlined at call site at compile-time. + * + * From a user perspective, we can treat the presence of that annotation as if + * it were the `inline` keyword of Scala 3: it forces the inlining to happen + * at compile-time. + * + * This is necessary for the target method to be used in the condition of a + * `LinkingInfo.linkTimeIf`. + * + * @param name The name used to resolve the link-time value. + * + * @see [[LinkingInfo.linkTimeIf]] + */ +private[scalajs] final class linkTimeProperty(name: String) + extends scala.annotation.StaticAnnotation diff --git a/linker/shared/src/main/scala/org/scalajs/linker/analyzer/Analyzer.scala b/linker/shared/src/main/scala/org/scalajs/linker/analyzer/Analyzer.scala index 22d3752fd4..c3b428dbeb 100644 --- a/linker/shared/src/main/scala/org/scalajs/linker/analyzer/Analyzer.scala +++ b/linker/shared/src/main/scala/org/scalajs/linker/analyzer/Analyzer.scala @@ -50,7 +50,7 @@ final class Analyzer(config: CommonPhaseConfig, initial: Boolean, private val linkTimeProperties = LinkTimeProperties.fromCoreSpec(config.coreSpec) private val infoLoader: InfoLoader = - new InfoLoader(irLoader, checkIRFor) + new InfoLoader(irLoader, checkIRFor, linkTimeProperties) def computeReachability(moduleInitializers: Seq[ModuleInitializer], symbolRequirements: SymbolRequirement, logger: Logger)(implicit ec: ExecutionContext): Future[Analysis] = { diff --git a/linker/shared/src/main/scala/org/scalajs/linker/analyzer/InfoLoader.scala b/linker/shared/src/main/scala/org/scalajs/linker/analyzer/InfoLoader.scala index 83003e6be5..c791727110 100644 --- a/linker/shared/src/main/scala/org/scalajs/linker/analyzer/InfoLoader.scala +++ b/linker/shared/src/main/scala/org/scalajs/linker/analyzer/InfoLoader.scala @@ -23,13 +23,16 @@ import org.scalajs.ir.Trees._ import org.scalajs.logging._ import org.scalajs.linker.checker._ -import org.scalajs.linker.frontend.IRLoader +import org.scalajs.linker.frontend.{IRLoader, LinkTimeProperties} import org.scalajs.linker.interface.LinkingException import org.scalajs.linker.CollectionsCompat.MutableMapCompatOps import Platform.emptyThreadSafeMap -private[analyzer] final class InfoLoader(irLoader: IRLoader, checkIRFor: Option[CheckingPhase]) { +private[analyzer] final class InfoLoader(irLoader: IRLoader, + checkIRFor: Option[CheckingPhase], linkTimeProperties: LinkTimeProperties) { + + private val generator = new Infos.InfoGenerator(linkTimeProperties) private var logger: Logger = _ private val cache = emptyThreadSafeMap[ClassName, InfoLoader.ClassInfoCache] @@ -44,7 +47,7 @@ private[analyzer] final class InfoLoader(irLoader: IRLoader, checkIRFor: Option[ implicit ec: ExecutionContext): Option[Future[Infos.ClassInfo]] = { if (irLoader.classExists(className)) { val infoCache = cache.getOrElseUpdate(className, - new InfoLoader.ClassInfoCache(className, irLoader, checkIRFor)) + new InfoLoader.ClassInfoCache(className, irLoader, checkIRFor, generator)) Some(infoCache.loadInfo(logger)) } else { None @@ -60,7 +63,9 @@ private[analyzer] final class InfoLoader(irLoader: IRLoader, checkIRFor: Option[ private[analyzer] object InfoLoader { private type MethodInfos = Array[Map[MethodName, Infos.MethodInfo]] - private class ClassInfoCache(className: ClassName, irLoader: IRLoader, checkIRFor: Option[CheckingPhase]) { + private class ClassInfoCache(className: ClassName, irLoader: IRLoader, + checkIRFor: Option[CheckingPhase], generator: Infos.InfoGenerator) { + private var cacheUsed: Boolean = false private var version: Version = Version.Unversioned private var info: Future[Infos.ClassInfo] = _ @@ -103,12 +108,12 @@ private[analyzer] object InfoLoader { } private def generateInfos(classDef: ClassDef): Infos.ClassInfo = { - val referencedFieldClasses = Infos.genReferencedFieldClasses(classDef.fields) + val referencedFieldClasses = generator.genReferencedFieldClasses(classDef.fields) - prevMethodInfos = genMethodInfos(classDef.methods, prevMethodInfos) - prevJSCtorInfo = genJSCtorInfo(classDef.jsConstructor, prevJSCtorInfo) + prevMethodInfos = genMethodInfos(classDef.methods, prevMethodInfos, generator) + prevJSCtorInfo = genJSCtorInfo(classDef.jsConstructor, prevJSCtorInfo, generator) prevJSMethodPropDefInfos = - genJSMethodPropDefInfos(classDef.jsMethodProps, prevJSMethodPropDefInfos) + genJSMethodPropDefInfos(classDef.jsMethodProps, prevJSMethodPropDefInfos, generator) val exportedMembers = prevJSCtorInfo.toList ::: prevJSMethodPropDefInfos @@ -116,7 +121,7 @@ private[analyzer] object InfoLoader { * and usually quite small when they exist. */ val topLevelExports = classDef.topLevelExportDefs - .map(Infos.generateTopLevelExportInfo(classDef.name.name, _)) + .map(generator.generateTopLevelExportInfo(classDef.name.name, _)) val jsNativeMembers = classDef.jsNativeMembers .map(m => m.name.name -> m.jsNativeLoadSpec).toMap @@ -136,7 +141,7 @@ private[analyzer] object InfoLoader { } private def genMethodInfos(methods: List[MethodDef], - prevMethodInfos: MethodInfos): MethodInfos = { + prevMethodInfos: MethodInfos, generator: Infos.InfoGenerator): MethodInfos = { val builders = Array.fill(MemberNamespace.Count)(Map.newBuilder[MethodName, Infos.MethodInfo]) @@ -144,7 +149,7 @@ private[analyzer] object InfoLoader { val info = prevMethodInfos(method.flags.namespace.ordinal) .get(method.methodName) .filter(_.version.sameVersion(method.version)) - .getOrElse(Infos.generateMethodInfo(method)) + .getOrElse(generator.generateMethodInfo(method)) builders(method.flags.namespace.ordinal) += method.methodName -> info } @@ -153,16 +158,18 @@ private[analyzer] object InfoLoader { } private def genJSCtorInfo(jsCtor: Option[JSConstructorDef], - prevJSCtorInfo: Option[Infos.ReachabilityInfo]): Option[Infos.ReachabilityInfo] = { + prevJSCtorInfo: Option[Infos.ReachabilityInfo], + generator: Infos.InfoGenerator): Option[Infos.ReachabilityInfo] = { jsCtor.map { ctor => prevJSCtorInfo .filter(_.version.sameVersion(ctor.version)) - .getOrElse(Infos.generateJSConstructorInfo(ctor)) + .getOrElse(generator.generateJSConstructorInfo(ctor)) } } private def genJSMethodPropDefInfos(jsMethodProps: List[JSMethodPropDef], - prevJSMethodPropDefInfos: List[Infos.ReachabilityInfo]): List[Infos.ReachabilityInfo] = { + prevJSMethodPropDefInfos: List[Infos.ReachabilityInfo], + generator: Infos.InfoGenerator): List[Infos.ReachabilityInfo] = { /* For JS method and property definitions, we use their index in the list of * `linkedClass.exportedMembers` as their identity. We cannot use their name * because the name itself is a `Tree`. @@ -176,13 +183,13 @@ private[analyzer] object InfoLoader { if (prevJSMethodPropDefInfos.size != jsMethodProps.size) { // Regenerate everything. - jsMethodProps.map(Infos.generateJSMethodPropDefInfo(_)) + jsMethodProps.map(generator.generateJSMethodPropDefInfo(_)) } else { for { (prevInfo, member) <- prevJSMethodPropDefInfos.zip(jsMethodProps) } yield { if (prevInfo.version.sameVersion(member.version)) prevInfo - else Infos.generateJSMethodPropDefInfo(member) + else generator.generateJSMethodPropDefInfo(member) } } } diff --git a/linker/shared/src/main/scala/org/scalajs/linker/analyzer/Infos.scala b/linker/shared/src/main/scala/org/scalajs/linker/analyzer/Infos.scala index fe957ca837..00b40402fe 100644 --- a/linker/shared/src/main/scala/org/scalajs/linker/analyzer/Infos.scala +++ b/linker/shared/src/main/scala/org/scalajs/linker/analyzer/Infos.scala @@ -22,8 +22,7 @@ import org.scalajs.ir.Types._ import org.scalajs.ir.Version import org.scalajs.ir.WellKnownNames._ -import org.scalajs.linker.backend.emitter.Transients._ -import org.scalajs.linker.standard.LinkedTopLevelExport +import org.scalajs.linker.frontend.{LinkTimeEvaluator, LinkTimeProperties} import org.scalajs.linker.standard.ModuleSet.ModuleID object Infos { @@ -184,27 +183,6 @@ object Infos { val methodName: MethodName ) extends MemberReachabilityInfo - def genReferencedFieldClasses(fields: List[AnyFieldDef]): Map[FieldName, ClassName] = { - val builder = Map.newBuilder[FieldName, ClassName] - - fields.foreach { - case FieldDef(flags, FieldIdent(name), _, ftpe) => - if (!flags.namespace.isStatic) { - ftpe match { - case ClassType(cls, _) => - builder += name -> cls - case ArrayType(ArrayTypeRef(ClassRef(cls), _), _) => - builder += name -> cls - case _ => - } - } - case _: JSFieldDef => - // Nothing to do. - } - - builder.result() - } - final class ReachabilityInfoBuilder(version: Version) { import ReachabilityInfoBuilder._ private val byClass = mutable.Map.empty[ClassName, ReachabilityInfoInClassBuilder] @@ -415,8 +393,11 @@ object Infos { def addUsedClassSuperClass(): this.type = setFlag(ReachabilityInfo.FlagUsedClassSuperClass) - def addReferencedLinkTimeProperty(linkTimeProperty: LinkTimeProperty): this.type = { + def markNeedsDesugaring(): this.type = setFlag(ReachabilityInfo.FlagNeedsDesugaring) + + def addReferencedLinkTimeProperty(linkTimeProperty: LinkTimeProperty): this.type = { + markNeedsDesugaring() linkTimeProperties.append((linkTimeProperty.name, linkTimeProperty.tpe)) this } @@ -539,46 +520,71 @@ object Infos { } } - /** Generates the [[MethodInfo]] of a - * [[org.scalajs.ir.Trees.MethodDef Trees.MethodDef]]. - */ - def generateMethodInfo(methodDef: MethodDef): MethodInfo = - new GenInfoTraverser(methodDef.version).generateMethodInfo(methodDef) + final class InfoGenerator(linkTimeProperties: LinkTimeProperties) { + def genReferencedFieldClasses(fields: List[AnyFieldDef]): Map[FieldName, ClassName] = { + val builder = Map.newBuilder[FieldName, ClassName] + + fields.foreach { + case FieldDef(flags, FieldIdent(name), _, ftpe) => + if (!flags.namespace.isStatic) { + ftpe match { + case ClassType(cls, _) => + builder += name -> cls + case ArrayType(ArrayTypeRef(ClassRef(cls), _), _) => + builder += name -> cls + case _ => + } + } + case _: JSFieldDef => + // Nothing to do. + } - /** Generates the [[ReachabilityInfo]] of a - * [[org.scalajs.ir.Trees.JSConstructorDef Trees.JSConstructorDef]]. - */ - def generateJSConstructorInfo(ctorDef: JSConstructorDef): ReachabilityInfo = - new GenInfoTraverser(ctorDef.version).generateJSConstructorInfo(ctorDef) + builder.result() + } - /** Generates the [[ReachabilityInfo]] of a - * [[org.scalajs.ir.Trees.JSMethodDef Trees.JSMethodDef]]. - */ - def generateJSMethodInfo(methodDef: JSMethodDef): ReachabilityInfo = - new GenInfoTraverser(methodDef.version).generateJSMethodInfo(methodDef) + /** Generates the [[MethodInfo]] of a + * [[org.scalajs.ir.Trees.MethodDef Trees.MethodDef]]. + */ + def generateMethodInfo(methodDef: MethodDef): MethodInfo = + new GenInfoTraverser(methodDef.version, linkTimeProperties).generateMethodInfo(methodDef) - /** Generates the [[ReachabilityInfo]] of a - * [[org.scalajs.ir.Trees.JSPropertyDef Trees.JSPropertyDef]]. - */ - def generateJSPropertyInfo(propertyDef: JSPropertyDef): ReachabilityInfo = - new GenInfoTraverser(propertyDef.version).generateJSPropertyInfo(propertyDef) + /** Generates the [[ReachabilityInfo]] of a + * [[org.scalajs.ir.Trees.JSConstructorDef Trees.JSConstructorDef]]. + */ + def generateJSConstructorInfo(ctorDef: JSConstructorDef): ReachabilityInfo = + new GenInfoTraverser(ctorDef.version, linkTimeProperties).generateJSConstructorInfo(ctorDef) - def generateJSMethodPropDefInfo(member: JSMethodPropDef): ReachabilityInfo = member match { - case methodDef: JSMethodDef => generateJSMethodInfo(methodDef) - case propertyDef: JSPropertyDef => generateJSPropertyInfo(propertyDef) - } + /** Generates the [[ReachabilityInfo]] of a + * [[org.scalajs.ir.Trees.JSMethodDef Trees.JSMethodDef]]. + */ + def generateJSMethodInfo(methodDef: JSMethodDef): ReachabilityInfo = + new GenInfoTraverser(methodDef.version, linkTimeProperties).generateJSMethodInfo(methodDef) + + /** Generates the [[ReachabilityInfo]] of a + * [[org.scalajs.ir.Trees.JSPropertyDef Trees.JSPropertyDef]]. + */ + def generateJSPropertyInfo(propertyDef: JSPropertyDef): ReachabilityInfo = + new GenInfoTraverser(propertyDef.version, linkTimeProperties).generateJSPropertyInfo(propertyDef) - /** Generates the [[MethodInfo]] for the top-level exports. */ - def generateTopLevelExportInfo(enclosingClass: ClassName, - topLevelExportDef: TopLevelExportDef): TopLevelExportInfo = { - val info = new GenInfoTraverser(Version.Unversioned) - .generateTopLevelExportInfo(enclosingClass, topLevelExportDef) - new TopLevelExportInfo(info, - ModuleID(topLevelExportDef.moduleID), - topLevelExportDef.topLevelExportName) + def generateJSMethodPropDefInfo(member: JSMethodPropDef): ReachabilityInfo = member match { + case methodDef: JSMethodDef => generateJSMethodInfo(methodDef) + case propertyDef: JSPropertyDef => generateJSPropertyInfo(propertyDef) + } + + /** Generates the [[MethodInfo]] for the top-level exports. */ + def generateTopLevelExportInfo(enclosingClass: ClassName, + topLevelExportDef: TopLevelExportDef): TopLevelExportInfo = { + val info = new GenInfoTraverser(Version.Unversioned, linkTimeProperties) + .generateTopLevelExportInfo(enclosingClass, topLevelExportDef) + new TopLevelExportInfo(info, + ModuleID(topLevelExportDef.moduleID), + topLevelExportDef.topLevelExportName) + } } - private final class GenInfoTraverser(version: Version) extends Traverser { + private final class GenInfoTraverser(version: Version, + linkTimeProperties: LinkTimeProperties) extends Traverser { + private val builder = new ReachabilityInfoBuilder(version) /** Whether we are currently in the body of an `async` closure. @@ -684,6 +690,36 @@ object Infos { // Capture values are in the enclosing scope; not the scope of the closure captureValues.foreach(traverse(_)) + // Do not call super.traverse(), as we must follow a single branch + case LinkTimeIf(cond, thenp, elsep) => + builder.markNeedsDesugaring() + traverse(cond) + LinkTimeEvaluator.tryEvalLinkTimeBooleanExpr(linkTimeProperties, cond) match { + case Some(result) => + if (result) + traverse(thenp) + else + traverse(elsep) + case None => + /* Ignore. Recall that we *assume* here that the ClassDef is + * valid on its own, i.e., it would pass the ClassDefChecker + * (irrespective of whether we actually run that checker). + * + * Under that assumption, the only failure mode for evaluating + * the `cond` is that it refers to a `LinkTimeProperty` that + * does not exist or has the wrong type. In that case, the + * analyzer will report a linking error at least for that + * `LinkTimeProperty` inside the `cond` (which we always + * traverse). + * + * If the assumption is broken and the evaluation failure was + * due to an ill-formed or ill-typed `cond`, then Desugar will + * eventually crash (with a message suggesting to enable checking + * the IR). + */ + () + } + // In all other cases, we'll have to call super.traverse() case _ => tree match { diff --git a/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/FunctionEmitter.scala b/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/FunctionEmitter.scala index a86c55909e..7cf164c228 100644 --- a/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/FunctionEmitter.scala +++ b/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/FunctionEmitter.scala @@ -634,7 +634,8 @@ private class FunctionEmitter private ( // Transients (only generated by the optimizer) case t: Transient => genTransient(t) - case _:JSSuperConstructorCall | _:LinkTimeProperty | _:NewLambda => + case _:JSSuperConstructorCall | _:LinkTimeProperty | _:LinkTimeIf | + _:NewLambda => throw new AssertionError(s"Invalid tree: $tree") } diff --git a/linker/shared/src/main/scala/org/scalajs/linker/checker/ClassDefChecker.scala b/linker/shared/src/main/scala/org/scalajs/linker/checker/ClassDefChecker.scala index a1c9f6363d..2d1437ee5f 100644 --- a/linker/shared/src/main/scala/org/scalajs/linker/checker/ClassDefChecker.scala +++ b/linker/shared/src/main/scala/org/scalajs/linker/checker/ClassDefChecker.scala @@ -761,6 +761,13 @@ private final class ClassDefChecker(classDef: ClassDef, checkTree(thenp, env) checkTree(elsep, env) + case LinkTimeIf(cond, thenp, elsep) => + if (!featureSet.supports(FeatureSet.LinkTimeNodes)) + reportError(i"Illegal link-time if after desugaring") + checkLinkTimeTree(cond, BooleanType) + checkTree(thenp, env) + checkTree(elsep, env) + case While(cond, body) => checkTree(cond, env) checkTree(body, env) @@ -923,9 +930,16 @@ private final class ClassDefChecker(classDef: ClassDef, } case LinkTimeProperty(name) => - if (!featureSet.supports(FeatureSet.LinkTimeProperty)) + if (!featureSet.supports(FeatureSet.LinkTimeNodes)) reportError(i"Illegal link-time property '$name' after desugaring") + tree.tpe match { + case BooleanType | IntType | StringType => + () // ok + case tpe => + reportError(i"$tpe is not a valid type for LinkTimeProperty") + } + // JavaScript expressions case JSNew(ctor, args) => @@ -1091,6 +1105,60 @@ private final class ClassDefChecker(classDef: ClassDef, } } + private def checkLinkTimeTree(tree: Tree, expectedType: PrimType): Unit = { + implicit val ctx = ErrorContext(tree) + + /* For link-time trees, we need to check the types. Having a well-typed + * condition is required for `LinkTimeIf` to be resolved, and that happens + * before IR checking. Fortunately, only trivial primitive types can appear + * in link-time trees, and it is therefore possible to check them now. + */ + if (tree.tpe != expectedType) + reportError(i"$expectedType expected but ${tree.tpe} found in link-time tree") + + /* Unlike the evaluation algorithm, at this time we allow LinkTimeProperty's + * that are not actually available. We only check that their declared type + * matches the expected type. If it does not exist or does not have the + * type it was declared with, that constitutes a *linking error*, but it + * does not make the ClassDef invalid. + */ + + tree match { + case _:IntLiteral | _:BooleanLiteral | _:StringLiteral | _:LinkTimeProperty => + () // ok + + case UnaryOp(op, lhs) => + import UnaryOp._ + op match { + case Boolean_! => + checkLinkTimeTree(lhs, BooleanType) + case _ => + reportError(i"illegal unary op $op in link-time tree") + } + + case BinaryOp(op, lhs, rhs) => + import BinaryOp._ + op match { + case Boolean_== | Boolean_!= | Boolean_| | Boolean_& => + checkLinkTimeTree(lhs, BooleanType) + checkLinkTimeTree(rhs, BooleanType) + case Int_== | Int_!= | Int_< | Int_<= | Int_> | Int_>= => + checkLinkTimeTree(lhs, IntType) + checkLinkTimeTree(rhs, IntType) + case _ => + reportError(i"illegal binary op $op in link-time tree") + } + + case LinkTimeIf(cond, thenp, elsep) => + checkLinkTimeTree(cond, BooleanType) + checkLinkTimeTree(thenp, expectedType) + checkLinkTimeTree(elsep, expectedType) + + case _ => + reportError(i"illegal tree of class ${tree.getClass().getName()} in link-time tree") + } + } + private def checkArrayType(tpe: ArrayType)( implicit ctx: ErrorContext): Unit = { checkArrayTypeRef(tpe.arrayTypeRef) diff --git a/linker/shared/src/main/scala/org/scalajs/linker/checker/FeatureSet.scala b/linker/shared/src/main/scala/org/scalajs/linker/checker/FeatureSet.scala index 33cbeaa135..94aabffff1 100644 --- a/linker/shared/src/main/scala/org/scalajs/linker/checker/FeatureSet.scala +++ b/linker/shared/src/main/scala/org/scalajs/linker/checker/FeatureSet.scala @@ -36,8 +36,8 @@ private[checker] object FeatureSet { // Individual features - /** The `LinkTimeProperty` IR node. */ - val LinkTimeProperty = new FeatureSet(1 << 0) + /** Link-time IR nodes: `LinkTimeProperty` and `LinkTimeIf`. */ + val LinkTimeNodes = new FeatureSet(1 << 0) /** The `NewLambda` IR node. */ val NewLambda = new FeatureSet(1 << 1) @@ -84,7 +84,7 @@ private[checker] object FeatureSet { /** Features that must be desugared away. */ private val NeedsDesugaring = - LinkTimeProperty | NewLambda + LinkTimeNodes | NewLambda /** IR that is only the result of desugaring (currently empty). */ private val Desugared = diff --git a/linker/shared/src/main/scala/org/scalajs/linker/checker/IRChecker.scala b/linker/shared/src/main/scala/org/scalajs/linker/checker/IRChecker.scala index b66dfeea1f..3f87f8be04 100644 --- a/linker/shared/src/main/scala/org/scalajs/linker/checker/IRChecker.scala +++ b/linker/shared/src/main/scala/org/scalajs/linker/checker/IRChecker.scala @@ -24,13 +24,13 @@ import org.scalajs.ir.WellKnownNames._ import org.scalajs.logging._ -import org.scalajs.linker.frontend.LinkingUnit +import org.scalajs.linker.frontend.{LinkingUnit, LinkTimeEvaluator, LinkTimeProperties} import org.scalajs.linker.standard.LinkedClass import org.scalajs.linker.checker.ErrorReporter._ /** Checker for the validity of the IR. */ -private final class IRChecker(unit: LinkingUnit, reporter: ErrorReporter, - previousPhase: CheckingPhase) { +private final class IRChecker(linkTimeProperties: LinkTimeProperties, + unit: LinkingUnit, reporter: ErrorReporter, previousPhase: CheckingPhase) { import IRChecker._ import reporter.reportError @@ -315,6 +315,26 @@ private final class IRChecker(unit: LinkingUnit, reporter: ErrorReporter, typecheckExpect(thenp, env, tpe) typecheckExpect(elsep, env, tpe) + case LinkTimeIf(cond, thenp, elsep) if featureSet.supports(FeatureSet.LinkTimeNodes) => + /* The `cond` is entirely checked in ClassDefChecker. + * + * We must only check the branch that is actually selected. + * We *cannot* check the dropped branch, because it may refer to types + * that are dropped by the reachability analysis (which is the whole + * point of LinkTimeIf). It is OK to have ill-typed IR in the dropped + * branch, because it is guaranteed to disappear during desugaring, + * before types are relied upon for any optimization or emission. + */ + LinkTimeEvaluator.tryEvalLinkTimeBooleanExpr(linkTimeProperties, cond) match { + case Some(value) => + if (value) + typecheckExpect(thenp, env, tree.tpe) + else + typecheckExpect(elsep, env, tree.tpe) + case None => + reportError(i"could not evaluate link-time condition: $cond") + } + case While(cond, body) => typecheckExpect(cond, env, BooleanType) typecheck(body, env) @@ -609,7 +629,7 @@ private final class IRChecker(unit: LinkingUnit, reporter: ErrorReporter, typecheckAny(expr, env) checkIsAsInstanceTargetType(tpe) - case LinkTimeProperty(name) if featureSet.supports(FeatureSet.LinkTimeProperty) => + case LinkTimeProperty(name) if featureSet.supports(FeatureSet.LinkTimeNodes) => // JavaScript expressions @@ -793,7 +813,7 @@ private final class IRChecker(unit: LinkingUnit, reporter: ErrorReporter, } case _:RecordSelect | _:RecordValue | _:Transient | - _:JSSuperConstructorCall | _:LinkTimeProperty | + _:JSSuperConstructorCall | _:LinkTimeProperty | _:LinkTimeIf | _:ApplyTypedClosure | _:NewLambda => reportError("invalid tree") } @@ -963,9 +983,10 @@ object IRChecker { * * @return Count of IR checking errors (0 in case of success) */ - def check(unit: LinkingUnit, logger: Logger, previousPhase: CheckingPhase): Int = { + def check(linkTimeProperties: LinkTimeProperties, unit: LinkingUnit, + logger: Logger, previousPhase: CheckingPhase): Int = { val reporter = new LoggerErrorReporter(logger) - new IRChecker(unit, reporter, previousPhase).check() + new IRChecker(linkTimeProperties, unit, reporter, previousPhase).check() reporter.errorCount } } diff --git a/linker/shared/src/main/scala/org/scalajs/linker/frontend/BaseLinker.scala b/linker/shared/src/main/scala/org/scalajs/linker/frontend/BaseLinker.scala index 62d05ff87e..b88ea4fd55 100644 --- a/linker/shared/src/main/scala/org/scalajs/linker/frontend/BaseLinker.scala +++ b/linker/shared/src/main/scala/org/scalajs/linker/frontend/BaseLinker.scala @@ -35,6 +35,8 @@ import Analysis._ final class BaseLinker(config: CommonPhaseConfig, checkIR: Boolean) { import BaseLinker._ + private val linkTimeProperties = LinkTimeProperties.fromCoreSpec(config.coreSpec) + private val irLoader = new FileIRLoader private val analyzer = { val checkIRFor = if (checkIR) Some(CheckingPhase.Compiler) else None @@ -58,7 +60,8 @@ final class BaseLinker(config: CommonPhaseConfig, checkIR: Boolean) { } yield { if (checkIR) { logger.time("Linker: Check IR") { - val errorCount = IRChecker.check(linkResult, logger, CheckingPhase.BaseLinker) + val errorCount = IRChecker.check(linkTimeProperties, linkResult, + logger, CheckingPhase.BaseLinker) if (errorCount != 0) { throw new LinkingException( s"There were $errorCount IR checking errors.") diff --git a/linker/shared/src/main/scala/org/scalajs/linker/frontend/Desugarer.scala b/linker/shared/src/main/scala/org/scalajs/linker/frontend/Desugarer.scala index 57f8eeb366..b97423440d 100644 --- a/linker/shared/src/main/scala/org/scalajs/linker/frontend/Desugarer.scala +++ b/linker/shared/src/main/scala/org/scalajs/linker/frontend/Desugarer.scala @@ -43,7 +43,8 @@ final class Desugarer(config: CommonPhaseConfig, checkIR: Boolean) { if (checkIR) { logger.time("Desugarer: Check IR") { - val errorCount = IRChecker.check(result, logger, CheckingPhase.Desugarer) + val errorCount = IRChecker.check(linkTimeProperties, result, logger, + CheckingPhase.Desugarer) if (errorCount != 0) { throw new AssertionError( s"There were $errorCount IR checking errors after desugaring (this is a Scala.js bug)") @@ -149,6 +150,21 @@ private[linker] object Desugarer { case LinkTimeProperties.LinkTimeString(value) => StringLiteral(value) } + case LinkTimeIf(cond, thenp, elsep) => + LinkTimeEvaluator.tryEvalLinkTimeBooleanExpr(linkTimeProperties, cond) match { + case Some(result) => + if (result) + transform(thenp) + else + transform(elsep) + case None => + throw new AssertionError( + s"Invalid link-time condition should not have passed the reachability analysis:\n" + + s"${tree.show}\n" + + s"at ${tree.pos}.\n" + + "Consider running the linker with `withCheckIR(true)` before submitting a bug report.") + } + case NewLambda(descriptor, fun) => implicit val pos = tree.pos val (className, ctorName) = syntheticLambdaNamesFor(descriptor) diff --git a/linker/shared/src/main/scala/org/scalajs/linker/frontend/LinkTimeEvaluator.scala b/linker/shared/src/main/scala/org/scalajs/linker/frontend/LinkTimeEvaluator.scala new file mode 100644 index 0000000000..3ab224306f --- /dev/null +++ b/linker/shared/src/main/scala/org/scalajs/linker/frontend/LinkTimeEvaluator.scala @@ -0,0 +1,129 @@ +/* + * Scala.js (https://www.scala-js.org/) + * + * Copyright EPFL. + * + * Licensed under Apache License 2.0 + * (https://www.apache.org/licenses/LICENSE-2.0). + * + * See the NOTICE file distributed with this work for + * additional information regarding copyright ownership. + */ + +package org.scalajs.linker.frontend + +import org.scalajs.ir.Position +import org.scalajs.ir.Trees._ +import org.scalajs.ir.Trees.LinkTimeProperty._ + +import org.scalajs.linker.frontend.LinkTimeProperties._ +import org.scalajs.linker.interface.LinkingException + +private[linker] object LinkTimeEvaluator { + + /** Try and evaluate a link-time expression tree as a boolean value. + * + * This method assumes that the given `tree` is valid according to the + * `ClassDefChecker` and that its `tpe` is `BooleanType`. + * If that is not the case, it may throw or return an arbitrary result. + * + * Returns `None` if any subtree that needed evaluation was a missing + * `LinkTimeProperty` or one with the wrong type (i.e., one that would not + * pass the reachability analysis). + */ + def tryEvalLinkTimeBooleanExpr( + linkTimeProperties: LinkTimeProperties, tree: Tree): Option[Boolean] = { + implicit val pos = tree.pos + + tryEvalLinkTimeExpr(linkTimeProperties, tree).map(booleanValue(_)) + } + + /** Try and evaluate a link-time expression tree. + * + * This method assumes that the given `tree` is valid according to the + * `ClassDefChecker`. + * If that is not the case, it may throw or return an arbitrary result. + * + * Returns `None` if any subtree that needed evaluation was a missing + * `LinkTimeProperty` or one with the wrong type (i.e., one that would not + * pass the reachability analysis). + */ + private def tryEvalLinkTimeExpr( + props: LinkTimeProperties, tree: Tree): Option[LinkTimeValue] = { + implicit val pos = tree.pos + + tree match { + case IntLiteral(value) => Some(LinkTimeInt(value)) + case BooleanLiteral(value) => Some(LinkTimeBoolean(value)) + case StringLiteral(value) => Some(LinkTimeString(value)) + + case LinkTimeProperty(name) => + props.get(name).filter(_.tpe == tree.tpe) + + case UnaryOp(op, lhs) => + import UnaryOp._ + for { + l <- tryEvalLinkTimeExpr(props, lhs) + } yield { + op match { + case Boolean_! => LinkTimeBoolean(!booleanValue(l)) + + case _ => + throw new LinkingException( + s"Illegal unary op $op in link-time tree at $pos") + } + } + + case BinaryOp(op, lhs, rhs) => + import BinaryOp._ + for { + l <- tryEvalLinkTimeExpr(props, lhs) + r <- tryEvalLinkTimeExpr(props, rhs) + } yield { + op match { + case Boolean_== => LinkTimeBoolean(booleanValue(l) == booleanValue(r)) + case Boolean_!= => LinkTimeBoolean(booleanValue(l) != booleanValue(r)) + case Boolean_| => LinkTimeBoolean(booleanValue(l) | booleanValue(r)) + case Boolean_& => LinkTimeBoolean(booleanValue(l) & booleanValue(r)) + + case Int_== => LinkTimeBoolean(intValue(l) == intValue(r)) + case Int_!= => LinkTimeBoolean(intValue(l) != intValue(r)) + case Int_< => LinkTimeBoolean(intValue(l) < intValue(r)) + case Int_<= => LinkTimeBoolean(intValue(l) <= intValue(r)) + case Int_> => LinkTimeBoolean(intValue(l) > intValue(r)) + case Int_>= => LinkTimeBoolean(intValue(l) >= intValue(r)) + + case _ => + throw new LinkingException( + s"Illegal binary op $op in link-time tree at $pos") + } + } + + case LinkTimeIf(cond, thenp, elsep) => + tryEvalLinkTimeExpr(props, cond).flatMap { c => + if (booleanValue(c)) + tryEvalLinkTimeExpr(props, thenp) + else + tryEvalLinkTimeExpr(props, elsep) + } + + case _ => + throw new LinkingException( + s"Illegal tree of class ${tree.getClass().getName()} in link-time tree at $pos") + } + } + + private def intValue(value: LinkTimeValue)(implicit pos: Position): Int = value match { + case LinkTimeInt(value) => + value + case _ => + throw new LinkingException(s"Value of type int expected but got $value at $pos") + } + + private def booleanValue(value: LinkTimeValue)(implicit pos: Position): Boolean = value match { + case LinkTimeBoolean(value) => + value + case _ => + throw new LinkingException(s"Value of type boolean expected but got $value at $pos") + } +} diff --git a/linker/shared/src/main/scala/org/scalajs/linker/frontend/Refiner.scala b/linker/shared/src/main/scala/org/scalajs/linker/frontend/Refiner.scala index 0f074adf55..4f778351ba 100644 --- a/linker/shared/src/main/scala/org/scalajs/linker/frontend/Refiner.scala +++ b/linker/shared/src/main/scala/org/scalajs/linker/frontend/Refiner.scala @@ -30,6 +30,8 @@ import org.scalajs.linker.analyzer._ final class Refiner(config: CommonPhaseConfig, checkIR: Boolean) { import Refiner._ + private val linkTimeProperties = LinkTimeProperties.fromCoreSpec(config.coreSpec) + private val irLoader = new ClassDefIRLoader private val analyzer = { val checkIRFor = if (checkIR) Some(CheckingPhase.Optimizer) else None @@ -81,7 +83,8 @@ final class Refiner(config: CommonPhaseConfig, checkIR: Boolean) { if (shouldRunIRChecker) { logger.time("Refiner: Check IR") { - val errorCount = IRChecker.check(result, logger, CheckingPhase.Optimizer) + val errorCount = IRChecker.check(linkTimeProperties, result, logger, + CheckingPhase.Optimizer) if (errorCount != 0) { throw new AssertionError( s"There were $errorCount IR checking errors after optimization (this is a Scala.js bug)") diff --git a/linker/shared/src/main/scala/org/scalajs/linker/frontend/optimizer/OptimizerCore.scala b/linker/shared/src/main/scala/org/scalajs/linker/frontend/optimizer/OptimizerCore.scala index 51cebcdcca..9f7fe1aa95 100644 --- a/linker/shared/src/main/scala/org/scalajs/linker/frontend/optimizer/OptimizerCore.scala +++ b/linker/shared/src/main/scala/org/scalajs/linker/frontend/optimizer/OptimizerCore.scala @@ -689,7 +689,8 @@ private[optimizer] abstract class OptimizerCore( _:JSGlobalRef | _:JSTypeOfGlobalRef | _:Literal => tree - case _:LinkTimeProperty | _:NewLambda | _:RecordSelect | _:Transient => + case _:LinkTimeProperty | _:LinkTimeIf | _:NewLambda | _:RecordSelect | + _:Transient => throw new IllegalArgumentException( s"Invalid tree in transform of class ${tree.getClass.getName}: $tree") } diff --git a/linker/shared/src/test/scala/org/scalajs/linker/AnalyzerTest.scala b/linker/shared/src/test/scala/org/scalajs/linker/AnalyzerTest.scala index 4eb535144d..c543be0f2b 100644 --- a/linker/shared/src/test/scala/org/scalajs/linker/AnalyzerTest.scala +++ b/linker/shared/src/test/scala/org/scalajs/linker/AnalyzerTest.scala @@ -874,6 +874,114 @@ class AnalyzerTest { ) Future.sequence(results) } + + @Test + def linkTimeIfReachable(): AsyncResult = await { + val mainMethodName = m("main", Nil, IntRef) + val fooMethodName = m("foo", Nil, IntRef) + val barMethodName = m("bar", Nil, IntRef) + + val thisType = ClassType("A", nullable = false) + + val productionMode = true + + /* linkTimeIf(productionMode) { + * this.foo() + * } { + * this.bar() + * } + */ + val mainBody = LinkTimeIf( + BinaryOp(BinaryOp.Boolean_==, + LinkTimeProperty("core/productionMode")(BooleanType), + BooleanLiteral(productionMode)), + Apply(EAF, This()(thisType), fooMethodName, Nil)(IntType), + Apply(EAF, This()(thisType), barMethodName, Nil)(IntType) + )(IntType) + + val classDefs = Seq( + classDef("A", superClass = Some(ObjectClass), + methods = List( + trivialCtor("A"), + MethodDef(EMF, mainMethodName, NON, Nil, IntType, Some(mainBody))(EOH, UNV), + MethodDef(EMF, fooMethodName, NON, Nil, IntType, Some(int(1)))(EOH, UNV), + MethodDef(EMF, barMethodName, NON, Nil, IntType, Some(int(2)))(EOH, UNV) + ) + ) + ) + + val requirements = { + reqsFactory.instantiateClass("A", NoArgConstructorName) ++ + reqsFactory.callMethod("A", mainMethodName) + } + + val analysisFuture = computeAnalysis(classDefs, requirements, + config = StandardConfig().withSemantics(_.withProductionMode(productionMode))) + + for (analysis <- analysisFuture) yield { + assertNoError(analysis) + + val AfooMethodInfo = analysis.classInfos("A") + .methodInfos(MemberNamespace.Public)(fooMethodName) + assertTrue(AfooMethodInfo.isReachable) + + val AbarMethodInfo = analysis.classInfos("A") + .methodInfos(MemberNamespace.Public)(barMethodName) + assertFalse(AbarMethodInfo.isReachable) + } + } + + @Test + def linkTimeIfError(): AsyncResult = await { + val mainMethodName = m("main", Nil, IntRef) + val fooMethodName = m("foo", Nil, IntRef) + + val thisType = ClassType("A", nullable = false) + + val productionMode = true + + /* linkTimeIf(unknownProperty) { + * this.foo() + * } { + * this.bar() + * } + */ + val mainBody = LinkTimeIf( + BinaryOp(BinaryOp.Boolean_==, + LinkTimeProperty("core/unknownProperty")(BooleanType), + BooleanLiteral(productionMode)), + Apply(EAF, This()(thisType), fooMethodName, Nil)(IntType), + Apply(EAF, This()(thisType), fooMethodName, Nil)(IntType) + )(IntType) + + val classDefs = Seq( + classDef("A", superClass = Some(ObjectClass), + methods = List( + trivialCtor("A"), + MethodDef(EMF, mainMethodName, NON, Nil, IntType, Some(mainBody))(EOH, UNV) + ) + ) + ) + + val requirements = { + reqsFactory.instantiateClass("A", NoArgConstructorName) ++ + reqsFactory.callMethod("A", mainMethodName) + } + + val analysisFuture = computeAnalysis(classDefs, requirements, + config = StandardConfig().withSemantics(_.withProductionMode(productionMode))) + + for (analysis <- analysisFuture) yield { + assertContainsError(s"InvalidLinkTimeProperty(core/unknownProperty)", analysis) { + case InvalidLinkTimeProperty("core/unknownProperty", BooleanType, _) => true + } + + // Branches are not taken, so there is no error for linking `foo` + assertNotContainsError(s"any MissingMethod", analysis) { + case MissingMethod(_, _) => true + } + } + } } object AnalyzerTest { @@ -962,10 +1070,21 @@ object AnalyzerTest { private def assertContainsError(msg: String, analysis: Analysis)( pf: PartialFunction[Error, Boolean]): Unit = { - val fullMessage = s"Expected $msg, got ${analysis.errors}" - assertTrue(fullMessage, analysis.errors.exists { + assertTrue(s"Expected $msg, got ${analysis.errors}", + containsError(analysis)(pf)) + } + + private def assertNotContainsError(msg: String, analysis: Analysis)( + pf: PartialFunction[Error, Boolean]): Unit = { + assertFalse(s"Did not expect $msg, got ${analysis.errors}", + containsError(analysis)(pf)) + } + + private def containsError(analysis: Analysis)( + pf: PartialFunction[Error, Boolean]): Boolean = { + analysis.errors.exists { e => pf.applyOrElse(e, (_: Error) => false) - }) + } } object ClsInfo { diff --git a/linker/shared/src/test/scala/org/scalajs/linker/IRCheckerTest.scala b/linker/shared/src/test/scala/org/scalajs/linker/IRCheckerTest.scala index 73dce25631..1c6ae731b1 100644 --- a/linker/shared/src/test/scala/org/scalajs/linker/IRCheckerTest.scala +++ b/linker/shared/src/test/scala/org/scalajs/linker/IRCheckerTest.scala @@ -446,6 +446,7 @@ object IRCheckerTest { new ClassTransformer { override def transform(tree: Tree): Tree = tree match { case tree: LinkTimeProperty => zeroOf(tree.tpe) + case tree: LinkTimeIf => zeroOf(tree.tpe) case tree: NewLambda => UnaryOp(UnaryOp.Throw, Null()) case _ => super.transform(tree) } diff --git a/linker/shared/src/test/scala/org/scalajs/linker/checker/ClassDefCheckerTest.scala b/linker/shared/src/test/scala/org/scalajs/linker/checker/ClassDefCheckerTest.scala index 6441fd0c48..309cc5d7a1 100644 --- a/linker/shared/src/test/scala/org/scalajs/linker/checker/ClassDefCheckerTest.scala +++ b/linker/shared/src/test/scala/org/scalajs/linker/checker/ClassDefCheckerTest.scala @@ -834,6 +834,84 @@ class ClassDefCheckerTest { "Assignment to RecordSelect of illegal tree: org.scalajs.ir.Trees$IntLiteral", previousPhase = CheckingPhase.Optimizer) } + + @Test + def linkTimePropertyTest(): Unit = { + // Test that some illegal types are rejected + for (tpe <- List(FloatType, NullType, NothingType, ClassType(BoxedStringClass, nullable = false))) { + assertError( + mainTestClassDef(LinkTimeProperty("foo")(tpe)), + s"${tpe.show()} is not a valid type for LinkTimeProperty") + } + + // Some error also gets reported if used in link-time-tree position + assertError( + mainTestClassDef { + LinkTimeIf(LinkTimeProperty("foo")(NothingType), int(5), int(6))(IntType) + }, + s"boolean expected but nothing found in link-time tree") + + // LinkTimeProperty is rejected after desugaring + assertError( + mainTestClassDef(LinkTimeProperty("foo")(IntType)), + "Illegal link-time property 'foo' after desugaring", + previousPhase = CheckingPhase.Optimizer) + } + + @Test + def linkTimeIfTest(): Unit = { + def makeTestClassDef(cond: Tree): ClassDef = { + classDef( + "Foo", + superClass = Some(ObjectClass), + methods = List( + trivialCtor("Foo"), + MethodDef(EMF, MethodName("foo", Nil, VoidRef), NON, Nil, VoidType, Some { + LinkTimeIf( + cond, + consoleLog(StringLiteral("foo")), + consoleLog(StringLiteral("bar")) + )(VoidType) + })(EOH, UNV) + ) + ) + } + + assertError( + makeTestClassDef( + UnaryOp(UnaryOp.Boolean_!, int(0)) + ), + "boolean expected but int found in link-time tree" + ) + + assertError( + makeTestClassDef( + BinaryOp(BinaryOp.Int_==, int(0), LinkTimeProperty("core/productionMode")(BooleanType)) + ), + "int expected but boolean found in link-time tree" + ) + + assertError( + makeTestClassDef( + BinaryOp(BinaryOp.Boolean_==, int(0), LinkTimeProperty("core/productionMode")(BooleanType)) + ), + "boolean expected but int found in link-time tree" + ) + + assertError( + makeTestClassDef( + BinaryOp(BinaryOp.===, int(0), int(1)) + ), + "illegal binary op 1 in link-time tree" + ) + + assertError( + makeTestClassDef( + If(BooleanLiteral(true), BooleanLiteral(true), BooleanLiteral(false))(BooleanType) + ), + "illegal tree of class org.scalajs.ir.Trees$If in link-time tree" + ) + } } private object ClassDefCheckerTest { diff --git a/linker/shared/src/test/scala/org/scalajs/linker/frontend/modulesplitter/LinkTimeEvaluatorTest.scala b/linker/shared/src/test/scala/org/scalajs/linker/frontend/modulesplitter/LinkTimeEvaluatorTest.scala new file mode 100644 index 0000000000..79e8d36ff6 --- /dev/null +++ b/linker/shared/src/test/scala/org/scalajs/linker/frontend/modulesplitter/LinkTimeEvaluatorTest.scala @@ -0,0 +1,102 @@ +/* + * Scala.js (https://www.scala-js.org/) + * + * Copyright EPFL. + * + * Licensed under Apache License 2.0 + * (https://www.apache.org/licenses/LICENSE-2.0). + * + * See the NOTICE file distributed with this work for + * additional information regarding copyright ownership. + */ + +package org.scalajs.linker.frontend + +import org.junit.Test +import org.junit.Assert._ + +import org.scalajs.ir.Trees._ +import org.scalajs.ir.Types._ + +import org.scalajs.linker.interface.{ESFeatures, ESVersion, Semantics, StandardConfig} +import org.scalajs.linker.standard.CoreSpec +import org.scalajs.linker.testutils.TestIRBuilder._ + +class LinkTimeEvaluatorTest { + /** Convenience builder for `LinkTimeProperties` with mostly-default configs. */ + private def make( + semantics: Semantics => Semantics = identity, + esFeatures: ESFeatures => ESFeatures = identity, + isWebAssembly: Boolean = false + ): LinkTimeProperties = { + val config = StandardConfig() + .withSemantics(semantics) + .withESFeatures(esFeatures) + .withExperimentalUseWebAssembly(isWebAssembly) + LinkTimeProperties.fromCoreSpec(CoreSpec.fromStandardConfig(config)) + } + + @Test + def testTryEvalLinkTimeBooleanExpr(): Unit = { + val defaults = make() + + def test(expected: Option[Boolean], tree: Tree, config: LinkTimeProperties = defaults): Unit = + assertEquals(expected, LinkTimeEvaluator.tryEvalLinkTimeBooleanExpr(config, tree)) + + def testTrue(tree: Tree, config: LinkTimeProperties = defaults): Unit = + test(Some(true), tree, config) + + def testFalse(tree: Tree, config: LinkTimeProperties = defaults): Unit = + test(Some(false), tree, config) + + def testFail(tree: Tree, config: LinkTimeProperties = defaults): Unit = + test(None, tree, config) + + // Boolean literal + testTrue(bool(true)) + testFalse(bool(false)) + + // Boolean link-time property + testFalse(LinkTimeProperty("core/isWebAssembly")(BooleanType)) + testTrue(LinkTimeProperty("core/isWebAssembly")(BooleanType), make(isWebAssembly = true)) + testFail(LinkTimeProperty("core/missing")(BooleanType)) + testFail(LinkTimeProperty("core/esVersion")(BooleanType)) + + // Int comparison + for (l <- List(3, 5, 7); r <- List(3, 5, 7)) { + test(Some(l == r), BinaryOp(BinaryOp.Int_==, int(l), int(r))) + test(Some(l != r), BinaryOp(BinaryOp.Int_!=, int(l), int(r))) + test(Some(l < r), BinaryOp(BinaryOp.Int_<, int(l), int(r))) + test(Some(l <= r), BinaryOp(BinaryOp.Int_<=, int(l), int(r))) + test(Some(l > r), BinaryOp(BinaryOp.Int_>, int(l), int(r))) + test(Some(l >= r), BinaryOp(BinaryOp.Int_>=, int(l), int(r))) + } + + // Boolean operator + testTrue(UnaryOp(UnaryOp.Boolean_!, bool(false))) + testFalse(UnaryOp(UnaryOp.Boolean_!, bool(true))) + + // Comparison with link-time property + val esVersionProp = LinkTimeProperty("core/esVersion")(IntType) + testTrue(BinaryOp(BinaryOp.Int_>=, esVersionProp, int(ESVersion.ES2015.edition))) + testFalse(BinaryOp(BinaryOp.Int_>=, esVersionProp, int(ESVersion.ES2019.edition))) + testTrue(BinaryOp(BinaryOp.Int_>=, esVersionProp, int(ESVersion.ES2019.edition)), + make(esFeatures = _.withESVersion(ESVersion.ES2021))) + + // LinkTimeIf + testTrue(LinkTimeIf(bool(true), bool(true), bool(false))(BooleanType)) + testFalse(LinkTimeIf(bool(true), bool(false), bool(true))(BooleanType)) + testFalse(LinkTimeIf(bool(false), bool(true), bool(false))(BooleanType)) + + // Complex expression: esVersion >= ES2016 && esVersion <= ES2019 + val complexExpr = LinkTimeIf( + BinaryOp(BinaryOp.Int_>=, esVersionProp, int(ESVersion.ES2016.edition)), + BinaryOp(BinaryOp.Int_<=, esVersionProp, int(ESVersion.ES2019.edition)), + bool(false))( + BooleanType) + testTrue(complexExpr, make(esFeatures = _.withESVersion(ESVersion.ES2017))) + testTrue(complexExpr, make(esFeatures = _.withESVersion(ESVersion.ES2019))) + testFalse(complexExpr, make(esFeatures = _.withESVersion(ESVersion.ES2015))) + testFalse(complexExpr, make(esFeatures = _.withESVersion(ESVersion.ES2021))) + } +} diff --git a/linker/shared/src/test/scala/org/scalajs/linker/testutils/TestIRBuilder.scala b/linker/shared/src/test/scala/org/scalajs/linker/testutils/TestIRBuilder.scala index 7d022a5123..a4284ec897 100644 --- a/linker/shared/src/test/scala/org/scalajs/linker/testutils/TestIRBuilder.scala +++ b/linker/shared/src/test/scala/org/scalajs/linker/testutils/TestIRBuilder.scala @@ -196,6 +196,7 @@ object TestIRBuilder { implicit def methodName2MethodIdent(name: MethodName): MethodIdent = MethodIdent(name) + def bool(x: Boolean): BooleanLiteral = BooleanLiteral(x) def int(x: Int): IntLiteral = IntLiteral(x) def str(x: String): StringLiteral = StringLiteral(x) } diff --git a/test-suite/js/src/test/scala/org/scalajs/testsuite/library/LinkTimeIfTest.scala b/test-suite/js/src/test/scala/org/scalajs/testsuite/library/LinkTimeIfTest.scala new file mode 100644 index 0000000000..1cca641fbf --- /dev/null +++ b/test-suite/js/src/test/scala/org/scalajs/testsuite/library/LinkTimeIfTest.scala @@ -0,0 +1,95 @@ +/* + * Scala.js (https://www.scala-js.org/) + * + * Copyright EPFL. + * + * Licensed under Apache License 2.0 + * (https://www.apache.org/licenses/LICENSE-2.0). + * + * See the NOTICE file distributed with this work for + * additional information regarding copyright ownership. + */ + +package org.scalajs.testsuite.library + +import scala.scalajs.js +import scala.scalajs.LinkingInfo._ + +import org.junit.Test +import org.junit.Assert._ +import org.junit.Assume._ + +import org.scalajs.testsuite.utils.Platform + +class LinkTimeIfTest { + @Test def linkTimeIfConst(): Unit = { + // boolean const + assertEquals(1, linkTimeIf(true) { 1 } { 2 }) + assertEquals(2, linkTimeIf(false) { 1 } { 2 }) + } + + @Test def linkTimeIfProp(): Unit = { + locally { + val cond = Platform.isInProductionMode + assertEquals(cond, linkTimeIf(productionMode) { true } { false }) + } + + locally { + val cond = !Platform.isInProductionMode + assertEquals(cond, linkTimeIf(!productionMode) { true } { false }) + } + } + + @Test def linkTimIfIntProp(): Unit = { + locally { + val cond = Platform.assumedESVersion >= ESVersion.ES2015 + assertEquals(cond, linkTimeIf(esVersion >= ESVersion.ES2015) { true } { false }) + } + + locally { + val cond = !(Platform.assumedESVersion < ESVersion.ES2015) + assertEquals(cond, linkTimeIf(!(esVersion < ESVersion.ES2015)) { true } { false }) + } + } + + @Test def linkTimeIfNested(): Unit = { + locally { + val cond = { + Platform.isInProductionMode && + Platform.assumedESVersion >= ESVersion.ES2015 + } + assertEquals(if (cond) 53 else 78, + linkTimeIf(productionMode && esVersion >= ESVersion.ES2015) { 53 } { 78 }) + } + + locally { + val cond = { + Platform.assumedESVersion >= ESVersion.ES2015 && + Platform.assumedESVersion < ESVersion.ES2019 && + Platform.isInProductionMode + } + val result = linkTimeIf(esVersion >= ESVersion.ES2015 && + esVersion < ESVersion.ES2019 && productionMode) { + 53 + } { + 78 + } + assertEquals(if (cond) 53 else 78, result) + } + } + + @Test def exponentOp(): Unit = { + def pow(x: Double, y: Double): Double = { + linkTimeIf(esVersion >= ESVersion.ES2016) { + assertTrue("Took the wrong branch of linkTimeIf when linking for ES 2016+", + esVersion >= ESVersion.ES2016) + (x.asInstanceOf[js.Dynamic] ** y.asInstanceOf[js.Dynamic]).asInstanceOf[Double] + } { + assertFalse("Took the wrong branch of linkTimeIf when linking for ES 2015-", + esVersion >= ESVersion.ES2016) + Math.pow(x, y) + } + } + assertEquals(pow(2.0, 8.0), 256.0, 0) + } +} From f0e7a337b03584b0bb2e7868153509fbd60a409b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Doeraene?= Date: Sun, 5 Jan 2025 19:19:20 +0100 Subject: [PATCH 3/3] Use JS bigint's if possible inside the `parseFloat` algorithm. We use a `linkTimeIf` to select a `bigint`-based implementation of `parseFloatDecimalCorrection` when they are supported. We need a `linkTimeIf` in this case because it uses the JS `**` operator, which does not link below ES 2016. The `bigint`-based implementation avoids bringing in the entire `BigInteger` implementation, which is a major code size win if that was the only reason `BigInteger` was needed. --- javalib/src/main/scala/java/lang/Float.scala | 82 ++++++++++++++++---- 1 file changed, 69 insertions(+), 13 deletions(-) diff --git a/javalib/src/main/scala/java/lang/Float.scala b/javalib/src/main/scala/java/lang/Float.scala index 8fa4ce3070..a2d54c77fd 100644 --- a/javalib/src/main/scala/java/lang/Float.scala +++ b/javalib/src/main/scala/java/lang/Float.scala @@ -13,9 +13,9 @@ package java.lang import java.lang.constant.{Constable, ConstantDesc} -import java.math.BigInteger import scala.scalajs.js +import scala.scalajs.LinkingInfo._ /* This is a hijacked class. Its instances are primitive numbers. * Constructors are not emitted. @@ -226,9 +226,23 @@ object Float { fractionalPartStr: String, exponentStr: String, zDown: scala.Float, zUp: scala.Float, mid: scala.Double): scala.Float = { + /* Get the best available implementation of big integers for the given platform. + * + * If JS bigint's are supported, use them. Otherwise fall back on + * `java.math.BigInteger`. + * + * We need a `linkTimeIf` here because the JS bigint implementation uses + * the `**` operator, which does not link when `esVersion < ESVersion.ES2016`. + */ + val bigIntImpl = linkTimeIf[BigIntImpl](esVersion >= ESVersion.ES2020) { + BigIntImpl.JSBigInt + } { + BigIntImpl.JBigInteger + } + // 1. Accurately parse the string with the representation f × 10ᵉ - val f: BigInteger = new BigInteger(integralPartStr + fractionalPartStr) + val f: bigIntImpl.Repr = bigIntImpl.fromString(integralPartStr + fractionalPartStr) val e: Int = Integer.parseInt(exponentStr) - fractionalPartStr.length() /* Note: we know that `e` is "reasonable" (in the range [-324, +308]). If @@ -261,24 +275,23 @@ object Float { val mExplicitBits = midBits & ((1L << mbits) - 1) val mImplicit1Bit = 1L << mbits // the implicit '1' bit of a normalized floating-point number - val m = BigInteger.valueOf(mExplicitBits | mImplicit1Bit) + val m = bigIntImpl.fromUnsignedLong53(mExplicitBits | mImplicit1Bit) val k = biasedK - bias - mbits // 3. Accurately compare f × 10ᵉ to m × 2ᵏ - @inline def compare(x: BigInteger, y: BigInteger): Int = - x.compareTo(y) + import bigIntImpl.{multiplyBy2Pow, multiplyBy10Pow} val cmp = if (e >= 0) { if (k >= 0) - compare(multiplyBy10Pow(f, e), multiplyBy2Pow(m, k)) + bigIntImpl.compare(multiplyBy10Pow(f, e), multiplyBy2Pow(m, k)) else - compare(multiplyBy2Pow(multiplyBy10Pow(f, e), -k), m) // this branch may be dead code in practice + bigIntImpl.compare(multiplyBy2Pow(multiplyBy10Pow(f, e), -k), m) // this branch may be dead code in practice } else { if (k >= 0) - compare(f, multiplyBy2Pow(multiplyBy10Pow(m, -e), k)) + bigIntImpl.compare(f, multiplyBy2Pow(multiplyBy10Pow(m, -e), k)) else - compare(multiplyBy2Pow(f, -k), multiplyBy10Pow(m, -e)) + bigIntImpl.compare(multiplyBy2Pow(f, -k), multiplyBy10Pow(m, -e)) } // 4. Choose zDown or zUp depending on the result of the comparison @@ -293,11 +306,54 @@ object Float { zUp } - @inline private def multiplyBy10Pow(v: BigInteger, e: Int): BigInteger = - v.multiply(BigInteger.TEN.pow(e)) + /** An implementation of big integer arithmetics that we need in the above method. */ + private sealed abstract class BigIntImpl { + type Repr + + def fromString(str: String): Repr + + /** Creates a big integer from a `Long` that needs at most 53 bits (unsigned). */ + def fromUnsignedLong53(x: scala.Long): Repr + + def multiplyBy2Pow(v: Repr, e: Int): Repr + def multiplyBy10Pow(v: Repr, e: Int): Repr + + def compare(x: Repr, y: Repr): Int + } + + private object BigIntImpl { + object JSBigInt extends BigIntImpl { + type Repr = js.BigInt + + @inline def fromString(str: String): Repr = js.BigInt(str) - @inline private def multiplyBy2Pow(v: BigInteger, e: Int): BigInteger = - v.shiftLeft(e) + // The 53-bit restriction guarantees that the conversion to `Double` is lossless. + @inline def fromUnsignedLong53(x: scala.Long): Repr = js.BigInt(x.toDouble) + + @inline def multiplyBy2Pow(v: Repr, e: Int): Repr = v << js.BigInt(e) + @inline def multiplyBy10Pow(v: Repr, e: Int): Repr = v * (js.BigInt(10) ** js.BigInt(e)) + + @inline def compare(x: Repr, y: Repr): Int = { + if (x < y) -1 + else if (x > y) 1 + else 0 + } + } + + object JBigInteger extends BigIntImpl { + import java.math.BigInteger + + type Repr = BigInteger + + @inline def fromString(str: String): Repr = new BigInteger(str) + @inline def fromUnsignedLong53(x: scala.Long): Repr = BigInteger.valueOf(x) + + @inline def multiplyBy2Pow(v: Repr, e: Int): Repr = v.shiftLeft(e) + @inline def multiplyBy10Pow(v: Repr, e: Int): Repr = v.multiply(BigInteger.TEN.pow(e)) + + @inline def compare(x: Repr, y: Repr): Int = x.compareTo(y) + } + } private def parseFloatHexadecimal(integralPartStr: String, fractionalPartStr: String, binaryExpStr: String): scala.Float = {