Thanks to visit codestin.com
Credit goes to github.com

Skip to content

[mlir][gpu] Add pass for emulating unsupported types. #138087

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: main
Choose a base branch
from

Conversation

mshahneo
Copy link
Contributor

@mshahneo mshahneo commented May 1, 2025

This pass imitates (bitcast/reinterpret_cast) unsupported types with supported types of same bitwidth. The imitation is done by bitcasting the unspported types to the supported types of same bitwidth. Therefore, the source type and destination type must have the same bitwidth. The imitation is done by using the following operations: arith.bitcast.

The imitation is often needed when the GPU target (dialect/IR) does not support a certain type but the underlying architecture does. Take SPIR-V for example, it does not support bf16, but an underlying architecture (e.g., intel pvc gpu) that uses SPIR-V for code-generation does. Therefore, bf16 is neither a valid data type to pass to gpu kernel, nor to be used inside the kernel. To use bf16 data type in a SPIR-V kernel (as a kernel parameter or inside the kernel), bf16 have to be bitcasted (similar to C++ reinterpret_cast) to a supported type (e.g., i16 for Intel GPUs). The SPIR-V kernel can then use the imitated type (i16) in the computation. However, i16 is not the same as bf16 (integer vs float), so the computation can not readily use the imitated type (i16).

Therefore, this transformation pass is intended to be used in conjuction with other transformation passes such as EmulateUnsupportedFloats and ExtendUnsupportedTypes that extend the bitwidth of bf16 to f32 and vice-versa.

Finally, usually, there are instructions available in the target (dialect/IR) that can take advantage of these generated patterns (bf16->i16->f32, f32->bf16->i16), and convert them to the supported types.
For example, Intel provides SPIR-V extension ops that can take imitated bf16 (i16) and convert them to f32 and vice-versa. https://github.com/KhronosGroup/SPIRV-Registry/blob/main/extensions/INTEL/SPV_INTEL_bfloat16_conversion.asciidoc https://mlir.llvm.org/docs/Dialects/SPIR-V/#spirvintelconvertbf16tof-spirvintelconvertbf16tofop https://mlir.llvm.org/docs/Dialects/SPIR-V/#spirvintelconvertftobf16-spirvintelconvertftobf16op

Example:

Let's use an example to show the full flow as to how this pass would interact with other passes (EmulateUnsupportedFloats and ExtendUnsupportedTypes) to solve the problem:

The following example code (both host and device) does an elementwise bf16 addition:

func.func @host(%arg0: memref<10x20xbf16>, %arg1: memref<10x20xbf16>) -> memref<10x20xbf16> {
    ...
    %gpu_arg0 = gpu.alloc  host_shared () : memref<10x20xbf16>
    memref.copy %arg0, %gpu_arg0 : memref<10x20xbf16> to memref<10x20xbf16>
    ...

    gpu.launch_func  @test_kernel::@test_kernel blocks in (%c10, %c20, %c1) threads in (%c1, %c1, %c1)  args(%gpu_arg0 : memref<10x20xbf16>,...)
    ...
    ...
    return %alloc : memref<10x20xbf16>
  }

gpu.module @test_kernel {
    gpu.func @test_kernel(%arg0: memref<10x20xbf16>, %arg1: memref<10x20xbf16>, %arg2: memref<10x20xbf16>)... {
      %block_id_x = gpu.block_id  x
      %block_id_y = gpu.block_id  y
      %0 = memref.load %arg0[%block_id_x, %block_id_y] : memref<10x20xbf16>
      %1 = memref.load %arg1[%block_id_x, %block_id_y] : memref<10x20xbf16>
      %2 = arith.addf %0, %1 : bf16
      memref.store %3, %arg2[%block_id_x, %block_id_y] : memref<10x20xbf16>
      gpu.return
    }
  }

Op-level emulation:

Now let's say we want to generate code for a target that does not support native bf16 addition, but supports f32 addition. We would use EmulateUnsupportedFloatsPass to extend the data types and do the addition, in other words do an op-level emulation. And the result would look like this:


gpu.module @test_kernel {
    gpu.func @test_kernel(%arg0: memref<10x20xbf16>, %arg1: memref<10x20xbf16>, %arg2: memref<10x20xbf16>)... {
      %block_id_x = gpu.block_id  x
      %block_id_y = gpu.block_id  y

      %0 = memref.load %arg0[%block_id_x, %block_id_y] : memref<10x20xbf16>
      %1 = memref.load %arg1[%block_id_x, %block_id_y] : memref<10x20xbf16>

      // Emulate the operands of the arith op using f32
      %f32_0 = arith.extf %0  fastmath<contract> : bf16 to f32
      %f32_1 = arith.extf %1  fastmath<contract> : bf16 to f32

      // Do the operation in f32
      %f32_2 = arith.addf %f32_0, %f32_1 : f32

      // Revert the result back to bf16,
      // since the original version returned a bf16 result.
      %2 = arith.truncf %f32_2  fastmath<contract> : f32 to bf16

      memref.store %2, %arg2[%block_id_x, %block_id_y] : memref<10x20xbf16>
      gpu.return
    }
  }

Data type emulation/Imitation:

The above emulation is enough if the target IR supports/recognizes bf16 as a data type. However, the target IR may not support/recognize bf16 as a valid data type, which is the case for SPIR-V. Then, we still have no way to lower the above code to SPIR-V. One possible way to handle this would be to imitate/emulate the bf16 data type as a some same bitwidth data type like i16. And change the signature of the kernel as well as arguments passed to the kernel, and do necessary bitcasts. In other words, we would have to do data type emulation/imitation. So, if we run the current pass, the output would look like following:

func.func @host(%arg0: memref<10x20xbf16>, %arg1: memref<10x20xbf16>) -> memref<10x20xbf16> {
    ...
    %gpu_arg0 = gpu.alloc  host_shared () : memref<10x20xbf16>
    memref.copy %arg0, %gpu_arg0 : memref<10x20xbf16> to memref<10x20xbf16>

    ...

    // gpu.launch_func is changed with modifed args, i16 is passed insted of bf16
    // Bitcast the gpu.launch_func bf16 arguments to i16 type
    %gpu_arg0_i16 = arith.bitcast % memref<10x20xb16> to memref<10x20xi16>
    gpu.launch_func  @test_kernel::@test_kernel blocks in (%c10, %c20, %c1) threads in (%c1, %c1, %c1)  args(%gpu_arg0_i16 : memref<10x20xi16>,...)
    ...
    ...
    return %alloc : memref<10x20xbf16>
}


gpu.module @test_kernel {
    // Kernel signature is modified to use i16 type args instead of bf16 args
    gpu.func @test_kernel(%arg0: memref<10x20xi16>, %arg1: memref<10x20xi16>, %arg2: memref<10x20xi16>)... {
      %block_id_x = gpu.block_id  x
      %block_id_y = gpu.block_id  y

      // Propagate the usage of i16 type
      %0 = memref.load %arg0[%block_id_x, %block_id_y] : memref<10x20xi16>
      %1 = memref.load %arg1[%block_id_x, %block_id_y] : memref<10x20xi16>

      // Add bit cast operation from i16 to bf16
      %bf16_0 = arith.bitcast %0 : i16 to bf16
      %bf16_1 = arith.bitcast %1 : i16 to bf16
      // Emulate the operands of the arith op using f32
      %f32_0 = arith.extf %bf16_0  fastmath<contract> : bf16 to f32
      %f32_1 = arith.extf %bf16_1  fastmath<contract> : bf16 to f32

      // Do the operation in f32
      %f32_2 = arith.addf %f32_0, %f32_1 : f32

      // Revert the result back to bf16,
      // since the original version returned a bf16 result.
      %bf16_2 = arith.truncf %f32_2  fastmath<contract> : f32 to bf16

      // Cast bf16 type back to i16
      %2 = arith.bitcast %bf16_2 : bf16 to i16

      // Propagate the usage of i16 type
      memref.store %2, %arg2[%block_id_x, %block_id_y] : memref<10x20xi16>
      gpu.return
    }
  }

Conversion Pattern/Pass to Target:

Now during the final conversion to the target IR (SPIR-V), the conversion can identify these bitcast and extf/truncf combination and generate supported operations.

  1. bf16 to f32 pattern: Look for the following pattern in device code:
%bf16_0 = arith.bitcast %0 : i16 to bf16
%f32_0 = arith.extf %bf16_0 fastmath<contract> : bf16 to f32

Replace them with :

%f32_0 = spirv.ConvertBF16ToF %0 : i16 to f32
  1. f32 to bf16 pattern: Look for the following pattern in device code:
%bf16_2 = arith.truncf %f32_2 fastmath<contract> : f32 to bf16
%2 = arith.bitcast %bf16_2 : bf16 to i16

Replace them with :

%2 = spirv.ConvertFToBF16 %f32_2 : f32 to i16

Once, this step is done all the bf16 data type is removed from the device code, and the code now can be converted to SPIR-V.

Discussion topic:

  • Name of the pass (Emulation vs. Imitation) : I actually wanted to use Emulation, but the Arith transform pass that extends and truncates the unsupported floats uses this name, I thought it may cause confusion as, this pass does something fundamentally different, although the end goal is similar. Hence, I chose the name Imitation. But I am open to change the name to emulation or any other suggestion the community has.

  • [Resolved]: We use the arith.bitcast on memrefs to handle this situation. To handle memref allocation, we utilize the view approach. Create and i8 allocation, then create 2 views: one with the original, one with same shape but supported data type. The reason, we had to resort to this is because there is no other way change the data type of memrefs. I wonder if we can allow the memref.reinterpret_cast to change the data type for similar bitwidth types. Or a new op like memref.bitcast. Having a cast would also remove a lot of handicaps. Because, view has a lot of restrictions:
    -- initial allocation has to be flat i8,
    -- can not have any layout (empty layout/identity),
    --has to be contiguous, so stride support
    I can open an RFC if you would like on this topic.

@llvmbot
Copy link
Member

llvmbot commented May 1, 2025

@llvm/pr-subscribers-mlir-arith
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-gpu

Author: Md Abdullah Shahneous Bari (mshahneo)

Changes

This pass imitates (bitcast/reinterpret_cast) unsupported types with supported types of same bitwidth. The imitation is done by bitcasting the unspported types to the supported types of same bitwidth. Therefore, the source type and destination type must have the same bitwidth. The imitation is done by using the following operations: arith.bitcast.

The imitation is often needed when the GPU target (dialect/IR) does not support a certain type but the underlying architecture does. Take SPIR-V for example, it does not support bf16, but an underlying architecture (e.g., intel pvc gpu) that uses SPIR-V for code-generation does. Therefore, bf16 is neither a valid data type to pass to gpu kernel, nor to be used inside the kernel. To use bf16 data type in a SPIR-V kernel (as a kernel parameter or inside the kernel), bf16 have to be bitcasted (similar to C++ reinterpret_cast) to a supported type (e.g., i16 for Intel GPUs). The SPIR-V kernel can then use the imitated type (i16) in the computation. However, i16 is not the same as bf16 (integer vs float), so the computation can not readily use the imitated type (i16).

Therefore, this transformation pass is intended to be used in conjuction with other transformation passes such as EmulateUnsupportedFloats and ExtendUnsupportedTypes that extend the bitwidth of bf16 to f32 and vice-versa.

Finally, usually, there are instructions available in the target (dialect/IR) that can take advantage of these generated patterns (bf16->i16->f32, f32->bf16->i16), and convert them to the supported types.
For example, Intel provides SPIR-V extension ops that can take imitated bf16 (i16) and convert them to f32 and vice-versa. https://github.com/KhronosGroup/SPIRV-Registry/blob/main/extensions/INTEL/SPV_INTEL_bfloat16_conversion.asciidoc https://mlir.llvm.org/docs/Dialects/SPIR-V/#spirvintelconvertbf16tof-spirvintelconvertbf16tofop https://mlir.llvm.org/docs/Dialects/SPIR-V/#spirvintelconvertftobf16-spirvintelconvertftobf16op


Patch is 48.64 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/138087.diff

4 Files Affected:

  • (modified) mlir/include/mlir/Dialect/GPU/Transforms/Passes.h (+20)
  • (modified) mlir/include/mlir/Dialect/GPU/Transforms/Passes.td (+53)
  • (added) mlir/lib/Dialect/GPU/Transforms/ImitateUnsupportedTypes.cpp (+916)
  • (added) mlir/test/Dialect/GPU/imitate-unsupported-types.mlir (+141)
diff --git a/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h b/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h
index 6cd6f03253aea..0b7339a94b274 100644
--- a/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h
@@ -16,6 +16,8 @@
 #include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/Dialect/GPU/Utils/GPUUtils.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Pass/Pass.h"
 #include <optional>
@@ -87,6 +89,24 @@ void populateGpuLowerClusteredSubgroupReduceToDPPPatterns(
     RewritePatternSet &patterns, unsigned subgroupSize, amdgpu::Chipset chipset,
     PatternBenefit benefit = 1);
 
+/// Set up a type converter to convert unsupported source types to
+/// supported target types.
+void populateImitateUnsupportedTypesTypeConverter(TypeConverter &typeConverter,
+                                                  ArrayRef<Type> sourceTypes,
+                                                  ArrayRef<Type> targetTypes);
+
+/// Collect a set of pattern needed to imitate unsupported source types
+/// using supported target types.
+void populateImitateUnsupportedTypesConversionPatterns(
+    RewritePatternSet &patterns, TypeConverter &typeConverter,
+    ArrayRef<Type> sourceTypes, ArrayRef<Type> targetTypes,
+    DenseMap<StringAttr, FunctionType> &convertedFuncTypes);
+
+/// Set up a dialect conversion to reject operations on unsupported
+/// float types.
+void configureImitateUnsupportedTypesLegality(ConversionTarget &target,
+                                              TypeConverter &typeConverter);
+
 /// Collect all patterns to rewrite ops within the GPU dialect.
 inline void populateGpuRewritePatterns(RewritePatternSet &patterns) {
   populateGpuAllReducePatterns(patterns);
diff --git a/mlir/include/mlir/Dialect/GPU/Transforms/Passes.td b/mlir/include/mlir/Dialect/GPU/Transforms/Passes.td
index 3766eb16e9429..feb1b2820abd6 100644
--- a/mlir/include/mlir/Dialect/GPU/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/GPU/Transforms/Passes.td
@@ -258,4 +258,57 @@ def GpuSPIRVAttachTarget: Pass<"spirv-attach-target", ""> {
   ];
 }
 
+def GpuImitateUnsupportedTypes : Pass<"imitate-unsupported-types", "::mlir::ModuleOp"> {
+  let summary = "Imitate unsupported types with supported types of same bitwidth.";
+  let description = [{
+    This pass imitates (bitcast/reinterpret_cast) unsupported types
+    with supported types of same bitwidth. The imitation is done
+    by bitcasting the unspported types to the supported types of same bitwidth.
+    Therefore, the source type and destination type must have the same bitwidth.
+    The imitation is done by using the following operations: arith.bitcast.
+
+    The imitation is often needed when the GPU target (dialect/IR) does not
+    support a certain type but the underlying architecture does. Take SPIR-V for
+    example, it does not support bf16, but an underlying architecture (e.g.,
+    intel pvc gpu) that uses SPIR-V for code-generation does.
+    Therefore, bf16 is neither a valid data type to pass to gpu kernel, nor to
+    be used inside the kernel. To use bf16 data type in a SPIR-V kernel (as a
+    kernel parameter or inside the kernel), bf16 have to be bitcasted (similar
+    to C++ reinterpret_cast) to a supported type (e.g., i16 for Intel GPUs). The
+    SPIR-V kernel can then use the imitated type (i16) in the computation.
+    However, i16 is not the same as bf16 (integer vs float), so the computation
+    can not readily use the imitated type (i16).
+
+    Therefore, this transformation pass is intended to be used in conjuction
+    with other transformation passes such as `EmulateUnsupportedFloats` and
+    `ExtendUnsupportedTypes` that extend the bitwidth of bf16 to f32 and
+    vice-versa.
+
+    Finally, usually, there are instructions available in the target
+    (dialect/IR) that can take advantage of these generated patterns
+    (bf16->i16->f32, f32->bf16->i16), and convert them to the supported
+    types.
+    For example, Intel provides SPIR-V extension ops that can
+    take imitated bf16 (i16) and convert them to f32 and vice-versa.
+    https://github.com/KhronosGroup/SPIRV-Registry/blob/main/extensions/INTEL/SPV_INTEL_bfloat16_conversion.asciidoc
+    https://mlir.llvm.org/docs/Dialects/SPIR-V/#spirvintelconvertbf16tof-spirvintelconvertbf16tofop
+    https://mlir.llvm.org/docs/Dialects/SPIR-V/#spirvintelconvertftobf16-spirvintelconvertftobf16op
+
+  }];
+
+  let options = [
+    ListOption<"sourceTypeStrs", "source-types", "std::string",
+      "MLIR types without type support on a given target">,
+    ListOption<"targetTypeStrs", "target-types", "std::string",
+      "MLIR types to convert the unsupported source types to">,
+  ];
+
+  let dependentDialects = [
+    "::mlir::gpu::GPUDialect",
+    "::mlir::arith::ArithDialect",
+    "::mlir::memref::MemRefDialect"
+    ];
+}
+
+
 #endif // MLIR_DIALECT_GPU_PASSES
diff --git a/mlir/lib/Dialect/GPU/Transforms/ImitateUnsupportedTypes.cpp b/mlir/lib/Dialect/GPU/Transforms/ImitateUnsupportedTypes.cpp
new file mode 100644
index 0000000000000..c83e6bec568e0
--- /dev/null
+++ b/mlir/lib/Dialect/GPU/Transforms/ImitateUnsupportedTypes.cpp
@@ -0,0 +1,916 @@
+//===- ImitateUnsupportedTypes.cpp - Unsupported Type Imitation ----*- C++
+//-*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//===----------------------------------------------------------------------===//
+//
+/// \file
+/// This pass imitates (bitcast/reinterpret_cast) unsupported types
+/// with supported types of same bitwidth. The imitation is done
+/// by bitcasting the unspported types to the supported types of same bitwidth.
+/// Therefore, the source type and destination type must have the same bitwidth.
+/// The imitation is done by using the following operations: arith.bitcast.
+///
+/// The imitation is often needed when the GPU target (dialect/IR) does not
+/// support a certain type but the underlying architecture does. Take SPIR-V for
+/// example, it does not support bf16, but an underlying architecture (e.g.,
+/// intel pvc gpu) that uses SPIR-V for code-generation does.
+/// Therefore, bf16 is neither a valid data type to pass to gpu kernel, nor to
+/// be used inside the kernel. To use bf16 data type in a SPIR-V kernel (as a
+/// kernel parameter or inside the kernel), bf16 have to be bitcasted (similar
+/// to C++ reinterpret_cast) to a supported type (e.g., i16 for Intel GPUs). The
+/// SPIR-V kernel can then use the imitated type (i16) in the computation.
+/// However, i16 is not the same as bf16 (integer vs float), so the computation
+/// can not readily use the imitated type (i16).
+///
+/// Therefore, this transformation pass is intended to be used in conjuction
+/// with other transformation passes such as `EmulateUnsupportedFloats` and
+/// `ExtendUnsupportedTypes` that extend the bitwidth of bf16 to f32 and
+/// vice-versa.
+///
+/// Finally, usually, there are instructions available in the target
+/// (dialect/IR) that can take advantage of these generated patterns
+/// (bf16->i16->f32, f32->bf16->i16), and convert them to the supported
+/// types.
+/// For example, Intel provides SPIR-V extension ops that can
+/// take imitated bf16 (i16) and convert them to f32 and vice-versa.
+/// https://github.com/KhronosGroup/SPIRV-Registry/blob/main/extensions/INTEL/SPV_INTEL_bfloat16_conversion.asciidoc
+/// https://mlir.llvm.org/docs/Dialects/SPIR-V/#spirvintelconvertbf16tof-spirvintelconvertbf16tofop
+/// https://mlir.llvm.org/docs/Dialects/SPIR-V/#spirvintelconvertftobf16-spirvintelconvertftobf16op
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/GPU/Transforms/Passes.h"
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Utils/Utils.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Location.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "llvm/ADT/APFloat.h"
+#include "llvm/ADT/APInt.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/Support/ErrorHandling.h"
+
+#include <optional>
+#include <type_traits>
+#include <variant>
+
+using namespace mlir;
+using namespace mlir::gpu;
+
+namespace mlir {
+#define GEN_PASS_DEF_GPUIMITATEUNSUPPORTEDTYPES
+#include "mlir/Dialect/GPU/Transforms/Passes.h.inc"
+} // namespace mlir
+
+//===----------------------------------------------------------------------===//
+// Utility functions
+//===----------------------------------------------------------------------===//
+
+APFloat bitcastAPIntToAPFloat(const APInt &intValue,
+                              const llvm::fltSemantics &semantics) {
+  // Get the bit width of the APInt.
+  unsigned intBitWidth = intValue.getBitWidth();
+  // Get the total bit size required for the APFloat based on the semantics.
+  unsigned floatBitWidth = APFloat::getSizeInBits(semantics);
+  // Ensure the bit widths match for a direct bitcast.
+  assert(intBitWidth == floatBitWidth &&
+         "Bitwidth of APInt and APFloat must match for bitcast");
+
+  // Get the raw bit representation of the APInt as a byte vector.
+  auto intWords = intValue.getRawData();
+  // Create an APFloat with the specified semantics and the raw integer bits.
+  APFloat floatValue(semantics, APInt(intBitWidth, *intWords));
+  return floatValue;
+}
+
+// Get FloatAttr from IntegerAttr.
+FloatAttr getFloatAttrFromIntegerAttr(IntegerAttr intAttr, Type dstType,
+                                      ConversionPatternRewriter &rewriter) {
+  APInt intVal = intAttr.getValue();
+  auto floatVal = bitcastAPIntToAPFloat(
+      intVal, cast<FloatType>(dstType).getFloatSemantics());
+  return rewriter.getFloatAttr(dstType, floatVal);
+}
+// Get IntegerAttr from FloatAttr.
+IntegerAttr getIntegerAttrFromFloatAttr(FloatAttr floatAttr, Type dstType,
+                                        ConversionPatternRewriter &rewriter) {
+  APFloat floatVal = floatAttr.getValue();
+  APInt intVal = floatVal.bitcastToAPInt();
+  return rewriter.getIntegerAttr(dstType, intVal);
+}
+
+struct RawAllocator {
+  RawAllocator(OpBuilder &builder, Location loc) : builder(builder), loc(loc) {}
+
+  std::variant<Value, int64_t> computeTotalBytes(MemRefType srcType,
+                                                 Value srcMemref) {
+    // Element size in bytes.
+    int64_t elemBitWidth = srcType.getElementTypeBitWidth();
+    int64_t elemByteWidth = (elemBitWidth + 7) / 8;
+
+    if (srcType.hasStaticShape()) {
+      // Static shape: compute total bytes statically.
+      int64_t numElements = 1;
+      for (int64_t dim : srcType.getShape()) {
+        numElements *= dim;
+      }
+      return numElements * elemByteWidth;
+    }
+
+    auto sizes = getSizes(srcType, srcMemref);
+    // Compute number of elements dynamically.
+    Value numElements = sizes.front();
+    for (auto size : llvm::drop_begin(sizes))
+      numElements = builder.create<arith::MulIOp>(loc, numElements, size);
+    Value elemSize = builder.create<arith::ConstantIndexOp>(loc, elemByteWidth);
+
+    return builder.create<arith::MulIOp>(loc, numElements, elemSize);
+  }
+
+  SmallVector<Value> getSizes(MemRefType type, Value memref) {
+    SmallVector<Value> sizes;
+    for (unsigned i = 0; i < type.getRank(); ++i) {
+      if (type.isDynamicDim(i)) {
+        sizes.push_back(builder.create<memref::DimOp>(loc, memref, i));
+      } else {
+        sizes.push_back(
+            builder.create<arith::ConstantIndexOp>(loc, type.getShape()[i]));
+      }
+    }
+    return sizes;
+  }
+
+  SmallVector<Value> getDynamicSizes(MemRefType type, Value memref) {
+    SmallVector<Value> sizes;
+    for (unsigned i = 0; i < type.getRank(); ++i) {
+      if (type.isDynamicDim(i)) {
+        sizes.push_back(builder.create<memref::DimOp>(loc, memref, i));
+      }
+    }
+    return sizes;
+  }
+
+  SmallVector<Value> getIdentityStrides(MemRefType type) {
+    SmallVector<Value> strides;
+    int64_t runningStride = 1;
+    for (int64_t dim : llvm::reverse(type.getShape())) {
+      strides.push_back(
+          builder.create<arith::ConstantIndexOp>(loc, runningStride));
+      if (dim != ShapedType::kDynamic)
+        runningStride *= dim;
+      else
+        runningStride = -1; // not handling dynamic strides.
+    }
+    std::reverse(strides.begin(), strides.end());
+    return strides;
+  }
+
+private:
+  OpBuilder &builder;
+  Location loc;
+};
+
+// Replace uses according to predicates automatically.
+template <typename OpTy>
+void replaceUsesWithPredicate(
+    OpTy originalValue,
+    ArrayRef<std::pair<std::function<bool(OpOperand &)>, Value>> replacements,
+    ConversionPatternRewriter &rewriter) {
+
+  for (OpOperand &use : llvm::make_early_inc_range(originalValue->getUses())) {
+    for (const auto &[predicate, newValue] : replacements) {
+      if (predicate(use)) {
+        use.set(newValue);
+        break;
+      }
+    }
+  }
+}
+
+//===----------------------------------------------------------------------===//
+// Convertion patterns
+//===----------------------------------------------------------------------===//
+namespace {
+
+//===----------------------------------------------------------------------===//
+// FunctionOp conversion pattern
+//===----------------------------------------------------------------------===//
+template <typename FuncLikeOp>
+struct ConvertFuncOp final : public OpConversionPattern<FuncLikeOp> {
+  ConvertFuncOp(MLIRContext *context, TypeConverter &typeConverter,
+                ArrayRef<Type> sourceTypes, ArrayRef<Type> targetTypes,
+                DenseMap<StringAttr, FunctionType> &convertedFuncTypes)
+      : OpConversionPattern<FuncLikeOp>(context),
+        typeConverter(typeConverter), // Store the reference
+        sourceTypes(sourceTypes), targetTypes(targetTypes),
+        convertedFuncTypes(convertedFuncTypes) {}
+  using OpConversionPattern<FuncLikeOp>::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(FuncLikeOp op, typename FuncLikeOp::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Only handle functions a gpu.module
+    if (!op->template getParentOfType<gpu::GPUModuleOp>())
+      return failure();
+    FunctionType oldFuncType = op.getFunctionType();
+
+    // Convert function signature
+    TypeConverter::SignatureConversion signatureConverter(
+        oldFuncType.getNumInputs());
+    for (const auto &argType :
+         llvm::enumerate(op.getFunctionType().getInputs())) {
+      auto convertedType = typeConverter.convertType(argType.value());
+      if (!convertedType)
+        return failure();
+      signatureConverter.addInputs(argType.index(), convertedType);
+    }
+    SmallVector<Type, 4> newResultTypes;
+    for (const auto &resultType : llvm::enumerate(oldFuncType.getResults())) {
+      auto convertedType = typeConverter.convertType(resultType.value());
+      if (!convertedType)
+        return failure();
+      newResultTypes.push_back(convertedType);
+    }
+
+    // Convert function signature
+    FunctionType newFuncType = rewriter.getFunctionType(
+        signatureConverter.getConvertedTypes(), newResultTypes);
+
+    if (!newFuncType)
+      return rewriter.notifyMatchFailure(op, "could not convert function "
+                                             "type");
+
+    // Create new GPU function with converted type
+    auto newFuncOp =
+        rewriter.create<FuncLikeOp>(op.getLoc(), op.getName(), newFuncType);
+
+    newFuncOp.setVisibility(op.getVisibility());
+    // Copy attributes
+    for (auto attr : op->getAttrs()) {
+      // Skip the function_type attribute since it is already set by
+      // the newFuncType and we don't want to overwrite it.
+      if (attr.getName() != op.getFunctionTypeAttrName() &&
+          attr.getName() != SymbolTable::getSymbolAttrName())
+        newFuncOp->setAttr(attr.getName(), attr.getValue());
+    }
+
+    newFuncOp.getRegion().getBlocks().clear();
+    // Inline region approach
+    rewriter.inlineRegionBefore(op.getBody(), newFuncOp.getBody(),
+                                newFuncOp.end());
+    // Convert block argument types using the type converter
+    if (failed(rewriter.convertRegionTypes(&newFuncOp.getBody(), typeConverter,
+                                           &signatureConverter))) {
+      return rewriter.notifyMatchFailure(op, "could not convert region "
+                                             "types");
+    }
+
+    if (!op.use_empty()) {
+      op.emitError("Cannot erase func: still has uses");
+    }
+    for (Operation *user : op->getUsers()) {
+      user->emitRemark() << "User of function " << op.getName();
+    }
+    rewriter.eraseOp(op);
+    // Add the converted function type to the map
+    newFuncOp.getNameAttr().getValue();
+    convertedFuncTypes[newFuncOp.getNameAttr()] = newFuncType;
+    return success();
+  }
+
+private:
+  TypeConverter &typeConverter; // Store a reference
+  ArrayRef<Type> sourceTypes;
+  ArrayRef<Type> targetTypes;
+  DenseMap<StringAttr, FunctionType> &convertedFuncTypes;
+};
+
+//===----------------------------------------------------------------------===//
+// CallOp conversion pattern
+//===----------------------------------------------------------------------===//
+struct ConvertCallOp : OpConversionPattern<func::CallOp> {
+  ConvertCallOp(MLIRContext *context, TypeConverter &typeConverter,
+                const DenseMap<StringAttr, FunctionType> &convertedFuncTypes)
+      : OpConversionPattern(context), convertedFuncTypes(convertedFuncTypes) {}
+
+  LogicalResult
+  matchAndRewrite(func::CallOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto callee = op.getCalleeAttr();
+
+    auto it = convertedFuncTypes.find(
+        StringAttr::get(callee.getContext(), callee.getValue()));
+    if (it == convertedFuncTypes.end())
+      return rewriter.notifyMatchFailure(
+          op, "Callee signature not converted. Perhaps the callee is not in "
+              "the same gpu module as the caller.");
+
+    auto newResultTypes = it->second.getResults();
+    rewriter.replaceOpWithNewOp<func::CallOp>(
+        op, callee.getValue(), newResultTypes, adaptor.getOperands());
+
+    return success();
+  }
+
+private:
+  const DenseMap<StringAttr, FunctionType> &convertedFuncTypes;
+};
+
+//===----------------------------------------------------------------------===//
+// GPULaunchFuncOp conversion pattern
+//===----------------------------------------------------------------------===//
+struct ConvertGPULaunchFuncOp : OpConversionPattern<gpu::LaunchFuncOp> {
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(gpu::LaunchFuncOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    std::optional<KernelDim3> clusterSizeOpernads =
+        op.hasClusterSize()
+            ? std::optional<gpu::KernelDim3>(op.getClusterSizeOperandValues())
+            : std::nullopt;
+
+    // Create the new launch_func.
+    auto newOp = rewriter.create<gpu::LaunchFuncOp>(
+        op.getLoc(), adaptor.getKernel(), op.getGridSizeOperandValues(),
+        op.getBlockSizeOperandValues(), op.getDynamicSharedMemorySize(),
+        adaptor.getKernelOperands(), op.getAsyncObject(), clusterSizeOpernads);
+
+    // Copy block size and grid size attributes
+    newOp->setAttrs(op->getAttrs());
+    rewriter.replaceOp(op, newOp.getResults());
+    return success();
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// AllocOp conversion pattern
+//===----------------------------------------------------------------------===//
+template <typename Allo...
[truncated]

This pass imitates (bitcast/reinterpret_cast) unsupported types
with supported types of same bitwidth. The imitation is done
by bitcasting the unspported types to the supported types of same bitwidth.
Therefore, the source type and destination type must have the same bitwidth.
The imitation is done by using the following operations: arith.bitcast.

The imitation is often needed when the GPU target (dialect/IR) does not
support a certain type but the underlying architecture does. Take SPIR-V for
example, it does not support bf16, but an underlying architecture (e.g.,
intel pvc gpu) that uses SPIR-V for code-generation does.
Therefore, bf16 is neither a valid data type to pass to gpu kernel, nor to
be used inside the kernel. To use bf16 data type in a SPIR-V kernel (as a
kernel parameter or inside the kernel), bf16 have to be bitcasted (similar
to C++ reinterpret_cast) to a supported type (e.g., i16 for Intel GPUs). The
SPIR-V kernel can then use the imitated type (i16) in the computation.
However, i16 is not the same as bf16 (integer vs float), so the computation
can not readily use the imitated type (i16).

Therefore, this transformation pass is intended to be used in conjuction
with other transformation passes such as `EmulateUnsupportedFloats` and
`ExtendUnsupportedTypes` that extend the bitwidth of bf16 to f32 and
vice-versa.

Finally, usually, there are instructions available in the target
(dialect/IR) that can take advantage of these generated patterns
(bf16->i16->f32, f32->bf16->i16), and convert them to the supported
types.
For example, Intel provides SPIR-V extension ops that can
take imitated bf16 (i16) and convert them to f32 and vice-versa.
https://github.com/KhronosGroup/SPIRV-Registry/blob/main/extensions/INTEL/SPV_INTEL_bfloat16_conversion.asciidoc
https://mlir.llvm.org/docs/Dialects/SPIR-V/#spirvintelconvertbf16tof-spirvintelconvertbf16tofop
https://mlir.llvm.org/docs/Dialects/SPIR-V/#spirvintelconvertftobf16-spirvintelconvertftobf16op
@mshahneo mshahneo force-pushed the data_type_emulation_refactor branch from a31fd58 to 3966b5d Compare May 1, 2025 07:04
@joker-eph joker-eph changed the title [mlir][gpu] Add pass for imitating unsupported types. [mlir][gpu] Add pass for emulating unsupported types. May 1, 2025
@joker-eph
Copy link
Collaborator

I believe the usual term is "emulate" (instead of "imitate"), unless I missed a nuance you're trying to make, can you update the description (and code) to reflect this?

https://mlir.llvm.org/docs/Dialects/SPIR-V/#spirvintelconvertbf16tof-spirvintelconvertbf16tofop
https://mlir.llvm.org/docs/Dialects/SPIR-V/#spirvintelconvertftobf16-spirvintelconvertftobf16op

}];
Copy link
Collaborator

@joker-eph joker-eph May 1, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The implementation only touches specific ops, but from the pass description it's absolutely not clear to me what is the scope here, especially considering we have also EmulateUnsupportedFloats

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I updated the description to provide a overall view. Please let me know if it's enough.

@krzysz00
Copy link
Contributor

krzysz00 commented May 1, 2025

High-level question: why is this in GPU?

Shouldn't the translation from bf16 to i16 either be a pass over on Arith or part of the SPIR-V lowering? See also, when going to LLVM, we replace all the 8-bit float types with i8

%dense_const_2 = arith.constant dense<6.000000e-01> : vector<10x10xbf16>

// CHECK: %[[ALLOC:.*]] = gpu.alloc () : memref<200xi8>
%alloc = gpu.alloc () : memref<10x10xbf16>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is the gpu.alloc linearized. Lookinag at this example there doesnt seem to be a reason to linearize the allocs.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi Mahesh (@MaheshRavishankar),

We are using the memref.view op to create the a view of the new data type and memref.view requires that the original allocated memref be a flat, contiguous i8 memref with empty layout.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removed the need for alloc linearization. Using arith.bitcast for casting memref types now.

@mshahneo
Copy link
Contributor Author

mshahneo commented May 1, 2025

  • I actually wanted use Emulation, but the Arith transform pass that extends and truncates the unsupported floats uses this name, I thought it may cause confusion as, this pass does something fundamentally different, although the end goal is similar. Hence, I chose the name Imitation. But I am open to change the name to emulation or any other suggestion the community has.

Hi Mehdi (@joker-eph ),

Thank you so much for your suggestion. I actually wanted to use Emulation, but the Arith transform pass that extends and truncates the unsupported floats uses this name, I thought it may cause confusion as, this pass does something fundamentally different, although the end goal is similar. Hence, I chose the name Imitation. But I am open to change the name to emulation or any other suggestion the community has.

@mshahneo
Copy link
Contributor Author

mshahneo commented May 1, 2025

High-level question: why is this in GPU?

Shouldn't the translation from bf16 to i16 either be a pass over on Arith or part of the SPIR-V lowering? See also, when going to LLVM, we replace all the 8-bit float types with i8

Hi Krzysztof(@krzysz00),

The reason I wanted keep it in GPU instead of a part of SPIR-V lowering is that, hopefully it can used by other targets as well, not just SPIR-V.

As for doing it as part of AithToSPIRV, it would not work, since, it requires modification to both host and device code.

That's why to me, gpu seemed like the right place.

@krzysz00
Copy link
Contributor

krzysz00 commented May 1, 2025

To rephrase my comment: this rewrite + patterns shouldn't be needed

If you have a bf16 in, say, a memref, then SPIR-V should just make that be a memref of i16 and load from it. Similarly, a vector<2 x bf16> is actually a vector,2 x i16> after conversion to SPIR-V. And arith.truncf %x : f32 to bf16 is just the ConvertToBFloat that returns an i16.

For an example of how this is handled, run an example that operations on f8E4M3FN through teh LLVM lowering. Notice that this type is search-replaced with i8 everywhere, and that you need some platform-specific logic for conversion (as with -arith-to-amdgpu). With SPIR-V, because the conversion is a built-in operation, you don't even need an -arith-to-amdgpu equivalent.

Therefore, I recommend rejecting this PR in favor of fixing the SPIR-V lowering patterns

@krzysz00 krzysz00 self-requested a review May 1, 2025 19:02
Copy link
Contributor

@krzysz00 krzysz00 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm putting a hold on this

The mapping of bf16 to i16 should be done in the SPIR-V lowering

Also ... why would this ever need to touch host code?

@krzysz00
Copy link
Contributor

krzysz00 commented May 1, 2025

(and ArithToSPIRV alone isn't sufficient - you'll need to touch the SPIR-V type converter and possibly MemRefToSPIRV)

Move common pass logics to initialize() from runOnOperation().
@krzysz00
Copy link
Contributor

krzysz00 commented May 1, 2025

Thanks for taking the time for explaining your motivations, though I still fundamentally disagree with them.

The lowering to SPIR-V should replace all mentions of bf16 with i16, and replace arith.extf %x : bf16 to i32 and arith.truncf %x : f32 to bf16 with operations that produce i16.

See, for example, the equivalent code in LLVMTypeConverter

Type LLVMTypeConverter::convertFloatType(FloatType type) const {
  // Valid LLVM float types are used directly.
  if (LLVM::isCompatibleType(type))
    return type;

  // F4, F6, F8 types are converted to integer types with the same bit width.
  if (isa<Float8E5M2Type, Float8E4M3Type, Float8E4M3FNType, Float8E5M2FNUZType,
          Float8E4M3FNUZType, Float8E4M3B11FNUZType, Float8E3M4Type,
          Float4E2M1FNType, Float6E2M3FNType, Float6E3M2FNType,
          Float8E8M0FNUType>(type))
    return IntegerType::get(&getContext(), type.getWidth());

  // Other floating-point types: A custom type conversion rule must be
  // specified by the user.
  return Type();
}

This should correctly convert kernel arguments etc.

One other thing we may want to do here is to relax the verifier on gpu.launch_func to allow a bf16/i16 mismatch if it turns out that's an issue. I think that's a better solution than this thing

@mshahneo
Copy link
Contributor Author

mshahneo commented May 2, 2025

Thanks for taking the time for explaining your motivations, though I still fundamentally disagree with them.

The lowering to SPIR-V should replace all mentions of bf16 with i16, and replace arith.extf %x : bf16 to i32 and arith.truncf %x : f32 to bf16 with operations that produce i16.

See, for example, the equivalent code in LLVMTypeConverter

Type LLVMTypeConverter::convertFloatType(FloatType type) const {
  // Valid LLVM float types are used directly.
  if (LLVM::isCompatibleType(type))
    return type;

  // F4, F6, F8 types are converted to integer types with the same bit width.
  if (isa<Float8E5M2Type, Float8E4M3Type, Float8E4M3FNType, Float8E5M2FNUZType,
          Float8E4M3FNUZType, Float8E4M3B11FNUZType, Float8E3M4Type,
          Float4E2M1FNType, Float6E2M3FNType, Float6E3M2FNType,
          Float8E8M0FNUType>(type))
    return IntegerType::get(&getContext(), type.getWidth());

  // Other floating-point types: A custom type conversion rule must be
  // specified by the user.
  return Type();
}

This should correctly convert kernel arguments etc.

One other thing we may want to do here is to relax the verifier on gpu.launch_func to allow a bf16/i16 mismatch if it turns out that's an issue. I think that's a better solution than this thing

Thank you so much for your suggestion. I'll take a look.
I do understand this approach has some restriction especially with its reliance on view op. Although a memref.bitcast like op would remove all the restrictions.
Could you please elaborate, why you disagree with this approach?

Let me specify my reasoning for this pass approach:

  • One of the main reasons, we chose this approach is due to the flexibility of a pass. Depending on the use case, you can either choose to or not to utilize this pipeline. Keeps the core support un-affected.
  • Portability: being a pass at a GPU dialect-level, it can be used for any backend (LLVM/SPIR-V). Although LLVM already handles this scenario as you pointed out.
  • Not changing fundamental of the existing lowering and verification. (e.g., as you mentioned a scenario gpu.launch_func would have to allow mismatch)
  • Allow for arbitrary source and target types of same bitwidth. (Although I know that the most common scenario is unsupported floats to same bit ints)
  • Makes way to keep the vendor-specific lowering (e.g., Intel-specific bf16 conversion instructions) patterns separate from the core lowering patterns.

Some of the above mentioned things can be achieved through your approach, but they make the core lowering logic to SPIR-V more convoluted.

@krzysz00
Copy link
Contributor

krzysz00 commented May 2, 2025

I claim that this pass is not necessary at the GPU level .

This issue - where MLIR has a type that the backend doesn't have and it has to be lowered to bytes + special operations - is already handled a bunch in the LLVM case - see everything to do with the 8-bit floats. Yes, it makes the GPU to {LLVM, SPIR-V, what have you} lowering a bit more complicated, but for something like "oh, SPIR-V doesn't have a bfloat type and just calls it i16", the existing pattern is to handle that in the lowering.

Secondly, this "make it look like bytes" change isn't generally applicable: only SPIR-V needs it. LLVM (the other target) can handle all this just fine.

... Also, gpu.launch_func should already be comfortable with post-lowering mismatch, given that, in the LLVM pipeline, you get memref<> on the host side but !llvm.ptr on the device side. So I don't think that's a problem either ... and so you might not need any sort of memref.bitcast because the device-side code will be post-memref and so doesn't care. If I'm wrong about that, do let me know - I don't know the SPIR-V lowering all that well. That is, memref<...xbf16> will be some flavor of !spirv.ptr and you can just swap bf16 for i16 during that process.

Now, passes that make this transformation more straightforward, like EmulateUnsupportedFloats, I'd be fine with, since it reduces the problem at lowering-time to arith.extf, arith.truncf, and any "native" operations like matmul primitives.

If people strongly want a "replace bf16 memrefs with i16 memrefs" pass to add to that suite of preprocessing steps, we could do that, but ... it certainly shouldn't have all this linearization stuff that's cloned from the narrow type emulation (where it's absolutely needed as a pre-processing step because there's no good way to handle sub-byte addressing). This SPIR-V situation ... isn't that, unless I'm missing a whole pile of context.

(That being said, memref.reinterpret_elements that'll let you take a memref<...S x bf16> into a memref<...S x i16> is something I'd be OK with, though I don't think we needed here.

Remove the usage of memref.view op and the restrictions
comes with it. Makes the pass straight forward.
@mshahneo
Copy link
Contributor Author

mshahneo commented May 5, 2025

I claim that this pass is not necessary at the GPU level .

This issue - where MLIR has a type that the backend doesn't have and it has to be lowered to bytes + special operations - is already handled a bunch in the LLVM case - see everything to do with the 8-bit floats. Yes, it makes the GPU to {LLVM, SPIR-V, what have you} lowering a bit more complicated, but for something like "oh, SPIR-V doesn't have a bfloat type and just calls it i16", the existing pattern is to handle that in the lowering.

Secondly, this "make it look like bytes" change isn't generally applicable: only SPIR-V needs it. LLVM (the other target) can handle all this just fine.

... Also, gpu.launch_func should already be comfortable with post-lowering mismatch, given that, in the LLVM pipeline, you get memref<> on the host side but !llvm.ptr on the device side. So I don't think that's a problem either ... and so you might not need any sort of memref.bitcast because the device-side code will be post-memref and so doesn't care. If I'm wrong about that, do let me know - I don't know the SPIR-V lowering all that well. That is, memref<...xbf16> will be some flavor of !spirv.ptr and you can just swap bf16 for i16 during that process.

Now, passes that make this transformation more straightforward, like EmulateUnsupportedFloats, I'd be fine with, since it reduces the problem at lowering-time to arith.extf, arith.truncf, and any "native" operations like matmul primitives.

If people strongly want a "replace bf16 memrefs with i16 memrefs" pass to add to that suite of preprocessing steps, we could do that, but ... it certainly shouldn't have all this linearization stuff that's cloned from the narrow type emulation (where it's absolutely needed as a pre-processing step because there's no good way to handle sub-byte addressing). This SPIR-V situation ... isn't that, unless I'm missing a whole pile of context.

(That being said, memref.reinterpret_elements that'll let you take a memref<...S x bf16> into a memref<...S x i16> is something I'd be OK with, though I don't think we needed here.

Thanks a lot for your explanation :).
It turns out arith.bitcast actually supports memrefs. So, we don't really need a new memref element cast op for this case. Updated the PR. Please let me what you think.

@krzysz00
Copy link
Contributor

krzysz00 commented May 6, 2025

I'm still not convinced we need this pass instead of fixing the conversion to SPIR-V

Do you have a concerte reason why updating the conversion and translation to SPIR-V to represent bf16 as i16 doesn't work?

@mshahneo
Copy link
Contributor Author

mshahneo commented May 6, 2025

I'm still not convinced we need this pass instead of fixing the conversion to SPIR-V

Do you have a concerte reason why updating the conversion and translation to SPIR-V to represent bf16 as i16 doesn't work?

The pass is not just for bf16, but any byte addressable data type (e.g., f8). Here is one specific use-case where this pass approach is better for code maintenance:

  • There is an extension from Khronos that adds BF16 data type conditionally for some ops. Once that's added, we may not need to emulate bf16 as i16 to pass launch_func boundary. If we add emulation for bf16 now, it means removing the codes later. Using this pass you don't have this issue.
  • One other major reason for me not to touch the conversion is, I personally think the strict verification and validation SPIR-V has in place for unsupported types should stay in place, and the user should make in informed decision if it wants to emulate.

@krzysz00
Copy link
Contributor

krzysz00 commented May 6, 2025

One other major reason for me not to touch the conversion is, I personally think the strict verification and validation SPIR-V has in place for unsupported types should stay in place, and the user should make in informed decision if it wants to emulate.

Except that's not actually what's happening here. bf16 is a supported a type on some instructions, you just have to represent it as i16.

See also the fact that amdgpu.mfma takes stuff like vector<8xf8E5M2> and turns that into i64 during the lowering

Like, bf16 is a semantically meaningful input type for the SPIR-V conversion, even though it doesn't exist in the type system itself, just like memref.

ANd similarly, operations on 8-bit floats would be operations on ... byte vectors? packed integers? whatever the extension says it is, really, and that's logic for the lowering to handle

@krzysz00
Copy link
Contributor

krzysz00 commented May 6, 2025

(and removing that code later - or making it conditional based on what extensions exist - is a non-issue in my book. I had to delete this exact sort of bf16-is-i16 thing for AMD LLVM lowering because the backend learned about bfloat - see #108409 for the deletion

I still claim that this is a better approach for the SPIR-V lowering than this bitcast-heavy rewrite on GPU)

@mshahneo
Copy link
Contributor Author

mshahneo commented May 6, 2025

(and removing that code later - or making it conditional based on what extensions exist - is a non-issue in my book. I had to delete this exact sort of bf16-is-i16 thing for AMD LLVM lowering because the backend learned about bfloat - see #108409 for the deletion

I still claim that this is a better approach for the SPIR-V lowering than this bitcast-heavy rewrite on GPU)

I respectfully disagree with your claim. I believe this is a better hands off approach to handle unsupported types.
And since at this point there is not technical issue with this approach (I believe I addressed those issues), how do you suggest we resolve this?

@MaheshRavishankar
Copy link
Contributor

Skimmed through the change again, and I think my main concern of the linearization bits have been solved. Just my two cents if it helps resolve the deadlock here.

First of all, @mshahneo thanks for working through this and making a fairly flushed out implementation with tests. I fully believe this solves the problem you are encountering, and I am sure other folks have hit this problem as well and would be a good common resource.

I have a few follow up questions though. It is not immediately clear to me what is the scope of this change. For example,

  1. what would happen for the case of i4 or any sub-byte type. How is this supposed to be handled? This falls within the scope of the sub-byte emulation support that exists in tree.
  2. How would this work for dialects that are not in tree. It seems like for this pass to work it will need to support all operations within the function. But if someone is using this in a function mixing operations from dialects in core, and downstream dialects, then this pass would break. So in that sense this pass is very fixed function.

In general it would have probably been better to get some community discussion going with RFC. Having an implementation like this actually makes the RFC stronger, but is a better forum to discuss than a PR. Another aspect to consider is the maintainability of the pass. When this lands in main, someone needs to be the defacto owner of it (ideally one of the existing folks who think this is a valuable thing to add and take on the maintainence of this). I know creating RFCs is kind of a pain, but MLIR does suffer a lot with abandoned dialects/methods/transformations that just become a maintainence burden now.

@mshahneo
Copy link
Contributor Author

mshahneo commented May 7, 2025

Skimmed through the change again, and I think my main concern of the linearization bits have been solved. Just my two cents if it helps resolve the deadlock here.

First of all, @mshahneo thanks for working through this and making a fairly flushed out implementation with tests. I fully believe this solves the problem you are encountering, and I am sure other folks have hit this problem as well and would be a good common resource.

I have a few follow up questions though. It is not immediately clear to me what is the scope of this change. For example,

  1. what would happen for the case of i4 or any sub-byte type. How is this supposed to be handled? This falls within the scope of the sub-byte emulation support that exists in tree.
  2. How would this work for dialects that are not in tree. It seems like for this pass to work it will need to support all operations within the function. But if someone is using this in a function mixing operations from dialects in core, and downstream dialects, then this pass would break. So in that sense this pass is very fixed function.

In general it would have probably been better to get some community discussion going with RFC. Having an implementation like this actually makes the RFC stronger, but is a better forum to discuss than a PR. Another aspect to consider is the maintainability of the pass. When this lands in main, someone needs to be the defacto owner of it (ideally one of the existing folks who think this is a valuable thing to add and take on the maintainence of this). I know creating RFCs is kind of a pain, but MLIR does suffer a lot with abandoned dialects/methods/transformations that just become a maintainence burden now.

Thank you so so much, @MaheshRavishankar :).

Let me jump to your question first:

  1. what would happen for the case of i4 or any sub-byte type. How is this supposed to be handled? This falls within the scope of the sub-byte emulation support that exists in tree.

It depends. Normally it would not affect the in-tree sub-byte emulation mechanism at all. But let's say a user/vendor have a use case where they want to emulate f4(f4E2M1FN) using i4. The pass would emulate it as i4. If they are used in any arith/math operations, they would have to be replaced by respective vendor-specific operation. Otherwiswe, it would be handled by the SPIR-V converter as a sub-byte integer.
Does it answer your concern?

  1. How would this work for dialects that are not in tree. It seems like for this pass to work it will need to support all operations within the function. But if someone is using this in a function mixing operations from dialects in core, and downstream dialects, then this pass would break. So in that sense this pass is very fixed function.

Currently, we handle any non-tree dialect ops using generic conversion logic.

That being said, downstream dialects can have other ops with different logic (e.g., arith/math type logic where they might want to keep the usnsupported data types). That's why we exposed the patterns through populateImitateUnsupportedTypesConversionPatterns(). So that the user can utilize it in their own downstream passes if they have special use cases. This is actually similar to other emulation passes (e.g., arith/math).

Please let me know if this answers your concern.

Please let me know if you have any more concern.

Another aspect to consider is the maintainability of the pass. When this lands in main, someone needs to be the defacto owner of it (ideally one of the existing folks who think this is a valuable thing to add and take on the maintainence of this).

Yes, that is a good point.
I can open an RFC if you want.

Again, thank you so so much.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants