[MLIR][XeGPU] Distribute load_gather/store_scatter op from Wg To Sg #154420
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-gpu

Author: Nishant Patel (nbpatel)

Changes

This PR adds distribution patterns for scatter ops (LoadGather and StoreScatter) with offsets.

Full diff: https://github.com/llvm/llvm-project/pull/154420.diff

4 Files Affected:
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index eb54d6887681d..44a72ed9435cc 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -767,6 +767,12 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [MemoryEffects<[MemRead]>]> {
let builders = [
OpBuilder<(ins "Type": $value, "Value": $source, "Value": $mask,
+ "xegpu::CachePolicyAttr": $l1_hint,
+ "xegpu::CachePolicyAttr": $l2_hint,
+ "xegpu::CachePolicyAttr": $l3_hint)>,
+ OpBuilder<(ins "Type": $value, "Value": $source,
+ "ArrayRef<OpFoldResult>": $offsets, "Value": $mask,
+ "IntegerAttr": $chunk_size,
"xegpu::CachePolicyAttr": $l1_hint,
"xegpu::CachePolicyAttr": $l2_hint,
"xegpu::CachePolicyAttr": $l3_hint)>
@@ -875,6 +881,12 @@ def XeGPU_StoreScatterOp : XeGPU_Op<"store", [MemoryEffects<[MemWrite]>]> {
let builders = [
OpBuilder<(ins "Value": $value, "Value": $dest, "Value": $mask,
+ "xegpu::CachePolicyAttr": $l1_hint,
+ "xegpu::CachePolicyAttr": $l2_hint,
+ "xegpu::CachePolicyAttr": $l3_hint)>,
+ OpBuilder<(ins "Value": $value, "Value": $dest,
+ "ArrayRef<OpFoldResult>": $offsets, "Value": $mask,
+ "IntegerAttr": $chunk_size,
"xegpu::CachePolicyAttr": $l1_hint,
"xegpu::CachePolicyAttr": $l2_hint,
"xegpu::CachePolicyAttr": $l3_hint)>
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 906c71d8b8dad..eed1b4e4ff62a 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -796,6 +796,22 @@ void LoadGatherOp::build(OpBuilder &builder, OperationState &state,
l1_hint, l2_hint, l3_hint);
}
+void LoadGatherOp::build(OpBuilder &builder, OperationState &state,
+ Type valueType, Value source,
+ ArrayRef<OpFoldResult> offsets, Value mask,
+ IntegerAttr chunk_size, xegpu::CachePolicyAttr l1_hint,
+ xegpu::CachePolicyAttr l2_hint,
+ xegpu::CachePolicyAttr l3_hint) {
+ auto loc = source.getLoc();
+ int64_t size = static_cast<int64_t>(offsets.size());
+ auto type = VectorType::get(size, builder.getIndexType());
+ auto values = getValueOrCreateConstantIndexOp(builder, loc, offsets);
+ auto offset = vector::FromElementsOp::create(builder, loc, type, values);
+
+ build(builder, state, valueType, source, offset, mask, chunk_size, l1_hint,
+ l2_hint, l3_hint);
+}
+
//===----------------------------------------------------------------------===//
// XeGPU_StoreScatterOp
//===----------------------------------------------------------------------===//
@@ -844,6 +860,24 @@ void StoreScatterOp::build(OpBuilder &builder, OperationState &state,
l2_hint, l3_hint);
}
+void StoreScatterOp::build(OpBuilder &builder, OperationState &state,
+ Value value, Value dest,
+ ArrayRef<OpFoldResult> offsets, Value mask,
+ IntegerAttr chunk_size,
+ xegpu::CachePolicyAttr l1_hint,
+ xegpu::CachePolicyAttr l2_hint,
+ xegpu::CachePolicyAttr l3_hint) {
+ auto loc = dest.getLoc();
+ int64_t size = static_cast<int64_t>(offsets.size());
+ auto type = VectorType::get(size, builder.getIndexType());
+ auto values = getValueOrCreateConstantIndexOp(builder, loc, offsets);
+ auto offset = vector::FromElementsOp::create(builder, loc, type, values);
+
+ // Call the correct builder overload that does not expect result types.
+ build(builder, state, value, dest, offset, mask, chunk_size, l1_hint, l2_hint,
+ l3_hint);
+}
+
//===----------------------------------------------------------------------===//
// XeGPU_UpdateOffsetOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 8f1208e77ca5d..c063cb7d5d8b6 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -884,6 +884,88 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
}
};
+// This pattern transforms the LoadGatherOp with explicit offsets to load
+// subgroup data, similar to WgToSgLoadNdOpWithOffset.
+struct WgToSgLoadGatherOpWithOffset
+ : public OpConversionPattern<xegpu::LoadGatherOp> {
+ using OpConversionPattern<xegpu::LoadGatherOp>::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(xegpu::LoadGatherOp op, OneToNOpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+
+ if (!op.getOffsets())
+ return failure();
+
+ Location loc = op.getLoc();
+ VectorType resultType = op.getResult().getType();
+ ArrayRef<int64_t> wgShape = resultType.getShape();
+
+ xegpu::LayoutAttr layout = xegpu::getLayoutAttr(op.getResult());
+ if (!layout || !layout.getSgLayout())
+ return failure();
+
+ SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
+
+ SmallVector<Value> newLoadOps;
+ auto chunkSizeAttr =
+ rewriter.getI64IntegerAttr(op.getChunkSize().value_or(1));
+ VectorType newTy = VectorType::get(sgShape, resultType.getElementType());
+ for (auto [offsets, mask] :
+ llvm::zip(adaptor.getOffsets(), adaptor.getMask())) {
+ auto newLoadOp = rewriter.create<xegpu::LoadGatherOp>(
+ loc, newTy, op.getSource(), offsets, mask, chunkSizeAttr,
+ op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr());
+ xegpu::setLayoutAttr(newLoadOp->getResult(0),
+ layout.dropSgLayoutAndData());
+ newLoadOps.push_back(newLoadOp);
+ }
+ rewriter.replaceOpWithMultiple(op, {newLoadOps});
+ return success();
+ }
+};
+
+// This pattern transforms the StoreScatterOp with explicit offsets to store
+// subgroup data, similar to WgToSgStoreNdOpWithOffset.
+struct WgToSgStoreScatterOpWithOffset
+ : public OpConversionPattern<xegpu::StoreScatterOp> {
+ using OpConversionPattern<xegpu::StoreScatterOp>::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(xegpu::StoreScatterOp op, OneToNOpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+
+ if (!op.getOffsets())
+ return failure();
+
+ Location loc = op.getLoc();
+ VectorType valueType = dyn_cast<VectorType>(op.getValue().getType());
+ if (!valueType)
+ return failure();
+
+ ArrayRef<int64_t> wgShape = valueType.getShape();
+ xegpu::LayoutAttr layout = xegpu::getLayoutAttr(op.getValue());
+ if (!layout || !layout.getSgLayout())
+ return failure();
+
+ auto chunkSizeOpt = op.getChunkSize();
+ int64_t chunkSize = chunkSizeOpt ? static_cast<int64_t>(*chunkSizeOpt) : 1;
+ auto chunkSizeAttr = rewriter.getI64IntegerAttr(chunkSize);
+ for (auto [val, offs, mask] : llvm::zip(
+ adaptor.getValue(), adaptor.getOffsets(), adaptor.getMask())) {
+ rewriter.create<xegpu::StoreScatterOp>(
+ loc, val, op.getDest(), offs, mask, chunkSizeAttr, op.getL1HintAttr(),
+ op.getL2HintAttr(), op.getL3HintAttr());
+ // Update the layout_result_0 attribute to drop sg_layout and sg_data.
+ if (auto layoutAttr =
+ op->getAttrOfType<xegpu::LayoutAttr>("layout_result_0")) {
+ if (auto newLayout = layoutAttr.dropSgLayoutAndData())
+ op->setAttr("layout_result_0", newLayout);
+ }
+ }
+ rewriter.eraseOp(op);
+ return success();
+ }
+};
+
} // namespace
namespace mlir {
@@ -895,7 +977,8 @@ void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
WgToSgUpdateNdOffsetOp, WgToSgDpasOp, WgToSgPrefetchNdOp,
WgToSgPrefetchNdOpWithOffset, UnrealizedConversionCastOpPattern,
WgToSgElementwiseOp, WgToSgVectorBroadcastOp, WgToSgConvertLayoutOp,
- WgToSgArithConstantOp>(patterns.getContext());
+ WgToSgArithConstantOp, WgToSgLoadGatherOpWithOffset,
+ WgToSgStoreScatterOpWithOffset>(patterns.getContext());
}
} // namespace xegpu
} // namespace mlir
@@ -1015,6 +1098,21 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
return isLegal(xegpu::getLayoutAttr(op.getResult()));
});
+ target.addDynamicallyLegalOp<xegpu::LoadGatherOp>(
+ [=](xegpu::LoadGatherOp op) -> bool {
+ auto layout = xegpu::getLayoutAttr(op.getResult());
+ return isLegal(layout);
+ });
+
+ target.addDynamicallyLegalOp<xegpu::StoreScatterOp>(
+ [=](xegpu::StoreScatterOp op) -> bool {
+ // Check if the layout attribute is present on the result.
+ auto layout = op->getAttrOfType<xegpu::LayoutAttr>("layout_result_0");
+ if (!layout)
+ return true;
+ return isLegal(layout);
+ });
+
target.addDynamicallyLegalOp<xegpu::ConvertLayoutOp>(
[=](xegpu::ConvertLayoutOp op) -> bool {
return isLegal(op.getInputLayout()) && isLegal(op.getTargetLayout());
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
index 07a0b86223c33..03fd95f0e778e 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
@@ -263,4 +263,45 @@ gpu.module @test_distribution {
} {sg_id_range = #xegpu.range<[3, 19]>}
gpu.return
}
+
+ // CHECK-LABEL: @load_gather
+ // CHECK-SAME: %[[ARG0:.*]]: memref<?xf16>
+ gpu.func @load_gather(%src : memref<?xf16>) {
+ // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<32x4xindex>
+ // CHECK: %[[MASK:.*]] = arith.constant dense<true> : vector<32x4xi1>
+ // CHECK: %[[LOAD:.*]] = xegpu.load %[[ARG0]][%[[CST]]], %[[MASK]] <{chunk_size = 1 : i64, l1_hint = #xegpu.cache_hint<cached>}> : memref<?xf16>, vector<32x4xindex>, vector<32x4xi1> -> vector<32x4xf16>
+ %offset = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 4]>} dense<0> : vector<256x16xindex>
+ %mask = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 4]>} dense<1> : vector<256x16xi1>
+ %load = xegpu.load %src[%offset], %mask {chunk_size = 1, layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 4]>, l1_hint = #xegpu.cache_hint<cached>}
+ : memref<?xf16>, vector<256x16xindex>, vector<256x16xi1> -> vector<256x16xf16>
+ gpu.return
+ }
+
+ // CHECK-LABEL: @store_scatter
+ // CHECK-SAME: %[[ARG0:.*]]: memref<256xf16>
+ gpu.func @store_scatter(%dest : memref<256xf16>) {
+ // CHECK: %[[VAL:.*]] = arith.constant dense<2.550000e+01> : vector<8xf16>
+ // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<8xindex>
+ // CHECK: %[[MASK:.*]] = arith.constant dense<true> : vector<8xi1>
+ // CHECK: xegpu.store %[[VAL]], %[[ARG0]][%[[CST]]], %[[MASK]] <{chunk_size = 1 : i64, l1_hint = #xegpu.cache_hint<cached>}> : vector<8xf16>, memref<256xf16>, vector<8xindex>, vector<8xi1>
+ %val = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [32], sg_data = [8]>} dense<25.5> : vector<256xf16>
+ %offset = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [32], sg_data = [8]>} dense<0> : vector<256xindex>
+ %mask = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [32], sg_data = [8]>} dense<1> : vector<256xi1>
+ xegpu.store %val, %dest[%offset], %mask {chunk_size = 1, layout_result_0 = #xegpu.layout<sg_layout = [32], sg_data = [8]>, l1_hint = #xegpu.cache_hint<cached>}
+ : vector<256xf16>, memref<256xf16>, vector<256xindex>, vector<256xi1>
+ gpu.return
+ }
+
+ // CHECK-LABEL: @load_with_chunk_size
+ // CHECK-SAME: %[[ARG0:.*]]: memref<?xf16>
+ gpu.func @load_with_chunk_size(%src : memref<?xf16>) {
+ // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<8xindex>
+ // CHECK: %[[MASK:.*]] = arith.constant dense<true> : vector<8xi1>
+ // CHECK: %[[LOAD:.*]] = xegpu.load %[[ARG0]][%[[CST]]], %[[MASK]] <{chunk_size = 4 : i64, l1_hint = #xegpu.cache_hint<cached>}> : memref<?xf16>, vector<8xindex>, vector<8xi1> -> vector<8x4xf16>
+ %offset = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [32], sg_data = [8]>} dense<0> : vector<256xindex>
+ %mask = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [32], sg_data = [8]>} dense<1> : vector<256xi1>
+ %load = xegpu.load %src[%offset], %mask {chunk_size = 4, layout_result_0 = #xegpu.layout<sg_layout = [32, 1], sg_data = [8, 4]>, l1_hint = #xegpu.cache_hint<cached>}
+ : memref<?xf16>, vector<256xindex>, vector<256xi1> -> vector<256x4xf16>
+ gpu.return
+ }
}
ping for review
auto chunkSizeAttr =
    rewriter.getI64IntegerAttr(op.getChunkSize().value_or(1));
VectorType newTy = VectorType::get(sgShape, resultType.getElementType());
for (auto [offsets, mask] :
Nit: Here the code assumes the offsets have been distributed by their defining op. That is not always true currently, e.g., when the offsets come from a function parameter or a non-splat arith constant. Thus, it would be better to check whether the offsets have been distributed yet.
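For example, something along these lines (a rough sketch only; the shape comparison against sgShape is an assumption, and the exact condition depends on how the pass models distribution):

// Bail out if the offsets were not distributed by their defining op
// (e.g. they come from a function parameter or a non-splat constant);
// in that case the adaptor still carries the workgroup-shaped vector.
auto offsetsTy =
    dyn_cast<VectorType>(adaptor.getOffsets().front().getType());
if (!offsetsTy || offsetsTy.getShape() != ArrayRef<int64_t>(sgShape))
  return failure();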
+1, could you add a test for this?
with the current design, it's not possible to add negative tests
    loc, val, op.getDest(), offs, mask, chunkSizeAttr, op.getL1HintAttr(),
    op.getL2HintAttr(), op.getL3HintAttr());
// Update the layout_result_0 attribute to drop sg_layout and sg_data.
if (auto layoutAttr =
better to use the getLayoutAttr and setLayoutAttr interface here.
store doesn't have a result, and the layout is attached to the op, so not sure if getters/setters will work
then, what is the purpose of setting "layout_result_0" here?
it is to retain lane layout & lane data if it's present... maybe we need to name it differently?
Since StoreScatterOp doesn't have a result, there should be no layout_result_0 on this op in the input code, but could it have layout_operand_0 instead?
changed it to "layout", similar to load/store matrix ops
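For reference, a minimal sketch of the resulting update (the plain "layout" key follows the load/store matrix ops mentioned above; newStoreOp is a stand-in for the op created inside the loop):

// StoreScatterOp has no results, so the layout lives on the op itself
// under a plain "layout" key rather than "layout_result_0".
if (auto layoutAttr = op->getAttrOfType<xegpu::LayoutAttr>("layout")) {
  if (auto newLayout = layoutAttr.dropSgLayoutAndData())
    newStoreOp->setAttr("layout", newLayout);
}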
auto chunkSizeOpt = op.getChunkSize();
int64_t chunkSize = chunkSizeOpt ? static_cast<int64_t>(*chunkSizeOpt) : 1;
auto chunkSizeAttr = rewriter.getI64IntegerAttr(chunkSize);
for (auto [val, offs, mask] : llvm::zip(
same check for offsets as above.
LGTM. thanks for the efforts.
SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;

// The offsets need to be distributed
if (dyn_cast<VectorType>(adaptor.getOffsets().front().getType())
Nit: it would read better if we did the dyn_cast before the if. Same for the following one.
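i.e., roughly (a style-only sketch; the elided part of the condition stays as in the quoted hunk):

// Hoist the cast out of the condition so the `if` stays simple.
VectorType offsetsTy =
    dyn_cast<VectorType>(adaptor.getOffsets().front().getType());
if (!offsetsTy) // ...plus the existing distribution check on offsetsTy
  return failure();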