Thanks for visiting codestin.com
Credit goes to github.com

Skip to content

Commit 9b0d7dd

Browse files
authored
[mlir][xegpu] Add support for vector.multi_reduction and vector.shape_cast SIMT distribution. (llvm#157560)
Add support for distributing the `vector.multi_reduction` operation across lanes in a warp. Currently only 2D to 1D reductions are supported. Given layouts for the source and accumulator vectors:

* If the reduction dimension is distributed across lanes, the reduction is non-lane-local and is done using warp shuffles. Here we simply rewrite the `MultiDimReductionOp` to a sequence of `ReductionOp`s inside the warp op body. The actual distribution will be done by the `WarpOpReduction` pattern.
* If the reduction dimension is not distributed across lanes, the reduction is lane-local. In this case, we yield the source and accumulator vectors from the warp op and perform the lane-local reduction outside the warp op using a sequence of `ReductionOp`s.

This PR also adds support for distributing `vector.shape_cast` based on layouts.
1 parent 5d088ba commit 9b0d7dd

File tree

8 files changed

+530
-72
lines changed

8 files changed

+530
-72
lines changed

mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td

Lines changed: 40 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -194,26 +194,29 @@ def DistributeLayoutAttr: AttrInterface<"DistributeLayoutAttr"> {
194194
InterfaceMethod<"Get the num of effective subgroups",
195195
"int64_t",
196196
"getNumSubgroups", (ins), [{
197-
std::optional<SmallVector<int64_t>> sgLayout = llvm::cast<ConcreteAttr>(tablegen_opaque_val).getSgLayoutAsInt();
197+
std::optional<SmallVector<int64_t>> sgLayout = llvm::cast<ConcreteAttr>(tablegen_opaque_val).getEffectiveSgLayoutAsInt();
198198
if (sgLayout.has_value())
199199
return computeProduct(*sgLayout);
200200
return 0;
201201
}], [{}]>,
202-
InterfaceMethod<"Get the SgLayout field of the attribute as integer array",
202+
InterfaceMethod<"Get the order of the layout attribute",
203+
"DenseI32ArrayAttr",
204+
"getOrder">,
205+
InterfaceMethod<"Get the effective SgLayout of the layout attribute as integer array",
203206
"SmallVector<int64_t>",
204-
"getSgLayoutAsInt">,
205-
InterfaceMethod<"Get the SgData field of the attribute as integer array",
207+
"getEffectiveSgLayoutAsInt">,
208+
InterfaceMethod<"Get the effective SgData of the layout attribute as integer array",
206209
"SmallVector<int64_t>",
207-
"getSgDataAsInt">,
208-
InterfaceMethod<"Get the InstData field of the attribute as integer array",
210+
"getEffectiveSgDataAsInt">,
211+
InterfaceMethod<"Get the effective InstData of the layout attribute as integer array",
209212
"SmallVector<int64_t>",
210-
"getInstDataAsInt">,
211-
InterfaceMethod<"Get the LaneLayout field of the attribute as integer array",
213+
"getEffectiveInstDataAsInt">,
214+
InterfaceMethod<"Get the effective LaneLayout of the layout attribute as integer array",
212215
"SmallVector<int64_t>",
213-
"getLaneLayoutAsInt">,
214-
InterfaceMethod<"Get the LaneData field of the attribute as integer array",
216+
"getEffectiveLaneLayoutAsInt">,
217+
InterfaceMethod<"Get the effective LaneData of the layout attribute as integer array",
215218
"SmallVector<int64_t>",
216-
"getLaneDataAsInt">,
219+
"getEffectiveLaneDataAsInt">,
217220
InterfaceMethod<"Derive a new layout by dropping sgLayout and sgData",
218221
"xegpu::DistributeLayoutAttr",
219222
"dropSgLayoutAndData">,
@@ -231,7 +234,11 @@ def DistributeLayoutAttr: AttrInterface<"DistributeLayoutAttr"> {
231234
multiple blocks according to round-robin distribution rules.}],
232235
"FailureOr<SmallVector<SmallVector<Value>>>",
233236
"getOffsets",
234-
(ins "OpBuilder &": $builder, "Location":$loc, "Value":$linearId, "ArrayRef<int64_t>":$shape)>
237+
(ins "OpBuilder &": $builder, "Location":$loc, "Value":$linearId, "ArrayRef<int64_t>":$shape)>,
238+
InterfaceMethod</*desc=*/[{Check if this layout is a slice of some other layout.}],
239+
/*retTy=*/"bool",
240+
/*methodName=*/"isSliceOf",
241+
/*args=*/(ins "const xegpu::DistributeLayoutAttr&": $other)>
235242
];
236243
}
237244

@@ -391,31 +398,31 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout", [DistributeLayoutAttr]> {
391398
getLaneLayout(), getLaneData(), getOrder());
392399
}
393400

394-
SmallVector<int64_t> getSgLayoutAsInt() const {
401+
SmallVector<int64_t> getEffectiveSgLayoutAsInt() const {
395402
if (DenseI32ArrayAttr layout = getSgLayout())
396403
return llvm::to_vector_of<int64_t>(layout.asArrayRef());
397404
return {};
398405
}
399406

400-
SmallVector<int64_t> getSgDataAsInt() const {
407+
SmallVector<int64_t> getEffectiveSgDataAsInt() const {
401408
if (DenseI32ArrayAttr data = getSgData())
402409
return llvm::to_vector_of<int64_t>(data.asArrayRef());
403410
return {};
404411
}
405412

406-
SmallVector<int64_t> getInstDataAsInt() const {
413+
SmallVector<int64_t> getEffectiveInstDataAsInt() const {
407414
if (DenseI32ArrayAttr inst = getInstData())
408415
return llvm::to_vector_of<int64_t>(inst.asArrayRef());
409416
return {};
410417
}
411418

412-
SmallVector<int64_t> getLaneLayoutAsInt() const {
419+
SmallVector<int64_t> getEffectiveLaneLayoutAsInt() const {
413420
if (DenseI32ArrayAttr layout = getLaneLayout())
414421
return llvm::to_vector_of<int64_t>(layout.asArrayRef());
415422
return {};
416423
}
417424

418-
SmallVector<int64_t> getLaneDataAsInt() const {
425+
SmallVector<int64_t> getEffectiveLaneDataAsInt() const {
419426
if (DenseI32ArrayAttr data = getLaneData())
420427
return llvm::to_vector_of<int64_t>(data.asArrayRef());
421428
return {};
@@ -433,6 +440,9 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout", [DistributeLayoutAttr]> {
433440
FailureOr<SmallVector<SmallVector<Value>>>
434441
getOffsets(OpBuilder &builder, Location loc, Value linearId, ArrayRef<int64_t> shape);
435442

443+
/// Check if this is slice of some other layout.
444+
bool isSliceOf(const xegpu::DistributeLayoutAttr &other) { return false; }
445+
436446
}];
437447

438448
let assemblyFormat = "`<` struct(params) `>`";
@@ -499,10 +509,10 @@ def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [DistributeLayoutAttr]> {
499509

500510
/// Returns the SgLayout of the attribute, computed by applying
501511
/// the slice dimensions to the underlying LayoutAttr.
502-
SmallVector<int64_t> getSgLayoutAsInt() const {
512+
SmallVector<int64_t> getEffectiveSgLayoutAsInt() const {
503513
SliceAttr attr = flatten();
504514
auto parent = dyn_cast<LayoutAttr>(attr.getParent());
505-
auto layout = parent.getSgLayoutAsInt();
515+
auto layout = parent.getEffectiveSgLayoutAsInt();
506516
if (layout.size()) {
507517
ArrayRef<int64_t> dims = attr.getDims().asArrayRef();
508518
return XeGPUDialect::slice(ArrayRef<int64_t>(layout), dims);
@@ -512,10 +522,10 @@ def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [DistributeLayoutAttr]> {
512522

513523
/// Returns the SgData of the attribute, computed by applying
514524
/// the slice dimensions to the underlying LayoutAttr.
515-
SmallVector<int64_t> getSgDataAsInt() const {
525+
SmallVector<int64_t> getEffectiveSgDataAsInt() const {
516526
SliceAttr attr = flatten();
517527
auto parent = dyn_cast<LayoutAttr>(attr.getParent());
518-
auto data = parent.getSgDataAsInt();
528+
auto data = parent.getEffectiveSgDataAsInt();
519529
if (data.size()) {
520530
ArrayRef<int64_t> dims = attr.getDims().asArrayRef();
521531
return XeGPUDialect::slice(ArrayRef<int64_t>(data), dims);
@@ -525,10 +535,10 @@ def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [DistributeLayoutAttr]> {
525535

526536
/// Returns the InstData of the attribute, computed by applying
527537
/// the slice dimensions to the underlying LayoutAttr.
528-
SmallVector<int64_t> getInstDataAsInt() const {
538+
SmallVector<int64_t> getEffectiveInstDataAsInt() const {
529539
SliceAttr attr = flatten();
530540
auto parent = dyn_cast<LayoutAttr>(attr.getParent());
531-
auto inst = parent.getInstDataAsInt();
541+
auto inst = parent.getEffectiveInstDataAsInt();
532542
if (inst.size()) {
533543
ArrayRef<int64_t> dims = attr.getDims().asArrayRef();
534544
return XeGPUDialect::slice(llvm::ArrayRef<int64_t>(inst), dims);
@@ -538,10 +548,10 @@ def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [DistributeLayoutAttr]> {
538548

539549
/// Returns the LaneLayout of the attribute, computed by applying
540550
/// the slice dimensions to the underlying LayoutAttr.
541-
SmallVector<int64_t> getLaneLayoutAsInt() const {
551+
SmallVector<int64_t> getEffectiveLaneLayoutAsInt() const {
542552
SliceAttr attr = flatten();
543553
auto parent = dyn_cast<LayoutAttr>(attr.getParent());
544-
auto layout = parent.getLaneLayoutAsInt();
554+
auto layout = parent.getEffectiveLaneLayoutAsInt();
545555
if (layout.size()) {
546556
ArrayRef<int64_t> dims = attr.getDims().asArrayRef();
547557
return XeGPUDialect::slice(llvm::ArrayRef<int64_t>(layout), dims);
@@ -551,10 +561,10 @@ def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [DistributeLayoutAttr]> {
551561

552562
/// Returns the LaneData of the attribute, computed by applying
553563
/// the slice dimensions to the underlying LayoutAttr.
554-
SmallVector<int64_t> getLaneDataAsInt() const {
564+
SmallVector<int64_t> getEffectiveLaneDataAsInt() const {
555565
SliceAttr attr = flatten();
556566
auto parent = dyn_cast<LayoutAttr>(attr.getParent());
557-
auto data = parent.getLaneDataAsInt();
567+
auto data = parent.getEffectiveLaneDataAsInt();
558568
if (data.size()) {
559569
ArrayRef<int64_t> dims = attr.getDims().asArrayRef();
560570
return XeGPUDialect::slice(llvm::ArrayRef<int64_t>(data), dims);
@@ -594,6 +604,9 @@ def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [DistributeLayoutAttr]> {
594604
FailureOr<SmallVector<SmallVector<Value>>>
595605
getOffsets(OpBuilder &builder, Location loc, Value linearId, ArrayRef<int64_t> shape);
596606

607+
/// Check if this is slice of some other layout.
608+
bool isSliceOf(const xegpu::DistributeLayoutAttr &other);
609+
597610
}];
598611

599612
let assemblyFormat = "`<` qualified($parent) `,` `dims` `=` $dims `>`";

mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,10 @@ def XeGPUSubgroupDistribute : Pass<"xegpu-subgroup-distribute"> {
2727
}];
2828
let dependentDialects = ["memref::MemRefDialect", "xegpu::XeGPUDialect",
2929
"vector::VectorDialect"];
30+
let options = [Option<
31+
"enableSGReductions", "enable-sg-reductions", "bool",
32+
/*default=*/"true",
33+
"Enable subgroup reductions using subgroup shuffles.">];
3034
}
3135

3236
def XeGPUPropagateLayout : Pass<"xegpu-propagate-layout"> {

mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp

Lines changed: 34 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -133,22 +133,23 @@ bool XeGPUDialect::isEvenlyDistributable(llvm::ArrayRef<int64_t> shape,
133133
};
134134

135135
// check the sgLayout and sgData
136-
auto maybeSgShape =
137-
tryDistribute(shape, attr.getSgLayoutAsInt(), attr.getSgDataAsInt());
136+
auto maybeSgShape = tryDistribute(shape, attr.getEffectiveSgLayoutAsInt(),
137+
attr.getEffectiveSgDataAsInt());
138138
if (!maybeSgShape)
139139
return false;
140140
auto sgShape = maybeSgShape.value();
141141

142142
// check InstData, it neither have layout nor need round-robin
143143
auto maybeInstShape =
144-
tryDistribute(sgShape, {}, attr.getInstDataAsInt(), false);
144+
tryDistribute(sgShape, {}, attr.getEffectiveInstDataAsInt(), false);
145145
if (!maybeInstShape)
146146
return false;
147147
auto instShape = maybeInstShape.value();
148148

149149
// check LaneLayout and LaneData
150-
auto maybeLaneShape = tryDistribute(instShape, attr.getLaneLayoutAsInt(),
151-
attr.getLaneDataAsInt(), false);
150+
auto maybeLaneShape =
151+
tryDistribute(instShape, attr.getEffectiveLaneLayoutAsInt(),
152+
attr.getEffectiveLaneDataAsInt(), false);
152153
return maybeLaneShape.has_value();
153154
}
154155

@@ -282,9 +283,10 @@ LayoutAttr::delinearizeSubgroupId(OpBuilder &builder, Location loc,
282283
if (!hasDefaultOrder())
283284
return mlir::emitError(loc, "order attribute is currently not supported.");
284285

285-
auto dims = llvm::map_to_vector(getSgLayoutAsInt(), [&](int64_t d) -> Value {
286-
return builder.createOrFold<arith::ConstantIndexOp>(loc, d);
287-
});
286+
auto dims =
287+
llvm::map_to_vector(getEffectiveSgLayoutAsInt(), [&](int64_t d) -> Value {
288+
return builder.createOrFold<arith::ConstantIndexOp>(loc, d);
289+
});
288290

289291
return affine::delinearizeIndex(builder, loc, linearId, dims);
290292
}
@@ -298,8 +300,8 @@ LayoutAttr::getOffsets(OpBuilder &builder, Location loc, Value linearId,
298300
if (!isForWorkgroup())
299301
return failure();
300302

301-
SmallVector<int64_t> sgLayout = getSgLayoutAsInt();
302-
SmallVector<int64_t> sgShape = getSgDataAsInt();
303+
SmallVector<int64_t> sgLayout = getEffectiveSgLayoutAsInt();
304+
SmallVector<int64_t> sgShape = getEffectiveSgDataAsInt();
303305
if (sgShape.empty()) {
304306
if (auto derivedShape = computeShapeRatio(shape, sgLayout))
305307
sgShape = derivedShape.value();
@@ -385,8 +387,8 @@ SliceAttr::getOffsets(OpBuilder &builder, Location loc, Value linearId,
385387
if (!isForWorkgroup())
386388
return failure();
387389

388-
SmallVector<int64_t> sgLayout = getSgLayoutAsInt();
389-
SmallVector<int64_t> sgShape = getSgDataAsInt();
390+
SmallVector<int64_t> sgLayout = getEffectiveSgLayoutAsInt();
391+
SmallVector<int64_t> sgShape = getEffectiveSgDataAsInt();
390392
if (sgShape.empty()) {
391393
if (auto derivedShape = computeShapeRatio(shape, sgLayout))
392394
sgShape = derivedShape.value();
@@ -409,6 +411,26 @@ SliceAttr::getOffsets(OpBuilder &builder, Location loc, Value linearId,
409411
shape);
410412
}
411413

414+
bool SliceAttr::isSliceOf(const xegpu::DistributeLayoutAttr &other) {
415+
auto flattenedThis = flatten();
416+
// If other is a LayoutAttr, just compare directly with parent of
417+
// flattenedThis.
418+
if (auto otherLayout = dyn_cast<xegpu::LayoutAttr>(other))
419+
return flattenedThis.getParent() == otherLayout;
420+
// If other is a SliceAttr, flatten it first before comparing.
421+
auto flattenedOther = dyn_cast<xegpu::SliceAttr>(other).flatten();
422+
// Both must have common parent LayoutAttr.
423+
if (flattenedThis.getParent() != flattenedOther.getParent())
424+
return false;
425+
// otherFlattened's sliced dims must be a subset of flattenedThis's sliced
426+
// dims.
427+
llvm::SmallDenseSet<int64_t> thisDims(
428+
flattenedThis.getDims().asArrayRef().begin(),
429+
flattenedThis.getDims().asArrayRef().end());
430+
return llvm::all_of(flattenedOther.getDims().asArrayRef(),
431+
[&](int64_t dim) { return thisDims.contains(dim); });
432+
}
433+
412434
//===----------------------------------------------------------------------===//
413435
// XeGPU_RangeAttr
414436
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -85,16 +85,16 @@ struct ConvertLayoutOpPattern
8585
using OpRewritePattern::OpRewritePattern;
8686
LogicalResult matchAndRewrite(xegpu::ConvertLayoutOp op,
8787
PatternRewriter &rewriter) const override {
88-
xegpu::DistributeLayoutAttr input_layout = op.getInputLayoutAttr();
89-
xegpu::DistributeLayoutAttr target_layout = op.getTargetLayoutAttr();
90-
if (input_layout.getInstDataAsInt().empty() ||
91-
target_layout.getInstDataAsInt().empty())
88+
xegpu::DistributeLayoutAttr inputLayout = op.getInputLayoutAttr();
89+
xegpu::DistributeLayoutAttr targetLayout = op.getTargetLayoutAttr();
90+
if (inputLayout.getEffectiveInstDataAsInt().empty() ||
91+
targetLayout.getEffectiveInstDataAsInt().empty())
9292
return rewriter.notifyMatchFailure(op, "Not a target ConvertLayoutOp.");
9393

94-
input_layout = input_layout.dropInstData();
95-
target_layout = target_layout.dropInstData();
94+
inputLayout = inputLayout.dropInstData();
95+
targetLayout = targetLayout.dropInstData();
9696
auto newOp = rewriter.createOrFold<xegpu::ConvertLayoutOp>(
97-
op.getLoc(), op.getType(), op.getSource(), input_layout, target_layout);
97+
op.getLoc(), op.getType(), op.getSource(), inputLayout, targetLayout);
9898
rewriter.replaceOp(op, newOp);
9999
return success();
100100
}
@@ -145,8 +145,8 @@ XeGPUBlockingPass::getTileShape(const T &operandOrResult) const {
145145
xegpu::DistributeLayoutAttr layout =
146146
xegpu::getDistributeLayoutAttr(operandOrResult);
147147
if (layout && layout.isForSubgroup()) {
148-
if (!layout.getInstDataAsInt().empty())
149-
return layout.getInstDataAsInt();
148+
if (!layout.getEffectiveInstDataAsInt().empty())
149+
return layout.getEffectiveInstDataAsInt();
150150

151151
if (auto type = dyn_cast<ShapedType>(value.getType()))
152152
return llvm::to_vector(type.getShape());
@@ -226,7 +226,7 @@ bool XeGPUBlockingPass::needsUnroll(Operation *op) const {
226226
Type valTy = value.getType();
227227
if (auto tdescTy = dyn_cast<xegpu::TensorDescType>(valTy)) {
228228
xegpu::DistributeLayoutAttr layout = tdescTy.getLayoutAttr();
229-
return layout && !layout.getInstDataAsInt().empty();
229+
return layout && !layout.getEffectiveInstDataAsInt().empty();
230230
}
231231
auto shapedType = dyn_cast<ShapedType>(valTy);
232232
return shapedType && !llvm::equal(tileShape, shapedType.getShape());

0 commit comments

Comments
 (0)