Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 715492e

Browse files
authored
[MLIR][XeGPU] Add wg-to-sg distribution for dpasmx, bitcast, interleave, and deinterleave (#194985)
As title. Assisted by Claude
1 parent c02e49a commit 715492e

4 files changed

Lines changed: 281 additions & 42 deletions

File tree

mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -127,9 +127,10 @@ static xegpu::DistributeLayoutAttr getLayoutFromUsePoints(Value result) {
127127
// For regular operations: First the result layouts are propagated from uses.
128128
// Then the result layouts are propagated to uses (operands).
129129
static void propagateResultsToRegularOperands(Operation *op) {
130-
if (op->getNumResults() == 0 || op->getNumResults() > 1)
130+
if (op->getNumResults() == 0)
131+
return;
132+
if (op->getNumResults() > 1 && !isa<vector::DeinterleaveOp>(op))
131133
return;
132-
133134
OpResult result = op->getResult(0);
134135
xegpu::DistributeLayoutAttr resLayout = getLayoutFromUsePoints(result);
135136
Type resultType = result.getType();

mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp

Lines changed: 169 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -325,9 +325,9 @@ struct WgToSgDpasOp : public OpConversionPattern<xegpu::DpasOp> {
325325
}
326326

327327
ArrayRef<int64_t> aVecShape =
328-
llvm::cast<VectorType>(aVec.getType()).getShape();
328+
cast<VectorType>(aVec.getType()).getShape();
329329
ArrayRef<int64_t> bVecShape =
330-
llvm::cast<VectorType>(bVec.getType()).getShape();
330+
cast<VectorType>(bVec.getType()).getShape();
331331
VectorType resTy = VectorType::get({aVecShape[0], bVecShape[1]},
332332
resultTy.getElementType());
333333
auto newDpasOp = xegpu::DpasOp::create(rewriter, loc, resTy, operands);
@@ -343,6 +343,58 @@ struct WgToSgDpasOp : public OpConversionPattern<xegpu::DpasOp> {
343343
}
344344
};
345345

346+
/// This pattern transforms the DpasMxOp to work at subgroup level.
347+
struct WgToSgDpasMxOp : public OpConversionPattern<xegpu::DpasMxOp> {
348+
using OpConversionPattern<xegpu::DpasMxOp>::OpConversionPattern;
349+
LogicalResult
350+
matchAndRewrite(xegpu::DpasMxOp op, OneToNOpAdaptor adaptor,
351+
ConversionPatternRewriter &rewriter) const override {
352+
353+
Location loc = op.getLoc();
354+
VectorType resultTy = op.getResult().getType();
355+
356+
if (resultTy.getRank() != 2)
357+
return failure();
358+
359+
auto layoutCd = op.getLayoutCdAttr();
360+
auto layoutA = op.getLayoutAAttr();
361+
auto layoutB = op.getLayoutBAttr();
362+
auto layoutAScale = op.getLayoutAScaleAttr();
363+
auto layoutBScale = op.getLayoutBScaleAttr();
364+
365+
if (!layoutCd || !layoutA || !layoutB || !layoutAScale || !layoutBScale)
366+
return failure();
367+
368+
size_t index_c = 0;
369+
SmallVector<Value> newDpasMxOps;
370+
for (auto [index_a, aVec] : llvm::enumerate(adaptor.getA())) {
371+
for (auto [index_b, bVec] : llvm::enumerate(adaptor.getB())) {
372+
Value accVal = (op.getAcc()) ? adaptor.getAcc()[index_c++] : Value();
373+
Value scaleAVal =
374+
(op.getScaleA()) ? adaptor.getScaleA()[index_a] : Value();
375+
Value scaleBVal =
376+
(op.getScaleB()) ? adaptor.getScaleB()[index_b] : Value();
377+
378+
ArrayRef<int64_t> aVecShape =
379+
cast<VectorType>(aVec.getType()).getShape();
380+
ArrayRef<int64_t> bVecShape =
381+
cast<VectorType>(bVec.getType()).getShape();
382+
VectorType resTy = VectorType::get({aVecShape[0], bVecShape[1]},
383+
resultTy.getElementType());
384+
auto newDpasMxOp = xegpu::DpasMxOp::create(
385+
rewriter, loc, resTy, aVec, bVec, accVal, scaleAVal, scaleBVal,
386+
layoutA.dropSgLayoutAndData(), layoutB.dropSgLayoutAndData(),
387+
layoutCd.dropSgLayoutAndData(), layoutAScale.dropSgLayoutAndData(),
388+
layoutBScale.dropSgLayoutAndData());
389+
390+
newDpasMxOps.push_back(newDpasMxOp);
391+
}
392+
}
393+
rewriter.replaceOpWithMultiple(op, {newDpasMxOps});
394+
return success();
395+
}
396+
};
397+
346398
/// This pattern transforms vector.broadcast ops to work at subgroup level.
347399
struct WgToSgVectorBroadcastOp
348400
: public OpConversionPattern<vector::BroadcastOp> {
@@ -1403,19 +1455,111 @@ struct WgToSgVectorMaskOp : public OpConversionPattern<MaskOpType> {
14031455

14041456
using WgToSgVectorConstantMaskOp = WgToSgVectorMaskOp<vector::ConstantMaskOp>;
14051457
using WgToSgVectorCreateMaskOp = WgToSgVectorMaskOp<vector::CreateMaskOp>;
1458+
1459+
// This pattern transforms vector.bitcast ops to work at subgroup level.
1460+
struct WgToSgVectorBitCastOp : public OpConversionPattern<vector::BitCastOp> {
1461+
using OpConversionPattern<vector::BitCastOp>::OpConversionPattern;
1462+
1463+
LogicalResult
1464+
matchAndRewrite(vector::BitCastOp op, OneToNOpAdaptor adaptor,
1465+
ConversionPatternRewriter &rewriter) const override {
1466+
VectorType resultType = op.getResultVectorType();
1467+
1468+
ArrayRef<int64_t> wgShape = resultType.getShape();
1469+
xegpu::DistributeLayoutAttr layout =
1470+
xegpu::getTemporaryLayout(dyn_cast<OpResult>(op.getResult()));
1471+
if (!layout || !layout.isForWorkgroup())
1472+
return failure();
1473+
1474+
SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
1475+
VectorType newResultType =
1476+
VectorType::get(sgShape, resultType.getElementType());
1477+
1478+
SmallVector<Value> newBitCastOps;
1479+
for (auto src : adaptor.getSource()) {
1480+
auto newBitCast =
1481+
vector::BitCastOp::create(rewriter, op.getLoc(), newResultType, src);
1482+
newBitCastOps.push_back(newBitCast.getResult());
1483+
}
1484+
1485+
rewriter.replaceOpWithMultiple(op, {newBitCastOps});
1486+
return success();
1487+
}
1488+
};
1489+
1490+
// This pattern transforms vector.interleave ops to work at subgroup level.
1491+
struct WgToSgVectorInterleaveOp
1492+
: public OpConversionPattern<vector::InterleaveOp> {
1493+
using OpConversionPattern<vector::InterleaveOp>::OpConversionPattern;
1494+
1495+
LogicalResult
1496+
matchAndRewrite(vector::InterleaveOp op, OneToNOpAdaptor adaptor,
1497+
ConversionPatternRewriter &rewriter) const override {
1498+
VectorType resultType = op.getResultVectorType();
1499+
1500+
ArrayRef<int64_t> wgShape = resultType.getShape();
1501+
xegpu::DistributeLayoutAttr layout =
1502+
xegpu::getTemporaryLayout(dyn_cast<OpResult>(op.getResult()));
1503+
if (!layout || !layout.isForWorkgroup())
1504+
return failure();
1505+
1506+
SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
1507+
VectorType newResultType =
1508+
VectorType::get(sgShape, resultType.getElementType());
1509+
1510+
SmallVector<Value> newInterleaveOps;
1511+
// Interleave operates pairwise: each lhs value is interleaved with
1512+
// corresponding rhs value
1513+
for (auto [lhs, rhs] : llvm::zip(adaptor.getLhs(), adaptor.getRhs())) {
1514+
auto newInterleave = vector::InterleaveOp::create(
1515+
rewriter, op.getLoc(), newResultType, lhs, rhs);
1516+
newInterleaveOps.push_back(newInterleave.getResult());
1517+
}
1518+
1519+
rewriter.replaceOpWithMultiple(op, {newInterleaveOps});
1520+
return success();
1521+
}
1522+
};
1523+
1524+
// This pattern transforms vector.deinterleave ops to work at subgroup level.
1525+
struct WgToSgVectorDeinterleaveOp
1526+
: public OpConversionPattern<vector::DeinterleaveOp> {
1527+
using OpConversionPattern<vector::DeinterleaveOp>::OpConversionPattern;
1528+
1529+
LogicalResult
1530+
matchAndRewrite(vector::DeinterleaveOp op, OneToNOpAdaptor adaptor,
1531+
ConversionPatternRewriter &rewriter) const override {
1532+
SmallVector<Value> newRes1Ops;
1533+
SmallVector<Value> newRes2Ops;
1534+
1535+
for (auto src : adaptor.getSource()) {
1536+
auto newDeinterleave =
1537+
vector::DeinterleaveOp::create(rewriter, op.getLoc(), src);
1538+
newRes1Ops.push_back(newDeinterleave.getRes1());
1539+
newRes2Ops.push_back(newDeinterleave.getRes2());
1540+
}
1541+
1542+
SmallVector<SmallVector<Value>> results = {newRes1Ops, newRes2Ops};
1543+
rewriter.replaceOpWithMultiple(op, results);
1544+
return success();
1545+
}
1546+
};
1547+
14061548
} // namespace
14071549

14081550
namespace mlir {
14091551
namespace xegpu {
14101552
void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
14111553
patterns.add<WgToSgCreateNdOp, WgToSgLoadNdOp, WgToSgStoreNdOp, WgToSgDpasOp,
1412-
WgToSgPrefetchNdOp, UnrealizedConversionCastOpPattern,
1413-
WgToSgElementwiseOp, WgToSgVectorBroadcastOp,
1414-
WgToSgConvertLayoutOp, WgToSgArithConstantOp, WgToSgLoadGatherOp,
1415-
WgToSgStoreScatterOp, WgToSgLoadMatrixOp, WgToSgStoreMatrixOp,
1416-
WgToSgVectorStepOp, WgToSgVectorShapeCastOp,
1417-
WgToSgMultiDimReductionOp, WgToSgVectorTransposeOp,
1418-
WgToSgVectorConstantMaskOp, WgToSgVectorCreateMaskOp>(
1554+
WgToSgDpasMxOp, WgToSgPrefetchNdOp,
1555+
UnrealizedConversionCastOpPattern, WgToSgElementwiseOp,
1556+
WgToSgVectorBroadcastOp, WgToSgConvertLayoutOp,
1557+
WgToSgArithConstantOp, WgToSgLoadGatherOp, WgToSgStoreScatterOp,
1558+
WgToSgLoadMatrixOp, WgToSgStoreMatrixOp, WgToSgVectorStepOp,
1559+
WgToSgVectorShapeCastOp, WgToSgMultiDimReductionOp,
1560+
WgToSgVectorTransposeOp, WgToSgVectorConstantMaskOp,
1561+
WgToSgVectorCreateMaskOp, WgToSgVectorBitCastOp,
1562+
WgToSgVectorInterleaveOp, WgToSgVectorDeinterleaveOp>(
14191563
patterns.getContext());
14201564
}
14211565
} // namespace xegpu
@@ -1539,6 +1683,12 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
15391683
return isLegal(layout);
15401684
});
15411685

1686+
target.addDynamicallyLegalOp<xegpu::DpasMxOp>(
1687+
[=](xegpu::DpasMxOp op) -> bool {
1688+
auto layout = op.getLayoutCdAttr();
1689+
return isLegal(layout);
1690+
});
1691+
15421692
target.addDynamicallyLegalOp<xegpu::LoadMatrixOp>(
15431693
[=](xegpu::LoadMatrixOp op) -> bool {
15441694
return isLegal(op.getLayoutAttr());
@@ -1560,16 +1710,16 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
15601710
return isLegal(layout);
15611711
});
15621712

1563-
target.addDynamicallyLegalOp<vector::ShapeCastOp, vector::StepOp,
1564-
vector::TransposeOp, vector::BroadcastOp,
1565-
vector::MultiDimReductionOp,
1566-
vector::ConstantMaskOp, vector::CreateMaskOp>(
1567-
[=](Operation *op) -> bool {
1568-
// Check for either a SliceAttr or LayoutAttr on the result.
1569-
auto layout =
1570-
xegpu::getTemporaryLayout(dyn_cast<OpResult>(op->getResult(0)));
1571-
return isLegal(layout);
1572-
});
1713+
target.addDynamicallyLegalOp<
1714+
vector::ShapeCastOp, vector::StepOp, vector::TransposeOp,
1715+
vector::BroadcastOp, vector::MultiDimReductionOp, vector::ConstantMaskOp,
1716+
vector::CreateMaskOp, vector::BitCastOp, vector::InterleaveOp,
1717+
vector::DeinterleaveOp>([=](Operation *op) -> bool {
1718+
// Check for either a SliceAttr or LayoutAttr on the result.
1719+
auto layout =
1720+
xegpu::getTemporaryLayout(dyn_cast<OpResult>(op->getResult(0)));
1721+
return isLegal(layout);
1722+
});
15731723

15741724
target.addDynamicallyLegalOp<xegpu::LoadGatherOp>(
15751725
[=](xegpu::LoadGatherOp op) -> bool {

mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp

Lines changed: 20 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -185,29 +185,28 @@ xegpu::getDistributeLayoutAttr(const OpOperand &opr) {
185185
}
186186
if (auto dpasMxOp = dyn_cast<xegpu::DpasMxOp>(op)) {
187187
// DpasMxOp has operands: a, b, optional acc, optional scale_a, optional
188-
// scale_b Use AttrSizedOperandSegments to determine which operand this is
189-
auto segmentSizesAttr = dpasMxOp->getAttrOfType<DenseI32ArrayAttr>(
190-
dpasMxOp.getOperandSegmentSizesAttrName());
191-
if (!segmentSizesAttr)
192-
return nullptr;
193-
194-
auto segmentSizes = segmentSizesAttr.asArrayRef();
195-
unsigned aSize = segmentSizes[0];
196-
unsigned bSize = segmentSizes[1];
197-
unsigned accSize = segmentSizes[2];
198-
unsigned scaleASize = segmentSizes[3];
199-
200-
if (idx < aSize) {
188+
// scale_b
189+
unsigned currentIdx = 0;
190+
191+
if (idx == currentIdx++)
201192
return dpasMxOp.getLayoutAAttr();
202-
} else if (idx < aSize + bSize) {
193+
194+
if (idx == currentIdx++)
203195
return dpasMxOp.getLayoutBAttr();
204-
} else if (idx < aSize + bSize + accSize) {
205-
return dpasMxOp.getLayoutCdAttr();
206-
} else if (idx < aSize + bSize + accSize + scaleASize) {
207-
return dpasMxOp.getLayoutAScaleAttr();
208-
} else {
209-
return dpasMxOp.getLayoutBScaleAttr();
210-
}
196+
197+
if (dpasMxOp.getAcc())
198+
if (idx == currentIdx++)
199+
return dpasMxOp.getLayoutCdAttr();
200+
201+
if (dpasMxOp.getScaleA())
202+
if (idx == currentIdx++)
203+
return dpasMxOp.getLayoutAScaleAttr();
204+
205+
if (dpasMxOp.getScaleB())
206+
if (idx == currentIdx++)
207+
return dpasMxOp.getLayoutBScaleAttr();
208+
209+
return nullptr;
211210
}
212211
if (auto convertOp = dyn_cast<xegpu::ConvertLayoutOp>(op)) {
213212
return convertOp.getInputLayoutAttr();

mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,40 @@ gpu.module @test_distribution {
9494
gpu.return
9595
}
9696

97+
// CHECK-LABEL: dpas_mx
98+
gpu.func @dpas_mx(%a: memref<128x128xf8E5M2>, %b: memref<128x128xf8E5M2>, %a_scale: memref<128x4xf8E8M0FNU>, %b_scale: memref<4x128xf8E8M0FNU>) {
99+
// CHECK: %[[DPAS_MX:.*]] = xegpu.dpas_mx %{{.*}}, %{{.*}}, %{{.*}} scale_a = %{{.*}} scale_b = %{{.*}} {layout_a = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 2]>, layout_a_scale = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>, layout_b = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>, layout_b_scale = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, layout_cd = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<16x128xf8E5M2>, vector<128x16xf8E5M2>, vector<16x16xbf16>, vector<16x4xf8E8M0FNU>, vector<4x16xf8E8M0FNU> -> vector<16x16xbf16>
100+
%tdesc_a = xegpu.create_nd_tdesc %a : memref<128x128xf8E5M2>
101+
-> !xegpu.tensor_desc<128x128xf8E5M2, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 128], lane_layout = [1, 16], lane_data = [1, 2]>>
102+
%load_a = xegpu.load_nd %tdesc_a[0, 0] {layout = #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 128], lane_layout = [1, 16], lane_data = [1, 2]>}
103+
: !xegpu.tensor_desc<128x128xf8E5M2, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 128], lane_layout = [1, 16], lane_data = [1, 2]>>
104+
-> vector<128x128xf8E5M2>
105+
%tdesc_b = xegpu.create_nd_tdesc %b : memref<128x128xf8E5M2>
106+
-> !xegpu.tensor_desc<128x128xf8E5M2, #xegpu.layout<sg_layout = [8, 8], sg_data = [128, 16], lane_layout = [1, 16], lane_data = [2, 1]>>
107+
%load_b = xegpu.load_nd %tdesc_b[0, 0] {layout = #xegpu.layout<sg_layout = [8, 8], sg_data = [128, 16], lane_layout = [1, 16], lane_data = [2, 1]>}
108+
: !xegpu.tensor_desc<128x128xf8E5M2, #xegpu.layout<sg_layout = [8, 8], sg_data = [128, 16], lane_layout = [1, 16], lane_data = [2, 1]>>
109+
-> vector<128x128xf8E5M2>
110+
%tdesc_a_scale = xegpu.create_nd_tdesc %a_scale : memref<128x4xf8E8M0FNU>
111+
-> !xegpu.tensor_desc<128x4xf8E8M0FNU, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 4], lane_layout = [16, 1], lane_data = [1, 1]>>
112+
%load_a_scale = xegpu.load_nd %tdesc_a_scale[0, 0] {layout = #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 4], lane_layout = [16, 1], lane_data = [1, 1]>}
113+
: !xegpu.tensor_desc<128x4xf8E8M0FNU, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 4], lane_layout = [16, 1], lane_data = [1, 1]>>
114+
-> vector<128x4xf8E8M0FNU>
115+
%tdesc_b_scale = xegpu.create_nd_tdesc %b_scale : memref<4x128xf8E8M0FNU>
116+
-> !xegpu.tensor_desc<4x128xf8E8M0FNU, #xegpu.layout<sg_layout = [8, 8], sg_data = [4, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
117+
%load_b_scale = xegpu.load_nd %tdesc_b_scale[0, 0] {layout = #xegpu.layout<sg_layout = [8, 8], sg_data = [4, 16], lane_layout = [1, 16], lane_data = [1, 1]>}
118+
: !xegpu.tensor_desc<4x128xf8E8M0FNU, #xegpu.layout<sg_layout = [8, 8], sg_data = [4, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
119+
-> vector<4x128xf8E8M0FNU>
120+
%cst = arith.constant dense<0.0> : vector<128x128xbf16>
121+
%dpas_mx = xegpu.dpas_mx %load_a, %load_b, %cst scale_a = %load_a_scale scale_b = %load_b_scale
122+
{layout_a = #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 128], lane_layout = [1, 16], lane_data = [1, 2]>,
123+
layout_b = #xegpu.layout<sg_layout = [8, 8], sg_data = [128, 16], lane_layout = [1, 16], lane_data = [2, 1]>,
124+
layout_cd = #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>,
125+
layout_a_scale = #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 4], lane_layout = [16, 1], lane_data = [1, 1]>,
126+
layout_b_scale = #xegpu.layout<sg_layout = [8, 8], sg_data = [4, 16], lane_layout = [1, 16], lane_data = [1, 1]>}
127+
: vector<128x128xf8E5M2>, vector<128x128xf8E5M2>, vector<128x128xbf16>, vector<128x4xf8E8M0FNU>, vector<4x128xf8E8M0FNU> -> vector<128x128xbf16>
128+
gpu.return
129+
}
130+
97131
// CHECK-LABEL: dpas_no_sg_data
98132
gpu.func @dpas_no_sg_data(%a: memref<128x128xf16>, %b: memref<128x128xf16>) {
99133
// CHECK: %[[DPAS:.*]] = xegpu.dpas %{{.*}}, %{{.*}} {layout_a = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1], order = [1, 0]>, layout_b = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1], order = [1, 0]>, layout_cd = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1], order = [1, 0]>} : vector<16x128xf16>, vector<128x16xf16> -> vector<16x16xf32>
@@ -1266,6 +1300,61 @@ gpu.module @test_distribution {
12661300
gpu.return
12671301
}
12681302

1303+
// CHECK-LABEL: @bitcast_distribution
1304+
gpu.func @bitcast_distribution(%src: memref<256x128xf32>) {
1305+
%tdesc = xegpu.create_nd_tdesc %src : memref<256x128xf32>
1306+
-> !xegpu.tensor_desc<256x128xf32>
1307+
%load = xegpu.load_nd %tdesc[0, 0] {layout = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>}
1308+
: !xegpu.tensor_desc<256x128xf32>
1309+
-> vector<256x128xf32>
1310+
// CHECK: vector.bitcast {{.*}} : vector<32x32xf32> to vector<32x64xi16>
1311+
%bitcast = vector.bitcast %load : vector<256x128xf32> to vector<256x256xi16>
1312+
%add = arith.addi %bitcast, %bitcast : vector<256x256xi16>
1313+
// CHECK: vector.bitcast {{.*}} : vector<32x64xi16> to vector<32x32xi32>
1314+
%bitcast2 = vector.bitcast %add : vector<256x256xi16> to vector<256x128xi32>
1315+
%anchor = xegpu.convert_layout %bitcast2
1316+
<{
1317+
input_layout = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>,
1318+
target_layout = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>
1319+
}> : vector<256x128xi32>
1320+
gpu.return
1321+
}
1322+
1323+
// CHECK-LABEL: @interleave_distribution
1324+
gpu.func @interleave_distribution(%src: memref<256x128xf32>) {
1325+
%tdesc = xegpu.create_nd_tdesc %src : memref<256x128xf32>
1326+
-> !xegpu.tensor_desc<256x128xf32>
1327+
%load1 = xegpu.load_nd %tdesc[0, 0] {layout = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>}
1328+
: !xegpu.tensor_desc<256x128xf32>
1329+
-> vector<256x128xf32>
1330+
%load2 = xegpu.load_nd %tdesc[0, 0] {layout = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>}
1331+
: !xegpu.tensor_desc<256x128xf32>
1332+
-> vector<256x128xf32>
1333+
// CHECK: vector.interleave {{.*}}, {{.*}} : vector<32x32xf32> -> vector<32x64xf32>
1334+
%interleave = vector.interleave %load1, %load2
1335+
: vector<256x128xf32> -> vector<256x256xf32>
1336+
%anchor = xegpu.convert_layout %interleave
1337+
<{
1338+
input_layout = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 64]>,
1339+
target_layout = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 64]>
1340+
}> : vector<256x256xf32>
1341+
gpu.return
1342+
}
1343+
1344+
// CHECK-LABEL: @deinterleave_distribution
1345+
gpu.func @deinterleave_distribution(%src: memref<256x256xf32>) {
1346+
%tdesc = xegpu.create_nd_tdesc %src : memref<256x256xf32> -> !xegpu.tensor_desc<256x256xf32>
1347+
%load = xegpu.load_nd %tdesc[0, 0] {layout = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 64]>} : !xegpu.tensor_desc<256x256xf32> -> vector<256x256xf32>
1348+
// CHECK: {{.*}} = vector.deinterleave {{.*}} : vector<32x64xf32> -> vector<32x32xf32>
1349+
%deinterleave:2 = vector.deinterleave %load : vector<256x256xf32> -> vector<256x128xf32>
1350+
%anchor = xegpu.convert_layout %deinterleave#0
1351+
<{
1352+
input_layout = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>,
1353+
target_layout = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>
1354+
}> : vector<256x128xf32>
1355+
gpu.return
1356+
}
1357+
12691358
}
12701359

12711360
// -----

0 commit comments

Comments
 (0)