@@ -325,9 +325,9 @@ struct WgToSgDpasOp : public OpConversionPattern<xegpu::DpasOp> {
325325 }
326326
327327 ArrayRef<int64_t > aVecShape =
328- llvm:: cast<VectorType>(aVec.getType ()).getShape ();
328+ cast<VectorType>(aVec.getType ()).getShape ();
329329 ArrayRef<int64_t > bVecShape =
330- llvm:: cast<VectorType>(bVec.getType ()).getShape ();
330+ cast<VectorType>(bVec.getType ()).getShape ();
331331 VectorType resTy = VectorType::get ({aVecShape[0 ], bVecShape[1 ]},
332332 resultTy.getElementType ());
333333 auto newDpasOp = xegpu::DpasOp::create (rewriter, loc, resTy, operands);
@@ -343,6 +343,58 @@ struct WgToSgDpasOp : public OpConversionPattern<xegpu::DpasOp> {
343343 }
344344};
345345
346+ // / This pattern transforms the DpasMxOp to work at subgroup level.
347+ struct WgToSgDpasMxOp : public OpConversionPattern <xegpu::DpasMxOp> {
348+ using OpConversionPattern<xegpu::DpasMxOp>::OpConversionPattern;
349+ LogicalResult
350+ matchAndRewrite (xegpu::DpasMxOp op, OneToNOpAdaptor adaptor,
351+ ConversionPatternRewriter &rewriter) const override {
352+
353+ Location loc = op.getLoc ();
354+ VectorType resultTy = op.getResult ().getType ();
355+
356+ if (resultTy.getRank () != 2 )
357+ return failure ();
358+
359+ auto layoutCd = op.getLayoutCdAttr ();
360+ auto layoutA = op.getLayoutAAttr ();
361+ auto layoutB = op.getLayoutBAttr ();
362+ auto layoutAScale = op.getLayoutAScaleAttr ();
363+ auto layoutBScale = op.getLayoutBScaleAttr ();
364+
365+ if (!layoutCd || !layoutA || !layoutB || !layoutAScale || !layoutBScale)
366+ return failure ();
367+
368+ size_t index_c = 0 ;
369+ SmallVector<Value> newDpasMxOps;
370+ for (auto [index_a, aVec] : llvm::enumerate (adaptor.getA ())) {
371+ for (auto [index_b, bVec] : llvm::enumerate (adaptor.getB ())) {
372+ Value accVal = (op.getAcc ()) ? adaptor.getAcc ()[index_c++] : Value ();
373+ Value scaleAVal =
374+ (op.getScaleA ()) ? adaptor.getScaleA ()[index_a] : Value ();
375+ Value scaleBVal =
376+ (op.getScaleB ()) ? adaptor.getScaleB ()[index_b] : Value ();
377+
378+ ArrayRef<int64_t > aVecShape =
379+ cast<VectorType>(aVec.getType ()).getShape ();
380+ ArrayRef<int64_t > bVecShape =
381+ cast<VectorType>(bVec.getType ()).getShape ();
382+ VectorType resTy = VectorType::get ({aVecShape[0 ], bVecShape[1 ]},
383+ resultTy.getElementType ());
384+ auto newDpasMxOp = xegpu::DpasMxOp::create (
385+ rewriter, loc, resTy, aVec, bVec, accVal, scaleAVal, scaleBVal,
386+ layoutA.dropSgLayoutAndData (), layoutB.dropSgLayoutAndData (),
387+ layoutCd.dropSgLayoutAndData (), layoutAScale.dropSgLayoutAndData (),
388+ layoutBScale.dropSgLayoutAndData ());
389+
390+ newDpasMxOps.push_back (newDpasMxOp);
391+ }
392+ }
393+ rewriter.replaceOpWithMultiple (op, {newDpasMxOps});
394+ return success ();
395+ }
396+ };
397+
346398// / This pattern transforms vector.broadcast ops to work at subgroup level.
347399struct WgToSgVectorBroadcastOp
348400 : public OpConversionPattern<vector::BroadcastOp> {
@@ -1403,19 +1455,111 @@ struct WgToSgVectorMaskOp : public OpConversionPattern<MaskOpType> {
14031455
14041456using WgToSgVectorConstantMaskOp = WgToSgVectorMaskOp<vector::ConstantMaskOp>;
14051457using WgToSgVectorCreateMaskOp = WgToSgVectorMaskOp<vector::CreateMaskOp>;
1458+
1459+ // This pattern transforms vector.bitcast ops to work at subgroup level.
1460+ struct WgToSgVectorBitCastOp : public OpConversionPattern <vector::BitCastOp> {
1461+ using OpConversionPattern<vector::BitCastOp>::OpConversionPattern;
1462+
1463+ LogicalResult
1464+ matchAndRewrite (vector::BitCastOp op, OneToNOpAdaptor adaptor,
1465+ ConversionPatternRewriter &rewriter) const override {
1466+ VectorType resultType = op.getResultVectorType ();
1467+
1468+ ArrayRef<int64_t > wgShape = resultType.getShape ();
1469+ xegpu::DistributeLayoutAttr layout =
1470+ xegpu::getTemporaryLayout (dyn_cast<OpResult>(op.getResult ()));
1471+ if (!layout || !layout.isForWorkgroup ())
1472+ return failure ();
1473+
1474+ SmallVector<int64_t > sgShape = getSgShapeAndCount (wgShape, layout).first ;
1475+ VectorType newResultType =
1476+ VectorType::get (sgShape, resultType.getElementType ());
1477+
1478+ SmallVector<Value> newBitCastOps;
1479+ for (auto src : adaptor.getSource ()) {
1480+ auto newBitCast =
1481+ vector::BitCastOp::create (rewriter, op.getLoc (), newResultType, src);
1482+ newBitCastOps.push_back (newBitCast.getResult ());
1483+ }
1484+
1485+ rewriter.replaceOpWithMultiple (op, {newBitCastOps});
1486+ return success ();
1487+ }
1488+ };
1489+
1490+ // This pattern transforms vector.interleave ops to work at subgroup level.
1491+ struct WgToSgVectorInterleaveOp
1492+ : public OpConversionPattern<vector::InterleaveOp> {
1493+ using OpConversionPattern<vector::InterleaveOp>::OpConversionPattern;
1494+
1495+ LogicalResult
1496+ matchAndRewrite (vector::InterleaveOp op, OneToNOpAdaptor adaptor,
1497+ ConversionPatternRewriter &rewriter) const override {
1498+ VectorType resultType = op.getResultVectorType ();
1499+
1500+ ArrayRef<int64_t > wgShape = resultType.getShape ();
1501+ xegpu::DistributeLayoutAttr layout =
1502+ xegpu::getTemporaryLayout (dyn_cast<OpResult>(op.getResult ()));
1503+ if (!layout || !layout.isForWorkgroup ())
1504+ return failure ();
1505+
1506+ SmallVector<int64_t > sgShape = getSgShapeAndCount (wgShape, layout).first ;
1507+ VectorType newResultType =
1508+ VectorType::get (sgShape, resultType.getElementType ());
1509+
1510+ SmallVector<Value> newInterleaveOps;
1511+ // Interleave operates pairwise: each lhs value is interleaved with
1512+ // corresponding rhs value
1513+ for (auto [lhs, rhs] : llvm::zip (adaptor.getLhs (), adaptor.getRhs ())) {
1514+ auto newInterleave = vector::InterleaveOp::create (
1515+ rewriter, op.getLoc (), newResultType, lhs, rhs);
1516+ newInterleaveOps.push_back (newInterleave.getResult ());
1517+ }
1518+
1519+ rewriter.replaceOpWithMultiple (op, {newInterleaveOps});
1520+ return success ();
1521+ }
1522+ };
1523+
1524+ // This pattern transforms vector.deinterleave ops to work at subgroup level.
1525+ struct WgToSgVectorDeinterleaveOp
1526+ : public OpConversionPattern<vector::DeinterleaveOp> {
1527+ using OpConversionPattern<vector::DeinterleaveOp>::OpConversionPattern;
1528+
1529+ LogicalResult
1530+ matchAndRewrite (vector::DeinterleaveOp op, OneToNOpAdaptor adaptor,
1531+ ConversionPatternRewriter &rewriter) const override {
1532+ SmallVector<Value> newRes1Ops;
1533+ SmallVector<Value> newRes2Ops;
1534+
1535+ for (auto src : adaptor.getSource ()) {
1536+ auto newDeinterleave =
1537+ vector::DeinterleaveOp::create (rewriter, op.getLoc (), src);
1538+ newRes1Ops.push_back (newDeinterleave.getRes1 ());
1539+ newRes2Ops.push_back (newDeinterleave.getRes2 ());
1540+ }
1541+
1542+ SmallVector<SmallVector<Value>> results = {newRes1Ops, newRes2Ops};
1543+ rewriter.replaceOpWithMultiple (op, results);
1544+ return success ();
1545+ }
1546+ };
1547+
14061548} // namespace
14071549
14081550namespace mlir {
14091551namespace xegpu {
14101552void populateXeGPUWgToSgDistributePatterns (RewritePatternSet &patterns) {
14111553 patterns.add <WgToSgCreateNdOp, WgToSgLoadNdOp, WgToSgStoreNdOp, WgToSgDpasOp,
1412- WgToSgPrefetchNdOp, UnrealizedConversionCastOpPattern,
1413- WgToSgElementwiseOp, WgToSgVectorBroadcastOp,
1414- WgToSgConvertLayoutOp, WgToSgArithConstantOp, WgToSgLoadGatherOp,
1415- WgToSgStoreScatterOp, WgToSgLoadMatrixOp, WgToSgStoreMatrixOp,
1416- WgToSgVectorStepOp, WgToSgVectorShapeCastOp,
1417- WgToSgMultiDimReductionOp, WgToSgVectorTransposeOp,
1418- WgToSgVectorConstantMaskOp, WgToSgVectorCreateMaskOp>(
1554+ WgToSgDpasMxOp, WgToSgPrefetchNdOp,
1555+ UnrealizedConversionCastOpPattern, WgToSgElementwiseOp,
1556+ WgToSgVectorBroadcastOp, WgToSgConvertLayoutOp,
1557+ WgToSgArithConstantOp, WgToSgLoadGatherOp, WgToSgStoreScatterOp,
1558+ WgToSgLoadMatrixOp, WgToSgStoreMatrixOp, WgToSgVectorStepOp,
1559+ WgToSgVectorShapeCastOp, WgToSgMultiDimReductionOp,
1560+ WgToSgVectorTransposeOp, WgToSgVectorConstantMaskOp,
1561+ WgToSgVectorCreateMaskOp, WgToSgVectorBitCastOp,
1562+ WgToSgVectorInterleaveOp, WgToSgVectorDeinterleaveOp>(
14191563 patterns.getContext ());
14201564}
14211565} // namespace xegpu
@@ -1539,6 +1683,12 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
15391683 return isLegal (layout);
15401684 });
15411685
1686+ target.addDynamicallyLegalOp <xegpu::DpasMxOp>(
1687+ [=](xegpu::DpasMxOp op) -> bool {
1688+ auto layout = op.getLayoutCdAttr ();
1689+ return isLegal (layout);
1690+ });
1691+
15421692 target.addDynamicallyLegalOp <xegpu::LoadMatrixOp>(
15431693 [=](xegpu::LoadMatrixOp op) -> bool {
15441694 return isLegal (op.getLayoutAttr ());
@@ -1560,16 +1710,16 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
15601710 return isLegal (layout);
15611711 });
15621712
1563- target.addDynamicallyLegalOp <vector::ShapeCastOp, vector::StepOp,
1564- vector::TransposeOp , vector::BroadcastOp ,
1565- vector::MultiDimReductionOp,
1566- vector::ConstantMaskOp , vector::CreateMaskOp>(
1567- [=](Operation *op) -> bool {
1568- // Check for either a SliceAttr or LayoutAttr on the result.
1569- auto layout =
1570- xegpu::getTemporaryLayout (dyn_cast<OpResult>(op->getResult (0 )));
1571- return isLegal (layout);
1572- });
1713+ target.addDynamicallyLegalOp <
1714+ vector::ShapeCastOp, vector::StepOp , vector::TransposeOp ,
1715+ vector::BroadcastOp, vector::MultiDimReductionOp, vector::ConstantMaskOp ,
1716+ vector::CreateMaskOp, vector::BitCastOp , vector::InterleaveOp,
1717+ vector::DeinterleaveOp>( [=](Operation *op) -> bool {
1718+ // Check for either a SliceAttr or LayoutAttr on the result.
1719+ auto layout =
1720+ xegpu::getTemporaryLayout (dyn_cast<OpResult>(op->getResult (0 )));
1721+ return isLegal (layout);
1722+ });
15731723
15741724 target.addDynamicallyLegalOp <xegpu::LoadGatherOp>(
15751725 [=](xegpu::LoadGatherOp op) -> bool {
0 commit comments