diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td index bbba495e613b2..f9c0f982c2118 100644 --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -1394,6 +1394,13 @@ def ConvertVectorToLLVMPass : Pass<"convert-vector-to-llvm"> { "bool", /*default=*/"true", "Allows compiler to assume vector indices fit in 32-bit if that " "yields faster code">, + Option<"useVectorAlignment", "use-vector-alignment", + "bool", /*default=*/"false", + "Use the preferred alignment of a vector type in load/store " + "operations instead of the alignment of the element type of the " + "memref. This flag is intended for use with hardware which requires " + "vector alignment, or in application contexts where it is known that " + "all vector accesses are naturally aligned.">, Option<"amx", "enable-amx", "bool", /*default=*/"false", "Enables the use of AMX dialect while lowering the vector " diff --git a/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h b/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h index 1e29bfeb9c392..f6b09deb4e44c 100644 --- a/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h +++ b/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h @@ -22,7 +22,8 @@ void populateVectorToLLVMMatrixConversionPatterns( /// Collect a set of patterns to convert from the Vector dialect to LLVM. void populateVectorToLLVMConversionPatterns( const LLVMTypeConverter &converter, RewritePatternSet &patterns, - bool reassociateFPReductions = false, bool force32BitVectorIndices = false); + bool reassociateFPReductions = false, bool force32BitVectorIndices = false, + bool useVectorAlignment = false); namespace vector { void registerConvertVectorToLLVMInterface(DialectRegistry &registry); diff --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td index 14cbbac99d9ae..299f198e4ab9c 100644 --- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td +++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td @@ -32,7 +32,8 @@ def ApplyVectorToLLVMConversionPatternsOp : Op:$reassociate_fp_reductions, - DefaultValuedAttr<BoolAttr, "false">:$force_32bit_vector_indices); + DefaultValuedAttr<BoolAttr, "false">:$force_32bit_vector_indices, + DefaultValuedAttr<BoolAttr, "false">:$use_vector_alignment); let assemblyFormat = "attr-dict"; } diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp index 076e5512f375b..5296013189b9e 100644 --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -67,6 +67,21 @@ static Value extractOne(ConversionPatternRewriter &rewriter, return rewriter.create<LLVM::ExtractValueOp>(loc, val, pos); } +// Helper that returns data layout alignment of a vector. +LogicalResult getVectorAlignment(const LLVMTypeConverter &typeConverter, + VectorType vectorType, unsigned &align) { + Type convertedVectorTy = typeConverter.convertType(vectorType); + if (!convertedVectorTy) + return failure(); + + llvm::LLVMContext llvmContext; + align = LLVM::TypeToLLVMIRTranslator(llvmContext) + .getPreferredAlignment(convertedVectorTy, + typeConverter.getDataLayout()); + + return success(); +} +
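As a rough illustration of what the two alignment helpers compute (not part of this patch): for vector<8xf32> backed by a memref of f32, the new vector-based helper returns the vector's preferred (natural) alignment, while the memref-based helper defined just below returns the element type's alignment. The sketch is hypothetical; printAlignments is invented for the example, it assumes placement in ConvertVectorToLLVM.cpp where both helpers are visible, and the numbers assume the default data layout used by LLVMTypeConverter.

#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "llvm/Support/raw_ostream.h"

// Hypothetical sketch (not part of the patch): compare the two alignments for
// vector<8xf32> against a 2-D memref of f32 under the default data layout.
static void printAlignments() {
  mlir::MLIRContext context;
  context.loadDialect<mlir::LLVM::LLVMDialect>();
  mlir::LLVMTypeConverter converter(&context);

  auto f32 = mlir::Float32Type::get(&context);
  auto vecTy = mlir::VectorType::get({8}, f32);
  auto memrefTy = mlir::MemRefType::get({200, 100}, f32);

  unsigned vecAlign = 0, elemAlign = 0;
  if (mlir::succeeded(getVectorAlignment(converter, vecTy, vecAlign)) &&
      mlir::succeeded(getMemRefAlignment(converter, memrefTy, elemAlign)))
    // Expected output is roughly "32 vs 4", matching the FileCheck
    // expectations in the new test file at the end of this patch.
    llvm::outs() << vecAlign << " vs " << elemAlign << "\n";
}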
// Helper that returns data layout alignment of a memref. LogicalResult getMemRefAlignment(const LLVMTypeConverter &typeConverter, MemRefType memrefType, unsigned &align) { @@ -82,6 +97,28 @@ LogicalResult getMemRefAlignment(const LLVMTypeConverter &typeConverter, return success(); } +// Helper to resolve the alignment for vector load/store, gather and scatter +// ops. If useVectorAlignment is true, get the preferred alignment for the +// vector type in the operation. This option is intended for hardware backends +// that require vector accesses to be naturally aligned. Otherwise, use the +// preferred alignment of the element type of the memref. Note that when vector +// alignment is used, the shape of the vector type must be resolved before the ConvertVectorToLLVM pass is run. +LogicalResult getVectorToLLVMAlignment(const LLVMTypeConverter &typeConverter, + VectorType vectorType, + MemRefType memrefType, unsigned &align, + bool useVectorAlignment) { + if (useVectorAlignment) { + if (failed(getVectorAlignment(typeConverter, vectorType, align))) { + return failure(); + } + } else { + if (failed(getMemRefAlignment(typeConverter, memrefType, align))) { + return failure(); + } + } + return success(); +} + // Check if the last stride is non-unit and has a valid memory space. static LogicalResult isMemRefTypeSupported(MemRefType memRefType, const LLVMTypeConverter &converter) { @@ -224,6 +261,10 @@ static void replaceLoadOrStoreOp(vector::MaskedStoreOp storeOp, template <class LoadOrStoreOp> class VectorLoadStoreConversion : public ConvertOpToLLVMPattern<LoadOrStoreOp> { public: + explicit VectorLoadStoreConversion(const LLVMTypeConverter &typeConv, + bool useVectorAlign) + : ConvertOpToLLVMPattern<LoadOrStoreOp>(typeConv), + useVectorAlignment(useVectorAlign) {} using ConvertOpToLLVMPattern<LoadOrStoreOp>::ConvertOpToLLVMPattern; LogicalResult @@ -240,8 +281,10 @@ class VectorLoadStoreConversion : public ConvertOpToLLVMPattern { // Resolve alignment. unsigned align; - if (failed(getMemRefAlignment(*this->getTypeConverter(), memRefTy, align))) - return failure(); + if (failed(getVectorToLLVMAlignment(*this->getTypeConverter(), vectorTy, + memRefTy, align, useVectorAlignment))) + return rewriter.notifyMatchFailure(loadOrStoreOp, + "could not resolve alignment"); // Resolve address. auto vtype = cast<VectorType>( @@ -252,12 +295,23 @@ class VectorLoadStoreConversion : public ConvertOpToLLVMPattern { rewriter); return success(); } + +private: + // If true, use the preferred alignment of the vector type. + // If false, use the preferred alignment of the element type + // of the memref. This flag is intended for use with hardware + // backends that require alignment of vector operations. + const bool useVectorAlignment; }; /// Conversion pattern for a vector.gather. class VectorGatherOpConversion : public ConvertOpToLLVMPattern<vector::GatherOp> { public: + explicit VectorGatherOpConversion(const LLVMTypeConverter &typeConv, + bool useVectorAlign) + : ConvertOpToLLVMPattern(typeConv), + useVectorAlignment(useVectorAlign) {} using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult @@ -278,10 +332,9 @@ class VectorGatherOpConversion // Resolve alignment. unsigned align; - if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align))) { - return rewriter.notifyMatchFailure(gather, - "could not resolve memref alignment"); - } + if (failed(getVectorToLLVMAlignment(*this->getTypeConverter(), vType, + memRefType, align, useVectorAlignment))) + return rewriter.notifyMatchFailure(gather, "could not resolve alignment"); // Resolve address.
Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(), @@ -297,12 +350,24 @@ class VectorGatherOpConversion adaptor.getPassThru(), rewriter.getI32IntegerAttr(align)); return success(); } + +private: + // If true, use the preferred alignment of the vector type. + // If false, use the preferred alignment of the element type + // of the memref. This flag is intended for use with hardware + // backends that require alignment of vector operations. + const bool useVectorAlignment; }; /// Conversion pattern for a vector.scatter. class VectorScatterOpConversion : public ConvertOpToLLVMPattern<vector::ScatterOp> { public: + explicit VectorScatterOpConversion(const LLVMTypeConverter &typeConv, + bool useVectorAlign) + : ConvertOpToLLVMPattern(typeConv), + useVectorAlignment(useVectorAlign) {} + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult @@ -322,10 +387,10 @@ class VectorScatterOpConversion // Resolve alignment. unsigned align; - if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align))) { + if (failed(getVectorToLLVMAlignment(*this->getTypeConverter(), vType, + memRefType, align, useVectorAlignment))) return rewriter.notifyMatchFailure(scatter, - "could not resolve memref alignment"); - } + "could not resolve alignment"); // Resolve address. Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(), @@ -340,6 +405,13 @@ class VectorScatterOpConversion rewriter.getI32IntegerAttr(align)); return success(); } + +private: + // If true, use the preferred alignment of the vector type. + // If false, use the preferred alignment of the element type + // of the memref. This flag is intended for use with hardware + // backends that require alignment of vector operations. + const bool useVectorAlignment; }; /// Conversion pattern for a vector.expandload. @@ -1928,21 +2000,23 @@ void mlir::vector::populateVectorRankReducingFMAPattern( /// Populate the given list with patterns that convert from Vector to LLVM. void mlir::populateVectorToLLVMConversionPatterns( const LLVMTypeConverter &converter, RewritePatternSet &patterns, - bool reassociateFPReductions, bool force32BitVectorIndices) { + bool reassociateFPReductions, bool force32BitVectorIndices, + bool useVectorAlignment) { // This function populates only ConversionPatterns, not RewritePatterns.
MLIRContext *ctx = converter.getDialect()->getContext(); patterns.add<VectorReductionOpConversion>(converter, reassociateFPReductions); patterns.add<VectorCreateMaskOpConversion>(ctx, force32BitVectorIndices); + patterns.add<VectorLoadStoreConversion<vector::LoadOp>, + VectorLoadStoreConversion<vector::MaskedLoadOp>, + VectorLoadStoreConversion<vector::StoreOp>, + VectorLoadStoreConversion<vector::MaskedStoreOp>, + VectorGatherOpConversion, VectorScatterOpConversion>( + converter, useVectorAlignment); patterns.add, - VectorLoadStoreConversion<vector::LoadOp>, - VectorLoadStoreConversion<vector::MaskedLoadOp>, - VectorLoadStoreConversion<vector::StoreOp>, - VectorLoadStoreConversion<vector::MaskedStoreOp>, - VectorGatherOpConversion, VectorScatterOpConversion, VectorExpandLoadOpConversion, VectorCompressStoreOpConversion, VectorSplatOpLowering, VectorSplatNdOpLowering, VectorScalableInsertOpLowering, VectorScalableExtractOpLowering, diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp index 7082b92c95d1d..0ee6dce9ee94b 100644 --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp @@ -92,7 +92,8 @@ void ConvertVectorToLLVMPass::runOnOperation() { populateVectorTransferLoweringPatterns(patterns); populateVectorToLLVMMatrixConversionPatterns(converter, patterns); populateVectorToLLVMConversionPatterns( - converter, patterns, reassociateFPReductions, force32BitVectorIndices); + converter, patterns, reassociateFPReductions, force32BitVectorIndices, + useVectorAlignment); populateVectorToLLVMMatrixConversionPatterns(converter, patterns); // Architecture specific augmentations. diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp index 1be436dd7bf41..125c3d918284c 100644 --- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp +++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp @@ -34,7 +34,8 @@ void transform::ApplyVectorToLLVMConversionPatternsOp::populatePatterns( TypeConverter &typeConverter, RewritePatternSet &patterns) { populateVectorToLLVMConversionPatterns( static_cast<LLVMTypeConverter &>(typeConverter), patterns, - getReassociateFpReductions(), getForce_32bitVectorIndices()); + getReassociateFpReductions(), getForce_32bitVectorIndices(), + getUseVectorAlignment()); } LogicalResult diff --git a/mlir/test/Conversion/VectorToLLVM/use-vector-alignment.mlir b/mlir/test/Conversion/VectorToLLVM/use-vector-alignment.mlir new file mode 100644 index 0000000000000..3fa248656cf3a --- /dev/null +++ b/mlir/test/Conversion/VectorToLLVM/use-vector-alignment.mlir @@ -0,0 +1,102 @@ +// RUN: mlir-opt %s --convert-vector-to-llvm='use-vector-alignment=0' --split-input-file | FileCheck %s --check-prefixes=ALL,MEMREF-ALIGN +// RUN: mlir-opt %s --convert-vector-to-llvm='use-vector-alignment=1' --split-input-file | FileCheck %s --check-prefixes=ALL,VEC-ALIGN + + +//===----------------------------------------------------------------------===// +// vector.load +//===----------------------------------------------------------------------===// + +func.func @load(%base : memref<200x100xf32>, %i : index, %j : index) -> vector<8xf32> { + %0 = vector.load %base[%i, %j] : memref<200x100xf32>, vector<8xf32> + return %0 : vector<8xf32> +} + +// ALL-LABEL: func @load + +// VEC-ALIGN: llvm.load %{{.*}} {alignment = 32 : i64} : !llvm.ptr -> vector<8xf32> +// MEMREF-ALIGN: llvm.load %{{.*}} {alignment = 4 : i64} : !llvm.ptr -> vector<8xf32> + +// ----- + +//===----------------------------------------------------------------------===// +// vector.store +//===----------------------------------------------------------------------===// + +func.func @store(%base :
memref<200x100xf32>, %i : index, %j : index) { + %val = arith.constant dense<11.0> : vector<4xf32> + vector.store %val, %base[%i, %j] : memref<200x100xf32>, vector<4xf32> + return +} + +// ALL-LABEL: func @store + +// VEC-ALIGN: llvm.store %{{.*}}, %{{.*}} {alignment = 16 : i64} : vector<4xf32>, !llvm.ptr +// MEMREF-ALIGN: llvm.store %{{.*}}, %{{.*}} {alignment = 4 : i64} : vector<4xf32>, !llvm.ptr + +// ----- + +//===----------------------------------------------------------------------===// +// vector.maskedload +//===----------------------------------------------------------------------===// + +func.func @masked_load(%base: memref<?xf32>, %mask: vector<16xi1>, %passthru: vector<16xf32>) -> vector<16xf32> { + %c0 = arith.constant 0: index + %0 = vector.maskedload %base[%c0], %mask, %passthru : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32> + return %0 : vector<16xf32> +} + +// ALL-LABEL: func @masked_load + +// VEC-ALIGN: %[[L:.*]] = llvm.intr.masked.load %{{.*}}, %{{.*}}, %{{.*}} {alignment = 64 : i32} : (!llvm.ptr, vector<16xi1>, vector<16xf32>) -> vector<16xf32> +// MEMREF-ALIGN: %[[L:.*]] = llvm.intr.masked.load %{{.*}}, %{{.*}}, %{{.*}} {alignment = 4 : i32} : (!llvm.ptr, vector<16xi1>, vector<16xf32>) -> vector<16xf32> + +// ----- + +//===----------------------------------------------------------------------===// +// vector.maskedstore +//===----------------------------------------------------------------------===// + +func.func @masked_store(%base: memref<?xf32>, %mask: vector<16xi1>, %passthru: vector<16xf32>) { + %c0 = arith.constant 0: index + vector.maskedstore %base[%c0], %mask, %passthru : memref<?xf32>, vector<16xi1>, vector<16xf32> + return +} + +// ALL-LABEL: func @masked_store + +// VEC-ALIGN: llvm.intr.masked.store %{{.*}}, %{{.*}}, %{{.*}} {alignment = 64 : i32} : vector<16xf32>, vector<16xi1> into !llvm.ptr +// MEMREF-ALIGN: llvm.intr.masked.store %{{.*}}, %{{.*}}, %{{.*}} {alignment = 4 : i32} : vector<16xf32>, vector<16xi1> into !llvm.ptr + +// ----- + +//===----------------------------------------------------------------------===// +// vector.scatter +//===----------------------------------------------------------------------===// + +func.func @scatter(%base: memref<?xf32>, %index: vector<3xi32>, %mask: vector<3xi1>, %value: vector<3xf32>) { + %0 = arith.constant 0: index + vector.scatter %base[%0][%index], %mask, %value : memref<?xf32>, vector<3xi32>, vector<3xi1>, vector<3xf32> + return +} + +// ALL-LABEL: func @scatter + +// VEC-ALIGN: llvm.intr.masked.scatter %{{.*}}, %{{.*}}, %{{.*}} {alignment = 16 : i32} : vector<3xf32>, vector<3xi1> into vector<3x!llvm.ptr> +// MEMREF-ALIGN: llvm.intr.masked.scatter %{{.*}}, %{{.*}}, %{{.*}} {alignment = 4 : i32} : vector<3xf32>, vector<3xi1> into vector<3x!llvm.ptr> + +// ----- + +//===----------------------------------------------------------------------===// +// vector.gather +//===----------------------------------------------------------------------===// + +func.func @gather(%base: memref<?xf32>, %index: vector<3xi32>, %mask: vector<3xi1>, %passthru: vector<3xf32>) -> vector<3xf32> { + %0 = arith.constant 0: index + %1 = vector.gather %base[%0][%index], %mask, %passthru : memref<?xf32>, vector<3xi32>, vector<3xi1>, vector<3xf32> into vector<3xf32> + return %1 : vector<3xf32> +} + +// ALL-LABEL: func @gather + +// VEC-ALIGN: %[[G:.*]] = llvm.intr.masked.gather %{{.*}}, %{{.*}}, %{{.*}} {alignment = 16 : i32} : (vector<3x!llvm.ptr>, vector<3xi1>, vector<3xf32>) -> vector<3xf32> +// MEMREF-ALIGN: %[[G:.*]] = llvm.intr.masked.gather %{{.*}}, %{{.*}}, %{{.*}}
{alignment = 4 : i32} : (vector<3x!llvm.ptr>, vector<3xi1>, vector<3xf32>) -> vector<3xf32>
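For pipelines that build the lowering from C++ rather than through the -convert-vector-to-llvm pass or the transform op, a minimal, hypothetical sketch of opting into vector alignment could look like the following. The function name and surrounding boilerplate are illustrative only and not part of this patch; it simply calls the extended populateVectorToLLVMConversionPatterns entry point with the new parameter set to true.

#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
#include "mlir/Transforms/DialectConversion.h"

// Hypothetical sketch (not part of the patch): populate the Vector-to-LLVM
// patterns with useVectorAlignment enabled and run a partial conversion.
static mlir::LogicalResult lowerVectorWithVectorAlignment(mlir::Operation *op) {
  mlir::MLIRContext *ctx = op->getContext();
  mlir::LowerToLLVMOptions options(ctx);
  mlir::LLVMTypeConverter converter(ctx, options);

  mlir::RewritePatternSet patterns(ctx);
  mlir::populateVectorToLLVMConversionPatterns(
      converter, patterns,
      /*reassociateFPReductions=*/false,
      /*force32BitVectorIndices=*/false,
      /*useVectorAlignment=*/true); // the knob added by this change

  mlir::LLVMConversionTarget target(*ctx);
  return mlir::applyPartialConversion(op, target, std::move(patterns));
}

The same knob is exposed as use-vector-alignment on the pass and as use_vector_alignment on transform.apply_conversion_patterns.vector.vector_to_llvm, so existing pipelines only need the flag shown in the RUN lines of the test above.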