From 86c18be06bfff8c0245e1ab18b404bcb238a56ba Mon Sep 17 00:00:00 2001 From: Tobias Schlatter Date: Sun, 15 Oct 2023 20:05:42 +0200 Subject: [PATCH 1/3] Fuse emitting and printing of trees in the backend This allows us to use the Emitter's powerful caching mechanism to directly cache printed trees (as byte buffers) and not cache JavaScript trees anymore at all. This reduces in-between run memory usage on the test suite from 1.12 GB (not GiB) to 1.00 GB on my machine (roughly 10%). Runtime performance (both batch and incremental) is unaffected. It is worth pointing out, that due to how the Emitter caches trees, classes that end up being ES6 classes is performed will be held twice in memory (once the individual methods, once the entire class). On the test suite, this is the case for 710 cases out of 6538. --- .../closure/ClosureLinkerBackend.scala | 9 +- .../linker/backend/BasicLinkerBackend.scala | 169 ++++++-------- .../linker/backend/emitter/ClassEmitter.scala | 6 +- .../linker/backend/emitter/CoreJSLib.scala | 22 +- .../linker/backend/emitter/Emitter.scala | 213 +++++++++++------- .../linker/backend/javascript/Printers.scala | 30 ++- .../linker/backend/javascript/Trees.scala | 18 ++ .../linker/BasicLinkerBackendTest.scala | 30 +-- .../org/scalajs/linker/EmitterTest.scala | 94 ++++++++ .../backend/javascript/PrintersTest.scala | 24 +- 10 files changed, 383 insertions(+), 232 deletions(-) diff --git a/linker/jvm/src/main/scala/org/scalajs/linker/backend/closure/ClosureLinkerBackend.scala b/linker/jvm/src/main/scala/org/scalajs/linker/backend/closure/ClosureLinkerBackend.scala index 003e873773..1f532767e2 100644 --- a/linker/jvm/src/main/scala/org/scalajs/linker/backend/closure/ClosureLinkerBackend.scala +++ b/linker/jvm/src/main/scala/org/scalajs/linker/backend/closure/ClosureLinkerBackend.scala @@ -60,7 +60,7 @@ final class ClosureLinkerBackend(config: LinkerBackendImpl.Config) .withTrackAllGlobalRefs(true) .withInternalModulePattern(m => OutputPatternsImpl.moduleName(config.outputPatterns, m.id)) - new Emitter(emitterConfig) + new Emitter(emitterConfig, ClosureLinkerBackend.PostTransformer) } val symbolRequirements: SymbolRequirement = emitter.symbolRequirements @@ -295,4 +295,11 @@ private object ClosureLinkerBackend { Function.prototype.apply; var NaN = 0.0/0.0, Infinity = 1.0/0.0, undefined = void 0; """ + + private object PostTransformer extends Emitter.PostTransformer[js.Tree] { + // Do not apply ClosureAstTransformer eagerly: + // The ASTs used by closure are highly mutable, so re-using them is non-trivial. + // Since closure is slow anyways, we haven't built the optimization. + def transformStats(trees: List[js.Tree], indent: Int): List[js.Tree] = trees + } } diff --git a/linker/shared/src/main/scala/org/scalajs/linker/backend/BasicLinkerBackend.scala b/linker/shared/src/main/scala/org/scalajs/linker/backend/BasicLinkerBackend.scala index 564ebbb99c..4faef57c0a 100644 --- a/linker/shared/src/main/scala/org/scalajs/linker/backend/BasicLinkerBackend.scala +++ b/linker/shared/src/main/scala/org/scalajs/linker/backend/BasicLinkerBackend.scala @@ -17,6 +17,8 @@ import scala.concurrent._ import java.nio.ByteBuffer import java.nio.charset.StandardCharsets +import java.util.concurrent.atomic.AtomicInteger + import org.scalajs.logging.Logger import org.scalajs.linker.interface.{IRFile, OutputDirectory, Report} @@ -36,12 +38,19 @@ final class BasicLinkerBackend(config: LinkerBackendImpl.Config) import BasicLinkerBackend._ + private[this] var totalModules = 0 + private[this] val rewrittenModules = new AtomicInteger(0) + private[this] val emitter = { val emitterConfig = Emitter.Config(config.commonConfig.coreSpec) .withJSHeader(config.jsHeader) .withInternalModulePattern(m => OutputPatternsImpl.moduleName(config.outputPatterns, m.id)) - new Emitter(emitterConfig) + val postTransformer = + if (config.sourceMap) PostTransformerWithSourceMap + else PostTransformerWithoutSourceMap + + new Emitter(emitterConfig, postTransformer) } val symbolRequirements: SymbolRequirement = emitter.symbolRequirements @@ -61,6 +70,11 @@ final class BasicLinkerBackend(config: LinkerBackendImpl.Config) implicit ec: ExecutionContext): Future[Report] = { verifyModuleSet(moduleSet) + // Reset stats. + + totalModules = moduleSet.modules.size + rewrittenModules.set(0) + val emitterResult = logger.time("Emitter") { emitter.emit(moduleSet, logger) } @@ -68,24 +82,25 @@ final class BasicLinkerBackend(config: LinkerBackendImpl.Config) val skipContentCheck = !isFirstRun isFirstRun = false - printedModuleSetCache.startRun(moduleSet) val allChanged = printedModuleSetCache.updateGlobal(emitterResult.header, emitterResult.footer) val writer = new OutputWriter(output, config, skipContentCheck) { protected def writeModuleWithoutSourceMap(moduleID: ModuleID, force: Boolean): Option[ByteBuffer] = { val cache = printedModuleSetCache.getModuleCache(moduleID) - val changed = cache.update(emitterResult.body(moduleID)) + val printedTrees = emitterResult.body(moduleID) + + val changed = cache.update(printedTrees) if (force || changed || allChanged) { - printedModuleSetCache.incRewrittenModules() + rewrittenModules.incrementAndGet() val jsFileWriter = new ByteArrayWriter(sizeHintFor(cache.getPreviousFinalJSFileSize())) jsFileWriter.write(printedModuleSetCache.headerBytes) jsFileWriter.writeASCIIString("'use strict';\n") - for (printedTree <- cache.printedTrees) + for (printedTree <- printedTrees) jsFileWriter.write(printedTree.jsCode) jsFileWriter.write(printedModuleSetCache.footerBytes) @@ -99,10 +114,12 @@ final class BasicLinkerBackend(config: LinkerBackendImpl.Config) protected def writeModuleWithSourceMap(moduleID: ModuleID, force: Boolean): Option[(ByteBuffer, ByteBuffer)] = { val cache = printedModuleSetCache.getModuleCache(moduleID) - val changed = cache.update(emitterResult.body(moduleID)) + val printedTrees = emitterResult.body(moduleID) + + val changed = cache.update(printedTrees) if (force || changed || allChanged) { - printedModuleSetCache.incRewrittenModules() + rewrittenModules.incrementAndGet() val jsFileWriter = new ByteArrayWriter(sizeHintFor(cache.getPreviousFinalJSFileSize())) val sourceMapWriter = new ByteArrayWriter(sizeHintFor(cache.getPreviousFinalSourceMapSize())) @@ -120,7 +137,7 @@ final class BasicLinkerBackend(config: LinkerBackendImpl.Config) jsFileWriter.writeASCIIString("'use strict';\n") smWriter.nextLine() - for (printedTree <- cache.printedTrees) { + for (printedTree <- printedTrees) { jsFileWriter.write(printedTree.jsCode) smWriter.insertFragment(printedTree.sourceMapFragment) } @@ -145,9 +162,15 @@ final class BasicLinkerBackend(config: LinkerBackendImpl.Config) writer.write(moduleSet) }.andThen { case _ => printedModuleSetCache.cleanAfterRun() - printedModuleSetCache.logStats(logger) + logStats(logger) } } + + private def logStats(logger: Logger): Unit = { + // Message extracted in BasicLinkerBackendTest + logger.debug( + s"BasicBackend: total modules: $totalModules; re-written: ${rewrittenModules.get()}") + } } private object BasicLinkerBackend { @@ -161,20 +184,6 @@ private object BasicLinkerBackend { private val modules = new java.util.concurrent.ConcurrentHashMap[ModuleID, PrintedModuleCache] - private var totalModules = 0 - private val rewrittenModules = new java.util.concurrent.atomic.AtomicInteger(0) - - private var totalTopLevelTrees = 0 - private var recomputedTopLevelTrees = 0 - - def startRun(moduleSet: ModuleSet): Unit = { - totalModules = moduleSet.modules.size - rewrittenModules.set(0) - - totalTopLevelTrees = 0 - recomputedTopLevelTrees = 0 - } - def updateGlobal(header: String, footer: String): Boolean = { if (header == lastHeader && footer == lastFooter) { false @@ -193,61 +202,32 @@ private object BasicLinkerBackend { def headerNewLineCount: Int = _headerNewLineCountCache def getModuleCache(moduleID: ModuleID): PrintedModuleCache = { - val result = modules.computeIfAbsent(moduleID, { _ => - if (withSourceMaps) new PrintedModuleCacheWithSourceMaps - else new PrintedModuleCache - }) - + val result = modules.computeIfAbsent(moduleID, _ => new PrintedModuleCache) result.startRun() result } - def incRewrittenModules(): Unit = - rewrittenModules.incrementAndGet() - def cleanAfterRun(): Unit = { val iter = modules.entrySet().iterator() while (iter.hasNext()) { val moduleCache = iter.next().getValue() - if (moduleCache.cleanAfterRun()) { - totalTopLevelTrees += moduleCache.getTotalTopLevelTrees - recomputedTopLevelTrees += moduleCache.getRecomputedTopLevelTrees - } else { + if (!moduleCache.cleanAfterRun()) { iter.remove() } } } - - def logStats(logger: Logger): Unit = { - /* These messages are extracted in BasicLinkerBackendTest to assert that - * we do not invalidate anything in a no-op second run. - */ - logger.debug( - s"BasicBackend: total top-level trees: $totalTopLevelTrees; re-computed: $recomputedTopLevelTrees") - logger.debug( - s"BasicBackend: total modules: $totalModules; re-written: ${rewrittenModules.get()}") - } - } - - private final class PrintedTree(val jsCode: Array[Byte], val sourceMapFragment: SourceMapWriter.Fragment) { - var cachedUsed: Boolean = false } private sealed class PrintedModuleCache { private var cacheUsed = false private var changed = false - private var lastJSTrees: List[js.Tree] = Nil - private var printedTreesCache: List[PrintedTree] = Nil - private val cache = new java.util.IdentityHashMap[js.Tree, PrintedTree] + private var lastPrintedTrees: List[js.PrintedTree] = Nil private var previousFinalJSFileSize: Int = 0 private var previousFinalSourceMapSize: Int = 0 - private var recomputedTopLevelTrees = 0 - def startRun(): Unit = { cacheUsed = true - recomputedTopLevelTrees = 0 } def getPreviousFinalJSFileSize(): Int = previousFinalJSFileSize @@ -259,72 +239,51 @@ private object BasicLinkerBackend { previousFinalSourceMapSize = finalSourceMapSize } - def update(newJSTrees: List[js.Tree]): Boolean = { - val changed = !newJSTrees.corresponds(lastJSTrees)(_ eq _) + def update(newPrintedTrees: List[js.PrintedTree]): Boolean = { + val changed = !newPrintedTrees.corresponds(lastPrintedTrees)(_ eq _) this.changed = changed if (changed) { - printedTreesCache = newJSTrees.map(getOrComputePrintedTree(_)) - lastJSTrees = newJSTrees + lastPrintedTrees = newPrintedTrees } changed } - private def getOrComputePrintedTree(tree: js.Tree): PrintedTree = { - val result = cache.computeIfAbsent(tree, { (tree: js.Tree) => - recomputedTopLevelTrees += 1 - computePrintedTree(tree) - }) - - result.cachedUsed = true - result - } - - protected def computePrintedTree(tree: js.Tree): PrintedTree = { - val jsCodeWriter = new ByteArrayWriter() - val printer = new Printers.JSTreePrinter(jsCodeWriter) - - printer.printStat(tree) - - new PrintedTree(jsCodeWriter.toByteArray(), SourceMapWriter.Fragment.Empty) + def cleanAfterRun(): Boolean = { + val wasUsed = cacheUsed + cacheUsed = false + wasUsed } + } - def printedTrees: List[PrintedTree] = printedTreesCache + private object PostTransformerWithoutSourceMap extends Emitter.PostTransformer[js.PrintedTree] { + def transformStats(trees: List[js.Tree], indent: Int): List[js.PrintedTree] = { + if (trees.isEmpty) { + Nil // Fast path + } else { + val jsCodeWriter = new ByteArrayWriter() + val printer = new Printers.JSTreePrinter(jsCodeWriter, indent) - def cleanAfterRun(): Boolean = { - if (cacheUsed) { - cacheUsed = false - - if (changed) { - val iter = cache.entrySet().iterator() - while (iter.hasNext()) { - val printedTree = iter.next().getValue() - if (printedTree.cachedUsed) - printedTree.cachedUsed = false - else - iter.remove() - } - } + trees.map(printer.printStat(_)) - true - } else { - false + js.PrintedTree(jsCodeWriter.toByteArray(), SourceMapWriter.Fragment.Empty) :: Nil } } - - def getTotalTopLevelTrees: Int = lastJSTrees.size - def getRecomputedTopLevelTrees: Int = recomputedTopLevelTrees } - private final class PrintedModuleCacheWithSourceMaps extends PrintedModuleCache { - override protected def computePrintedTree(tree: js.Tree): PrintedTree = { - val jsCodeWriter = new ByteArrayWriter() - val smFragmentBuilder = new SourceMapWriter.FragmentBuilder() - val printer = new Printers.JSTreePrinterWithSourceMap(jsCodeWriter, smFragmentBuilder) + private object PostTransformerWithSourceMap extends Emitter.PostTransformer[js.PrintedTree] { + def transformStats(trees: List[js.Tree], indent: Int): List[js.PrintedTree] = { + if (trees.isEmpty) { + Nil // Fast path + } else { + val jsCodeWriter = new ByteArrayWriter() + val smFragmentBuilder = new SourceMapWriter.FragmentBuilder() + val printer = new Printers.JSTreePrinterWithSourceMap(jsCodeWriter, smFragmentBuilder, indent) - printer.printStat(tree) - smFragmentBuilder.complete() + trees.map(printer.printStat(_)) + smFragmentBuilder.complete() - new PrintedTree(jsCodeWriter.toByteArray(), smFragmentBuilder.result()) + js.PrintedTree(jsCodeWriter.toByteArray(), smFragmentBuilder.result()) :: Nil + } } } } diff --git a/linker/shared/src/main/scala/org/scalajs/linker/backend/emitter/ClassEmitter.scala b/linker/shared/src/main/scala/org/scalajs/linker/backend/emitter/ClassEmitter.scala index 424b962989..aade6c4b8a 100644 --- a/linker/shared/src/main/scala/org/scalajs/linker/backend/emitter/ClassEmitter.scala +++ b/linker/shared/src/main/scala/org/scalajs/linker/backend/emitter/ClassEmitter.scala @@ -45,7 +45,7 @@ private[emitter] final class ClassEmitter(sjsGen: SJSGen) { def buildClass(className: ClassName, isJSClass: Boolean, jsClassCaptures: Option[List[ParamDef]], hasClassInitializer: Boolean, - superClass: Option[ClassIdent], storeJSSuperClass: Option[js.Tree], useESClass: Boolean, + superClass: Option[ClassIdent], storeJSSuperClass: List[js.Tree], useESClass: Boolean, members: List[js.Tree])( implicit moduleContext: ModuleContext, globalKnowledge: GlobalKnowledge, pos: Position): WithGlobals[List[js.Tree]] = { @@ -75,7 +75,7 @@ private[emitter] final class ClassEmitter(sjsGen: SJSGen) { val createClassValueVar = genEmptyMutableLet(classValueIdent) val entireClassDefWithGlobals = if (useESClass) { - genJSSuperCtor(superClass, storeJSSuperClass.isDefined).map { jsSuperClass => + genJSSuperCtor(superClass, storeJSSuperClass.nonEmpty).map { jsSuperClass => List(classValueVar := js.ClassDef(Some(classValueIdent), Some(jsSuperClass), members)) } } else { @@ -86,7 +86,7 @@ private[emitter] final class ClassEmitter(sjsGen: SJSGen) { entireClassDef <- entireClassDefWithGlobals createStaticFields <- genCreateStaticFieldsOfJSClass(className) } yield { - storeJSSuperClass.toList ::: entireClassDef ::: createStaticFields + storeJSSuperClass ::: entireClassDef ::: createStaticFields } jsClassCaptures.fold { diff --git a/linker/shared/src/main/scala/org/scalajs/linker/backend/emitter/CoreJSLib.scala b/linker/shared/src/main/scala/org/scalajs/linker/backend/emitter/CoreJSLib.scala index 97330e7ccf..290bb7f362 100644 --- a/linker/shared/src/main/scala/org/scalajs/linker/backend/emitter/CoreJSLib.scala +++ b/linker/shared/src/main/scala/org/scalajs/linker/backend/emitter/CoreJSLib.scala @@ -32,9 +32,9 @@ import PolyfillableBuiltin._ private[emitter] object CoreJSLib { - def build(sjsGen: SJSGen, moduleContext: ModuleContext, - globalKnowledge: GlobalKnowledge): WithGlobals[Lib] = { - new CoreJSLibBuilder(sjsGen)(moduleContext, globalKnowledge).build() + def build[E](sjsGen: SJSGen, postTransform: List[Tree] => E, moduleContext: ModuleContext, + globalKnowledge: GlobalKnowledge): WithGlobals[Lib[E]] = { + new CoreJSLibBuilder(sjsGen)(moduleContext, globalKnowledge).build(postTransform) } /** A fully built CoreJSLib @@ -52,10 +52,10 @@ private[emitter] object CoreJSLib { * @param initialization Things that depend on Scala.js generated classes. * These must have class definitions (but not static fields) available. */ - final class Lib private[CoreJSLib] ( - val preObjectDefinitions: List[Tree], - val postObjectDefinitions: List[Tree], - val initialization: List[Tree]) + final class Lib[E] private[CoreJSLib] ( + val preObjectDefinitions: E, + val postObjectDefinitions: E, + val initialization: E) private class CoreJSLibBuilder(sjsGen: SJSGen)( implicit moduleContext: ModuleContext, globalKnowledge: GlobalKnowledge) { @@ -115,9 +115,11 @@ private[emitter] object CoreJSLib { private val specializedArrayTypeRefs: List[NonArrayTypeRef] = ClassRef(ObjectClass) :: orderedPrimRefsWithoutVoid - def build(): WithGlobals[Lib] = { - val lib = new Lib(buildPreObjectDefinitions(), - buildPostObjectDefinitions(), buildInitializations()) + def build[E](postTransform: List[Tree] => E): WithGlobals[Lib[E]] = { + val lib = new Lib( + postTransform(buildPreObjectDefinitions()), + postTransform(buildPostObjectDefinitions()), + postTransform(buildInitializations())) WithGlobals(lib, trackedGlobalRefs) } diff --git a/linker/shared/src/main/scala/org/scalajs/linker/backend/emitter/Emitter.scala b/linker/shared/src/main/scala/org/scalajs/linker/backend/emitter/Emitter.scala index a5191cdf8c..6035764aca 100644 --- a/linker/shared/src/main/scala/org/scalajs/linker/backend/emitter/Emitter.scala +++ b/linker/shared/src/main/scala/org/scalajs/linker/backend/emitter/Emitter.scala @@ -33,7 +33,8 @@ import EmitterNames._ import GlobalRefUtils._ /** Emits a desugared JS tree to a builder */ -final class Emitter(config: Emitter.Config) { +final class Emitter[E >: Null <: js.Tree]( + config: Emitter.Config, postTransformer: Emitter.PostTransformer[E]) { import Emitter._ import config._ @@ -71,13 +72,16 @@ final class Emitter(config: Emitter.Config) { private[this] var statsClassesInvalidated: Int = 0 private[this] var statsMethodsReused: Int = 0 private[this] var statsMethodsInvalidated: Int = 0 + private[this] var statsPostTransforms: Int = 0 + private[this] var statsNestedPostTransforms: Int = 0 + private[this] var statsNestedPostTransformsAvoided: Int = 0 val symbolRequirements: SymbolRequirement = Emitter.symbolRequirements(config) val injectedIRFiles: Seq[IRFile] = PrivateLibHolder.files - def emit(moduleSet: ModuleSet, logger: Logger): Result = { + def emit(moduleSet: ModuleSet, logger: Logger): Result[E] = { val WithGlobals(body, globalRefs) = emitInternal(moduleSet, logger) moduleKind match { @@ -108,12 +112,15 @@ final class Emitter(config: Emitter.Config) { } private def emitInternal(moduleSet: ModuleSet, - logger: Logger): WithGlobals[Map[ModuleID, List[js.Tree]]] = { + logger: Logger): WithGlobals[Map[ModuleID, List[E]]] = { // Reset caching stats. statsClassesReused = 0 statsClassesInvalidated = 0 statsMethodsReused = 0 statsMethodsInvalidated = 0 + statsPostTransforms = 0 + statsNestedPostTransforms = 0 + statsNestedPostTransformsAvoided = 0 // Update GlobalKnowledge. val invalidateAll = knowledgeGuardian.update(moduleSet) @@ -128,13 +135,17 @@ final class Emitter(config: Emitter.Config) { try { emitAvoidGlobalClash(moduleSet, logger, secondAttempt = false) } finally { - // Report caching stats. + // Report caching stats (extracted in EmitterTest). logger.debug( s"Emitter: Class tree cache stats: reused: $statsClassesReused -- "+ s"invalidated: $statsClassesInvalidated") logger.debug( s"Emitter: Method tree cache stats: reused: $statsMethodsReused -- "+ s"invalidated: $statsMethodsInvalidated") + logger.debug( + s"Emitter: Post transforms: total: $statsPostTransforms -- " + + s"nested: $statsNestedPostTransforms -- " + + s"nested avoided: $statsNestedPostTransformsAvoided") // Inform caches about run completion. state.moduleCaches.filterInPlace((_, c) => c.cleanAfterRun()) @@ -142,6 +153,14 @@ final class Emitter(config: Emitter.Config) { } } + private def postTransform(trees: List[js.Tree], indent: Int): List[E] = { + statsPostTransforms += 1 + postTransformer.transformStats(trees, indent) + } + + private def postTransform(tree: js.Tree, indent: Int): List[E] = + postTransform(tree :: Nil, indent) + /** Emits all JavaScript code avoiding clashes with global refs. * * If, at the end of the process, the set of accessed dangerous globals has @@ -150,7 +169,7 @@ final class Emitter(config: Emitter.Config) { */ @tailrec private def emitAvoidGlobalClash(moduleSet: ModuleSet, - logger: Logger, secondAttempt: Boolean): WithGlobals[Map[ModuleID, List[js.Tree]]] = { + logger: Logger, secondAttempt: Boolean): WithGlobals[Map[ModuleID, List[E]]] = { val result = emitOnce(moduleSet, logger) val mentionedDangerousGlobalRefs = @@ -175,7 +194,7 @@ final class Emitter(config: Emitter.Config) { } private def emitOnce(moduleSet: ModuleSet, - logger: Logger): WithGlobals[Map[ModuleID, List[js.Tree]]] = { + logger: Logger): WithGlobals[Map[ModuleID, List[E]]] = { // Genreate classes first so we can measure time separately. val generatedClasses = logger.time("Emitter: Generate Classes") { moduleSet.modules.map { module => @@ -200,7 +219,7 @@ final class Emitter(config: Emitter.Config) { val moduleImports = extractWithGlobals { moduleCache.getOrComputeImports(module.externalDependencies, module.internalDependencies) { - genModuleImports(module) + genModuleImports(module).map(postTransform(_, 0)) } } @@ -210,7 +229,7 @@ final class Emitter(config: Emitter.Config) { */ moduleCache.getOrComputeTopLevelExports(module.topLevelExports) { classEmitter.genTopLevelExports(module.topLevelExports)( - moduleContext, moduleCache) + moduleContext, moduleCache).map(postTransform(_, 0)) } } @@ -220,7 +239,7 @@ final class Emitter(config: Emitter.Config) { WithGlobals.list(initializers.map { initializer => classEmitter.genModuleInitializer(initializer)( moduleContext, moduleCache) - }) + }).map(postTransform(_, 0)) } } @@ -241,7 +260,7 @@ final class Emitter(config: Emitter.Config) { * requires consistency between the Analyzer and the Emitter. As such, * it is crucial that we verify it. */ - val defTrees: List[js.Tree] = ( + val defTrees: List[E] = ( /* The definitions of the CoreJSLib that come before the definition * of `j.l.Object`. They depend on nothing else. */ @@ -357,7 +376,7 @@ final class Emitter(config: Emitter.Config) { } private def genClass(linkedClass: LinkedClass, - moduleContext: ModuleContext): GeneratedClass = { + moduleContext: ModuleContext): GeneratedClass[E] = { val className = linkedClass.className val classCache = classCaches.getOrElseUpdate( @@ -379,7 +398,7 @@ final class Emitter(config: Emitter.Config) { // Main part - val main = List.newBuilder[js.Tree] + val main = List.newBuilder[E] val (linkedInlineableInit, linkedMethods) = classEmitter.extractInlineableInit(linkedClass)(classCache) @@ -388,7 +407,7 @@ final class Emitter(config: Emitter.Config) { if (kind.isJSClass) { val fieldDefs = classTreeCache.privateJSFields.getOrElseUpdate { classEmitter.genCreatePrivateJSFieldDefsOfJSClass(className)( - moduleContext, classCache) + moduleContext, classCache).map(postTransform(_, 0)) } main ++= extractWithGlobals(fieldDefs) } @@ -407,8 +426,10 @@ final class Emitter(config: Emitter.Config) { val methodCache = classCache.getStaticLikeMethodCache(namespace, methodDef.methodName) - main ++= extractWithGlobals(methodCache.getOrElseUpdate(methodDef.version, - classEmitter.genStaticLikeMethod(className, methodDef)(moduleContext, methodCache))) + main ++= extractWithGlobals(methodCache.getOrElseUpdate(methodDef.version, { + classEmitter.genStaticLikeMethod(className, methodDef)(moduleContext, methodCache) + .map(postTransform(_, 0)) + })) } } @@ -447,11 +468,21 @@ final class Emitter(config: Emitter.Config) { (isJSClass || linkedClass.ancestors.contains(ThrowableClass)) } + val memberIndent = { + (if (isJSClass) 1 else 0) + // accessor function + (if (useESClass) 1 else 0) // nesting from class + } + val hasJSSuperClass = linkedClass.jsSuperClass.isDefined - val storeJSSuperClass = linkedClass.jsSuperClass.map { jsSuperClass => - extractWithGlobals(classTreeCache.storeJSSuperClass.getOrElseUpdate( - classEmitter.genStoreJSSuperClass(jsSuperClass)(moduleContext, classCache, linkedClass.pos))) + val storeJSSuperClass = if (hasJSSuperClass) { + extractWithGlobals(classTreeCache.storeJSSuperClass.getOrElseUpdate({ + val jsSuperClass = linkedClass.jsSuperClass.get + classEmitter.genStoreJSSuperClass(jsSuperClass)(moduleContext, classCache, linkedClass.pos) + .map(postTransform(_, 1)) + })) + } else { + Nil } // JS constructor @@ -478,7 +509,7 @@ final class Emitter(config: Emitter.Config) { hasJSSuperClass, // invalidated by class version useESClass, // invalidated by class version jsConstructorDef // part of ctor version - )(moduleContext, ctorCache, linkedClass.pos)) + )(moduleContext, ctorCache, linkedClass.pos).map(postTransform(_, memberIndent))) } else { val ctorVersion = linkedInlineableInit.fold { Version.combine(linkedClass.version) @@ -492,7 +523,7 @@ final class Emitter(config: Emitter.Config) { linkedClass.superClass, // invalidated by class version useESClass, // invalidated by class version, linkedInlineableInit // part of ctor version - )(moduleContext, ctorCache, linkedClass.pos)) + )(moduleContext, ctorCache, linkedClass.pos).map(postTransform(_, memberIndent))) } } @@ -546,7 +577,7 @@ final class Emitter(config: Emitter.Config) { isJSClass, // invalidated by isJSClassVersion useESClass, // invalidated by isJSClassVersion method // invalidated by method.version - )(moduleContext, methodCache)) + )(moduleContext, methodCache).map(postTransform(_, memberIndent))) } // Exported Members @@ -561,7 +592,7 @@ final class Emitter(config: Emitter.Config) { isJSClass, // invalidated by isJSClassVersion useESClass, // invalidated by isJSClassVersion member // invalidated by version - )(moduleContext, memberCache)) + )(moduleContext, memberCache).map(postTransform(_, memberIndent))) } val hasClassInitializer: Boolean = { @@ -578,8 +609,9 @@ final class Emitter(config: Emitter.Config) { memberMethodsWithGlobals, exportedMembersWithGlobals, { for { ctor <- ctorWithGlobals - memberMethods <- WithGlobals.list(memberMethodsWithGlobals) - exportedMembers <- WithGlobals.list(exportedMembersWithGlobals) + memberMethods <- WithGlobals.flatten(memberMethodsWithGlobals) + exportedMembers <- WithGlobals.flatten(exportedMembersWithGlobals) + allMembers = ctor ::: memberMethods ::: exportedMembers clazz <- classEmitter.buildClass( className, // invalidated by overall class cache (part of ancestors) isJSClass, // invalidated by class version @@ -588,10 +620,17 @@ final class Emitter(config: Emitter.Config) { linkedClass.superClass, // invalidated by class version storeJSSuperClass, // invalidated by class version useESClass, // invalidated by class version (depends on kind, config and ancestry only) - ctor ::: memberMethods ::: exportedMembers.flatten // all 3 invalidated directly + allMembers // invalidated directly )(moduleContext, fullClassCache, linkedClass.pos) // pos invalidated by class version } yield { - clazz + // Avoid a nested post transform if we just got the original members back. + if (clazz eq allMembers) { + statsNestedPostTransformsAvoided += 1 + allMembers + } else { + statsNestedPostTransforms += 1 + postTransform(clazz, 0) + } } }) } @@ -614,8 +653,10 @@ final class Emitter(config: Emitter.Config) { */ if (classEmitter.needInstanceTests(linkedClass)(classCache)) { - main ++= extractWithGlobals(classTreeCache.instanceTests.getOrElseUpdate( - classEmitter.genInstanceTests(className, kind)(moduleContext, classCache, linkedClass.pos))) + main ++= extractWithGlobals(classTreeCache.instanceTests.getOrElseUpdate({ + classEmitter.genInstanceTests(className, kind)(moduleContext, classCache, linkedClass.pos) + .map(postTransform(_, 0)) + })) } if (linkedClass.hasRuntimeTypeInfo) { @@ -626,18 +667,22 @@ final class Emitter(config: Emitter.Config) { linkedClass.superClass, // invalidated by class version linkedClass.ancestors, // invalidated by overall class cache (identity) linkedClass.jsNativeLoadSpec // invalidated by class version - )(moduleContext, classCache, linkedClass.pos))) + )(moduleContext, classCache, linkedClass.pos).map(postTransform(_, 0)))) } if (linkedClass.hasInstances && kind.isClass && linkedClass.hasRuntimeTypeInfo) { - main += classTreeCache.setTypeData.getOrElseUpdate( - classEmitter.genSetTypeData(className)(moduleContext, classCache, linkedClass.pos)) + main ++= classTreeCache.setTypeData.getOrElseUpdate({ + val tree = classEmitter.genSetTypeData(className)(moduleContext, classCache, linkedClass.pos) + postTransform(tree, 0) + }) } } if (linkedClass.kind.hasModuleAccessor && linkedClass.hasInstances) { - main ++= extractWithGlobals(classTreeCache.moduleAccessor.getOrElseUpdate( - classEmitter.genModuleAccessor(className, isJSClass)(moduleContext, classCache, linkedClass.pos))) + main ++= extractWithGlobals(classTreeCache.moduleAccessor.getOrElseUpdate({ + classEmitter.genModuleAccessor(className, isJSClass)(moduleContext, classCache, linkedClass.pos) + .map(postTransform(_, 0)) + })) } // Static fields @@ -645,15 +690,19 @@ final class Emitter(config: Emitter.Config) { val staticFields = if (linkedClass.kind.isJSType) { Nil } else { - extractWithGlobals(classTreeCache.staticFields.getOrElseUpdate( - classEmitter.genCreateStaticFieldsOfScalaClass(className)(moduleContext, classCache))) + extractWithGlobals(classTreeCache.staticFields.getOrElseUpdate({ + classEmitter.genCreateStaticFieldsOfScalaClass(className)(moduleContext, classCache) + .map(postTransform(_, 0)) + })) } // Static initialization val staticInitialization = if (classEmitter.needStaticInitialization(linkedClass)) { - classTreeCache.staticInitialization.getOrElseUpdate( - classEmitter.genStaticInitialization(className)(moduleContext, classCache, linkedClass.pos)) + classTreeCache.staticInitialization.getOrElseUpdate({ + val tree = classEmitter.genStaticInitialization(className)(moduleContext, classCache, linkedClass.pos) + postTransform(tree, 0) + }) } else { Nil } @@ -674,14 +723,14 @@ final class Emitter(config: Emitter.Config) { private final class ModuleCache extends knowledgeGuardian.KnowledgeAccessor { private[this] var _cacheUsed: Boolean = false - private[this] var _importsCache: WithGlobals[List[js.Tree]] = WithGlobals.nil + private[this] var _importsCache: WithGlobals[List[E]] = WithGlobals.nil private[this] var _lastExternalDependencies: Set[String] = Set.empty private[this] var _lastInternalDependencies: Set[ModuleID] = Set.empty - private[this] var _topLevelExportsCache: WithGlobals[List[js.Tree]] = WithGlobals.nil + private[this] var _topLevelExportsCache: WithGlobals[List[E]] = WithGlobals.nil private[this] var _lastTopLevelExports: List[LinkedTopLevelExport] = Nil - private[this] var _initializersCache: WithGlobals[List[js.Tree]] = WithGlobals.nil + private[this] var _initializersCache: WithGlobals[List[E]] = WithGlobals.nil private[this] var _lastInitializers: List[ModuleInitializer.Initializer] = Nil override def invalidate(): Unit = { @@ -702,7 +751,7 @@ final class Emitter(config: Emitter.Config) { } def getOrComputeImports(externalDependencies: Set[String], internalDependencies: Set[ModuleID])( - compute: => WithGlobals[List[js.Tree]]): WithGlobals[List[js.Tree]] = { + compute: => WithGlobals[List[E]]): WithGlobals[List[E]] = { _cacheUsed = true @@ -715,7 +764,7 @@ final class Emitter(config: Emitter.Config) { } def getOrComputeTopLevelExports(topLevelExports: List[LinkedTopLevelExport])( - compute: => WithGlobals[List[js.Tree]]): WithGlobals[List[js.Tree]] = { + compute: => WithGlobals[List[E]]): WithGlobals[List[E]] = { _cacheUsed = true @@ -754,7 +803,7 @@ final class Emitter(config: Emitter.Config) { } def getOrComputeInitializers(initializers: List[ModuleInitializer.Initializer])( - compute: => WithGlobals[List[js.Tree]]): WithGlobals[List[js.Tree]] = { + compute: => WithGlobals[List[E]]): WithGlobals[List[E]] = { _cacheUsed = true @@ -773,20 +822,20 @@ final class Emitter(config: Emitter.Config) { } private final class ClassCache extends knowledgeGuardian.KnowledgeAccessor { - private[this] var _cache: DesugaredClassCache = null + private[this] var _cache: DesugaredClassCache[List[E]] = null private[this] var _lastVersion: Version = Version.Unversioned private[this] var _cacheUsed = false private[this] val _methodCaches = - Array.fill(MemberNamespace.Count)(mutable.Map.empty[MethodName, MethodCache[List[js.Tree]]]) + Array.fill(MemberNamespace.Count)(mutable.Map.empty[MethodName, MethodCache[List[E]]]) private[this] val _memberMethodCache = - mutable.Map.empty[MethodName, MethodCache[js.Tree]] + mutable.Map.empty[MethodName, MethodCache[List[E]]] - private[this] var _constructorCache: Option[MethodCache[List[js.Tree]]] = None + private[this] var _constructorCache: Option[MethodCache[List[E]]] = None private[this] val _exportedMembersCache = - mutable.Map.empty[Int, MethodCache[List[js.Tree]]] + mutable.Map.empty[Int, MethodCache[List[E]]] private[this] var _fullClassCache: Option[FullClassCache] = None @@ -807,12 +856,12 @@ final class Emitter(config: Emitter.Config) { _fullClassCache.foreach(_.startRun()) } - def getCache(version: Version): DesugaredClassCache = { + def getCache(version: Version): DesugaredClassCache[List[E]] = { if (_cache == null || !_lastVersion.sameVersion(version)) { invalidate() statsClassesInvalidated += 1 _lastVersion = version - _cache = new DesugaredClassCache + _cache = new DesugaredClassCache[List[E]] } else { statsClassesReused += 1 } @@ -821,25 +870,25 @@ final class Emitter(config: Emitter.Config) { } def getMemberMethodCache( - methodName: MethodName): MethodCache[js.Tree] = { + methodName: MethodName): MethodCache[List[E]] = { _memberMethodCache.getOrElseUpdate(methodName, new MethodCache) } def getStaticLikeMethodCache(namespace: MemberNamespace, - methodName: MethodName): MethodCache[List[js.Tree]] = { + methodName: MethodName): MethodCache[List[E]] = { _methodCaches(namespace.ordinal) .getOrElseUpdate(methodName, new MethodCache) } - def getConstructorCache(): MethodCache[List[js.Tree]] = { + def getConstructorCache(): MethodCache[List[E]] = { _constructorCache.getOrElse { - val cache = new MethodCache[List[js.Tree]] + val cache = new MethodCache[List[E]] _constructorCache = Some(cache) cache } } - def getExportedMemberCache(idx: Int): MethodCache[List[js.Tree]] = + def getExportedMemberCache(idx: Int): MethodCache[List[E]] = _exportedMembersCache.getOrElseUpdate(idx, new MethodCache) def getFullClassCache(): FullClassCache = { @@ -905,11 +954,11 @@ final class Emitter(config: Emitter.Config) { } private class FullClassCache extends knowledgeGuardian.KnowledgeAccessor { - private[this] var _tree: WithGlobals[List[js.Tree]] = null + private[this] var _tree: WithGlobals[List[E]] = null private[this] var _lastVersion: Version = Version.Unversioned - private[this] var _lastCtor: WithGlobals[List[js.Tree]] = null - private[this] var _lastMemberMethods: List[WithGlobals[js.Tree]] = null - private[this] var _lastExportedMembers: List[WithGlobals[List[js.Tree]]] = null + private[this] var _lastCtor: WithGlobals[List[E]] = null + private[this] var _lastMemberMethods: List[WithGlobals[List[E]]] = null + private[this] var _lastExportedMembers: List[WithGlobals[List[E]]] = null private[this] var _cacheUsed = false override def invalidate(): Unit = { @@ -923,9 +972,9 @@ final class Emitter(config: Emitter.Config) { def startRun(): Unit = _cacheUsed = false - def getOrElseUpdate(version: Version, ctor: WithGlobals[List[js.Tree]], - memberMethods: List[WithGlobals[js.Tree]], exportedMembers: List[WithGlobals[List[js.Tree]]], - compute: => WithGlobals[List[js.Tree]]): WithGlobals[List[js.Tree]] = { + def getOrElseUpdate(version: Version, ctor: WithGlobals[List[E]], + memberMethods: List[WithGlobals[List[E]]], exportedMembers: List[WithGlobals[List[E]]], + compute: => WithGlobals[List[E]]): WithGlobals[List[E]] = { @tailrec def allSame[A <: AnyRef](xs: List[A], ys: List[A]): Boolean = { @@ -960,11 +1009,11 @@ final class Emitter(config: Emitter.Config) { private class CoreJSLibCache extends knowledgeGuardian.KnowledgeAccessor { private[this] var _lastModuleContext: ModuleContext = _ - private[this] var _lib: WithGlobals[CoreJSLib.Lib] = _ + private[this] var _lib: WithGlobals[CoreJSLib.Lib[List[E]]] = _ - def build(moduleContext: ModuleContext): WithGlobals[CoreJSLib.Lib] = { + def build(moduleContext: ModuleContext): WithGlobals[CoreJSLib.Lib[List[E]]] = { if (_lib == null || _lastModuleContext != moduleContext) { - _lib = CoreJSLib.build(sjsGen, moduleContext, this) + _lib = CoreJSLib.build(sjsGen, postTransform(_, 0), moduleContext, this) _lastModuleContext = moduleContext } _lib @@ -979,9 +1028,9 @@ final class Emitter(config: Emitter.Config) { object Emitter { /** Result of an emitter run. */ - final class Result private[Emitter]( + final class Result[E] private[Emitter]( val header: String, - val body: Map[ModuleID, List[js.Tree]], + val body: Map[ModuleID, List[E]], val footer: String, val topLevelVarDecls: List[String], val globalRefs: Set[String] @@ -1052,22 +1101,26 @@ object Emitter { new Config(coreSpec.semantics, coreSpec.moduleKind, coreSpec.esFeatures) } - private final class DesugaredClassCache { - val privateJSFields = new OneTimeCache[WithGlobals[List[js.Tree]]] - val storeJSSuperClass = new OneTimeCache[WithGlobals[js.Tree]] - val instanceTests = new OneTimeCache[WithGlobals[List[js.Tree]]] - val typeData = new OneTimeCache[WithGlobals[List[js.Tree]]] - val setTypeData = new OneTimeCache[js.Tree] - val moduleAccessor = new OneTimeCache[WithGlobals[List[js.Tree]]] - val staticInitialization = new OneTimeCache[List[js.Tree]] - val staticFields = new OneTimeCache[WithGlobals[List[js.Tree]]] + trait PostTransformer[E] { + def transformStats(trees: List[js.Tree], indent: Int): List[E] + } + + private final class DesugaredClassCache[E >: Null] { + val privateJSFields = new OneTimeCache[WithGlobals[E]] + val storeJSSuperClass = new OneTimeCache[WithGlobals[E]] + val instanceTests = new OneTimeCache[WithGlobals[E]] + val typeData = new OneTimeCache[WithGlobals[E]] + val setTypeData = new OneTimeCache[E] + val moduleAccessor = new OneTimeCache[WithGlobals[E]] + val staticInitialization = new OneTimeCache[E] + val staticFields = new OneTimeCache[WithGlobals[E]] } - private final class GeneratedClass( + private final class GeneratedClass[E]( val className: ClassName, - val main: List[js.Tree], - val staticFields: List[js.Tree], - val staticInitialization: List[js.Tree], + val main: List[E], + val staticFields: List[E], + val staticInitialization: List[E], val trackedGlobalRefs: Set[String] ) diff --git a/linker/shared/src/main/scala/org/scalajs/linker/backend/javascript/Printers.scala b/linker/shared/src/main/scala/org/scalajs/linker/backend/javascript/Printers.scala index 4d675d4dec..a6d632a1cd 100644 --- a/linker/shared/src/main/scala/org/scalajs/linker/backend/javascript/Printers.scala +++ b/linker/shared/src/main/scala/org/scalajs/linker/backend/javascript/Printers.scala @@ -12,6 +12,8 @@ package org.scalajs.linker.backend.javascript +import java.nio.charset.StandardCharsets + import scala.annotation.switch // Unimport default print and println to avoid invoking them by mistake @@ -31,10 +33,10 @@ import Trees._ object Printers { private val ReusableIndentArray = Array.fill(128)(' '.toByte) - class JSTreePrinter(protected val out: ByteArrayWriter) { + class JSTreePrinter(protected val out: ByteArrayWriter, initIndent: Int = 0) { private final val IndentStep = 2 - private var indentMargin = 0 + private var indentMargin = initIndent * IndentStep private var indentArray = ReusableIndentArray private def indent(): Unit = indentMargin += IndentStep @@ -117,10 +119,15 @@ object Printers { printRow(args, '(', ')') /** Prints a stat including leading indent and trailing newline. */ - final def printStat(tree: Tree): Unit = { - printIndent() - printTree(tree, isStat = true) - println() + final def printStat(tree: Tree): Unit = tree match { + case tree: PrintedTree => + // PrintedTree already contains indent and trailing newline. + print(tree) + + case _ => + printIndent() + printTree(tree, isStat = true) + println() } private def print(tree: Tree): Unit = @@ -750,6 +757,9 @@ object Printers { print("]") } + protected def print(printedTree: PrintedTree): Unit = + out.write(printedTree.jsCode) + private def print(exportName: ExportName): Unit = printEscapeJS(exportName.name) @@ -762,7 +772,8 @@ object Printers { } class JSTreePrinterWithSourceMap(_out: ByteArrayWriter, - sourceMap: SourceMapWriter.Builder) extends JSTreePrinter(_out) { + sourceMap: SourceMapWriter.Builder, initIndent: Int) + extends JSTreePrinter(_out, initIndent) { private var column = 0 @@ -788,6 +799,11 @@ object Printers { sourceMap.endNode(column) } + override protected def print(printedTree: PrintedTree): Unit = { + super.print(printedTree) + sourceMap.insertFragment(printedTree.sourceMapFragment) + } + override protected def println(): Unit = { super.println() sourceMap.nextLine() diff --git a/linker/shared/src/main/scala/org/scalajs/linker/backend/javascript/Trees.scala b/linker/shared/src/main/scala/org/scalajs/linker/backend/javascript/Trees.scala index efcf98e609..ec5b72e850 100644 --- a/linker/shared/src/main/scala/org/scalajs/linker/backend/javascript/Trees.scala +++ b/linker/shared/src/main/scala/org/scalajs/linker/backend/javascript/Trees.scala @@ -499,4 +499,22 @@ object Trees { from: StringLiteral)( implicit val pos: Position) extends Tree + + /** An already printed tree. + * + * This is a special purpose node to store partially transformed trees. + * + * A cleaner abstraction would be to have something like ir.Tree.Transient + * (for different output formats), but for now, we do not need this. + */ + sealed case class PrintedTree(jsCode: Array[Byte], + sourceMapFragment: SourceMapWriter.Fragment) extends Tree { + val pos: Position = Position.NoPosition + + override def show: String = new String(jsCode, StandardCharsets.UTF_8) + } + + object PrintedTree { + def empty: PrintedTree = PrintedTree(Array(), SourceMapWriter.Fragment.Empty) + } } diff --git a/linker/shared/src/test/scala/org/scalajs/linker/BasicLinkerBackendTest.scala b/linker/shared/src/test/scala/org/scalajs/linker/BasicLinkerBackendTest.scala index 1ce36b9153..2da5fe2324 100644 --- a/linker/shared/src/test/scala/org/scalajs/linker/BasicLinkerBackendTest.scala +++ b/linker/shared/src/test/scala/org/scalajs/linker/BasicLinkerBackendTest.scala @@ -20,6 +20,7 @@ import org.junit.Test import org.junit.Assert._ import org.scalajs.ir.Trees._ +import org.scalajs.ir.Version import org.scalajs.junit.async._ @@ -33,17 +34,17 @@ import org.scalajs.logging._ class BasicLinkerBackendTest { import scala.concurrent.ExecutionContext.Implicits.global - private val BackendInvalidatedTopLevelTreesStatsMessage = - raw"""BasicBackend: total top-level trees: (\d+); re-computed: (\d+)""".r + private val BackendInvalidatedPrintedTreesStatsMessage = + raw"""BasicBackend: total top-level printed trees: (\d+); re-computed: (\d+)""".r private val BackendInvalidatedModulesStatsMessage = raw"""BasicBackend: total modules: (\d+); re-written: (\d+)""".r /** Makes sure that linking a "substantial" program (using `println`) twice - * does not invalidate any top-level tree nor module in the second run. + * does not invalidate any module in the second run. */ @Test - def noInvalidatedTopLevelTreeOrModuleInSecondRun(): AsyncResult = await { + def noInvalidatedModuleInSecondRun(): AsyncResult = await { import ModuleSplitStyle._ val classDefs = List( @@ -60,7 +61,7 @@ class BasicLinkerBackendTest { .withModuleSplitStyle(splitStyle) val linker = StandardImpl.linker(config) - val classDefsFiles = classDefs.map(MemClassDefIRFile(_)) + val classDefsFiles = classDefs.map(MemClassDefIRFile(_, Version.fromInt(0))) val initializers = MainTestModuleInitializers val outputDir = MemOutputDirectory() @@ -74,25 +75,6 @@ class BasicLinkerBackendTest { val lines1 = logger1.allLogLines val lines2 = logger2.allLogLines - // Top-level trees - - val Seq(totalTrees1, recomputedTrees1) = - lines1.assertContainsMatch(BackendInvalidatedTopLevelTreesStatsMessage).map(_.toInt) - - val Seq(totalTrees2, recomputedTrees2) = - lines2.assertContainsMatch(BackendInvalidatedTopLevelTreesStatsMessage).map(_.toInt) - - // At the time of writing this test, totalTrees1 reports 382 trees - assertTrue( - s"Not enough total top-level trees (got $totalTrees1); extraction must have gone wrong", - totalTrees1 > 300) - - assertEquals("First run must invalidate every top-level tree", totalTrees1, recomputedTrees1) - assertEquals("Second run must have the same total top-level trees as first run", totalTrees1, totalTrees2) - assertEquals("Second run must not invalidate any top-level tree", 0, recomputedTrees2) - - // Modules - val Seq(totalModules1, rewrittenModules1) = lines1.assertContainsMatch(BackendInvalidatedModulesStatsMessage).map(_.toInt) diff --git a/linker/shared/src/test/scala/org/scalajs/linker/EmitterTest.scala b/linker/shared/src/test/scala/org/scalajs/linker/EmitterTest.scala index 17512130bc..935c2a57ae 100644 --- a/linker/shared/src/test/scala/org/scalajs/linker/EmitterTest.scala +++ b/linker/shared/src/test/scala/org/scalajs/linker/EmitterTest.scala @@ -20,6 +20,7 @@ import org.junit.Test import org.junit.Assert._ import org.scalajs.ir.Trees._ +import org.scalajs.ir.Version import org.scalajs.junit.async._ @@ -128,6 +129,99 @@ class EmitterTest { logger.allLogLines.assertContains(EmitterSetOfDangerousGlobalRefsChangedMessage) } } + + private val EmitterClassTreeCacheStatsMessage = + raw"""Emitter: Class tree cache stats: reused: (\d+) -- invalidated: (\d+)""".r + + private val EmitterMethodTreeCacheStatsMessage = + raw"""Emitter: Method tree cache stats: reused: (\d+) -- invalidated: (\d+)""".r + + private val EmitterPostTransformStatsMessage = + raw"""Emitter: Post transforms: total: (\d+) -- nested: (\d+) -- nested avoided: (\d+)""".r + + /** Makes sure that linking a "substantial" program (using `println`) twice + * does not invalidate any cache or top-level tree in the second run. + */ + @Test + def noInvalidatedCacheOrTopLevelTreeInSecondRun(): AsyncResult = await { + val classDefs = List( + mainTestClassDef(systemOutPrintln(str("Hello world!"))) + ) + + val logger1 = new CapturingLogger + val logger2 = new CapturingLogger + + val config = StandardConfig() + .withCheckIR(true) + .withModuleKind(ModuleKind.ESModule) + + val linker = StandardImpl.linker(config) + val classDefsFiles = classDefs.map(MemClassDefIRFile(_, Version.fromInt(0))) + + val initializers = MainTestModuleInitializers + val outputDir = MemOutputDirectory() + + for { + javalib <- TestIRRepo.javalib + allIRFiles = javalib ++ classDefsFiles + _ <- linker.link(allIRFiles, initializers, outputDir, logger1) + _ <- linker.link(allIRFiles, initializers, outputDir, logger2) + } yield { + val lines1 = logger1.allLogLines + val lines2 = logger2.allLogLines + + // Class tree caches + + val Seq(classCacheReused1, classCacheInvalidated1) = + lines1.assertContainsMatch(EmitterClassTreeCacheStatsMessage).map(_.toInt) + + val Seq(classCacheReused2, classCacheInvalidated2) = + lines2.assertContainsMatch(EmitterClassTreeCacheStatsMessage).map(_.toInt) + + // At the time of writing this test, classCacheInvalidated1 reports 47 + assertTrue( + s"Not enough invalidated class caches (got $classCacheInvalidated1); extraction must have gone wrong", + classCacheInvalidated1 > 40) + + assertEquals("First run must not reuse any class cache", 0, classCacheReused1) + + assertEquals("Second run must reuse all class caches", classCacheReused2, classCacheInvalidated1) + assertEquals("Second run must not invalidate any class cache", 0, classCacheInvalidated2) + + // Method tree caches + + val Seq(methodCacheReused1, methodCacheInvalidated1) = + lines1.assertContainsMatch(EmitterMethodTreeCacheStatsMessage).map(_.toInt) + + val Seq(methodCacheReused2, methodCacheInvalidated2) = + lines2.assertContainsMatch(EmitterMethodTreeCacheStatsMessage).map(_.toInt) + + // At the time of writing this test, methodCacheInvalidated1 reports 107 + assertTrue( + s"Not enough invalidated method caches (got $methodCacheInvalidated1); extraction must have gone wrong", + methodCacheInvalidated1 > 100) + + assertEquals("First run must not reuse any method cache", 0, methodCacheReused1) + + assertEquals("Second run must reuse all method caches", methodCacheReused2, methodCacheInvalidated1) + assertEquals("Second run must not invalidate any method cache", 0, methodCacheInvalidated2) + + // Post transforms + + val Seq(postTransforms1, _, _) = + lines1.assertContainsMatch(EmitterPostTransformStatsMessage).map(_.toInt) + + val Seq(postTransforms2, _, _) = + lines2.assertContainsMatch(EmitterPostTransformStatsMessage).map(_.toInt) + + // At the time of writing this test, postTransformsTotal1 reports 216 + assertTrue( + s"Not enough post transforms (got $postTransforms1); extraction must have gone wrong", + postTransforms1 > 200) + + assertEquals("Second run must not have any post transforms", 0, postTransforms2) + } + } } object EmitterTest { diff --git a/linker/shared/src/test/scala/org/scalajs/linker/backend/javascript/PrintersTest.scala b/linker/shared/src/test/scala/org/scalajs/linker/backend/javascript/PrintersTest.scala index 86c9215f02..ba4848f668 100644 --- a/linker/shared/src/test/scala/org/scalajs/linker/backend/javascript/PrintersTest.scala +++ b/linker/shared/src/test/scala/org/scalajs/linker/backend/javascript/PrintersTest.scala @@ -14,7 +14,7 @@ package org.scalajs.linker.backend.javascript import scala.language.implicitConversions -import java.nio.charset.StandardCharsets +import java.nio.charset.StandardCharsets.UTF_8 import org.junit.Test import org.junit.Assert._ @@ -35,7 +35,7 @@ class PrintersTest { val printer = new Printers.JSTreePrinter(out) printer.printStat(tree) assertEquals(expected.stripMargin.trim + "\n", - new String(out.toByteArray(), StandardCharsets.UTF_8)) + new String(out.toByteArray(), UTF_8)) } @Test def printFunctionDef(): Unit = { @@ -158,4 +158,24 @@ class PrintersTest { If(BooleanLiteral(true), IntLiteral(2), IntLiteral(3))) ) } + + @Test def showPrintedTree(): Unit = { + val tree = PrintedTree("test".getBytes(UTF_8), SourceMapWriter.Fragment.Empty) + + assertEquals("test", tree.show) + } + + @Test def showNestedPrintedTree(): Unit = { + val tree = PrintedTree(" test\n".getBytes(UTF_8), SourceMapWriter.Fragment.Empty) + + val str = While(BooleanLiteral(false), tree).show + assertEquals( + """ + |while (false) { + | test + |} + """.stripMargin.trim, + str + ) + } } From 5c56042c11adc6df0f830c71d35b053864bc0b66 Mon Sep 17 00:00:00 2001 From: Tobias Schlatter Date: Sun, 17 Dec 2023 18:59:07 +0100 Subject: [PATCH 2/3] Track in Emitter whether a module changed in an incremental run In the next commit, we want to avoid caching entire classes because of the memory cost. However, the BasicLinkerBackend relies on the identity of the generated trees to detect changes: Since that identity will change if we stop caching them, we need to provide an explicit "changed" signal. --- .../closure/ClosureLinkerBackend.scala | 3 +- .../linker/backend/BasicLinkerBackend.scala | 19 +--- .../linker/backend/emitter/Emitter.scala | 98 ++++++++++++------- 3 files changed, 68 insertions(+), 52 deletions(-) diff --git a/linker/jvm/src/main/scala/org/scalajs/linker/backend/closure/ClosureLinkerBackend.scala b/linker/jvm/src/main/scala/org/scalajs/linker/backend/closure/ClosureLinkerBackend.scala index 1f532767e2..64160204ac 100644 --- a/linker/jvm/src/main/scala/org/scalajs/linker/backend/closure/ClosureLinkerBackend.scala +++ b/linker/jvm/src/main/scala/org/scalajs/linker/backend/closure/ClosureLinkerBackend.scala @@ -106,7 +106,8 @@ final class ClosureLinkerBackend(config: LinkerBackendImpl.Config) sjsModule <- moduleSet.modules.headOption } yield { val closureChunk = logger.time("Closure: Create trees)") { - buildChunk(emitterResult.body(sjsModule.id)) + val (trees, _) = emitterResult.body(sjsModule.id) + buildChunk(trees) } logger.time("Closure: Compiler pass") { diff --git a/linker/shared/src/main/scala/org/scalajs/linker/backend/BasicLinkerBackend.scala b/linker/shared/src/main/scala/org/scalajs/linker/backend/BasicLinkerBackend.scala index 4faef57c0a..e07e31597b 100644 --- a/linker/shared/src/main/scala/org/scalajs/linker/backend/BasicLinkerBackend.scala +++ b/linker/shared/src/main/scala/org/scalajs/linker/backend/BasicLinkerBackend.scala @@ -88,9 +88,7 @@ final class BasicLinkerBackend(config: LinkerBackendImpl.Config) val writer = new OutputWriter(output, config, skipContentCheck) { protected def writeModuleWithoutSourceMap(moduleID: ModuleID, force: Boolean): Option[ByteBuffer] = { val cache = printedModuleSetCache.getModuleCache(moduleID) - val printedTrees = emitterResult.body(moduleID) - - val changed = cache.update(printedTrees) + val (printedTrees, changed) = emitterResult.body(moduleID) if (force || changed || allChanged) { rewrittenModules.incrementAndGet() @@ -114,9 +112,7 @@ final class BasicLinkerBackend(config: LinkerBackendImpl.Config) protected def writeModuleWithSourceMap(moduleID: ModuleID, force: Boolean): Option[(ByteBuffer, ByteBuffer)] = { val cache = printedModuleSetCache.getModuleCache(moduleID) - val printedTrees = emitterResult.body(moduleID) - - val changed = cache.update(printedTrees) + val (printedTrees, changed) = emitterResult.body(moduleID) if (force || changed || allChanged) { rewrittenModules.incrementAndGet() @@ -220,8 +216,6 @@ private object BasicLinkerBackend { private sealed class PrintedModuleCache { private var cacheUsed = false - private var changed = false - private var lastPrintedTrees: List[js.PrintedTree] = Nil private var previousFinalJSFileSize: Int = 0 private var previousFinalSourceMapSize: Int = 0 @@ -239,15 +233,6 @@ private object BasicLinkerBackend { previousFinalSourceMapSize = finalSourceMapSize } - def update(newPrintedTrees: List[js.PrintedTree]): Boolean = { - val changed = !newPrintedTrees.corresponds(lastPrintedTrees)(_ eq _) - this.changed = changed - if (changed) { - lastPrintedTrees = newPrintedTrees - } - changed - } - def cleanAfterRun(): Boolean = { val wasUsed = cacheUsed cacheUsed = false diff --git a/linker/shared/src/main/scala/org/scalajs/linker/backend/emitter/Emitter.scala b/linker/shared/src/main/scala/org/scalajs/linker/backend/emitter/Emitter.scala index 6035764aca..8aafe7b745 100644 --- a/linker/shared/src/main/scala/org/scalajs/linker/backend/emitter/Emitter.scala +++ b/linker/shared/src/main/scala/org/scalajs/linker/backend/emitter/Emitter.scala @@ -112,7 +112,7 @@ final class Emitter[E >: Null <: js.Tree]( } private def emitInternal(moduleSet: ModuleSet, - logger: Logger): WithGlobals[Map[ModuleID, List[E]]] = { + logger: Logger): WithGlobals[Map[ModuleID, (List[E], Boolean)]] = { // Reset caching stats. statsClassesReused = 0 statsClassesInvalidated = 0 @@ -169,7 +169,7 @@ final class Emitter[E >: Null <: js.Tree]( */ @tailrec private def emitAvoidGlobalClash(moduleSet: ModuleSet, - logger: Logger, secondAttempt: Boolean): WithGlobals[Map[ModuleID, List[E]]] = { + logger: Logger, secondAttempt: Boolean): WithGlobals[Map[ModuleID, (List[E], Boolean)]] = { val result = emitOnce(moduleSet, logger) val mentionedDangerousGlobalRefs = @@ -194,7 +194,7 @@ final class Emitter[E >: Null <: js.Tree]( } private def emitOnce(moduleSet: ModuleSet, - logger: Logger): WithGlobals[Map[ModuleID, List[E]]] = { + logger: Logger): WithGlobals[Map[ModuleID, (List[E], Boolean)]] = { // Genreate classes first so we can measure time separately. val generatedClasses = logger.time("Emitter: Generate Classes") { moduleSet.modules.map { module => @@ -212,18 +212,26 @@ final class Emitter[E >: Null <: js.Tree]( val moduleTrees = logger.time("Emitter: Write trees") { moduleSet.modules.map { module => + var changed = false + def extractChangedAndWithGlobals[T](x: (WithGlobals[T], Boolean)): T = { + changed ||= x._2 + extractWithGlobals(x._1) + } + val moduleContext = ModuleContext.fromModule(module) val moduleCache = state.moduleCaches.getOrElseUpdate(module.id, new ModuleCache) val moduleClasses = generatedClasses(module.id) - val moduleImports = extractWithGlobals { + changed ||= moduleClasses.exists(_.changed) + + val moduleImports = extractChangedAndWithGlobals { moduleCache.getOrComputeImports(module.externalDependencies, module.internalDependencies) { genModuleImports(module).map(postTransform(_, 0)) } } - val topLevelExports = extractWithGlobals { + val topLevelExports = extractChangedAndWithGlobals { /* We cache top level exports all together, rather than individually, * since typically there are few. */ @@ -233,7 +241,7 @@ final class Emitter[E >: Null <: js.Tree]( } } - val moduleInitializers = extractWithGlobals { + val moduleInitializers = extractChangedAndWithGlobals { val initializers = module.initializers.toList moduleCache.getOrComputeInitializers(initializers) { WithGlobals.list(initializers.map { initializer => @@ -324,7 +332,7 @@ final class Emitter[E >: Null <: js.Tree]( trackedGlobalRefs = unionPreserveEmpty(trackedGlobalRefs, genClass.trackedGlobalRefs) } - module.id -> allTrees + module.id -> (allTrees, changed) } } @@ -382,8 +390,14 @@ final class Emitter[E >: Null <: js.Tree]( val classCache = classCaches.getOrElseUpdate( new ClassID(linkedClass.ancestors, moduleContext), new ClassCache) + var changed = false + def extractChanged[T](x: (T, Boolean)): T = { + changed ||= x._2 + x._1 + } + val classTreeCache = - classCache.getCache(linkedClass.version) + extractChanged(classCache.getCache(linkedClass.version)) val kind = linkedClass.kind @@ -396,6 +410,9 @@ final class Emitter[E >: Null <: js.Tree]( withGlobals.value } + def extractWithGlobalsAndChanged[T](x: (WithGlobals[T], Boolean)): T = + extractWithGlobals(extractChanged(x)) + // Main part val main = List.newBuilder[E] @@ -426,7 +443,7 @@ final class Emitter[E >: Null <: js.Tree]( val methodCache = classCache.getStaticLikeMethodCache(namespace, methodDef.methodName) - main ++= extractWithGlobals(methodCache.getOrElseUpdate(methodDef.version, { + main ++= extractWithGlobalsAndChanged(methodCache.getOrElseUpdate(methodDef.version, { classEmitter.genStaticLikeMethod(className, methodDef)(moduleContext, methodCache) .map(postTransform(_, 0)) })) @@ -486,7 +503,7 @@ final class Emitter[E >: Null <: js.Tree]( } // JS constructor - val ctorWithGlobals = { + val ctorWithGlobals = extractChanged { /* The constructor depends both on the class version, and the version * of the inlineable init, if there is one. * @@ -571,13 +588,13 @@ final class Emitter[E >: Null <: js.Tree]( classCache.getMemberMethodCache(method.methodName) val version = Version.combine(isJSClassVersion, method.version) - methodCache.getOrElseUpdate(version, + extractChanged(methodCache.getOrElseUpdate(version, classEmitter.genMemberMethod( className, // invalidated by overall class cache isJSClass, // invalidated by isJSClassVersion useESClass, // invalidated by isJSClassVersion method // invalidated by method.version - )(moduleContext, methodCache).map(postTransform(_, memberIndent))) + )(moduleContext, methodCache).map(postTransform(_, memberIndent)))) } // Exported Members @@ -586,13 +603,13 @@ final class Emitter[E >: Null <: js.Tree]( } yield { val memberCache = classCache.getExportedMemberCache(idx) val version = Version.combine(isJSClassVersion, member.version) - memberCache.getOrElseUpdate(version, + extractChanged(memberCache.getOrElseUpdate(version, classEmitter.genExportedMember( className, // invalidated by overall class cache isJSClass, // invalidated by isJSClassVersion useESClass, // invalidated by isJSClassVersion member // invalidated by version - )(moduleContext, memberCache).map(postTransform(_, memberIndent))) + )(moduleContext, memberCache).map(postTransform(_, memberIndent)))) } val hasClassInitializer: Boolean = { @@ -602,7 +619,7 @@ final class Emitter[E >: Null <: js.Tree]( } } - val fullClass = { + val fullClass = extractChanged { val fullClassCache = classCache.getFullClassCache() fullClassCache.getOrElseUpdate(linkedClass.version, ctorWithGlobals, @@ -714,7 +731,8 @@ final class Emitter[E >: Null <: js.Tree]( main.result(), staticFields, staticInitialization, - trackedGlobalRefs + trackedGlobalRefs, + changed ) } @@ -751,7 +769,7 @@ final class Emitter[E >: Null <: js.Tree]( } def getOrComputeImports(externalDependencies: Set[String], internalDependencies: Set[ModuleID])( - compute: => WithGlobals[List[E]]): WithGlobals[List[E]] = { + compute: => WithGlobals[List[E]]): (WithGlobals[List[E]], Boolean) = { _cacheUsed = true @@ -759,20 +777,25 @@ final class Emitter[E >: Null <: js.Tree]( _importsCache = compute _lastExternalDependencies = externalDependencies _lastInternalDependencies = internalDependencies + (_importsCache, true) + } else { + (_importsCache, false) } - _importsCache + } def getOrComputeTopLevelExports(topLevelExports: List[LinkedTopLevelExport])( - compute: => WithGlobals[List[E]]): WithGlobals[List[E]] = { + compute: => WithGlobals[List[E]]): (WithGlobals[List[E]], Boolean) = { _cacheUsed = true if (!sameTopLevelExports(topLevelExports, _lastTopLevelExports)) { _topLevelExportsCache = compute _lastTopLevelExports = topLevelExports + (_topLevelExportsCache, true) + } else { + (_topLevelExportsCache, false) } - _topLevelExportsCache } private def sameTopLevelExports(tles1: List[LinkedTopLevelExport], tles2: List[LinkedTopLevelExport]): Boolean = { @@ -803,15 +826,17 @@ final class Emitter[E >: Null <: js.Tree]( } def getOrComputeInitializers(initializers: List[ModuleInitializer.Initializer])( - compute: => WithGlobals[List[E]]): WithGlobals[List[E]] = { + compute: => WithGlobals[List[E]]): (WithGlobals[List[E]], Boolean) = { _cacheUsed = true if (initializers != _lastInitializers) { _initializersCache = compute _lastInitializers = initializers + (_initializersCache, true) + } else { + (_initializersCache, false) } - _initializersCache } def cleanAfterRun(): Boolean = { @@ -856,17 +881,18 @@ final class Emitter[E >: Null <: js.Tree]( _fullClassCache.foreach(_.startRun()) } - def getCache(version: Version): DesugaredClassCache[List[E]] = { + def getCache(version: Version): (DesugaredClassCache[List[E]], Boolean) = { + _cacheUsed = true if (_cache == null || !_lastVersion.sameVersion(version)) { invalidate() statsClassesInvalidated += 1 _lastVersion = version _cache = new DesugaredClassCache[List[E]] + (_cache, true) } else { statsClassesReused += 1 + (_cache, false) } - _cacheUsed = true - _cache } def getMemberMethodCache( @@ -932,17 +958,18 @@ final class Emitter[E >: Null <: js.Tree]( def startRun(): Unit = _cacheUsed = false def getOrElseUpdate(version: Version, - v: => WithGlobals[T]): WithGlobals[T] = { + v: => WithGlobals[T]): (WithGlobals[T], Boolean) = { + _cacheUsed = true if (_tree == null || !_lastVersion.sameVersion(version)) { invalidate() statsMethodsInvalidated += 1 _tree = v _lastVersion = version + (_tree, true) } else { statsMethodsReused += 1 + (_tree, false) } - _cacheUsed = true - _tree } def cleanAfterRun(): Boolean = { @@ -974,7 +1001,7 @@ final class Emitter[E >: Null <: js.Tree]( def getOrElseUpdate(version: Version, ctor: WithGlobals[List[E]], memberMethods: List[WithGlobals[List[E]]], exportedMembers: List[WithGlobals[List[E]]], - compute: => WithGlobals[List[E]]): WithGlobals[List[E]] = { + compute: => WithGlobals[List[E]]): (WithGlobals[List[E]], Boolean) = { @tailrec def allSame[A <: AnyRef](xs: List[A], ys: List[A]): Boolean = { @@ -984,6 +1011,8 @@ final class Emitter[E >: Null <: js.Tree]( } } + _cacheUsed = true + if (_tree == null || !version.sameVersion(_lastVersion) || (_lastCtor ne ctor) || !allSame(_lastMemberMethods, memberMethods) || !allSame(_lastExportedMembers, exportedMembers)) { @@ -993,10 +1022,10 @@ final class Emitter[E >: Null <: js.Tree]( _lastCtor = ctor _lastMemberMethods = memberMethods _lastExportedMembers = exportedMembers + (_tree, true) + } else { + (_tree, false) } - - _cacheUsed = true - _tree } def cleanAfterRun(): Boolean = { @@ -1030,7 +1059,7 @@ object Emitter { /** Result of an emitter run. */ final class Result[E] private[Emitter]( val header: String, - val body: Map[ModuleID, List[E]], + val body: Map[ModuleID, (List[E], Boolean)], val footer: String, val topLevelVarDecls: List[String], val globalRefs: Set[String] @@ -1121,7 +1150,8 @@ object Emitter { val main: List[E], val staticFields: List[E], val staticInitialization: List[E], - val trackedGlobalRefs: Set[String] + val trackedGlobalRefs: Set[String], + val changed: Boolean ) private final class OneTimeCache[A >: Null] { From 42efb2aa153b3feea58cda36ee0d25f896e52202 Mon Sep 17 00:00:00 2001 From: Tobias Schlatter Date: Sat, 23 Dec 2023 18:07:04 +0100 Subject: [PATCH 3/3] Do not cache overall class This reduces some memory overhead for negligible performance cost. Residual (post link memory) benchmarks for the test suite: Baseline: 1.13 GB, new 1.01 GB --- .../linker/backend/emitter/Emitter.scala | 119 +++++++++--------- .../org/scalajs/linker/EmitterTest.scala | 11 +- 2 files changed, 70 insertions(+), 60 deletions(-) diff --git a/linker/shared/src/main/scala/org/scalajs/linker/backend/emitter/Emitter.scala b/linker/shared/src/main/scala/org/scalajs/linker/backend/emitter/Emitter.scala index 8aafe7b745..1906ffe84f 100644 --- a/linker/shared/src/main/scala/org/scalajs/linker/backend/emitter/Emitter.scala +++ b/linker/shared/src/main/scala/org/scalajs/linker/backend/emitter/Emitter.scala @@ -619,37 +619,41 @@ final class Emitter[E >: Null <: js.Tree]( } } - val fullClass = extractChanged { - val fullClassCache = classCache.getFullClassCache() - - fullClassCache.getOrElseUpdate(linkedClass.version, ctorWithGlobals, - memberMethodsWithGlobals, exportedMembersWithGlobals, { - for { - ctor <- ctorWithGlobals - memberMethods <- WithGlobals.flatten(memberMethodsWithGlobals) - exportedMembers <- WithGlobals.flatten(exportedMembersWithGlobals) - allMembers = ctor ::: memberMethods ::: exportedMembers - clazz <- classEmitter.buildClass( - className, // invalidated by overall class cache (part of ancestors) - isJSClass, // invalidated by class version - linkedClass.jsClassCaptures, // invalidated by class version - hasClassInitializer, // invalidated by class version (optimizer cannot remove it) - linkedClass.superClass, // invalidated by class version - storeJSSuperClass, // invalidated by class version - useESClass, // invalidated by class version (depends on kind, config and ancestry only) - allMembers // invalidated directly - )(moduleContext, fullClassCache, linkedClass.pos) // pos invalidated by class version - } yield { - // Avoid a nested post transform if we just got the original members back. - if (clazz eq allMembers) { - statsNestedPostTransformsAvoided += 1 - allMembers - } else { - statsNestedPostTransforms += 1 - postTransform(clazz, 0) - } + val fullClass = { + val fullClassChangeTracker = classCache.getFullClassChangeTracker() + + // Put changed state into a val to avoid short circuiting behavior of ||. + val classChanged = fullClassChangeTracker.trackChanged( + linkedClass.version, ctorWithGlobals, + memberMethodsWithGlobals, exportedMembersWithGlobals) + + changed ||= classChanged + + for { + ctor <- ctorWithGlobals + memberMethods <- WithGlobals.flatten(memberMethodsWithGlobals) + exportedMembers <- WithGlobals.flatten(exportedMembersWithGlobals) + allMembers = ctor ::: memberMethods ::: exportedMembers + clazz <- classEmitter.buildClass( + className, // invalidated by overall class cache (part of ancestors) + isJSClass, // invalidated by class version + linkedClass.jsClassCaptures, // invalidated by class version + hasClassInitializer, // invalidated by class version (optimizer cannot remove it) + linkedClass.superClass, // invalidated by class version + storeJSSuperClass, // invalidated by class version + useESClass, // invalidated by class version (depends on kind, config and ancestry only) + allMembers // invalidated directly + )(moduleContext, fullClassChangeTracker, linkedClass.pos) // pos invalidated by class version + } yield { + // Avoid a nested post transform if we just got the original members back. + if (clazz eq allMembers) { + statsNestedPostTransformsAvoided += 1 + allMembers + } else { + statsNestedPostTransforms += 1 + postTransform(clazz, 0) } - }) + } } main ++= extractWithGlobals(fullClass) @@ -862,7 +866,7 @@ final class Emitter[E >: Null <: js.Tree]( private[this] val _exportedMembersCache = mutable.Map.empty[Int, MethodCache[List[E]]] - private[this] var _fullClassCache: Option[FullClassCache] = None + private[this] var _fullClassChangeTracker: Option[FullClassChangeTracker] = None override def invalidate(): Unit = { /* Do not invalidate contained methods, as they have their own @@ -878,7 +882,7 @@ final class Emitter[E >: Null <: js.Tree]( _methodCaches.foreach(_.valuesIterator.foreach(_.startRun())) _memberMethodCache.valuesIterator.foreach(_.startRun()) _constructorCache.foreach(_.startRun()) - _fullClassCache.foreach(_.startRun()) + _fullClassChangeTracker.foreach(_.startRun()) } def getCache(version: Version): (DesugaredClassCache[List[E]], Boolean) = { @@ -917,10 +921,10 @@ final class Emitter[E >: Null <: js.Tree]( def getExportedMemberCache(idx: Int): MethodCache[List[E]] = _exportedMembersCache.getOrElseUpdate(idx, new MethodCache) - def getFullClassCache(): FullClassCache = { - _fullClassCache.getOrElse { - val cache = new FullClassCache - _fullClassCache = Some(cache) + def getFullClassChangeTracker(): FullClassChangeTracker = { + _fullClassChangeTracker.getOrElse { + val cache = new FullClassChangeTracker + _fullClassChangeTracker = Some(cache) cache } } @@ -934,8 +938,8 @@ final class Emitter[E >: Null <: js.Tree]( _exportedMembersCache.filterInPlace((_, c) => c.cleanAfterRun()) - if (_fullClassCache.exists(!_.cleanAfterRun())) - _fullClassCache = None + if (_fullClassChangeTracker.exists(!_.cleanAfterRun())) + _fullClassChangeTracker = None if (!_cacheUsed) invalidate() @@ -980,28 +984,26 @@ final class Emitter[E >: Null <: js.Tree]( } } - private class FullClassCache extends knowledgeGuardian.KnowledgeAccessor { - private[this] var _tree: WithGlobals[List[E]] = null + private class FullClassChangeTracker extends knowledgeGuardian.KnowledgeAccessor { private[this] var _lastVersion: Version = Version.Unversioned private[this] var _lastCtor: WithGlobals[List[E]] = null private[this] var _lastMemberMethods: List[WithGlobals[List[E]]] = null private[this] var _lastExportedMembers: List[WithGlobals[List[E]]] = null - private[this] var _cacheUsed = false + private[this] var _trackerUsed = false override def invalidate(): Unit = { super.invalidate() - _tree = null _lastVersion = Version.Unversioned _lastCtor = null _lastMemberMethods = null _lastExportedMembers = null } - def startRun(): Unit = _cacheUsed = false + def startRun(): Unit = _trackerUsed = false - def getOrElseUpdate(version: Version, ctor: WithGlobals[List[E]], - memberMethods: List[WithGlobals[List[E]]], exportedMembers: List[WithGlobals[List[E]]], - compute: => WithGlobals[List[E]]): (WithGlobals[List[E]], Boolean) = { + def trackChanged(version: Version, ctor: WithGlobals[List[E]], + memberMethods: List[WithGlobals[List[E]]], + exportedMembers: List[WithGlobals[List[E]]]): Boolean = { @tailrec def allSame[A <: AnyRef](xs: List[A], ys: List[A]): Boolean = { @@ -1011,28 +1013,33 @@ final class Emitter[E >: Null <: js.Tree]( } } - _cacheUsed = true + _trackerUsed = true - if (_tree == null || !version.sameVersion(_lastVersion) || (_lastCtor ne ctor) || - !allSame(_lastMemberMethods, memberMethods) || - !allSame(_lastExportedMembers, exportedMembers)) { + val changed = { + !version.sameVersion(_lastVersion) || + (_lastCtor ne ctor) || + !allSame(_lastMemberMethods, memberMethods) || + !allSame(_lastExportedMembers, exportedMembers) + } + + if (changed) { + // Input has changed or we were invalidated. + // Clean knowledge tracking and re-track dependencies. invalidate() - _tree = compute _lastVersion = version _lastCtor = ctor _lastMemberMethods = memberMethods _lastExportedMembers = exportedMembers - (_tree, true) - } else { - (_tree, false) } + + changed } def cleanAfterRun(): Boolean = { - if (!_cacheUsed) + if (!_trackerUsed) invalidate() - _cacheUsed + _trackerUsed } } diff --git a/linker/shared/src/test/scala/org/scalajs/linker/EmitterTest.scala b/linker/shared/src/test/scala/org/scalajs/linker/EmitterTest.scala index 935c2a57ae..1f7884c0f1 100644 --- a/linker/shared/src/test/scala/org/scalajs/linker/EmitterTest.scala +++ b/linker/shared/src/test/scala/org/scalajs/linker/EmitterTest.scala @@ -208,18 +208,21 @@ class EmitterTest { // Post transforms - val Seq(postTransforms1, _, _) = + val Seq(postTransforms1, nestedPostTransforms1, _) = lines1.assertContainsMatch(EmitterPostTransformStatsMessage).map(_.toInt) - val Seq(postTransforms2, _, _) = + val Seq(postTransforms2, nestedPostTransforms2, _) = lines2.assertContainsMatch(EmitterPostTransformStatsMessage).map(_.toInt) - // At the time of writing this test, postTransformsTotal1 reports 216 + // At the time of writing this test, postTransforms1 reports 216 assertTrue( s"Not enough post transforms (got $postTransforms1); extraction must have gone wrong", postTransforms1 > 200) - assertEquals("Second run must not have any post transforms", 0, postTransforms2) + assertEquals("Second run must only have nested post transforms", + nestedPostTransforms2, postTransforms2) + assertEquals("Both runs must have the same number of nested post transforms", + nestedPostTransforms1, nestedPostTransforms2) } } }