Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Fix #4997: Add linkTimeIf for link-time conditional branching. #5110

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
May 15, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 87 additions & 0 deletions compiler/src/main/scala/org/scalajs/nscplugin/GenJSCode.scala
Original file line number Diff line number Diff line change
Expand Up @@ -5511,6 +5511,16 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G)
js.UnaryOp(js.UnaryOp.UnwrapFromThrowable,
js.UnaryOp(js.UnaryOp.CheckNotNull, genArgs1))

case LINKTIME_IF =>
// LinkingInfo.linkTimeIf(cond, thenp, elsep)
val cond = genLinkTimeExpr(args(0))
val thenp = genExpr(args(1))
val elsep = genExpr(args(2))
val tpe =
if (isStat) jstpe.VoidType
else toIRType(tree.tpe)
js.LinkTimeIf(cond, thenp, elsep)(tpe)

case LINKTIME_PROPERTY =>
// LinkingInfo.linkTimePropertyXXX("...")
val arg = genArgs1
Expand All @@ -5529,6 +5539,83 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G)
}
}

private def genLinkTimeExpr(tree: Tree): js.Tree = {
import scalaPrimitives._

implicit val pos = tree.pos

def invalid(): js.Tree = {
reporter.error(tree.pos,
"Illegal expression in the condition of a linkTimeIf. " +
"Valid expressions are: boolean and int primitives; " +
"references to link-time properties; " +
"primitive operations on booleans; " +
"and comparisons on ints.")
js.BooleanLiteral(false)
}

tree match {
case Literal(c) =>
c.tag match {
case BooleanTag => js.BooleanLiteral(c.booleanValue)
case IntTag => js.IntLiteral(c.intValue)
case _ => invalid()
}

case Apply(fun @ Select(receiver, _), args) =>
fun.symbol.getAnnotation(LinkTimePropertyAnnotation) match {
case Some(annotation) =>
val propName = annotation.constantAtIndex(0).get.stringValue
js.LinkTimeProperty(propName)(toIRType(tree.tpe))

case None if isPrimitive(fun.symbol) =>
val code = getPrimitive(fun.symbol)

def genLhs: js.Tree = genLinkTimeExpr(receiver)
def genRhs: js.Tree = genLinkTimeExpr(args.head)

def unaryOp(op: js.UnaryOp.Code): js.Tree =
js.UnaryOp(op, genLhs)
def binaryOp(op: js.BinaryOp.Code): js.Tree =
js.BinaryOp(op, genLhs, genRhs)

toIRType(receiver.tpe) match {
case jstpe.BooleanType =>
(code: @switch) match {
case ZNOT => unaryOp(js.UnaryOp.Boolean_!)
case EQ => binaryOp(js.BinaryOp.Boolean_==)
case NE | XOR => binaryOp(js.BinaryOp.Boolean_!=)
case OR => binaryOp(js.BinaryOp.Boolean_|)
case AND => binaryOp(js.BinaryOp.Boolean_&)
case ZOR => js.LinkTimeIf(genLhs, js.BooleanLiteral(true), genRhs)(jstpe.BooleanType)
case ZAND => js.LinkTimeIf(genLhs, genRhs, js.BooleanLiteral(false))(jstpe.BooleanType)
case _ => invalid()
}

case jstpe.IntType =>
(code: @switch) match {
case EQ => binaryOp(js.BinaryOp.Int_==)
case NE => binaryOp(js.BinaryOp.Int_!=)
case LT => binaryOp(js.BinaryOp.Int_<)
case LE => binaryOp(js.BinaryOp.Int_<=)
case GT => binaryOp(js.BinaryOp.Int_>)
case GE => binaryOp(js.BinaryOp.Int_>=)
case _ => invalid()
}

case _ =>
invalid()
}

case None => // if !isPrimitive
invalid()
}

case _ =>
invalid()
}
}

/** Gen JS code for a primitive JS call (to a method of a subclass of js.Any)
* This is the typed Scala.js to JS bridge feature. Basically it boils
* down to calling the method without name mangling. But other aspects
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -135,10 +135,13 @@ trait JSDefinitions {
lazy val Runtime_dynamicImport = getMemberMethod(RuntimePackageModule, newTermName("dynamicImport"))

lazy val LinkingInfoModule = getRequiredModule("scala.scalajs.LinkingInfo")
lazy val LinkingInfo_linkTimeIf = getMemberMethod(LinkingInfoModule, newTermName("linkTimeIf"))
lazy val LinkingInfo_linkTimePropertyBoolean = getMemberMethod(LinkingInfoModule, newTermName("linkTimePropertyBoolean"))
lazy val LinkingInfo_linkTimePropertyInt = getMemberMethod(LinkingInfoModule, newTermName("linkTimePropertyInt"))
lazy val LinkingInfo_linkTimePropertyString = getMemberMethod(LinkingInfoModule, newTermName("linkTimePropertyString"))

lazy val LinkTimePropertyAnnotation = getRequiredClass("scala.scalajs.annotation.linkTimeProperty")

lazy val DynamicImportThunkClass = getRequiredClass("scala.scalajs.runtime.DynamicImportThunk")
lazy val DynamicImportThunkClass_apply = getMemberMethod(DynamicImportThunkClass, nme.apply)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,8 @@ abstract class JSPrimitives {
final val WRAP_AS_THROWABLE = JS_TRY_CATCH + 1 // js.special.wrapAsThrowable
final val UNWRAP_FROM_THROWABLE = WRAP_AS_THROWABLE + 1 // js.special.unwrapFromThrowable
final val DEBUGGER = UNWRAP_FROM_THROWABLE + 1 // js.special.debugger
final val LINKTIME_PROPERTY = DEBUGGER + 1 // LinkingInfo.linkTimePropertyXXX
final val LINKTIME_IF = DEBUGGER + 1 // LinkingInfo.linkTimeIf
final val LINKTIME_PROPERTY = LINKTIME_IF + 1 // LinkingInfo.linkTimePropertyXXX

final val LastJSPrimitiveCode = LINKTIME_PROPERTY

Expand Down Expand Up @@ -128,6 +129,7 @@ abstract class JSPrimitives {
addPrimitive(Special_unwrapFromThrowable, UNWRAP_FROM_THROWABLE)
addPrimitive(Special_debugger, DEBUGGER)

addPrimitive(LinkingInfo_linkTimeIf, LINKTIME_IF)
addPrimitive(LinkingInfo_linkTimePropertyBoolean, LINKTIME_PROPERTY)
addPrimitive(LinkingInfo_linkTimePropertyInt, LINKTIME_PROPERTY)
addPrimitive(LinkingInfo_linkTimePropertyString, LINKTIME_PROPERTY)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
/*
* Scala.js (https://www.scala-js.org/)
*
* Copyright EPFL.
*
* Licensed under Apache License 2.0
* (https://www.apache.org/licenses/LICENSE-2.0).
*
* See the NOTICE file distributed with this work for
* additional information regarding copyright ownership.
*/

package org.scalajs.nscplugin.test

import util._

import org.junit.Test
import org.junit.Assert._

// scalastyle:off line.size.limit

class LinkTimeIfTest extends TestHelpers {
override def preamble: String = "import scala.scalajs.LinkingInfo._"

private final val IllegalLinkTimeIfArgMessage = {
"Illegal expression in the condition of a linkTimeIf. " +
"Valid expressions are: boolean and int primitives; " +
"references to link-time properties; " +
"primitive operations on booleans; " +
"and comparisons on ints."
}

@Test
def linkTimeErrorInvalidOp(): Unit = {
"""
object A {
def foo =
linkTimeIf((esVersion + 1) < ESVersion.ES2015) { } { }
}
""" hasErrors
s"""
|newSource1.scala:4: error: $IllegalLinkTimeIfArgMessage
| linkTimeIf((esVersion + 1) < ESVersion.ES2015) { } { }
| ^
"""
}

@Test
def linkTimeErrorInvalidEntities(): Unit = {
"""
object A {
def foo(x: String) = {
val bar = 1
linkTimeIf(bar == 0) { } { }
}
}
""" hasErrors
s"""
|newSource1.scala:5: error: $IllegalLinkTimeIfArgMessage
| linkTimeIf(bar == 0) { } { }
| ^
"""

// String comparison is a `BinaryOp.===`, which is not allowed
"""
object A {
def foo(x: String) =
linkTimeIf("foo" == x) { } { }
}
""" hasErrors
s"""
|newSource1.scala:4: error: $IllegalLinkTimeIfArgMessage
| linkTimeIf("foo" == x) { } { }
| ^
"""

"""
object A {
def bar = true
def foo(x: String) =
linkTimeIf(bar || !bar) { } { }
}
""" hasErrors
s"""
|newSource1.scala:5: error: $IllegalLinkTimeIfArgMessage
| linkTimeIf(bar || !bar) { } { }
| ^
|newSource1.scala:5: error: $IllegalLinkTimeIfArgMessage
| linkTimeIf(bar || !bar) { } { }
| ^
"""
}

@Test
def linkTimeCondInvalidTree(): Unit = {
"""
object A {
def bar = true
def foo(x: String) =
linkTimeIf(if (bar) true else false) { } { }
}
""" hasErrors
s"""
|newSource1.scala:5: error: $IllegalLinkTimeIfArgMessage
| linkTimeIf(if (bar) true else false) { } { }
| ^
"""
}
}
7 changes: 7 additions & 0 deletions ir/shared/src/main/scala/org/scalajs/ir/Hashers.scala
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,13 @@ object Hashers {
mixTree(elsep)
mixType(tree.tpe)

case LinkTimeIf(cond, thenp, elsep) =>
mixTag(TagLinkTimeIf)
mixTree(cond)
mixTree(thenp)
mixTree(elsep)
mixType(tree.tpe)

case While(cond, body) =>
mixTag(TagWhile)
mixTree(cond)
Expand Down
9 changes: 9 additions & 0 deletions ir/shared/src/main/scala/org/scalajs/ir/Printers.scala
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ object Printers {
protected def printBlock(tree: Tree): Unit = {
val trees = tree match {
case Block(trees) => trees
case Skip() => Nil
case _ => tree :: Nil
}
printBlock(trees)
Expand Down Expand Up @@ -232,6 +233,14 @@ object Printers {
printBlock(elsep)
}

case LinkTimeIf(cond, thenp, elsep) =>
print("link-time-if (")
print(cond)
print(") ")
printBlock(thenp)
print(" else ")
printBlock(elsep)

case While(cond, body) =>
print("while (")
print(cond)
Expand Down
4 changes: 2 additions & 2 deletions ir/shared/src/main/scala/org/scalajs/ir/ScalaJSVersions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ import java.util.concurrent.ConcurrentHashMap
import scala.util.matching.Regex

object ScalaJSVersions extends VersionChecks(
current = "1.19.1-SNAPSHOT",
binaryEmitted = "1.19"
current = "1.20.0-SNAPSHOT",
binaryEmitted = "1.20-SNAPSHOT"
)

/** Helper class to allow for testing of logic. */
Expand Down
16 changes: 13 additions & 3 deletions ir/shared/src/main/scala/org/scalajs/ir/Serializers.scala
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,11 @@ object Serializers {
writeTree(cond); writeTree(thenp); writeTree(elsep)
writeType(tree.tpe)

case LinkTimeIf(cond, thenp, elsep) =>
writeTagAndPos(TagLinkTimeIf)
writeTree(cond); writeTree(thenp); writeTree(elsep)
writeType(tree.tpe)

case While(cond, body) =>
writeTagAndPos(TagWhile)
writeTree(cond); writeTree(body)
Expand Down Expand Up @@ -1196,9 +1201,14 @@ object Serializers {

Assign(lhs.asInstanceOf[AssignLhs], rhs)

case TagReturn => Return(readTree(), readLabelName())
case TagIf => If(readTree(), readTree(), readTree())(readType())
case TagWhile => While(readTree(), readTree())
case TagReturn =>
Return(readTree(), readLabelName())
case TagIf =>
If(readTree(), readTree(), readTree())(readType())
case TagLinkTimeIf =>
LinkTimeIf(readTree(), readTree(), readTree())(readType())
case TagWhile =>
While(readTree(), readTree())

case TagDoWhile =>
if (!hacks.useBelow(13))
Expand Down
3 changes: 3 additions & 0 deletions ir/shared/src/main/scala/org/scalajs/ir/Tags.scala
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,9 @@ private[ir] object Tags {
final val TagNewLambda = TagApplyTypedClosure + 1
final val TagJSAwait = TagNewLambda + 1

// New in 1.20
final val TagLinkTimeIf = TagJSAwait + 1

// Tags for member defs

final val TagFieldDef = 1
Expand Down
3 changes: 3 additions & 0 deletions ir/shared/src/main/scala/org/scalajs/ir/Transformers.scala
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,9 @@ object Transformers {
case If(cond, thenp, elsep) =>
If(transform(cond), transform(thenp), transform(elsep))(tree.tpe)

case LinkTimeIf(cond, thenp, elsep) =>
LinkTimeIf(transform(cond), transform(thenp), transform(elsep))(tree.tpe)

case While(cond, body) =>
While(transform(cond), transform(body))

Expand Down
5 changes: 5 additions & 0 deletions ir/shared/src/main/scala/org/scalajs/ir/Traversers.scala
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,11 @@ object Traversers {
traverse(thenp)
traverse(elsep)

case LinkTimeIf(cond, thenp, elsep) =>
traverse(cond)
traverse(thenp)
traverse(elsep)

case While(cond, body) =>
traverse(cond)
traverse(body)
Expand Down
32 changes: 32 additions & 0 deletions ir/shared/src/main/scala/org/scalajs/ir/Trees.scala
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,38 @@ object Trees {
sealed case class If(cond: Tree, thenp: Tree, elsep: Tree)(val tpe: Type)(
implicit val pos: Position) extends Tree

/** Link-time `if` expression.
*
* The `cond` must be a well-typed link-time tree of type `boolean`.
*
* A link-time tree is a `Tree` matching the following sub-grammar:
*
* {{{
* link-time-tree ::=
* BooleanLiteral
* | IntLiteral
* | StringLiteral
* | LinkTimeProperty
* | UnaryOp(link-time-unary-op, link-time-tree)
* | BinaryOp(link-time-binary-op, link-time-tree, link-time-tree)
* | LinkTimeIf(link-time-tree, link-time-tree, link-time-tree)
*
* link-time-unary-op ::=
* Boolean_!
*
* link-time-binary-op ::=
* Boolean_== | Boolean_!= | Boolean_| | Boolean_&
* | Int_== | Int_!= | Int_< | Int_<= | Int_> | Int_>=
* }}}
*
* Note: nested `LinkTimeIf` nodes in the `cond` are used to encode
* short-circuiting boolean `&&` and `||`, just like we do with regular
* `If` nodes.
*/
sealed case class LinkTimeIf(cond: Tree, thenp: Tree, elsep: Tree)(
val tpe: Type)(implicit val pos: Position)
extends Tree

sealed case class While(cond: Tree, body: Tree)(
implicit val pos: Position) extends Tree {
val tpe = cond match {
Expand Down
Loading