Thanks to visit codestin.com
Credit goes to github.com

Skip to content

[mlir][xegpu] refine basic routines #138701

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

chencha3
Copy link
Contributor

@chencha3 chencha3 commented May 6, 2025

This PR adds two interfaces for LayoutAttr and updates the builder of CreateNdOp for convenience.

@llvmbot
Copy link
Member

llvmbot commented May 6, 2025

@llvm/pr-subscribers-mlir-gpu

@llvm/pr-subscribers-mlir

Author: Chao Chen (chencha3)

Changes

This PR adds two interfaces for LayoutAttr and updates the builder of CreateNdOp for convenience.


Full diff: https://github.com/llvm/llvm-project/pull/138701.diff

3 Files Affected:

  • (modified) mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td (+34-1)
  • (modified) mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td (+1-6)
  • (modified) mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp (+16-27)
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
index f1bed70253ef3..b4236af497587 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
@@ -253,6 +253,28 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout"> {
         return $_get($_ctxt, sg_layout, sg_data, inst_data,
                      DenseI32ArrayAttr::get($_ctxt, lane_layout),
                      DenseI32ArrayAttr::get($_ctxt, lane_data), order);
+      }]>,
+    AttrBuilder<(ins "llvm::ArrayRef<int>": $lane_layout,
+                     "llvm::ArrayRef<int>": $lane_data,
+                     "llvm::ArrayRef<int>": $order),
+      [{
+        auto sg_layout = DenseI32ArrayAttr();
+        auto sg_data = DenseI32ArrayAttr();
+        auto inst_data = DenseI32ArrayAttr();
+        return $_get($_ctxt, sg_layout, sg_data, inst_data,
+                     DenseI32ArrayAttr::get($_ctxt, lane_layout),
+                     DenseI32ArrayAttr::get($_ctxt, lane_data),
+                     DenseI32ArrayAttr::get($_ctxt, order));
+      }]>,
+    AttrBuilder<(ins "DenseI32ArrayAttr": $lane_layout,
+                     "DenseI32ArrayAttr": $lane_data,
+                     "DenseI32ArrayAttr": $order),
+      [{
+        auto sg_layout = DenseI32ArrayAttr();
+        auto sg_data = DenseI32ArrayAttr();
+        auto inst_data = DenseI32ArrayAttr();
+        return $_get($_ctxt, sg_layout, sg_data, inst_data,
+                     lane_layout, lane_data, order);
       }]>
   ];
 
@@ -262,7 +284,7 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout"> {
     }
 
     bool isSgLayout() {
-      return getSgLayout() == nullptr && getLaneLayout() != nullptr;
+      return !isWgLayout();
     }
 
     int64_t getRank() {
@@ -274,6 +296,17 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout"> {
         return attr.size();
       return 0;
     }
+
+    LayoutAttr dropSgLayoutAndData() {
+      return LayoutAttr::get(getContext(), nullptr, nullptr, getInstData(),
+                             getLaneLayout(), getLaneData(), getOrder());
+    }
+
+    LayoutAttr dropInstData() {
+      return LayoutAttr::get(getContext(), getSgLayout(), getSgData(), nullptr,
+                             getLaneLayout(), getLaneData(), getOrder());
+    }
+
   }];
 
   let assemblyFormat = "`<` struct(params) `>`";
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index 5fa18754305ca..627de858d94aa 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -142,12 +142,7 @@ def XeGPU_CreateNdDescOp: XeGPU_Op<"create_nd_tdesc", [Pure, ViewLikeOpInterface
     OpBuilder<(ins "Type": $tdesc, "TypedValue<MemRefType>": $source,
                    "llvm::ArrayRef<OpFoldResult>": $offsets)>,
 
-    OpBuilder<(ins "Type": $tdesc, "TypedValue<MemRefType> ": $source,
-                   "llvm::ArrayRef<OpFoldResult>": $offsets,
-                   "llvm::ArrayRef<OpFoldResult>": $shape,
-                   "llvm::ArrayRef<OpFoldResult>": $strides)>,
-
-    OpBuilder<(ins "Type": $tdesc, "TypedValue<IntegerType> ": $source,
+    OpBuilder<(ins "Type": $tdesc, "Value": $source,
                    "llvm::ArrayRef<OpFoldResult>": $offsets,
                    "llvm::ArrayRef<OpFoldResult>": $shape,
                    "llvm::ArrayRef<OpFoldResult>": $strides)>
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index f9d7e013826ed..7df6c794615a1 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -141,46 +141,24 @@ void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
 }
 
 void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
-                           Type tdesc, TypedValue<MemRefType> source,
+                           Type tdesc, Value source,
                            llvm::ArrayRef<OpFoldResult> offsets,
                            llvm::ArrayRef<OpFoldResult> shape,
                            llvm::ArrayRef<OpFoldResult> strides) {
   assert(shape.size() && offsets.size() && strides.size() &&
          shape.size() == strides.size() && shape.size() == offsets.size());
 
-  llvm::SmallVector<int64_t> staticOffsets;
-  llvm::SmallVector<int64_t> staticShape;
-  llvm::SmallVector<int64_t> staticStrides;
+  auto intTy = dyn_cast<IntegerType>(source.getType());
+  auto memrefTy = dyn_cast<MemRefType>(source.getType());
+  assert(intTy || memrefTy && "Source has to be either int or memref.");
+
   llvm::SmallVector<Value> dynamicOffsets;
   llvm::SmallVector<Value> dynamicShape;
   llvm::SmallVector<Value> dynamicStrides;
 
-  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
-  dispatchIndexOpFoldResults(shape, dynamicShape, staticShape);
-  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
-
-  auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
-  auto staticShapeAttr = builder.getDenseI64ArrayAttr(staticShape);
-  auto staticStridesAttr = builder.getDenseI64ArrayAttr(staticStrides);
-
-  build(builder, state, tdesc, source, dynamicOffsets, dynamicShape,
-        dynamicStrides, staticOffsetsAttr, staticShapeAttr, staticStridesAttr);
-}
-
-void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
-                           Type tdesc, TypedValue<IntegerType> source,
-                           llvm::ArrayRef<OpFoldResult> offsets,
-                           llvm::ArrayRef<OpFoldResult> shape,
-                           llvm::ArrayRef<OpFoldResult> strides) {
-  assert(shape.size() && offsets.size() && strides.size() &&
-         shape.size() == strides.size() && shape.size() == offsets.size());
-
   llvm::SmallVector<int64_t> staticOffsets;
   llvm::SmallVector<int64_t> staticShape;
   llvm::SmallVector<int64_t> staticStrides;
-  llvm::SmallVector<Value> dynamicOffsets;
-  llvm::SmallVector<Value> dynamicShape;
-  llvm::SmallVector<Value> dynamicStrides;
 
   dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
   dispatchIndexOpFoldResults(shape, dynamicShape, staticShape);
@@ -190,6 +168,17 @@ void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
   auto staticShapeAttr = builder.getDenseI64ArrayAttr(staticShape);
   auto staticStridesAttr = builder.getDenseI64ArrayAttr(staticStrides);
 
+  if (memrefTy) {
+    auto memrefShape = memrefTy.getShape();
+    auto [memrefStrides, offset] = memrefTy.getStridesAndOffset();
+
+    // if shape and strides are from Memref, we don't need attributes for them
+    if (staticShape == memrefShape && staticStrides == memrefStrides) {
+      staticShapeAttr = DenseI64ArrayAttr();
+      staticStridesAttr = DenseI64ArrayAttr();
+    }
+  }
+
   build(builder, state, tdesc, source, dynamicOffsets, dynamicShape,
         dynamicStrides, staticOffsetsAttr, staticShapeAttr, staticStridesAttr);
 }

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants