Thanks to visit codestin.com
Credit goes to github.com

Skip to content

[mlir][xegpu] refine basic routines #138701

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
May 9, 2025

Conversation

chencha3
Copy link
Contributor

@chencha3 chencha3 commented May 6, 2025

This PR adds two interfaces for LayoutAttr and updates the builder of CreateNdOp for convenience.

@llvmbot
Copy link
Member

llvmbot commented May 6, 2025

@llvm/pr-subscribers-mlir-gpu

@llvm/pr-subscribers-mlir

Author: Chao Chen (chencha3)

Changes

This PR adds two interfaces for LayoutAttr and updates the builder of CreateNdOp for convenience.


Full diff: https://github.com/llvm/llvm-project/pull/138701.diff

3 Files Affected:

  • (modified) mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td (+34-1)
  • (modified) mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td (+1-6)
  • (modified) mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp (+16-27)
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
index f1bed70253ef3..b4236af497587 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
@@ -253,6 +253,28 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout"> {
         return $_get($_ctxt, sg_layout, sg_data, inst_data,
                      DenseI32ArrayAttr::get($_ctxt, lane_layout),
                      DenseI32ArrayAttr::get($_ctxt, lane_data), order);
+      }]>,
+    AttrBuilder<(ins "llvm::ArrayRef<int>": $lane_layout,
+                     "llvm::ArrayRef<int>": $lane_data,
+                     "llvm::ArrayRef<int>": $order),
+      [{
+        auto sg_layout = DenseI32ArrayAttr();
+        auto sg_data = DenseI32ArrayAttr();
+        auto inst_data = DenseI32ArrayAttr();
+        return $_get($_ctxt, sg_layout, sg_data, inst_data,
+                     DenseI32ArrayAttr::get($_ctxt, lane_layout),
+                     DenseI32ArrayAttr::get($_ctxt, lane_data),
+                     DenseI32ArrayAttr::get($_ctxt, order));
+      }]>,
+    AttrBuilder<(ins "DenseI32ArrayAttr": $lane_layout,
+                     "DenseI32ArrayAttr": $lane_data,
+                     "DenseI32ArrayAttr": $order),
+      [{
+        auto sg_layout = DenseI32ArrayAttr();
+        auto sg_data = DenseI32ArrayAttr();
+        auto inst_data = DenseI32ArrayAttr();
+        return $_get($_ctxt, sg_layout, sg_data, inst_data,
+                     lane_layout, lane_data, order);
       }]>
   ];
 
@@ -262,7 +284,7 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout"> {
     }
 
     bool isSgLayout() {
-      return getSgLayout() == nullptr && getLaneLayout() != nullptr;
+      return !isWgLayout();
     }
 
     int64_t getRank() {
@@ -274,6 +296,17 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout"> {
         return attr.size();
       return 0;
     }
+
+    LayoutAttr dropSgLayoutAndData() {
+      return LayoutAttr::get(getContext(), nullptr, nullptr, getInstData(),
+                             getLaneLayout(), getLaneData(), getOrder());
+    }
+
+    LayoutAttr dropInstData() {
+      return LayoutAttr::get(getContext(), getSgLayout(), getSgData(), nullptr,
+                             getLaneLayout(), getLaneData(), getOrder());
+    }
+
   }];
 
   let assemblyFormat = "`<` struct(params) `>`";
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index 5fa18754305ca..627de858d94aa 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -142,12 +142,7 @@ def XeGPU_CreateNdDescOp: XeGPU_Op<"create_nd_tdesc", [Pure, ViewLikeOpInterface
     OpBuilder<(ins "Type": $tdesc, "TypedValue<MemRefType>": $source,
                    "llvm::ArrayRef<OpFoldResult>": $offsets)>,
 
-    OpBuilder<(ins "Type": $tdesc, "TypedValue<MemRefType> ": $source,
-                   "llvm::ArrayRef<OpFoldResult>": $offsets,
-                   "llvm::ArrayRef<OpFoldResult>": $shape,
-                   "llvm::ArrayRef<OpFoldResult>": $strides)>,
-
-    OpBuilder<(ins "Type": $tdesc, "TypedValue<IntegerType> ": $source,
+    OpBuilder<(ins "Type": $tdesc, "Value": $source,
                    "llvm::ArrayRef<OpFoldResult>": $offsets,
                    "llvm::ArrayRef<OpFoldResult>": $shape,
                    "llvm::ArrayRef<OpFoldResult>": $strides)>
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index f9d7e013826ed..7df6c794615a1 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -141,46 +141,24 @@ void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
 }
 
 void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
-                           Type tdesc, TypedValue<MemRefType> source,
+                           Type tdesc, Value source,
                            llvm::ArrayRef<OpFoldResult> offsets,
                            llvm::ArrayRef<OpFoldResult> shape,
                            llvm::ArrayRef<OpFoldResult> strides) {
   assert(shape.size() && offsets.size() && strides.size() &&
          shape.size() == strides.size() && shape.size() == offsets.size());
 
-  llvm::SmallVector<int64_t> staticOffsets;
-  llvm::SmallVector<int64_t> staticShape;
-  llvm::SmallVector<int64_t> staticStrides;
+  auto intTy = dyn_cast<IntegerType>(source.getType());
+  auto memrefTy = dyn_cast<MemRefType>(source.getType());
+  assert(intTy || memrefTy && "Source has to be either int or memref.");
+
   llvm::SmallVector<Value> dynamicOffsets;
   llvm::SmallVector<Value> dynamicShape;
   llvm::SmallVector<Value> dynamicStrides;
 
-  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
-  dispatchIndexOpFoldResults(shape, dynamicShape, staticShape);
-  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
-
-  auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
-  auto staticShapeAttr = builder.getDenseI64ArrayAttr(staticShape);
-  auto staticStridesAttr = builder.getDenseI64ArrayAttr(staticStrides);
-
-  build(builder, state, tdesc, source, dynamicOffsets, dynamicShape,
-        dynamicStrides, staticOffsetsAttr, staticShapeAttr, staticStridesAttr);
-}
-
-void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
-                           Type tdesc, TypedValue<IntegerType> source,
-                           llvm::ArrayRef<OpFoldResult> offsets,
-                           llvm::ArrayRef<OpFoldResult> shape,
-                           llvm::ArrayRef<OpFoldResult> strides) {
-  assert(shape.size() && offsets.size() && strides.size() &&
-         shape.size() == strides.size() && shape.size() == offsets.size());
-
   llvm::SmallVector<int64_t> staticOffsets;
   llvm::SmallVector<int64_t> staticShape;
   llvm::SmallVector<int64_t> staticStrides;
-  llvm::SmallVector<Value> dynamicOffsets;
-  llvm::SmallVector<Value> dynamicShape;
-  llvm::SmallVector<Value> dynamicStrides;
 
   dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
   dispatchIndexOpFoldResults(shape, dynamicShape, staticShape);
@@ -190,6 +168,17 @@ void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
   auto staticShapeAttr = builder.getDenseI64ArrayAttr(staticShape);
   auto staticStridesAttr = builder.getDenseI64ArrayAttr(staticStrides);
 
+  if (memrefTy) {
+    auto memrefShape = memrefTy.getShape();
+    auto [memrefStrides, offset] = memrefTy.getStridesAndOffset();
+
+    // if shape and strides are from Memref, we don't need attributes for them
+    if (staticShape == memrefShape && staticStrides == memrefStrides) {
+      staticShapeAttr = DenseI64ArrayAttr();
+      staticStridesAttr = DenseI64ArrayAttr();
+    }
+  }
+
   build(builder, state, tdesc, source, dynamicOffsets, dynamicShape,
         dynamicStrides, staticOffsetsAttr, staticShapeAttr, staticStridesAttr);
 }

Copy link
Contributor

@adam-smnk adam-smnk left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM % nits

@chencha3 chencha3 merged commit 20d6def into main May 9, 2025
11 checks passed
@chencha3 chencha3 deleted the users/chencha3/xegpu/refine_xegpu_basic_routines branch May 9, 2025 14:41
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants