[mlir][gpu] Add pass for emulating unsupported types. #138087
Conversation
@llvm/pr-subscribers-mlir-arith @llvm/pr-subscribers-mlir-gpu

Author: Md Abdullah Shahneous Bari (mshahneo)

Changes

This pass imitates (bitcast/reinterpret_cast) unsupported types with supported types of the same bitwidth. The imitation is done by bitcasting the unsupported types to supported types of the same bitwidth; therefore, the source type and destination type must have the same bitwidth. The only operation used for the imitation is arith.bitcast.

The imitation is often needed when the GPU target (dialect/IR) does not support a certain type but the underlying architecture does. Take SPIR-V for example: it does not support bf16, but an underlying architecture (e.g., an Intel PVC GPU) that uses SPIR-V for code generation does. Therefore, bf16 is not a valid data type to pass to a gpu kernel, nor can it be used inside the kernel. To use the bf16 data type in a SPIR-V kernel (as a kernel parameter or inside the kernel), bf16 has to be bitcast (similar to a C++ reinterpret_cast) to a supported type (e.g., i16 for Intel GPUs). The SPIR-V kernel can then use the imitated type (i16) in the computation. However, i16 is not the same as bf16 (integer vs. float), so the computation cannot readily use the imitated type (i16). Therefore, this transformation pass is intended to be used in conjunction with other transformation passes such as `EmulateUnsupportedFloats` and `ExtendUnsupportedTypes` that extend the bitwidth of bf16 to f32 and vice versa. Finally, there are usually instructions available in the target (dialect/IR) that can take advantage of these generated patterns (bf16->i16->f32, f32->bf16->i16) and convert them to the supported types.

Patch is 48.64 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/138087.diff

4 Files Affected:
diff --git a/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h b/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h
index 6cd6f03253aea..0b7339a94b274 100644
--- a/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h
@@ -16,6 +16,8 @@
#include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/GPU/Utils/GPUUtils.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include <optional>
@@ -87,6 +89,24 @@ void populateGpuLowerClusteredSubgroupReduceToDPPPatterns(
RewritePatternSet &patterns, unsigned subgroupSize, amdgpu::Chipset chipset,
PatternBenefit benefit = 1);
+/// Set up a type converter to convert unsupported source types to
+/// supported target types.
+void populateImitateUnsupportedTypesTypeConverter(TypeConverter &typeConverter,
+ ArrayRef<Type> sourceTypes,
+ ArrayRef<Type> targetTypes);
+
+/// Collect a set of patterns needed to imitate unsupported source types
+/// using supported target types.
+void populateImitateUnsupportedTypesConversionPatterns(
+ RewritePatternSet &patterns, TypeConverter &typeConverter,
+ ArrayRef<Type> sourceTypes, ArrayRef<Type> targetTypes,
+ DenseMap<StringAttr, FunctionType> &convertedFuncTypes);
+
+/// Set up a dialect conversion to reject operations on unsupported
+/// float types.
+void configureImitateUnsupportedTypesLegality(ConversionTarget &target,
+ TypeConverter &typeConverter);
+
/// Collect all patterns to rewrite ops within the GPU dialect.
inline void populateGpuRewritePatterns(RewritePatternSet &patterns) {
populateGpuAllReducePatterns(patterns);
diff --git a/mlir/include/mlir/Dialect/GPU/Transforms/Passes.td b/mlir/include/mlir/Dialect/GPU/Transforms/Passes.td
index 3766eb16e9429..feb1b2820abd6 100644
--- a/mlir/include/mlir/Dialect/GPU/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/GPU/Transforms/Passes.td
@@ -258,4 +258,57 @@ def GpuSPIRVAttachTarget: Pass<"spirv-attach-target", ""> {
];
}
+def GpuImitateUnsupportedTypes : Pass<"imitate-unsupported-types", "::mlir::ModuleOp"> {
+ let summary = "Imitate unsupported types with supported types of same bitwidth.";
+ let description = [{
+ This pass imitates (bitcast/reinterpret_cast) unsupported types
+ with supported types of the same bitwidth. The imitation is done
+ by bitcasting the unsupported types to supported types of the same bitwidth.
+ Therefore, the source type and destination type must have the same bitwidth.
+ The only operation used for the imitation is arith.bitcast.
+
+ The imitation is often needed when the GPU target (dialect/IR) does not
+ support a certain type but the underlying architecture does. Take SPIR-V
+ for example: it does not support bf16, but an underlying architecture
+ (e.g., an Intel PVC GPU) that uses SPIR-V for code generation does.
+ Therefore, bf16 is not a valid data type to pass to a gpu kernel, nor can
+ it be used inside the kernel. To use the bf16 data type in a SPIR-V kernel
+ (as a kernel parameter or inside the kernel), bf16 has to be bitcast
+ (similar to a C++ reinterpret_cast) to a supported type (e.g., i16 for
+ Intel GPUs). The SPIR-V kernel can then use the imitated type (i16) in
+ the computation. However, i16 is not the same as bf16 (integer vs. float),
+ so the computation cannot readily use the imitated type (i16).
+
+ Therefore, this transformation pass is intended to be used in conjunction
+ with other transformation passes such as `EmulateUnsupportedFloats` and
+ `ExtendUnsupportedTypes` that extend the bitwidth of bf16 to f32 and
+ vice versa.
+
+ Finally, there are usually instructions available in the target
+ (dialect/IR) that can take advantage of these generated patterns
+ (bf16->i16->f32, f32->bf16->i16) and convert them to the supported
+ types.
+ For example, Intel provides SPIR-V extension ops that can
+ take imitated bf16 (i16) and convert them to f32 and vice-versa.
+ https://github.com/KhronosGroup/SPIRV-Registry/blob/main/extensions/INTEL/SPV_INTEL_bfloat16_conversion.asciidoc
+ https://mlir.llvm.org/docs/Dialects/SPIR-V/#spirvintelconvertbf16tof-spirvintelconvertbf16tofop
+ https://mlir.llvm.org/docs/Dialects/SPIR-V/#spirvintelconvertftobf16-spirvintelconvertftobf16op
+
+ }];
+
+ let options = [
+ ListOption<"sourceTypeStrs", "source-types", "std::string",
+ "MLIR types without type support on a given target">,
+ ListOption<"targetTypeStrs", "target-types", "std::string",
+ "MLIR types to convert the unsupported source types to">,
+ ];
+
+ let dependentDialects = [
+ "::mlir::gpu::GPUDialect",
+ "::mlir::arith::ArithDialect",
+ "::mlir::memref::MemRefDialect"
+ ];
+}
+
+
#endif // MLIR_DIALECT_GPU_PASSES
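Based on the options declared above, the pass would presumably be driven from mlir-opt along these lines (hypothetical invocation; the exact list-option syntax may differ):

```
mlir-opt input.mlir \
  --imitate-unsupported-types="source-types=bf16 target-types=i16"
```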
diff --git a/mlir/lib/Dialect/GPU/Transforms/ImitateUnsupportedTypes.cpp b/mlir/lib/Dialect/GPU/Transforms/ImitateUnsupportedTypes.cpp
new file mode 100644
index 0000000000000..c83e6bec568e0
--- /dev/null
+++ b/mlir/lib/Dialect/GPU/Transforms/ImitateUnsupportedTypes.cpp
@@ -0,0 +1,916 @@
+//===- ImitateUnsupportedTypes.cpp - Unsupported Type Imitation -*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//===----------------------------------------------------------------------===//
+//
+/// \file
+/// This pass imitates (bitcast/reinterpret_cast) unsupported types
+/// with supported types of the same bitwidth. The imitation is done
+/// by bitcasting the unsupported types to supported types of the same
+/// bitwidth. Therefore, the source type and destination type must have the
+/// same bitwidth. The only operation used for the imitation is arith.bitcast.
+///
+/// The imitation is often needed when the GPU target (dialect/IR) does not
+/// support a certain type but the underlying architecture does. Take SPIR-V
+/// for example: it does not support bf16, but an underlying architecture
+/// (e.g., an Intel PVC GPU) that uses SPIR-V for code generation does.
+/// Therefore, bf16 is not a valid data type to pass to a gpu kernel, nor can
+/// it be used inside the kernel. To use the bf16 data type in a SPIR-V kernel
+/// (as a kernel parameter or inside the kernel), bf16 has to be bitcast
+/// (similar to a C++ reinterpret_cast) to a supported type (e.g., i16 for
+/// Intel GPUs). The SPIR-V kernel can then use the imitated type (i16) in
+/// the computation. However, i16 is not the same as bf16 (integer vs. float),
+/// so the computation cannot readily use the imitated type (i16).
+///
+/// Therefore, this transformation pass is intended to be used in conjunction
+/// with other transformation passes such as `EmulateUnsupportedFloats` and
+/// `ExtendUnsupportedTypes` that extend the bitwidth of bf16 to f32 and
+/// vice versa.
+///
+/// Finally, there are usually instructions available in the target
+/// (dialect/IR) that can take advantage of these generated patterns
+/// (bf16->i16->f32, f32->bf16->i16) and convert them to the supported
+/// types.
+/// For example, Intel provides SPIR-V extension ops that can
+/// take imitated bf16 (i16) and convert them to f32 and vice-versa.
+/// https://github.com/KhronosGroup/SPIRV-Registry/blob/main/extensions/INTEL/SPV_INTEL_bfloat16_conversion.asciidoc
+/// https://mlir.llvm.org/docs/Dialects/SPIR-V/#spirvintelconvertbf16tof-spirvintelconvertbf16tofop
+/// https://mlir.llvm.org/docs/Dialects/SPIR-V/#spirvintelconvertftobf16-spirvintelconvertftobf16op
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/GPU/Transforms/Passes.h"
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Utils/Utils.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Location.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "llvm/ADT/APFloat.h"
+#include "llvm/ADT/APInt.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/Support/ErrorHandling.h"
+
+#include <optional>
+#include <type_traits>
+#include <variant>
+
+using namespace mlir;
+using namespace mlir::gpu;
+
+namespace mlir {
+#define GEN_PASS_DEF_GPUIMITATEUNSUPPORTEDTYPES
+#include "mlir/Dialect/GPU/Transforms/Passes.h.inc"
+} // namespace mlir
+
+//===----------------------------------------------------------------------===//
+// Utility functions
+//===----------------------------------------------------------------------===//
+
+APFloat bitcastAPIntToAPFloat(const APInt &intValue,
+ const llvm::fltSemantics &semantics) {
+ // Get the bit width of the APInt.
+ unsigned intBitWidth = intValue.getBitWidth();
+ // Get the total bit size required for the APFloat based on the semantics.
+ unsigned floatBitWidth = APFloat::getSizeInBits(semantics);
+ // Ensure the bit widths match for a direct bitcast.
+ assert(intBitWidth == floatBitWidth &&
+ "Bitwidth of APInt and APFloat must match for bitcast");
+
+ // Get the raw bit representation of the APInt as a byte vector.
+ auto intWords = intValue.getRawData();
+ // Create an APFloat with the specified semantics and the raw integer bits.
+ APFloat floatValue(semantics, APInt(intBitWidth, *intWords));
+ return floatValue;
+}
+
+// Get FloatAttr from IntegerAttr.
+FloatAttr getFloatAttrFromIntegerAttr(IntegerAttr intAttr, Type dstType,
+ ConversionPatternRewriter &rewriter) {
+ APInt intVal = intAttr.getValue();
+ auto floatVal = bitcastAPIntToAPFloat(
+ intVal, cast<FloatType>(dstType).getFloatSemantics());
+ return rewriter.getFloatAttr(dstType, floatVal);
+}
+// Get IntegerAttr from FloatAttr.
+IntegerAttr getIntegerAttrFromFloatAttr(FloatAttr floatAttr, Type dstType,
+ ConversionPatternRewriter &rewriter) {
+ APFloat floatVal = floatAttr.getValue();
+ APInt intVal = floatVal.bitcastToAPInt();
+ return rewriter.getIntegerAttr(dstType, intVal);
+}
+
+struct RawAllocator {
+ RawAllocator(OpBuilder &builder, Location loc) : builder(builder), loc(loc) {}
+
+ std::variant<Value, int64_t> computeTotalBytes(MemRefType srcType,
+ Value srcMemref) {
+ // Element size in bytes.
+ int64_t elemBitWidth = srcType.getElementTypeBitWidth();
+ int64_t elemByteWidth = (elemBitWidth + 7) / 8;
+
+ if (srcType.hasStaticShape()) {
+ // Static shape: compute total bytes statically.
+ int64_t numElements = 1;
+ for (int64_t dim : srcType.getShape()) {
+ numElements *= dim;
+ }
+ return numElements * elemByteWidth;
+ }
+
+ auto sizes = getSizes(srcType, srcMemref);
+ // Compute number of elements dynamically.
+ Value numElements = sizes.front();
+ for (auto size : llvm::drop_begin(sizes))
+ numElements = builder.create<arith::MulIOp>(loc, numElements, size);
+ Value elemSize = builder.create<arith::ConstantIndexOp>(loc, elemByteWidth);
+
+ return builder.create<arith::MulIOp>(loc, numElements, elemSize);
+ }
+
+ SmallVector<Value> getSizes(MemRefType type, Value memref) {
+ SmallVector<Value> sizes;
+ for (unsigned i = 0; i < type.getRank(); ++i) {
+ if (type.isDynamicDim(i)) {
+ sizes.push_back(builder.create<memref::DimOp>(loc, memref, i));
+ } else {
+ sizes.push_back(
+ builder.create<arith::ConstantIndexOp>(loc, type.getShape()[i]));
+ }
+ }
+ return sizes;
+ }
+
+ SmallVector<Value> getDynamicSizes(MemRefType type, Value memref) {
+ SmallVector<Value> sizes;
+ for (unsigned i = 0; i < type.getRank(); ++i) {
+ if (type.isDynamicDim(i)) {
+ sizes.push_back(builder.create<memref::DimOp>(loc, memref, i));
+ }
+ }
+ return sizes;
+ }
+
+ SmallVector<Value> getIdentityStrides(MemRefType type) {
+ SmallVector<Value> strides;
+ int64_t runningStride = 1;
+ for (int64_t dim : llvm::reverse(type.getShape())) {
+ strides.push_back(
+ builder.create<arith::ConstantIndexOp>(loc, runningStride));
+ if (dim != ShapedType::kDynamic)
+ runningStride *= dim;
+ else
+ runningStride = -1; // not handling dynamic strides.
+ }
+ std::reverse(strides.begin(), strides.end());
+ return strides;
+ }
+
+private:
+ OpBuilder &builder;
+ Location loc;
+};
+
+// Replace uses according to predicates automatically.
+template <typename OpTy>
+void replaceUsesWithPredicate(
+ OpTy originalValue,
+ ArrayRef<std::pair<std::function<bool(OpOperand &)>, Value>> replacements,
+ ConversionPatternRewriter &rewriter) {
+
+ for (OpOperand &use : llvm::make_early_inc_range(originalValue->getUses())) {
+ for (const auto &[predicate, newValue] : replacements) {
+ if (predicate(use)) {
+ use.set(newValue);
+ break;
+ }
+ }
+ }
+}
+
+//===----------------------------------------------------------------------===//
+// Conversion patterns
+//===----------------------------------------------------------------------===//
+namespace {
+
+//===----------------------------------------------------------------------===//
+// FunctionOp conversion pattern
+//===----------------------------------------------------------------------===//
+template <typename FuncLikeOp>
+struct ConvertFuncOp final : public OpConversionPattern<FuncLikeOp> {
+ ConvertFuncOp(MLIRContext *context, TypeConverter &typeConverter,
+ ArrayRef<Type> sourceTypes, ArrayRef<Type> targetTypes,
+ DenseMap<StringAttr, FunctionType> &convertedFuncTypes)
+ : OpConversionPattern<FuncLikeOp>(context),
+ typeConverter(typeConverter), // Store the reference
+ sourceTypes(sourceTypes), targetTypes(targetTypes),
+ convertedFuncTypes(convertedFuncTypes) {}
+ using OpConversionPattern<FuncLikeOp>::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(FuncLikeOp op, typename FuncLikeOp::Adaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ // Only handle functions inside a gpu.module
+ if (!op->template getParentOfType<gpu::GPUModuleOp>())
+ return failure();
+ FunctionType oldFuncType = op.getFunctionType();
+
+ // Convert function signature
+ TypeConverter::SignatureConversion signatureConverter(
+ oldFuncType.getNumInputs());
+ for (const auto &argType :
+ llvm::enumerate(op.getFunctionType().getInputs())) {
+ auto convertedType = typeConverter.convertType(argType.value());
+ if (!convertedType)
+ return failure();
+ signatureConverter.addInputs(argType.index(), convertedType);
+ }
+ SmallVector<Type, 4> newResultTypes;
+ for (const auto &resultType : llvm::enumerate(oldFuncType.getResults())) {
+ auto convertedType = typeConverter.convertType(resultType.value());
+ if (!convertedType)
+ return failure();
+ newResultTypes.push_back(convertedType);
+ }
+
+ // Convert function signature
+ FunctionType newFuncType = rewriter.getFunctionType(
+ signatureConverter.getConvertedTypes(), newResultTypes);
+
+ if (!newFuncType)
+ return rewriter.notifyMatchFailure(op, "could not convert function "
+ "type");
+
+ // Create new GPU function with converted type
+ auto newFuncOp =
+ rewriter.create<FuncLikeOp>(op.getLoc(), op.getName(), newFuncType);
+
+ newFuncOp.setVisibility(op.getVisibility());
+ // Copy attributes
+ for (auto attr : op->getAttrs()) {
+ // Skip the function_type attribute since it is already set by
+ // the newFuncType and we don't want to overwrite it.
+ if (attr.getName() != op.getFunctionTypeAttrName() &&
+ attr.getName() != SymbolTable::getSymbolAttrName())
+ newFuncOp->setAttr(attr.getName(), attr.getValue());
+ }
+
+ newFuncOp.getRegion().getBlocks().clear();
+ // Inline region approach
+ rewriter.inlineRegionBefore(op.getBody(), newFuncOp.getBody(),
+ newFuncOp.end());
+ // Convert block argument types using the type converter
+ if (failed(rewriter.convertRegionTypes(&newFuncOp.getBody(), typeConverter,
+ &signatureConverter))) {
+ return rewriter.notifyMatchFailure(op, "could not convert region "
+ "types");
+ }
+
+ if (!op.use_empty()) {
+ op.emitError("Cannot erase func: still has uses");
+ }
+ for (Operation *user : op->getUsers()) {
+ user->emitRemark() << "User of function " << op.getName();
+ }
+ rewriter.eraseOp(op);
+ // Add the converted function type to the map
+ convertedFuncTypes[newFuncOp.getNameAttr()] = newFuncType;
+ return success();
+ }
+
+private:
+ TypeConverter &typeConverter; // Store a reference
+ ArrayRef<Type> sourceTypes;
+ ArrayRef<Type> targetTypes;
+ DenseMap<StringAttr, FunctionType> &convertedFuncTypes;
+};
+
+//===----------------------------------------------------------------------===//
+// CallOp conversion pattern
+//===----------------------------------------------------------------------===//
+struct ConvertCallOp : OpConversionPattern<func::CallOp> {
+ ConvertCallOp(MLIRContext *context, TypeConverter &typeConverter,
+ const DenseMap<StringAttr, FunctionType> &convertedFuncTypes)
+ : OpConversionPattern(context), convertedFuncTypes(convertedFuncTypes) {}
+
+ LogicalResult
+ matchAndRewrite(func::CallOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto callee = op.getCalleeAttr();
+
+ auto it = convertedFuncTypes.find(
+ StringAttr::get(callee.getContext(), callee.getValue()));
+ if (it == convertedFuncTypes.end())
+ return rewriter.notifyMatchFailure(
+ op, "Callee signature not converted. Perhaps the callee is not in "
+ "the same gpu module as the caller.");
+
+ auto newResultTypes = it->second.getResults();
+ rewriter.replaceOpWithNewOp<func::CallOp>(
+ op, callee.getValue(), newResultTypes, adaptor.getOperands());
+
+ return success();
+ }
+
+private:
+ const DenseMap<StringAttr, FunctionType> &convertedFuncTypes;
+};
+
+//===----------------------------------------------------------------------===//
+// GPULaunchFuncOp conversion pattern
+//===----------------------------------------------------------------------===//
+struct ConvertGPULaunchFuncOp : OpConversionPattern<gpu::LaunchFuncOp> {
+ using OpConversionPattern::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(gpu::LaunchFuncOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ std::optional<KernelDim3> clusterSizeOperands =
+ op.hasClusterSize()
+ ? std::optional<gpu::KernelDim3>(op.getClusterSizeOperandValues())
+ : std::nullopt;
+
+ // Create the new launch_func.
+ auto newOp = rewriter.create<gpu::LaunchFuncOp>(
+ op.getLoc(), adaptor.getKernel(), op.getGridSizeOperandValues(),
+ op.getBlockSizeOperandValues(), op.getDynamicSharedMemorySize(),
+ adaptor.getKernelOperands(), op.getAsyncObject(), clusterSizeOperands);
+
+ // Copy block size and grid size attributes
+ newOp->setAttrs(op->getAttrs());
+ rewriter.replaceOp(op, newOp.getResults());
+ return success();
+ }
+};
+
+//===----------------------------------------------------------------------===//
+// AllocOp conversion pattern
+//===----------------------------------------------------------------------===//
+template <typename Allo...
[truncated]
Force-pushed from a31fd58 to 3966b5d.
I believe the usual term is "emulate" (instead of "imitate"). Unless I missed a nuance you're trying to make, can you update the description (and code) to reflect this?
https://mlir.llvm.org/docs/Dialects/SPIR-V/#spirvintelconvertbf16tof-spirvintelconvertbf16tofop
https://mlir.llvm.org/docs/Dialects/SPIR-V/#spirvintelconvertftobf16-spirvintelconvertftobf16op

}];
The implementation only touches specific ops, but from the pass description it's absolutely not clear to me what the scope is here, especially considering we also have EmulateUnsupportedFloats.
I updated the description to provide an overall view. Please let me know if it's enough.
High-level question: why is this in GPU? Shouldn't the translation from bf16 to i16 either be a pass over Arith or part of the SPIR-V lowering? See also: when going to LLVM, we replace all the 8-bit float types with i8.
%dense_const_2 = arith.constant dense<6.000000e-01> : vector<10x10xbf16>
// CHECK: %[[ALLOC:.*]] = gpu.alloc () : memref<200xi8> | ||
%alloc = gpu.alloc () : memref<10x10xbf16> |
Why is the gpu.alloc linearized? Looking at this example there doesn't seem to be a reason to linearize the allocs.
Hi Mahesh (@MaheshRavishankar),
We are using the memref.view op to create a view of the new data type, and memref.view requires that the original allocated memref be a flat, contiguous i8 memref with an empty layout.
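To illustrate the constraint (a sketch, not taken from the patch): memref.view can only reinterpret a flat, contiguous i8 buffer with the default identity layout, which is why the allocation had to be linearized to bytes first.

```mlir
// Hypothetical sketch: a 10x10 buffer of 2-byte elements is allocated as
// 200 raw bytes, then re-viewed at the imitated element type (i16).
%alloc = gpu.alloc () : memref<200xi8>
%c0 = arith.constant 0 : index
// The source must be 1-D, contiguous i8 with the trivial layout.
%view = memref.view %alloc[%c0][] : memref<200xi8> to memref<10x10xi16>
```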
Removed the need for alloc linearization. Using arith.bitcast for casting memref types now.
Hi Mehdi (@joker-eph), thank you so much for your suggestion. I actually wanted to use "emulation", but the Arith transform pass that extends and truncates the unsupported floats uses this name, and I thought it might cause confusion, as this pass does something fundamentally different, although the end goal is similar. Hence, I chose the name "imitation". But I am open to changing the name to "emulation" or any other suggestion the community has.
Hi Krzysztof (@krzysz00), the reason I wanted to keep it in GPU instead of making it part of the SPIR-V lowering is that it can hopefully be used by other targets as well, not just SPIR-V. As for doing it as part of ArithToSPIRV: that would not work, since it requires modification to both host and device code. That's why gpu seemed like the right place to me.
To rephrase my comment: this rewrite + patterns shouldn't be needed. If you have a [...] For an example of how this is handled, run an example that operations on [...] Therefore, I recommend rejecting this PR in favor of fixing the SPIR-V lowering patterns.
I'm putting a hold on this. The mapping of bf16 to i16 should be done in the SPIR-V lowering.
Also... why would this ever need to touch host code?
Move common pass logic to initialize() from runOnOperation().
Thanks for taking the time to explain your motivations, though I still fundamentally disagree with them. The lowering to SPIR-V should replace all mentions of bf16. See, for example, the equivalent code in LLVMTypeConverter:

```cpp
Type LLVMTypeConverter::convertFloatType(FloatType type) const {
  // Valid LLVM float types are used directly.
  if (LLVM::isCompatibleType(type))
    return type;
  // F4, F6, F8 types are converted to integer types with the same bit width.
  if (isa<Float8E5M2Type, Float8E4M3Type, Float8E4M3FNType, Float8E5M2FNUZType,
          Float8E4M3FNUZType, Float8E4M3B11FNUZType, Float8E3M4Type,
          Float4E2M1FNType, Float6E2M3FNType, Float6E3M2FNType,
          Float8E8M0FNUType>(type))
    return IntegerType::get(&getContext(), type.getWidth());
  // Other floating-point types: A custom type conversion rule must be
  // specified by the user.
  return Type();
}
```

This should correctly convert kernel arguments etc. One other thing we may want to do here is to relax the verifier on
Thank you so much for your suggestion. I'll take a look. Let me specify my reasoning for this pass approach:

Some of the above-mentioned things can be achieved through your approach, but they make the core lowering logic to SPIR-V more convoluted.
I claim that this pass is not necessary at the GPU level. This issue - where MLIR has a type that the backend doesn't have and it has to be lowered to bytes + special operations - is already handled in a bunch of places in the LLVM case - see everything to do with the 8-bit floats. Yes, it makes the GPU to {LLVM, SPIR-V, what have you} lowering a bit more complicated, but for something like "oh, SPIR-V doesn't have a bfloat type and just calls it i16", the existing pattern is to handle that in the lowering. Secondly, this "make it look like bytes" change isn't generally applicable: only SPIR-V needs it. LLVM (the other target) can handle all this just fine. ... Also, Now, passes that make this transformation more straightforward, like EmulateUnsupportedFloats, I'd be fine with, since they reduce the problem at lowering time to arith.extf, arith.truncf, and any "native" operations like matmul primitives. If people strongly want a "replace bf16 memrefs with i16 memrefs" pass to add to that suite of preprocessing steps, we could do that, but... it certainly shouldn't have all this linearization stuff that's cloned from the narrow type emulation (where it's absolutely needed as a preprocessing step because there's no good way to handle sub-byte addressing). This SPIR-V situation... isn't that, unless I'm missing a whole pile of context. (That being said,
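The reduction described here can be sketched as IR (illustrative, assuming bf16 is listed as an unsupported source type and f32 as the target):

```mlir
// Before: bf16 arithmetic the target cannot execute natively.
%sum = arith.addf %a, %b : bf16

// After an EmulateUnsupportedFloats-style rewrite: compute in f32,
// keeping bf16 only at the boundaries via extf/truncf.
%a32  = arith.extf %a : bf16 to f32
%b32  = arith.extf %b : bf16 to f32
%s32  = arith.addf %a32, %b32 : f32
%sum2 = arith.truncf %s32 : f32 to bf16
```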
Remove the usage of the memref.view op and the restrictions that come with it. Makes the pass straightforward.
Thanks a lot for your explanation :).
I'm still not convinced we need this pass instead of fixing the conversion to SPIR-V. Do you have a concrete reason why updating the conversion and translation to SPIR-V to represent bf16 as i16 doesn't work?
The pass is not just for bf16, but for any byte-addressable data type (e.g., f8). Here is one specific use case where this pass approach is better for code maintenance:
Except that's not actually what's happening here. bf16 is a supported type on some instructions, you just have to represent it as i16. See also the fact that
Like, bf16 is a semantically meaningful input type for the SPIR-V conversion, even though it doesn't exist in the type system itself, just like
And similarly, operations on 8-bit floats would be operations on... byte vectors? packed integers? whatever the extension says it is, really, and that's logic for the lowering to handle.
(and removing that code later - or making it conditional based on what extensions exist - is a non-issue in my book. I had to delete this exact sort of bf16-is-i16 thing for the AMD LLVM lowering because the backend learned about bfloat - see #108409 for the deletion. I still claim that this is a better approach for the SPIR-V lowering than this bitcast-heavy rewrite on GPU)
I respectfully disagree with your claim. I believe this is a better hands-off approach to handle unsupported types.
Skimmed through the change again, and I think my main concern about the linearization bits has been solved. Just my two cents if it helps resolve the deadlock here. First of all, @mshahneo, thanks for working through this and making a fairly fleshed-out implementation with tests. I fully believe this solves the problem you are encountering, and I am sure other folks have hit this problem as well, so it would be a good common resource. I have a few follow-up questions, though. It is not immediately clear to me what the scope of this change is. For example,

In general it would probably have been better to get some community discussion going with an RFC. Having an implementation like this actually makes the RFC stronger, but it is a better forum for discussion than a PR. Another aspect to consider is the maintainability of the pass. When this lands in main, someone needs to be the de facto owner of it (ideally one of the existing folks who think this is a valuable thing to add and will take on its maintenance). I know creating RFCs is kind of a pain, but MLIR does suffer a lot from abandoned dialects/methods/transformations that just become a maintenance burden.
Thank you so so much, @MaheshRavishankar :). Let me jump to your question first:
It depends. Normally it would not affect the in-tree sub-byte emulation mechanism at all. But let's say a user/vendor has a use case where they want to emulate f4 (f4E2M1FN) using i4. The pass would emulate it as i4. If it is used in any arith/math operations, those would have to be replaced by the respective vendor-specific operations. Otherwise, it would be handled by the SPIR-V converter as a sub-byte integer.
Currently, we handle any out-of-tree dialect ops using generic conversion logic. That being said, downstream dialects can have other ops with different logic (e.g., arith/math-type logic where they might want to keep the unsupported data types). That's why we exposed the patterns through populateImitateUnsupportedTypesConversionPatterns(), so that users can utilize them in their own downstream passes if they have special use cases. This is actually similar to other emulation passes (e.g., arith/math). Please let me know if this answers your concern, or if you have any more.
Yes, that is a good point. Again, thank you so so much. |
…marked legal explicitly.
This pass imitates (bitcast/reinterpret_cast) unsupported types with supported types of the same bitwidth. The imitation is done by bitcasting the unsupported types to supported types of the same bitwidth; therefore, the source type and destination type must have the same bitwidth. The imitation is done using the `arith.bitcast` operation.
The imitation is often needed when the GPU target (dialect/IR) does not support a certain type but the underlying architecture does. Take SPIR-V for example: it does not support bf16, but an underlying architecture (e.g., the Intel PVC GPU) that uses SPIR-V for code generation does. Therefore, bf16 is neither a valid data type to pass to a GPU kernel, nor to be used inside the kernel. To use the bf16 data type in a SPIR-V kernel (as a kernel parameter or inside the kernel), bf16 has to be bitcast (similar to C++ reinterpret_cast) to a supported type (e.g., i16 for Intel GPUs). The SPIR-V kernel can then use the imitated type (i16) in the computation. However, i16 is not the same as bf16 (integer vs. float), so the computation cannot readily use the imitated type (i16).
Therefore, this transformation pass is intended to be used in conjunction with other transformation passes such as `EmulateUnsupportedFloats` and `ExtendUnsupportedTypes` that extend the bitwidth of bf16 to f32 and vice-versa.

Finally, there are usually instructions available in the target (dialect/IR) that can take advantage of these generated patterns (bf16->i16->f32, f32->bf16->i16) and convert them to the supported types.
For example, Intel provides SPIR-V extension ops that can take imitated bf16 (i16) and convert it to f32 and vice-versa:
https://github.com/KhronosGroup/SPIRV-Registry/blob/main/extensions/INTEL/SPV_INTEL_bfloat16_conversion.asciidoc
https://mlir.llvm.org/docs/Dialects/SPIR-V/#spirvintelconvertbf16tof-spirvintelconvertbf16tofop
https://mlir.llvm.org/docs/Dialects/SPIR-V/#spirvintelconvertftobf16-spirvintelconvertftobf16op
Example:
Let's use an example to show the full flow as to how this pass would interact with other passes (`EmulateUnsupportedFloats` and `ExtendUnsupportedTypes`) to solve the problem. The following example code (both host and device) does an elementwise bf16 addition:
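The original example listing did not survive here; a minimal sketch of what the device side of such code might look like (module, function, and shape names are hypothetical):

```mlir
// Hypothetical device code: elementwise bf16 addition.
gpu.module @kernels {
  gpu.func @add_bf16(%a: memref<8xbf16>, %b: memref<8xbf16>,
                     %out: memref<8xbf16>) kernel {
    %tid = gpu.thread_id x
    %x = memref.load %a[%tid] : memref<8xbf16>
    %y = memref.load %b[%tid] : memref<8xbf16>
    // Native bf16 addition, which the target may not support.
    %sum = arith.addf %x, %y : bf16
    memref.store %sum, %out[%tid] : memref<8xbf16>
    gpu.return
  }
}
```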
Op-level emulation:
Now let's say we want to generate code for a target that does not support native bf16 addition, but supports f32 addition. We would use `EmulateUnsupportedFloatsPass` to extend the data types and do the addition, in other words, do an op-level emulation. The result would look like this:
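A hedged sketch of the kernel body after such op-level emulation (names hypothetical): loads and stores still use bf16, and only the arithmetic is widened to f32:

```mlir
// Hypothetical result of op-level emulation: bf16 arithmetic widened to f32.
%x = memref.load %a[%tid] : memref<8xbf16>
%y = memref.load %b[%tid] : memref<8xbf16>
%x_f32 = arith.extf %x : bf16 to f32
%y_f32 = arith.extf %y : bf16 to f32
%sum_f32 = arith.addf %x_f32, %y_f32 : f32
%sum = arith.truncf %sum_f32 : f32 to bf16
memref.store %sum, %out[%tid] : memref<8xbf16>
```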
Data type emulation/Imitation:

The above emulation is enough if the target IR supports/recognizes bf16 as a data type. However, the target IR may not recognize bf16 as a valid data type, which is the case for SPIR-V. Then we still have no way to lower the above code to SPIR-V. One possible way to handle this is to imitate/emulate the bf16 data type as some same-bitwidth data type like i16, change the signature of the kernel as well as the arguments passed to the kernel, and do the necessary bitcasts. In other words, we would have to do data type emulation/imitation. So, if we run the current pass, the output would look like the following:
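A hedged sketch of what the kernel might look like after this pass (names hypothetical): the kernel signature now uses i16, and scalar bitcasts bridge back to the bf16/f32 arithmetic:

```mlir
// Hypothetical result of this pass: bf16 is imitated as i16 at the kernel
// boundary; scalar bitcasts recover bf16 around the widened arithmetic.
gpu.func @add_bf16(%a: memref<8xi16>, %b: memref<8xi16>,
                   %out: memref<8xi16>) kernel {
  %tid = gpu.thread_id x
  %x_i16 = memref.load %a[%tid] : memref<8xi16>
  %y_i16 = memref.load %b[%tid] : memref<8xi16>
  %x = arith.bitcast %x_i16 : i16 to bf16
  %y = arith.bitcast %y_i16 : i16 to bf16
  %x_f32 = arith.extf %x : bf16 to f32
  %y_f32 = arith.extf %y : bf16 to f32
  %sum_f32 = arith.addf %x_f32, %y_f32 : f32
  %sum = arith.truncf %sum_f32 : f32 to bf16
  %sum_i16 = arith.bitcast %sum : bf16 to i16
  memref.store %sum_i16, %out[%tid] : memref<8xi16>
  gpu.return
}
```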
Conversion Pattern/Pass to Target:
Now during the final conversion to the target IR (SPIR-V), the conversion can identify these bitcast and extf/truncf combinations and generate supported operations.
For example, an `arith.bitcast` (i16 to bf16) followed by an `arith.extf` (bf16 to f32) can be replaced with `spirv.INTEL.ConvertBF16ToF`, and an `arith.truncf` (f32 to bf16) followed by an `arith.bitcast` (bf16 to i16) can be replaced with `spirv.INTEL.ConvertFToBF16`.
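A sketch of such a rewrite using the Intel SPIR-V extension ops linked above (SSA names are illustrative before/after pairs, not a single valid block):

```mlir
// Before: imitated-bf16 load pattern produced by the passes above.
%v = arith.bitcast %x_i16 : i16 to bf16
%v_f32 = arith.extf %v : bf16 to f32
// After: folded into the Intel SPIR-V extension op.
%v_f32 = spirv.INTEL.ConvertBF16ToF %x_i16 : i16 to f32

// Before: the symmetric store-side pattern.
%r = arith.truncf %v_f32 : f32 to bf16
%r_i16 = arith.bitcast %r : bf16 to i16
// After:
%r_i16 = spirv.INTEL.ConvertFToBF16 %v_f32 : f32 to i16
```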
Once this step is done, all bf16 data types are removed from the device code, and the code can now be converted to SPIR-V.
Discussion topic:
Name of the pass (Emulation vs. Imitation): I actually wanted to use Emulation, but the Arith transform pass that extends and truncates the unsupported floats uses this name, and I thought it may cause confusion, as this pass does something fundamentally different, although the end goal is similar. Hence, I chose the name Imitation. But I am open to changing the name to emulation or any other suggestion the community has.
[Resolved]: We use arith.bitcast on memrefs to handle this situation. To handle memref allocation, we utilize the view approach: create an i8 allocation, then create two views: one with the original data type, one with the same shape but the supported data type. The reason we had to resort to this is that there is no other way to change the data type of a memref. I wonder if we can allow memref.reinterpret_cast to change the data type for same-bitwidth types, or add a new op like memref.bitcast. Having such a cast would also remove a lot of handicaps, because view has a lot of restrictions:
-- the initial allocation has to be flat i8,
-- it cannot have any layout (empty layout/identity),
-- it has to be contiguous, so no stride support.
I can open an RFC on this topic if you would like.
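A sketch of the view-based workaround described above (sizes hypothetical: 8 bf16 elements occupy 16 bytes):

```mlir
// Allocate flat i8 storage, then create two typed views over the same bytes.
%bytes = memref.alloc() : memref<16xi8>
%c0 = arith.constant 0 : index
// View with the original (unsupported) element type.
%as_bf16 = memref.view %bytes[%c0][] : memref<16xi8> to memref<8xbf16>
// View with the same shape but the supported same-bitwidth element type.
%as_i16 = memref.view %bytes[%c0][] : memref<16xi8> to memref<8xi16>
```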