diff --git a/linker/shared/src/main/scala/org/scalajs/linker/backend/BasicLinkerBackend.scala b/linker/shared/src/main/scala/org/scalajs/linker/backend/BasicLinkerBackend.scala index a1238e7433..564ebbb99c 100644 --- a/linker/shared/src/main/scala/org/scalajs/linker/backend/BasicLinkerBackend.scala +++ b/linker/shared/src/main/scala/org/scalajs/linker/backend/BasicLinkerBackend.scala @@ -283,7 +283,7 @@ private object BasicLinkerBackend { val jsCodeWriter = new ByteArrayWriter() val printer = new Printers.JSTreePrinter(jsCodeWriter) - printer.printTopLevelTree(tree) + printer.printStat(tree) new PrintedTree(jsCodeWriter.toByteArray(), SourceMapWriter.Fragment.Empty) } @@ -321,7 +321,7 @@ private object BasicLinkerBackend { val smFragmentBuilder = new SourceMapWriter.FragmentBuilder() val printer = new Printers.JSTreePrinterWithSourceMap(jsCodeWriter, smFragmentBuilder) - printer.printTopLevelTree(tree) + printer.printStat(tree) smFragmentBuilder.complete() new PrintedTree(jsCodeWriter.toByteArray(), smFragmentBuilder.result()) diff --git a/linker/shared/src/main/scala/org/scalajs/linker/backend/emitter/ClassEmitter.scala b/linker/shared/src/main/scala/org/scalajs/linker/backend/emitter/ClassEmitter.scala index 0701d2fd84..144672a471 100644 --- a/linker/shared/src/main/scala/org/scalajs/linker/backend/emitter/ClassEmitter.scala +++ b/linker/shared/src/main/scala/org/scalajs/linker/backend/emitter/ClassEmitter.scala @@ -683,12 +683,12 @@ private[emitter] final class ClassEmitter(sjsGen: SJSGen) { def genInstanceTests(className: ClassName, kind: ClassKind)( implicit moduleContext: ModuleContext, - globalKnowledge: GlobalKnowledge, pos: Position): WithGlobals[js.Tree] = { + globalKnowledge: GlobalKnowledge, pos: Position): WithGlobals[List[js.Tree]] = { for { single <- genSingleInstanceTests(className, kind) array <- genArrayInstanceTests(className) } yield { - js.Block(single ::: array) + single ::: array } } @@ -1028,16 +1028,16 @@ private[emitter] final class ClassEmitter(sjsGen: SJSGen) { case e: TopLevelMethodExportDef => genTopLevelMethodExportDef(e) case e: TopLevelFieldExportDef => - genTopLevelFieldExportDef(topLevelExport.owningClass, e) + genTopLevelFieldExportDef(topLevelExport.owningClass, e).map(_ :: Nil) } } - WithGlobals.list(exportsWithGlobals) + WithGlobals.flatten(exportsWithGlobals) } private def genTopLevelMethodExportDef(tree: TopLevelMethodExportDef)( implicit moduleContext: ModuleContext, - globalKnowledge: GlobalKnowledge): WithGlobals[js.Tree] = { + globalKnowledge: GlobalKnowledge): WithGlobals[List[js.Tree]] = { import TreeDSL._ val JSMethodDef(flags, StringLiteral(exportName), args, restParam, body) = @@ -1056,22 +1056,22 @@ private[emitter] final class ClassEmitter(sjsGen: SJSGen) { private def genConstValueExportDef(exportName: String, exportedValue: js.Tree)( - implicit pos: Position): WithGlobals[js.Tree] = { + implicit pos: Position): WithGlobals[List[js.Tree]] = { moduleKind match { case ModuleKind.NoModule => - genAssignToNoModuleExportVar(exportName, exportedValue) + genAssignToNoModuleExportVar(exportName, exportedValue).map(_ :: Nil) case ModuleKind.ESModule => val field = fileLevelVar(VarField.e, exportName) val let = js.Let(field.ident, mutable = true, Some(exportedValue)) val exportStat = js.Export((field.ident -> js.ExportName(exportName)) :: Nil) - WithGlobals(js.Block(let, exportStat)) + WithGlobals(List(let, exportStat)) case ModuleKind.CommonJSModule => globalRef("exports").map { exportsVarRef => js.Assign( genBracketSelect(exportsVarRef, js.StringLiteral(exportName)), - exportedValue) + exportedValue) :: Nil } } } diff --git a/linker/shared/src/main/scala/org/scalajs/linker/backend/emitter/Emitter.scala b/linker/shared/src/main/scala/org/scalajs/linker/backend/emitter/Emitter.scala index d5b23e4525..d89de7ae4d 100644 --- a/linker/shared/src/main/scala/org/scalajs/linker/backend/emitter/Emitter.scala +++ b/linker/shared/src/main/scala/org/scalajs/linker/backend/emitter/Emitter.scala @@ -595,7 +595,7 @@ final class Emitter(config: Emitter.Config) { */ if (classEmitter.needInstanceTests(linkedClass)(classCache)) { - main += extractWithGlobals(classTreeCache.instanceTests.getOrElseUpdate( + main ++= extractWithGlobals(classTreeCache.instanceTests.getOrElseUpdate( classEmitter.genInstanceTests(className, kind)(moduleContext, classCache, linkedClass.pos))) } @@ -1035,7 +1035,7 @@ object Emitter { private final class DesugaredClassCache { val privateJSFields = new OneTimeCache[WithGlobals[List[js.Tree]]] - val instanceTests = new OneTimeCache[WithGlobals[js.Tree]] + val instanceTests = new OneTimeCache[WithGlobals[List[js.Tree]]] val typeData = new OneTimeCache[WithGlobals[List[js.Tree]]] val setTypeData = new OneTimeCache[js.Tree] val moduleAccessor = new OneTimeCache[WithGlobals[List[js.Tree]]] diff --git a/linker/shared/src/main/scala/org/scalajs/linker/backend/javascript/Printers.scala b/linker/shared/src/main/scala/org/scalajs/linker/backend/javascript/Printers.scala index 2df5acc9f9..9690946e69 100644 --- a/linker/shared/src/main/scala/org/scalajs/linker/backend/javascript/Printers.scala +++ b/linker/shared/src/main/scala/org/scalajs/linker/backend/javascript/Printers.scala @@ -44,6 +44,9 @@ object Printers { protected def println(): Unit = { out.write('\n') + } + + protected def printIndent(): Unit = { val indentArray = this.indentArray val indentMargin = this.indentMargin val bigEnoughIndentArray = @@ -61,30 +64,7 @@ object Printers { newIndentArray } - def printTopLevelTree(tree: Tree): Unit = { - tree match { - case Skip() => - // do not print anything - case tree: Block => - var rest = tree.stats - while (rest.nonEmpty) { - printTopLevelTree(rest.head) - rest = rest.tail - } - case _ => - printStat(tree) - if (shouldPrintSepAfterTree(tree)) - print(';') - println() - } - } - - protected def shouldPrintSepAfterTree(tree: Tree): Boolean = tree match { - case _:DocComment | _:FunctionDef | _:ClassDef => false - case _ => true - } - - protected def printRow(ts: List[Tree], start: Char, end: Char): Unit = { + private def printRow(ts: List[Tree], start: Char, end: Char): Unit = { print(start) var rest = ts while (rest.nonEmpty) { @@ -96,29 +76,26 @@ object Printers { print(end) } - protected def printBlock(tree: Tree): Unit = { + private def printBlock(tree: Tree): Unit = { print('{'); indent(); println() tree match { + case Skip() => + // do not print anything + case tree: Block => var rest = tree.stats while (rest.nonEmpty) { - val x = rest.head + printStat(rest.head) rest = rest.tail - printStat(x) - if (rest.nonEmpty) { - if (shouldPrintSepAfterTree(x)) - print(';') - println() - } } case _ => printStat(tree) } - undent(); println(); print('}') + undent(); printIndent(); print('}') } - protected def printSig(args: List[ParamDef], restParam: Option[ParamDef]): Unit = { + private def printSig(args: List[ParamDef], restParam: Option[ParamDef]): Unit = { print("(") var rem = args while (rem.nonEmpty) { @@ -136,16 +113,31 @@ object Printers { print(") ") } - protected def printArgs(args: List[Tree]): Unit = + private def printArgs(args: List[Tree]): Unit = printRow(args, '(', ')') - protected def printStat(tree: Tree): Unit = + /** Prints a stat including leading indent and trailing newline. */ + final def printStat(tree: Tree): Unit = { + printIndent() printTree(tree, isStat = true) + println() + } - protected def print(tree: Tree): Unit = + private def print(tree: Tree): Unit = printTree(tree, isStat = false) + /** Print the "meat" of a tree. + * + * Even if it is a stat: + * - No leading indent. + * - No trailing newline. + */ def printTree(tree: Tree, isStat: Boolean): Unit = { + def printSeparatorIfStat() = { + if (isStat) + print(';') + } + tree match { // Comments @@ -158,12 +150,12 @@ object Printers { } else { print("/** ") print(lines.head) - println() + println(); printIndent() var rest = lines.tail while (rest.nonEmpty) { print(" * ") print(rest.head) - println() + println(); printIndent() rest = rest.tail } print(" */") @@ -178,6 +170,8 @@ object Printers { print(" = ") print(rhs) } + // VarDef is an "expr" in a "For" / "ForIn" tree + printSeparatorIfStat() case Let(ident, mutable, optRhs) => print(if (mutable) "let " else "const ") @@ -186,6 +180,8 @@ object Printers { print(" = ") print(rhs) } + // Let is an "expr" in a "For" / "ForIn" tree + printSeparatorIfStat() case ParamDef(ident) => print(ident) @@ -210,10 +206,12 @@ object Printers { print(lhs) print(" = ") print(rhs) + printSeparatorIfStat() case Return(expr) => print("return ") print(expr) + print(';') case If(cond, thenp, elsep) => if (isStat) { @@ -306,19 +304,22 @@ object Printers { case Throw(expr) => print("throw ") print(expr) + print(';') case Break(label) => - if (label.isEmpty) print("break") + if (label.isEmpty) print("break;") else { print("break ") print(label.get) + print(';') } case Continue(label) => - if (label.isEmpty) print("continue") + if (label.isEmpty) print("continue;") else { print("continue ") print(label.get) + print(';') } case Switch(selector, cases, default) => @@ -331,7 +332,7 @@ object Printers { while (rest.nonEmpty) { val next = rest.head rest = rest.tail - println() + println(); printIndent() print("case ") print(next._1) print(':') @@ -344,17 +345,17 @@ object Printers { default match { case Skip() => case _ => - println() + println(); printIndent() print("default: ") printBlock(default) } undent() - println() + println(); printIndent() print('}') case Debugger() => - print("debugger") + print("debugger;") // Expressions @@ -375,6 +376,7 @@ object Printers { print(')') } printArgs(args) + printSeparatorIfStat() case DotSelect(qualifier, item) => qualifier match { @@ -387,27 +389,33 @@ object Printers { } print(".") print(item) + printSeparatorIfStat() case BracketSelect(qualifier, item) => print(qualifier) print('[') print(item) print(']') + printSeparatorIfStat() case Apply(fun, args) => print(fun) printArgs(args) + printSeparatorIfStat() case ImportCall(arg) => print("import(") print(arg) print(')') + printSeparatorIfStat() case NewTarget() => print("new.target") + printSeparatorIfStat() case ImportMeta() => print("import.meta") + printSeparatorIfStat() case Spread(items) => print("...") @@ -416,6 +424,7 @@ object Printers { case Delete(prop) => print("delete ") print(prop) + printSeparatorIfStat() case UnaryOp(op, lhs) => import ir.Trees.JSUnaryOp._ @@ -433,6 +442,7 @@ object Printers { } print(lhs) print(')') + printSeparatorIfStat() case IncDec(prefix, inc, arg) => val op = if (inc) "++" else "--" @@ -443,6 +453,7 @@ object Printers { if (!prefix) print(op) print(')') + printSeparatorIfStat() case BinaryOp(op, lhs, rhs) => import ir.Trees.JSBinaryOp._ @@ -482,13 +493,15 @@ object Printers { print(' ') print(rhs) print(')') + printSeparatorIfStat() case ArrayConstr(items) => printRow(items, '[', ']') + printSeparatorIfStat() case ObjectConstr(Nil) => if (isStat) - print("({})") // force expression position for the object literal + print("({});") // force expression position for the object literal else print("{}") @@ -502,30 +515,34 @@ object Printers { while (rest.nonEmpty) { val x = rest.head rest = rest.tail + printIndent() print(x._1) print(": ") print(x._2) if (rest.nonEmpty) { print(',') - println() } + println() } undent() - println() + printIndent() print('}') if (isStat) - print(')') + print(");") // Literals case Undefined() => print("(void 0)") + printSeparatorIfStat() case Null() => print("null") + printSeparatorIfStat() case BooleanLiteral(value) => print(if (value) "true" else "false") + printSeparatorIfStat() case IntLiteral(value) => if (value >= 0) { @@ -535,6 +552,7 @@ object Printers { print(value.toString) print(')') } + printSeparatorIfStat() case DoubleLiteral(value) => if (value == 0 && 1 / value < 0) { @@ -546,11 +564,13 @@ object Printers { print(value.toString) print(')') } + printSeparatorIfStat() case StringLiteral(value) => print('\"') printEscapeJS(value) print('\"') + printSeparatorIfStat() case BigIntLiteral(value) => if (value >= 0) { @@ -561,14 +581,17 @@ object Printers { print(value.toString) print("n)") } + printSeparatorIfStat() // Atomic expressions case VarRef(ident) => print(ident) + printSeparatorIfStat() case This() => print("this") + printSeparatorIfStat() case Function(arrow, args, restParam, body) => if (arrow) { @@ -595,6 +618,7 @@ object Printers { printBlock(body) print(')') } + printSeparatorIfStat() // Named function definition @@ -620,15 +644,13 @@ object Printers { print(" extends ") print(optParentClass.get) } - print(" {"); indent() + print(" {"); indent(); println() var rest = members while (rest.nonEmpty) { - println() - print(rest.head) - print(';') + printStat(rest.head) rest = rest.tail } - undent(); println(); print('}') + undent(); printIndent(); print('}') case MethodDef(static, name, params, restParam, body) => if (static) @@ -677,12 +699,14 @@ object Printers { } print(" } from ") print(from: Tree) + print(';') case ImportNamespace(binding, from) => print("import * as ") print(binding) print(" from ") print(from: Tree) + print(';') case Export(bindings) => print("export { ") @@ -699,7 +723,7 @@ object Printers { print(binding._2) rest = rest.tail } - print(" }") + print(" };") case ExportImport(bindings, from) => print("export { ") @@ -718,6 +742,7 @@ object Printers { } print(" } from ") print(from: Tree) + print(';') case _ => throw new IllegalArgumentException( @@ -741,7 +766,7 @@ object Printers { print("]") } - protected def print(exportName: ExportName): Unit = + private def print(exportName: ExportName): Unit = printEscapeJS(exportName.name) /** Prints an ASCII string -- use for syntax strings, not for user strings. */ @@ -782,6 +807,13 @@ object Printers { override protected def println(): Unit = { super.println() sourceMap.nextLine() + column = 0 + } + + override protected def printIndent(): Unit = { + assert(column == 0) + + super.printIndent() column = this.getIndentMargin() } diff --git a/linker/shared/src/test/scala/org/scalajs/linker/LibrarySizeTest.scala b/linker/shared/src/test/scala/org/scalajs/linker/LibrarySizeTest.scala index a64f546d68..9dcc074647 100644 --- a/linker/shared/src/test/scala/org/scalajs/linker/LibrarySizeTest.scala +++ b/linker/shared/src/test/scala/org/scalajs/linker/LibrarySizeTest.scala @@ -70,8 +70,8 @@ class LibrarySizeTest { ) testLinkedSizes( - expectedFastLinkSize = 150031, - expectedFullLinkSizeWithoutClosure = 130655, + expectedFastLinkSize = 150339, + expectedFullLinkSizeWithoutClosure = 130884, expectedFullLinkSizeWithClosure = 21394, classDefs, moduleInitializers = MainTestModuleInitializers diff --git a/linker/shared/src/test/scala/org/scalajs/linker/backend/javascript/PrintersTest.scala b/linker/shared/src/test/scala/org/scalajs/linker/backend/javascript/PrintersTest.scala new file mode 100644 index 0000000000..316f2c3907 --- /dev/null +++ b/linker/shared/src/test/scala/org/scalajs/linker/backend/javascript/PrintersTest.scala @@ -0,0 +1,159 @@ +/* + * Scala.js (https://www.scala-js.org/) + * + * Copyright EPFL. + * + * Licensed under Apache License 2.0 + * (https://www.apache.org/licenses/LICENSE-2.0). + * + * See the NOTICE file distributed with this work for + * additional information regarding copyright ownership. + */ + +package org.scalajs.linker.backend.javascript + +import scala.language.implicitConversions + +import java.nio.charset.StandardCharsets + +import org.junit.Test +import org.junit.Assert._ + +import org.scalajs.ir + +import Trees._ + +class PrintersTest { + + private implicit val pos: ir.Position = ir.Position.NoPosition + + private implicit def str2ident(name: String): Ident = + Ident(name, ir.OriginalName.NoOriginalName) + + private def assertPrintEquals(expected: String, tree: Tree): Unit = { + val out = new ByteArrayWriter + val printer = new Printers.JSTreePrinter(out) + printer.printStat(tree) + assertEquals(expected.stripMargin.trim + "\n", + new String(out.toByteArray(), StandardCharsets.UTF_8)) + } + + @Test def printFunctionDef(): Unit = { + assertPrintEquals( + """ + |function test() { + | const x = 2; + | return x; + |} + """, + FunctionDef("test", Nil, None, Block( + Let("x", mutable = false, Some(IntLiteral(2))), + Return(VarRef("x")))) + ) + + assertPrintEquals( + """ + |function test() { + |} + """, + FunctionDef("test", Nil, None, Skip()) + ) + } + + @Test def printClassDef(): Unit = { + assertPrintEquals( + """ + |class MyClass extends foo.Other { + |} + """, + ClassDef(Some("MyClass"), Some(DotSelect(VarRef("foo"), "Other")), Nil) + ) + + assertPrintEquals( + """ + |class MyClass { + | foo() { + | } + | get a() { + | return 1; + | } + | set a(x) { + | } + |} + """, + ClassDef(Some("MyClass"), None, List( + MethodDef(false, "foo", Nil, None, Skip()), + GetterDef(false, "a", Return(IntLiteral(1))), + SetterDef(false, "a", ParamDef("x"), Skip()) + )) + ) + } + + @Test def printDocComment(): Unit = { + assertPrintEquals( + """ + | /** test */ + """, + DocComment("test") + ) + } + + @Test def printFor(): Unit = { + assertPrintEquals( + """ + |for (let x = 1; (x < 15); x = (x + 1)) { + |} + """, + For(Let("x", true, Some(IntLiteral(1))), + BinaryOp(ir.Trees.JSBinaryOp.<, VarRef("x"), IntLiteral(15)), + Assign(VarRef("x"), BinaryOp(ir.Trees.JSBinaryOp.+, VarRef("x"), IntLiteral(1))), + Skip()) + ) + } + + @Test def printForIn(): Unit = { + assertPrintEquals( + """ + |for (var x in foo) { + |} + """, + ForIn(VarDef("x", None), VarRef("foo"), Skip()) + ) + } + + @Test def printIf(): Unit = { + assertPrintEquals( + """ + |if (false) { + | 1; + |} + """, + If(BooleanLiteral(false), IntLiteral(1), Skip()) + ) + + assertPrintEquals( + """ + |if (false) { + | 1; + |} else { + | 2; + |} + """, + If(BooleanLiteral(false), IntLiteral(1), IntLiteral(2)) + ) + + assertPrintEquals( + """ + |if (false) { + | 1; + |} else if (true) { + | 2; + |} else { + | 3; + |} + """, + If(BooleanLiteral(false), IntLiteral(1), + If(BooleanLiteral(true), IntLiteral(2), IntLiteral(3))) + ) + } +} diff --git a/project/Build.scala b/project/Build.scala index 93209c3ddb..758975e8f0 100644 --- a/project/Build.scala +++ b/project/Build.scala @@ -1967,15 +1967,15 @@ object Build { scalaVersion.value match { case `default212Version` => Some(ExpectedSizes( - fastLink = 772000 to 773000, + fastLink = 770000 to 771000, fullLink = 145000 to 146000, - fastLinkGz = 91000 to 92000, + fastLinkGz = 90000 to 91000, fullLinkGz = 35000 to 36000, )) case `default213Version` => Some(ExpectedSizes( - fastLink = 480000 to 481000, + fastLink = 479000 to 480000, fullLink = 102000 to 103000, fastLinkGz = 62000 to 63000, fullLinkGz = 27000 to 28000,