Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Conversation

nbpatel
Copy link
Contributor

@nbpatel nbpatel commented Aug 19, 2025

This PR adds distribution patterns for scatter ops (LoadGather and StoreScatter) with offsets.

@nbpatel nbpatel requested review from chencha3 and adam-smnk August 19, 2025 20:21
@nbpatel nbpatel marked this pull request as ready for review August 19, 2025 22:50
@llvmbot
Copy link
Member

llvmbot commented Aug 19, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-gpu

Author: Nishant Patel (nbpatel)

Changes

This PR adds distribution patterns for scatter ops (LoadGather and StoreScatter) with offsets.


Full diff: https://github.com/llvm/llvm-project/pull/154420.diff

4 Files Affected:

  • (modified) mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td (+12)
  • (modified) mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp (+34)
  • (modified) mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp (+99-1)
  • (modified) mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir (+41)
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index eb54d6887681d..44a72ed9435cc 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -767,6 +767,12 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [MemoryEffects<[MemRead]>]> {
 
   let builders = [
     OpBuilder<(ins "Type": $value, "Value": $source, "Value": $mask,
+                    "xegpu::CachePolicyAttr": $l1_hint,
+                    "xegpu::CachePolicyAttr": $l2_hint,
+                    "xegpu::CachePolicyAttr": $l3_hint)>,
+    OpBuilder<(ins "Type": $value, "Value": $source,
+                    "ArrayRef<OpFoldResult>": $offsets, "Value": $mask,
+                    "IntegerAttr": $chunk_size,
                     "xegpu::CachePolicyAttr": $l1_hint,
                     "xegpu::CachePolicyAttr": $l2_hint,
                     "xegpu::CachePolicyAttr": $l3_hint)>
@@ -875,6 +881,12 @@ def XeGPU_StoreScatterOp : XeGPU_Op<"store", [MemoryEffects<[MemWrite]>]> {
 
   let builders = [
     OpBuilder<(ins "Value": $value, "Value": $dest, "Value": $mask,
+                    "xegpu::CachePolicyAttr": $l1_hint,
+                    "xegpu::CachePolicyAttr": $l2_hint,
+                    "xegpu::CachePolicyAttr": $l3_hint)>,
+    OpBuilder<(ins "Value": $value, "Value": $dest, 
+                    "ArrayRef<OpFoldResult>": $offsets, "Value": $mask,
+                    "IntegerAttr": $chunk_size,
                     "xegpu::CachePolicyAttr": $l1_hint,
                     "xegpu::CachePolicyAttr": $l2_hint,
                     "xegpu::CachePolicyAttr": $l3_hint)>
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 906c71d8b8dad..eed1b4e4ff62a 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -796,6 +796,22 @@ void LoadGatherOp::build(OpBuilder &builder, OperationState &state,
         l1_hint, l2_hint, l3_hint);
 }
 
+void LoadGatherOp::build(OpBuilder &builder, OperationState &state,
+                         Type valueType, Value source,
+                         ArrayRef<OpFoldResult> offsets, Value mask,
+                         IntegerAttr chunk_size, xegpu::CachePolicyAttr l1_hint,
+                         xegpu::CachePolicyAttr l2_hint,
+                         xegpu::CachePolicyAttr l3_hint) {
+  auto loc = source.getLoc();
+  int64_t size = static_cast<int64_t>(offsets.size());
+  auto type = VectorType::get(size, builder.getIndexType());
+  auto values = getValueOrCreateConstantIndexOp(builder, loc, offsets);
+  auto offset = vector::FromElementsOp::create(builder, loc, type, values);
+
+  build(builder, state, valueType, source, offset, mask, chunk_size, l1_hint,
+        l2_hint, l3_hint);
+}
+
 //===----------------------------------------------------------------------===//
 // XeGPU_StoreScatterOp
 //===----------------------------------------------------------------------===//
@@ -844,6 +860,24 @@ void StoreScatterOp::build(OpBuilder &builder, OperationState &state,
         l2_hint, l3_hint);
 }
 
+void StoreScatterOp::build(OpBuilder &builder, OperationState &state,
+                           Value value, Value dest,
+                           ArrayRef<OpFoldResult> offsets, Value mask,
+                           IntegerAttr chunk_size,
+                           xegpu::CachePolicyAttr l1_hint,
+                           xegpu::CachePolicyAttr l2_hint,
+                           xegpu::CachePolicyAttr l3_hint) {
+  auto loc = dest.getLoc();
+  int64_t size = static_cast<int64_t>(offsets.size());
+  auto type = VectorType::get(size, builder.getIndexType());
+  auto values = getValueOrCreateConstantIndexOp(builder, loc, offsets);
+  auto offset = vector::FromElementsOp::create(builder, loc, type, values);
+
+  // Call the correct builder overload that does not expect result types.
+  build(builder, state, value, dest, offset, mask, chunk_size, l1_hint, l2_hint,
+        l3_hint);
+}
+
 //===----------------------------------------------------------------------===//
 // XeGPU_UpdateOffsetOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 8f1208e77ca5d..c063cb7d5d8b6 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -884,6 +884,88 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
   }
 };
 
+// This pattern transforms the LoadGatherOp with explicit offsets to load
+// subgroup data, similar to WgToSgLoadNdOpWithOffset.
+struct WgToSgLoadGatherOpWithOffset
+    : public OpConversionPattern<xegpu::LoadGatherOp> {
+  using OpConversionPattern<xegpu::LoadGatherOp>::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(xegpu::LoadGatherOp op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+
+    if (!op.getOffsets())
+      return failure();
+
+    Location loc = op.getLoc();
+    VectorType resultType = op.getResult().getType();
+    ArrayRef<int64_t> wgShape = resultType.getShape();
+
+    xegpu::LayoutAttr layout = xegpu::getLayoutAttr(op.getResult());
+    if (!layout || !layout.getSgLayout())
+      return failure();
+
+    SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
+
+    SmallVector<Value> newLoadOps;
+    auto chunkSizeAttr =
+        rewriter.getI64IntegerAttr(op.getChunkSize().value_or(1));
+    VectorType newTy = VectorType::get(sgShape, resultType.getElementType());
+    for (auto [offsets, mask] :
+         llvm::zip(adaptor.getOffsets(), adaptor.getMask())) {
+      auto newLoadOp = rewriter.create<xegpu::LoadGatherOp>(
+          loc, newTy, op.getSource(), offsets, mask, chunkSizeAttr,
+          op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr());
+      xegpu::setLayoutAttr(newLoadOp->getResult(0),
+                           layout.dropSgLayoutAndData());
+      newLoadOps.push_back(newLoadOp);
+    }
+    rewriter.replaceOpWithMultiple(op, {newLoadOps});
+    return success();
+  }
+};
+
+// This pattern transforms the StoreScatterOp with explicit offsets to store
+// subgroup data, similar to WgToSgStoreNdOpWithOffset.
+struct WgToSgStoreScatterOpWithOffset
+    : public OpConversionPattern<xegpu::StoreScatterOp> {
+  using OpConversionPattern<xegpu::StoreScatterOp>::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(xegpu::StoreScatterOp op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+
+    if (!op.getOffsets())
+      return failure();
+
+    Location loc = op.getLoc();
+    VectorType valueType = dyn_cast<VectorType>(op.getValue().getType());
+    if (!valueType)
+      return failure();
+
+    ArrayRef<int64_t> wgShape = valueType.getShape();
+    xegpu::LayoutAttr layout = xegpu::getLayoutAttr(op.getValue());
+    if (!layout || !layout.getSgLayout())
+      return failure();
+
+    auto chunkSizeOpt = op.getChunkSize();
+    int64_t chunkSize = chunkSizeOpt ? static_cast<int64_t>(*chunkSizeOpt) : 1;
+    auto chunkSizeAttr = rewriter.getI64IntegerAttr(chunkSize);
+    for (auto [val, offs, mask] : llvm::zip(
+             adaptor.getValue(), adaptor.getOffsets(), adaptor.getMask())) {
+      rewriter.create<xegpu::StoreScatterOp>(
+          loc, val, op.getDest(), offs, mask, chunkSizeAttr, op.getL1HintAttr(),
+          op.getL2HintAttr(), op.getL3HintAttr());
+      // Update the layout_result_0 attribute to drop sg_layout and sg_data.
+      if (auto layoutAttr =
+              op->getAttrOfType<xegpu::LayoutAttr>("layout_result_0")) {
+        if (auto newLayout = layoutAttr.dropSgLayoutAndData())
+          op->setAttr("layout_result_0", newLayout);
+      }
+    }
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
 } // namespace
 
 namespace mlir {
@@ -895,7 +977,8 @@ void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
            WgToSgUpdateNdOffsetOp, WgToSgDpasOp, WgToSgPrefetchNdOp,
            WgToSgPrefetchNdOpWithOffset, UnrealizedConversionCastOpPattern,
            WgToSgElementwiseOp, WgToSgVectorBroadcastOp, WgToSgConvertLayoutOp,
-           WgToSgArithConstantOp>(patterns.getContext());
+           WgToSgArithConstantOp, WgToSgLoadGatherOpWithOffset,
+           WgToSgStoreScatterOpWithOffset>(patterns.getContext());
 }
 } // namespace xegpu
 } // namespace mlir
@@ -1015,6 +1098,21 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
         return isLegal(xegpu::getLayoutAttr(op.getResult()));
       });
 
+  target.addDynamicallyLegalOp<xegpu::LoadGatherOp>(
+      [=](xegpu::LoadGatherOp op) -> bool {
+        auto layout = xegpu::getLayoutAttr(op.getResult());
+        return isLegal(layout);
+      });
+
+  target.addDynamicallyLegalOp<xegpu::StoreScatterOp>(
+      [=](xegpu::StoreScatterOp op) -> bool {
+        // Check if the layout attribute is present on the result.
+        auto layout = op->getAttrOfType<xegpu::LayoutAttr>("layout_result_0");
+        if (!layout)
+          return true;
+        return isLegal(layout);
+      });
+
   target.addDynamicallyLegalOp<xegpu::ConvertLayoutOp>(
       [=](xegpu::ConvertLayoutOp op) -> bool {
         return isLegal(op.getInputLayout()) && isLegal(op.getTargetLayout());
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
index 07a0b86223c33..03fd95f0e778e 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
@@ -263,4 +263,45 @@ gpu.module @test_distribution {
   } {sg_id_range = #xegpu.range<[3, 19]>}
   gpu.return
   }
+
+  // CHECK-LABEL: @load_gather
+  // CHECK-SAME: %[[ARG0:.*]]: memref<?xf16>
+  gpu.func @load_gather(%src : memref<?xf16>) {
+    // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<32x4xindex>
+    // CHECK: %[[MASK:.*]] = arith.constant dense<true> : vector<32x4xi1>
+    // CHECK: %[[LOAD:.*]] = xegpu.load %[[ARG0]][%[[CST]]], %[[MASK]] <{chunk_size = 1 : i64, l1_hint = #xegpu.cache_hint<cached>}> : memref<?xf16>, vector<32x4xindex>, vector<32x4xi1> -> vector<32x4xf16>
+    %offset =  arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 4]>} dense<0> : vector<256x16xindex>
+    %mask = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 4]>} dense<1> : vector<256x16xi1>
+    %load = xegpu.load %src[%offset], %mask {chunk_size = 1, layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 4]>, l1_hint = #xegpu.cache_hint<cached>}
+      : memref<?xf16>, vector<256x16xindex>, vector<256x16xi1> -> vector<256x16xf16>
+    gpu.return
+  }
+
+  // CHECK-LABEL: @store_scatter
+  // CHECK-SAME: %[[ARG0:.*]]: memref<256xf16>
+  gpu.func @store_scatter(%dest : memref<256xf16>) {
+    // CHECK: %[[VAL:.*]] = arith.constant dense<2.550000e+01> : vector<8xf16>
+    // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<8xindex>
+    // CHECK: %[[MASK:.*]] = arith.constant dense<true> : vector<8xi1>
+    // CHECK: xegpu.store %[[VAL]], %[[ARG0]][%[[CST]]], %[[MASK]] <{chunk_size = 1 : i64, l1_hint = #xegpu.cache_hint<cached>}> : vector<8xf16>, memref<256xf16>, vector<8xindex>, vector<8xi1>
+    %val = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [32], sg_data = [8]>} dense<25.5> : vector<256xf16>
+    %offset = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [32], sg_data = [8]>} dense<0> : vector<256xindex>
+    %mask = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [32], sg_data = [8]>} dense<1> : vector<256xi1>
+    xegpu.store %val, %dest[%offset], %mask {chunk_size = 1, layout_result_0 = #xegpu.layout<sg_layout = [32], sg_data = [8]>, l1_hint = #xegpu.cache_hint<cached>}
+      : vector<256xf16>, memref<256xf16>, vector<256xindex>, vector<256xi1>
+    gpu.return
+  }
+
+  // CHECK-LABEL: @load_with_chunk_size
+  // CHECK-SAME: %[[ARG0:.*]]: memref<?xf16>
+  gpu.func @load_with_chunk_size(%src : memref<?xf16>) {
+    // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<8xindex>
+    // CHECK: %[[MASK:.*]] = arith.constant dense<true> : vector<8xi1>
+    // CHECK: %[[LOAD:.*]] = xegpu.load %[[ARG0]][%[[CST]]], %[[MASK]] <{chunk_size = 4 : i64, l1_hint = #xegpu.cache_hint<cached>}> : memref<?xf16>, vector<8xindex>, vector<8xi1> -> vector<8x4xf16>
+    %offset =  arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [32], sg_data = [8]>} dense<0> : vector<256xindex>
+    %mask = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [32], sg_data = [8]>} dense<1> : vector<256xi1>
+    %load = xegpu.load %src[%offset], %mask {chunk_size = 4, layout_result_0 = #xegpu.layout<sg_layout = [32, 1], sg_data = [8, 4]>, l1_hint = #xegpu.cache_hint<cached>}
+      : memref<?xf16>, vector<256xindex>, vector<256xi1> -> vector<256x4xf16>
+    gpu.return
+  }
 }

@nbpatel
Copy link
Contributor Author

nbpatel commented Aug 22, 2025

ping for review

auto chunkSizeAttr =
rewriter.getI64IntegerAttr(op.getChunkSize().value_or(1));
VectorType newTy = VectorType::get(sgShape, resultType.getElementType());
for (auto [offsets, mask] :
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: Here the code assums the offset has been distributed by its defining op. It is not always true currently,
e.g., the offsets is from an function parameter, or non-splat arith constant. Thus, it would be better to check
whether the offsets has been distributed yet.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1, could you add a test for this?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

with the current design, its not possible to add negative tests

loc, val, op.getDest(), offs, mask, chunkSizeAttr, op.getL1HintAttr(),
op.getL2HintAttr(), op.getL3HintAttr());
// Update the layout_result_0 attribute to drop sg_layout and sg_data.
if (auto layoutAttr =
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

better to use the getLayoutAttr and setLayoutAttr interface here.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

store doesn't have a result, and the layout is attached to the op, so not sure if getters/setters will work

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

then, what is the purpose of setting "layout_result_0" here?

Copy link
Contributor Author

@nbpatel nbpatel Aug 28, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it is to retain lane layout & lane data if its present...maybe we need to name it differently?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since StoreScatterOp doesn't have result. There should be no layout_result_0 for this op from the input code. but it could have layout_operand_0 instead?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

changed it to "layout", similar to load/store matrix ops

auto chunkSizeOpt = op.getChunkSize();
int64_t chunkSize = chunkSizeOpt ? static_cast<int64_t>(*chunkSizeOpt) : 1;
auto chunkSizeAttr = rewriter.getI64IntegerAttr(chunkSize);
for (auto [val, offs, mask] : llvm::zip(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same check for offsets as above.

auto chunkSizeAttr =
rewriter.getI64IntegerAttr(op.getChunkSize().value_or(1));
VectorType newTy = VectorType::get(sgShape, resultType.getElementType());
for (auto [offsets, mask] :
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1, could you add a test for this?

Copy link
Contributor

@chencha3 chencha3 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM. thanks for the efforts.

SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;

// The offsets need to be distributed
if (dyn_cast<VectorType>(adaptor.getOffsets().front().getType())
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: it would read better if we do dyn_cast before the if. and same for the following one.

@nbpatel nbpatel merged commit fdfc751 into llvm:main Sep 1, 2025
9 checks passed
@nbpatel nbpatel deleted the xegpu-scatter-ops branch September 25, 2025 20:34
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants