diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td index f1bed70253ef3..6d04ee5599a23 100644 --- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td +++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td @@ -243,8 +243,8 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout"> { ); let builders = [ - AttrBuilder<(ins "llvm::ArrayRef": $lane_layout, - "llvm::ArrayRef": $lane_data), + AttrBuilder<(ins "llvm::ArrayRef": $lane_layout, + "llvm::ArrayRef": $lane_data), [{ auto sg_layout = DenseI32ArrayAttr(); auto sg_data = DenseI32ArrayAttr(); @@ -253,6 +253,25 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout"> { return $_get($_ctxt, sg_layout, sg_data, inst_data, DenseI32ArrayAttr::get($_ctxt, lane_layout), DenseI32ArrayAttr::get($_ctxt, lane_data), order); + }]>, + AttrBuilder<(ins "llvm::ArrayRef": $lane_layout, + "llvm::ArrayRef": $lane_data, + "llvm::ArrayRef": $order), + [{ + return $_get($_ctxt, + /*sg_layout =*/ nullptr, + /*sg_data =*/ nullptr, + /*inst_data =*/ nullptr, + DenseI32ArrayAttr::get($_ctxt, lane_layout), + DenseI32ArrayAttr::get($_ctxt, lane_data), + DenseI32ArrayAttr::get($_ctxt, order)); + }]>, + AttrBuilder<(ins "DenseI32ArrayAttr": $lane_layout, + "DenseI32ArrayAttr": $lane_data, + "DenseI32ArrayAttr": $order), + [{ + return $_get($_ctxt, /*sg_layout =*/ nullptr, /*sg_data =*/ nullptr, + /*inst_data =*/ nullptr, lane_layout, lane_data, order); }]> ]; @@ -262,7 +281,7 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout"> { } bool isSgLayout() { - return getSgLayout() == nullptr && getLaneLayout() != nullptr; + return !isWgLayout(); } int64_t getRank() { @@ -274,6 +293,17 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout"> { return attr.size(); return 0; } + + LayoutAttr dropSgLayoutAndData() { + return LayoutAttr::get(getContext(), nullptr, nullptr, getInstData(), + getLaneLayout(), getLaneData(), getOrder()); + } + + LayoutAttr dropInstData() { + return LayoutAttr::get(getContext(), getSgLayout(), getSgData(), nullptr, + getLaneLayout(), getLaneData(), getOrder()); + } + }]; let assemblyFormat = "`<` struct(params) `>`"; diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td index 5fa18754305ca..627de858d94aa 100644 --- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td +++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td @@ -142,12 +142,7 @@ def XeGPU_CreateNdDescOp: XeGPU_Op<"create_nd_tdesc", [Pure, ViewLikeOpInterface OpBuilder<(ins "Type": $tdesc, "TypedValue": $source, "llvm::ArrayRef": $offsets)>, - OpBuilder<(ins "Type": $tdesc, "TypedValue ": $source, - "llvm::ArrayRef": $offsets, - "llvm::ArrayRef": $shape, - "llvm::ArrayRef": $strides)>, - - OpBuilder<(ins "Type": $tdesc, "TypedValue ": $source, + OpBuilder<(ins "Type": $tdesc, "Value": $source, "llvm::ArrayRef": $offsets, "llvm::ArrayRef": $shape, "llvm::ArrayRef": $strides)> diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp index f9d7e013826ed..f2cfa50e102f8 100644 --- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp +++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp @@ -141,46 +141,24 @@ void CreateNdDescOp::build(OpBuilder &builder, OperationState &state, } void CreateNdDescOp::build(OpBuilder &builder, OperationState &state, - Type tdesc, TypedValue source, + Type tdesc, Value source, llvm::ArrayRef offsets, llvm::ArrayRef shape, llvm::ArrayRef strides) { assert(shape.size() && offsets.size() && strides.size() && shape.size() == strides.size() && shape.size() == offsets.size()); - llvm::SmallVector staticOffsets; - llvm::SmallVector staticShape; - llvm::SmallVector staticStrides; + Type srcTy = source.getType(); + assert(isa(srcTy) || + isa(srcTy) && "Source has to be either int or memref."); + llvm::SmallVector dynamicOffsets; llvm::SmallVector dynamicShape; llvm::SmallVector dynamicStrides; - dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets); - dispatchIndexOpFoldResults(shape, dynamicShape, staticShape); - dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides); - - auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets); - auto staticShapeAttr = builder.getDenseI64ArrayAttr(staticShape); - auto staticStridesAttr = builder.getDenseI64ArrayAttr(staticStrides); - - build(builder, state, tdesc, source, dynamicOffsets, dynamicShape, - dynamicStrides, staticOffsetsAttr, staticShapeAttr, staticStridesAttr); -} - -void CreateNdDescOp::build(OpBuilder &builder, OperationState &state, - Type tdesc, TypedValue source, - llvm::ArrayRef offsets, - llvm::ArrayRef shape, - llvm::ArrayRef strides) { - assert(shape.size() && offsets.size() && strides.size() && - shape.size() == strides.size() && shape.size() == offsets.size()); - llvm::SmallVector staticOffsets; llvm::SmallVector staticShape; llvm::SmallVector staticStrides; - llvm::SmallVector dynamicOffsets; - llvm::SmallVector dynamicShape; - llvm::SmallVector dynamicStrides; dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets); dispatchIndexOpFoldResults(shape, dynamicShape, staticShape); @@ -190,6 +168,18 @@ void CreateNdDescOp::build(OpBuilder &builder, OperationState &state, auto staticShapeAttr = builder.getDenseI64ArrayAttr(staticShape); auto staticStridesAttr = builder.getDenseI64ArrayAttr(staticStrides); + if (auto memrefTy = dyn_cast(srcTy)) { + auto memrefShape = memrefTy.getShape(); + auto [memrefStrides, _] = memrefTy.getStridesAndOffset(); + + // if shape and strides are from Memref, we don't need attributes for them + // to keep the IR print clean. + if (staticShape == memrefShape && staticStrides == memrefStrides) { + staticShapeAttr = DenseI64ArrayAttr(); + staticStridesAttr = DenseI64ArrayAttr(); + } + } + build(builder, state, tdesc, source, dynamicOffsets, dynamicShape, dynamicStrides, staticOffsetsAttr, staticShapeAttr, staticStridesAttr); }