Thanks to visit codestin.com
Credit goes to github.com

Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Opt/Wasm: Add a number of Wasm-specific intrinsics and transients.
The motivation is mainly to get intrinsics for the bit-conversions
between integers and floating point numbers, as well as for
`numberLeadingZeros`. These are building blocks for many other
low-level operations, and their JS-builtin-based implementation is
really bad on Wasm for those use cases.

Once we have the infrastructure for those as transients in the Wasm
backend, we also take the opportunity to add a series of other
methods that have a direct Wasm opcode equivalent.
  • Loading branch information
sjrd committed Aug 24, 2024
commit 7276564fb8277de3e22b961232d76006ed6c96a6
2 changes: 2 additions & 0 deletions javalib/src/main/scala/java/lang/Double.scala
Original file line number Diff line number Diff line change
Expand Up @@ -366,9 +366,11 @@ object Double {
@inline def hashCode(value: scala.Double): Int =
FloatingPointBits.numberHashCode(value)

// Wasm intrinsic
@inline def longBitsToDouble(bits: scala.Long): scala.Double =
FloatingPointBits.longBitsToDouble(bits)

// Wasm intrinsic
@inline def doubleToLongBits(value: scala.Double): scala.Long =
FloatingPointBits.doubleToLongBits(value)

Expand Down
2 changes: 2 additions & 0 deletions javalib/src/main/scala/java/lang/Float.scala
Original file line number Diff line number Diff line change
Expand Up @@ -378,9 +378,11 @@ object Float {
@inline def hashCode(value: scala.Float): Int =
FloatingPointBits.numberHashCode(value)

// Wasm intrinsic
@inline def intBitsToFloat(bits: scala.Int): scala.Float =
FloatingPointBits.intBitsToFloat(bits)

// Wasm intrinsic
@inline def floatToIntBits(value: scala.Float): scala.Int =
FloatingPointBits.floatToIntBits(value)

Expand Down
7 changes: 7 additions & 0 deletions javalib/src/main/scala/java/lang/Integer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,7 @@ object Integer {
@inline def toUnsignedLong(x: Int): scala.Long =
x.toLong & 0xffffffffL

// Wasm intrinsic
def bitCount(i: scala.Int): scala.Int = {
/* See http://graphics.stanford.edu/~seander/bithacks.html#CountBitsSetParallel
*
Expand All @@ -219,10 +220,12 @@ object Integer {
(((t2 + (t2 >> 4)) & 0xF0F0F0F) * 0x1010101) >> 24
}

// Wasm intrinsic
@inline def divideUnsigned(dividend: Int, divisor: Int): Int =
if (divisor == 0) 0 / 0
else asInt(asUint(dividend) / asUint(divisor))

// Wasm intrinsic
@inline def remainderUnsigned(dividend: Int, divisor: Int): Int =
if (divisor == 0) 0 % 0
else asInt(asUint(dividend) % asUint(divisor))
Expand Down Expand Up @@ -263,15 +266,18 @@ object Integer {
reverseBytes((k & 0x0F0F0F0F) << 4 | (k >> 4) & 0x0F0F0F0F)
}

// Wasm intrinsic
@inline def rotateLeft(i: scala.Int, distance: scala.Int): scala.Int =
(i << distance) | (i >>> -distance)

// Wasm intrinsic
@inline def rotateRight(i: scala.Int, distance: scala.Int): scala.Int =
(i >>> distance) | (i << -distance)

@inline def signum(i: scala.Int): scala.Int =
if (i == 0) 0 else if (i < 0) -1 else 1

// Intrinsic, fallback on actual code for non-literal in JS
@inline def numberOfLeadingZeros(i: scala.Int): scala.Int = {
if (linkingInfo.esVersion >= ESVersion.ES2015) js.Math.clz32(i)
else clz32Dynamic(i)
Expand All @@ -296,6 +302,7 @@ object Integer {
}
}

// Wasm intrinsic
@inline def numberOfTrailingZeros(i: scala.Int): scala.Int =
if (i == 0) 32
else 31 - numberOfLeadingZeros(i & -i)
Expand Down
9 changes: 7 additions & 2 deletions javalib/src/main/scala/java/lang/Long.scala
Original file line number Diff line number Diff line change
Expand Up @@ -348,11 +348,11 @@ object Long {
@inline def compareUnsigned(x: scala.Long, y: scala.Long): scala.Int =
compare(x ^ SignBit, y ^ SignBit)

// Intrinsic
// Intrinsic, except for JS when using bigint's for longs
def divideUnsigned(dividend: scala.Long, divisor: scala.Long): scala.Long =
divModUnsigned(dividend, divisor, isDivide = true)

// Intrinsic
// Intrinsic, except for JS when using bigint's for longs
def remainderUnsigned(dividend: scala.Long, divisor: scala.Long): scala.Long =
divModUnsigned(dividend, divisor, isDivide = false)

Expand Down Expand Up @@ -408,6 +408,7 @@ object Long {
if (lo != 0) 0 else Integer.lowestOneBit(hi))
}

// Wasm intrinsic
@inline
def bitCount(i: scala.Long): scala.Int = {
val lo = i.toInt
Expand Down Expand Up @@ -436,10 +437,12 @@ object Long {
private def makeLongFromLoHi(lo: Int, hi: Int): scala.Long =
(lo.toLong & 0xffffffffL) | (hi.toLong << 32)

// Wasm intrinsic
@inline
def rotateLeft(i: scala.Long, distance: scala.Int): scala.Long =
(i << distance) | (i >>> -distance)

// Wasm intrinsic
@inline
def rotateRight(i: scala.Long, distance: scala.Int): scala.Long =
(i >>> distance) | (i << -distance)
Expand All @@ -452,13 +455,15 @@ object Long {
else 1
}

// Wasm intrinsic
@inline
def numberOfLeadingZeros(l: scala.Long): Int = {
val hi = (l >>> 32).toInt
if (hi != 0) Integer.numberOfLeadingZeros(hi)
else Integer.numberOfLeadingZeros(l.toInt) + 32
}

// Wasm intrinsic
@inline
def numberOfTrailingZeros(l: scala.Long): Int = {
val lo = l.toInt
Expand Down
10 changes: 10 additions & 0 deletions javalib/src/main/scala/java/lang/Math.scala
Original file line number Diff line number Diff line change
Expand Up @@ -28,22 +28,30 @@ object Math {

@inline def abs(a: scala.Int): scala.Int = if (a < 0) -a else a
@inline def abs(a: scala.Long): scala.Long = if (a < 0) -a else a

// Wasm intrinsics
@inline def abs(a: scala.Float): scala.Float = js.Math.abs(a).toFloat
@inline def abs(a: scala.Double): scala.Double = js.Math.abs(a)

@inline def max(a: scala.Int, b: scala.Int): scala.Int = if (a > b) a else b
@inline def max(a: scala.Long, b: scala.Long): scala.Long = if (a > b) a else b

// Wasm intrinsics
@inline def max(a: scala.Float, b: scala.Float): scala.Float = js.Math.max(a, b).toFloat
@inline def max(a: scala.Double, b: scala.Double): scala.Double = js.Math.max(a, b)

@inline def min(a: scala.Int, b: scala.Int): scala.Int = if (a < b) a else b
@inline def min(a: scala.Long, b: scala.Long): scala.Long = if (a < b) a else b

// Wasm intrinsics
@inline def min(a: scala.Float, b: scala.Float): scala.Float = js.Math.min(a, b).toFloat
@inline def min(a: scala.Double, b: scala.Double): scala.Double = js.Math.min(a, b)

// Wasm intrinsics
@inline def ceil(a: scala.Double): scala.Double = js.Math.ceil(a)
@inline def floor(a: scala.Double): scala.Double = js.Math.floor(a)

// Wasm intrinsic
def rint(a: scala.Double): scala.Double = {
val rounded = js.Math.round(a)
val mod = a % 1.0
Expand All @@ -60,7 +68,9 @@ object Math {
@inline def round(a: scala.Float): scala.Int = js.Math.round(a).toInt
@inline def round(a: scala.Double): scala.Long = js.Math.round(a).toLong

// Wasm intrinsic
@inline def sqrt(a: scala.Double): scala.Double = js.Math.sqrt(a)

@inline def pow(a: scala.Double, b: scala.Double): scala.Double = js.Math.pow(a, b)

@inline def exp(a: scala.Double): scala.Double = js.Math.exp(a)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2939,6 +2939,19 @@ private class FunctionEmitter private (
fb += wa.Call(genFunctionID.anyGetClassName)
StringType

case value @ WasmTransients.WasmUnaryOp(_, lhs) =>
genTreeAuto(lhs)
markPosition(tree)
fb += value.wasmInstr
value.tpe

case value @ WasmTransients.WasmBinaryOp(_, lhs, rhs) =>
genTreeAuto(lhs)
genTreeAuto(rhs)
markPosition(tree)
fb += value.wasmInstr
value.tpe

case other =>
throw new AssertionError(s"Unknown transient: $other")
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,209 @@
/*
* Scala.js (https://www.scala-js.org/)
*
* Copyright EPFL.
*
* Licensed under Apache License 2.0
* (https://www.apache.org/licenses/LICENSE-2.0).
*
* See the NOTICE file distributed with this work for
* additional information regarding copyright ownership.
*/

package org.scalajs.linker.backend.wasmemitter

import scala.annotation.switch

import org.scalajs.ir.Position
import org.scalajs.ir.Printers._
import org.scalajs.ir.Transformers._
import org.scalajs.ir.Traversers._
import org.scalajs.ir.Trees._
import org.scalajs.ir.Types._

import org.scalajs.linker.backend.webassembly.{Instructions => wa}

/** Transients generated by the optimizer that only makes sense in Wasm. */
object WasmTransients {

/** Wasm unary op.
*
* Wasm features a number of dedicated opcodes for operations that are not
* in the IR, but only implemented in user space. We can see `WasmUnaryOp`
* as an extension of `ir.Trees.UnaryOp` that covers those.
*
* Wasm unary ops always preserve pureness.
*/
final case class WasmUnaryOp(op: WasmUnaryOp.Code, lhs: Tree)
extends Transient.Value {
import WasmUnaryOp._

val tpe: Type = resultTypeOf(op)

def traverse(traverser: Traverser): Unit =
traverser.traverse(lhs)

def transform(transformer: Transformer, isStat: Boolean)(
implicit pos: Position): Tree = {
Transient(WasmUnaryOp(op, transformer.transformExpr(lhs)))
}

def wasmInstr: wa.SimpleInstr = (op: @switch) match {
case I32Clz => wa.I32Clz
case I32Ctz => wa.I32Ctz
case I32Popcnt => wa.I32Popcnt

case I64Clz => wa.I64Clz
case I64Ctz => wa.I64Ctz
case I64Popcnt => wa.I64Popcnt

case F32Abs => wa.F32Abs

case F64Abs => wa.F64Abs
case F64Ceil => wa.F64Ceil
case F64Floor => wa.F64Floor
case F64Nearest => wa.F64Nearest
case F64Sqrt => wa.F64Sqrt

case I32ReinterpretF32 => wa.I32ReinterpretF32
case I64ReinterpretF64 => wa.I64ReinterpretF64
case F32ReinterpretI32 => wa.F32ReinterpretI32
case F64ReinterpretI64 => wa.F64ReinterpretI64
}

def printIR(out: IRTreePrinter): Unit = {
out.print("$")
out.print(wasmInstr.mnemonic)
out.printArgs(List(lhs))
}
}

object WasmUnaryOp {
/** Codes are raw Ints to be able to write switch matches on them. */
type Code = Int

final val I32Clz = 1
final val I32Ctz = 2
final val I32Popcnt = 3

final val I64Clz = 4
final val I64Ctz = 5
final val I64Popcnt = 6

final val F32Abs = 7

final val F64Abs = 8
final val F64Ceil = 9
final val F64Floor = 10
final val F64Nearest = 11
final val F64Sqrt = 12

final val I32ReinterpretF32 = 13
final val I64ReinterpretF64 = 14
final val F32ReinterpretI32 = 15
final val F64ReinterpretI64 = 16

def resultTypeOf(op: Code): Type = (op: @switch) match {
case I32Clz | I32Ctz | I32Popcnt | I32ReinterpretF32 =>
IntType

case I64Clz | I64Ctz | I64Popcnt | I64ReinterpretF64 =>
LongType

case F32Abs | F32ReinterpretI32 =>
FloatType

case F64Abs | F64Ceil | F64Floor | F64Nearest | F64Sqrt | F64ReinterpretI64 =>
DoubleType
}
}

/** Wasm binary op.
*
* Wasm features a number of dedicated opcodes for operations that are not
* in the IR, but only implemented in user space. We can see `WasmBinaryOp`
* as an extension of `ir.Trees.BinaryOp` that covers those.
*
* Unsigned divisions and remainders exhibit always-unchecked undefined
* behavior when their rhs is 0. It is up to code generating those transient
* nodes to check for 0 themselves if necessary.
*
* All other Wasm binary ops preserve pureness.
*/
final case class WasmBinaryOp(op: WasmBinaryOp.Code, lhs: Tree, rhs: Tree)
extends Transient.Value {
import WasmBinaryOp._

val tpe: Type = resultTypeOf(op)

def traverse(traverser: Traverser): Unit = {
traverser.traverse(lhs)
traverser.traverse(rhs)
}

def transform(transformer: Transformer, isStat: Boolean)(
implicit pos: Position): Tree = {
Transient(WasmBinaryOp(op, transformer.transformExpr(lhs),
transformer.transformExpr(rhs)))
}

def wasmInstr: wa.SimpleInstr = (op: @switch) match {
case I32DivU => wa.I32DivU
case I32RemU => wa.I32RemU
case I32Rotl => wa.I32Rotl
case I32Rotr => wa.I32Rotr

case I64DivU => wa.I64DivU
case I64RemU => wa.I64RemU
case I64Rotl => wa.I64Rotl
case I64Rotr => wa.I64Rotr

case F32Min => wa.F32Min
case F32Max => wa.F32Max

case F64Min => wa.F64Min
case F64Max => wa.F64Max
}

def printIR(out: IRTreePrinter): Unit = {
out.print("$")
out.print(wasmInstr.mnemonic)
out.printArgs(List(lhs, rhs))
}
}

object WasmBinaryOp {
/** Codes are raw Ints to be able to write switch matches on them. */
type Code = Int

final val I32DivU = 1
final val I32RemU = 2
final val I32Rotl = 3
final val I32Rotr = 4

final val I64DivU = 5
final val I64RemU = 6
final val I64Rotl = 7
final val I64Rotr = 8

final val F32Min = 9
final val F32Max = 10

final val F64Min = 11
final val F64Max = 12

def resultTypeOf(op: Code): Type = (op: @switch) match {
case I32DivU | I32RemU | I32Rotl | I32Rotr =>
IntType

case I64DivU | I64RemU | I64Rotl | I64Rotr =>
LongType

case F32Min | F32Max =>
FloatType

case F64Min | F64Max =>
DoubleType
}
}
}
Loading