
[DAGCombine] Simplify partial_reduce_*mla with constant. #138289


Merged: 2 commits into llvm:main from sdesmalen-arm:partial-reduction-fold on May 6, 2025

Conversation

sdesmalen-arm
Collaborator

partial_reduce_*mla(acc, mul(ext(x), splat(C)), splat(1))
-> partial_reduce_*mla(acc, x, C)
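
In equation form (a sketch of the soundness argument; the notation is mine, not from the patch): if the wide splat constant C is itself the extension of a narrower constant C', i.e. C = ext(C'), then for each accumulator lane j

  acc_j + \sum_{i \in \mathrm{group}(j)} \mathrm{ext}(x_i) \cdot C \;=\; acc_j + \sum_{i \in \mathrm{group}(j)} \mathrm{ext}(x_i) \cdot \mathrm{ext}(C'),

and the right-hand side is exactly partial_reduce_*mla(acc, x, C') with both extends folded into the operation. If C does not survive truncating to the narrow element type and re-extending with the matching signedness, the fold does not apply (see the *_does_not_fit tests in the diff below).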

@llvmbot added the backend:AArch64 and llvm:SelectionDAG labels on May 2, 2025
@llvmbot
Member

llvmbot commented May 2, 2025

@llvm/pr-subscribers-llvm-selectiondag

Author: Sander de Smalen (sdesmalen-arm)

Changes

partial_reduce_*mla(acc, mul(ext(x), splat(C)), splat(1))
-> partial_reduce_*mla(acc, x, C)


Full diff: https://github.com/llvm/llvm-project/pull/138289.diff

2 Files Affected:

  • (modified) llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp (+34-20)
  • (modified) llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll (+142-1)
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index ea1435c3934be..345cb4f9fb6ee 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -12612,47 +12612,63 @@ SDValue DAGCombiner::visitMHISTOGRAM(SDNode *N) {
   return SDValue();
 }
 
-// Makes PARTIAL_REDUCE_*MLA(Acc, MUL(ZEXT(LHSExtOp), ZEXT(RHSExtOp)),
-// Splat(1)) into
-// PARTIAL_REDUCE_UMLA(Acc, LHSExtOp, RHSExtOp).
-// Makes PARTIAL_REDUCE_*MLA(Acc, MUL(SEXT(LHSExtOp), SEXT(RHSExtOp)),
-// Splat(1)) into
-// PARTIAL_REDUCE_SMLA(Acc, LHSExtOp, RHSExtOp).
+// partial_reduce_*mla(acc, mul(zext(a), zext(b)))
+// -> partial_reduce_umla(acc, a, b)
+//
+// partial_reduce_*mla(acc, mul(ext(x), splat(C)), splat(1))
+// -> partial_reduce_*mla(acc, x, C)
 SDValue DAGCombiner::visitPARTIAL_REDUCE_MLA(SDNode *N) {
   SDLoc DL(N);
-
+  auto *Context = DAG.getContext();
   SDValue Acc = N->getOperand(0);
   SDValue Op1 = N->getOperand(1);
   SDValue Op2 = N->getOperand(2);
 
-  APInt ConstantOne;
+  APInt C;
   if (Op1->getOpcode() != ISD::MUL ||
-      !ISD::isConstantSplatVector(Op2.getNode(), ConstantOne) ||
-      !ConstantOne.isOne())
+      !ISD::isConstantSplatVector(Op2.getNode(), C) || !C.isOne())
     return SDValue();
 
   SDValue LHS = Op1->getOperand(0);
   SDValue RHS = Op1->getOperand(1);
   unsigned LHSOpcode = LHS->getOpcode();
-  unsigned RHSOpcode = RHS->getOpcode();
-  if (!ISD::isExtOpcode(LHSOpcode) || !ISD::isExtOpcode(RHSOpcode))
+  if (!ISD::isExtOpcode(LHSOpcode))
     return SDValue();
 
   SDValue LHSExtOp = LHS->getOperand(0);
-  SDValue RHSExtOp = RHS->getOperand(0);
   EVT LHSExtOpVT = LHSExtOp.getValueType();
-  if (LHSExtOpVT != RHSExtOp.getValueType() || LHSOpcode != RHSOpcode)
-    return SDValue();
 
-  // Only perform the DAG combine if there is custom lowering provided by the
-  // target
-  auto *Context = DAG.getContext();
+  // Only perform these combines if the target supports folding
+  // the extends into the operation.
   if (!TLI.isPartialReduceMLALegalOrCustom(
           TLI.getTypeToTransformTo(*Context, N->getValueType(0)),
           TLI.getTypeToTransformTo(*Context, LHSExtOpVT)))
     return SDValue();
 
   bool ExtIsSigned = LHSOpcode == ISD::SIGN_EXTEND;
+  unsigned NewOpcode =
+      ExtIsSigned ? ISD::PARTIAL_REDUCE_SMLA : ISD::PARTIAL_REDUCE_UMLA;
+
+  // partial_reduce_*mla(acc, mul(zext(x), splat(C)), splat(1))
+  // -> partial_reduce_umla(acc, x, C)
+  if (ISD::isConstantSplatVector(RHS.getNode(), C)) {
+    APInt CTrunc = C.trunc(LHSExtOpVT.getScalarSizeInBits());
+    unsigned LHSBits = LHS.getValueType().getScalarSizeInBits();
+    if ((LHSOpcode != ISD::ZERO_EXTEND || CTrunc.zext(LHSBits) != C) &&
+        (LHSOpcode != ISD::SIGN_EXTEND || CTrunc.sext(LHSBits) != C))
+      return SDValue();
+
+    return DAG.getNode(NewOpcode, DL, N->getValueType(0), Acc, LHSExtOp,
+                       DAG.getConstant(CTrunc, DL, LHSExtOpVT));
+  }
+
+  unsigned RHSOpcode = RHS->getOpcode();
+  if (!ISD::isExtOpcode(RHSOpcode))
+    return SDValue();
+
+  SDValue RHSExtOp = RHS->getOperand(0);
+  if (LHSExtOpVT != RHSExtOp.getValueType() || LHSOpcode != RHSOpcode)
+    return SDValue();
 
   // For a 2-stage extend the signedness of both of the extends must be the
   // same. This is so the node can be folded into only a signed or unsigned
@@ -12663,8 +12679,6 @@ SDValue DAGCombiner::visitPARTIAL_REDUCE_MLA(SDNode *N) {
       Op1.getValueType().getVectorElementType() != AccElemVT)
     return SDValue();
 
-  unsigned NewOpcode =
-      ExtIsSigned ? ISD::PARTIAL_REDUCE_SMLA : ISD::PARTIAL_REDUCE_UMLA;
   return DAG.getNode(NewOpcode, DL, N->getValueType(0), Acc, LHSExtOp,
                      RHSExtOp);
 }
diff --git a/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll b/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
index 039cac01008b8..5326bccbbc3d5 100644
--- a/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
+++ b/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
@@ -1139,7 +1139,6 @@ entry:
   ret <vscale x 2 x i16> %partial.reduce
 }
 
-
 define <vscale x 4 x i64> @partial_reduce_only_split_acc(<vscale x 4 x i64> %acc, <vscale x 8 x i8> %a, <vscale x 8 x i8> %b) {
 ; CHECK-LABEL: partial_reduce_only_split_acc:
 ; CHECK:       // %bb.0: // %entry
@@ -1178,3 +1177,145 @@ entry:
   <vscale x 4 x i64> %acc, <vscale x 8 x i64> %mult)
   ret <vscale x 4 x i64> %partial.reduce
 }
+
+define <vscale x 4 x i32> @sdot_imm(<vscale x 4 x i32> %acc, <vscale x 16 x i8> %a) {
+; CHECK-LABEL: sdot_imm:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    sunpklo z2.h, z1.b
+; CHECK-NEXT:    sunpkhi z1.h, z1.b
+; CHECK-NEXT:    sunpklo z3.s, z2.h
+; CHECK-NEXT:    sunpkhi z2.s, z2.h
+; CHECK-NEXT:    sub z0.s, z0.s, z3.s
+; CHECK-NEXT:    sunpklo z3.s, z1.h
+; CHECK-NEXT:    sunpkhi z1.s, z1.h
+; CHECK-NEXT:    sub z0.s, z0.s, z2.s
+; CHECK-NEXT:    sub z0.s, z0.s, z3.s
+; CHECK-NEXT:    sub z0.s, z0.s, z1.s
+; CHECK-NEXT:    ret
+;
+; CHECK-NEWLOWERING-LABEL: sdot_imm:
+; CHECK-NEWLOWERING:       // %bb.0: // %entry
+; CHECK-NEWLOWERING-NEXT:    mov z2.b, #-1 // =0xffffffffffffffff
+; CHECK-NEWLOWERING-NEXT:    sdot z0.s, z1.b, z2.b
+; CHECK-NEWLOWERING-NEXT:    ret
+entry:
+  %a.wide = sext <vscale x 16 x i8> %a to <vscale x 16 x i32>
+  %mult = mul nuw nsw <vscale x 16 x i32> %a.wide, splat(i32 -1)
+  %partial.reduce = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> %acc, <vscale x 16 x i32> %mult)
+  ret <vscale x 4 x i32> %partial.reduce
+}
+
+define <vscale x 4 x i32> @sdot_imm_does_not_fit(<vscale x 4 x i32> %acc, <vscale x 16 x i8> %a) {
+; CHECK-LABEL: sdot_imm_does_not_fit:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    sunpklo z2.h, z1.b
+; CHECK-NEXT:    sunpkhi z1.h, z1.b
+; CHECK-NEXT:    sunpklo z3.s, z2.h
+; CHECK-NEXT:    sunpkhi z2.s, z2.h
+; CHECK-NEXT:    sunpklo z4.s, z1.h
+; CHECK-NEXT:    sunpkhi z1.s, z1.h
+; CHECK-NEXT:    lsl z4.s, z4.s, #8
+; CHECK-NEXT:    lsl z2.s, z2.s, #8
+; CHECK-NEXT:    lsl z3.s, z3.s, #8
+; CHECK-NEXT:    lsl z1.s, z1.s, #8
+; CHECK-NEXT:    add z0.s, z0.s, z3.s
+; CHECK-NEXT:    add z2.s, z2.s, z4.s
+; CHECK-NEXT:    add z0.s, z0.s, z2.s
+; CHECK-NEXT:    add z0.s, z0.s, z1.s
+; CHECK-NEXT:    ret
+;
+; CHECK-NEWLOWERING-LABEL: sdot_imm_does_not_fit:
+; CHECK-NEWLOWERING:       // %bb.0: // %entry
+; CHECK-NEWLOWERING-NEXT:    sunpklo z2.h, z1.b
+; CHECK-NEWLOWERING-NEXT:    sunpkhi z1.h, z1.b
+; CHECK-NEWLOWERING-NEXT:    sunpklo z3.s, z2.h
+; CHECK-NEWLOWERING-NEXT:    sunpkhi z2.s, z2.h
+; CHECK-NEWLOWERING-NEXT:    sunpklo z4.s, z1.h
+; CHECK-NEWLOWERING-NEXT:    sunpkhi z1.s, z1.h
+; CHECK-NEWLOWERING-NEXT:    lsl z4.s, z4.s, #8
+; CHECK-NEWLOWERING-NEXT:    lsl z2.s, z2.s, #8
+; CHECK-NEWLOWERING-NEXT:    lsl z3.s, z3.s, #8
+; CHECK-NEWLOWERING-NEXT:    lsl z1.s, z1.s, #8
+; CHECK-NEWLOWERING-NEXT:    add z0.s, z0.s, z3.s
+; CHECK-NEWLOWERING-NEXT:    add z2.s, z2.s, z4.s
+; CHECK-NEWLOWERING-NEXT:    add z0.s, z0.s, z2.s
+; CHECK-NEWLOWERING-NEXT:    add z0.s, z0.s, z1.s
+; CHECK-NEWLOWERING-NEXT:    ret
+entry:
+  %a.wide = sext <vscale x 16 x i8> %a to <vscale x 16 x i32>
+  %mult = mul nuw nsw <vscale x 16 x i32> %a.wide, splat(i32 256)
+  %partial.reduce = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> %acc, <vscale x 16 x i32> %mult)
+  ret <vscale x 4 x i32> %partial.reduce
+}
+
+define <vscale x 4 x i32> @udot_imm(<vscale x 4 x i32> %acc, <vscale x 16 x i8> %a) {
+; CHECK-LABEL: udot_imm:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    uunpklo z3.h, z1.b
+; CHECK-NEXT:    mov z2.s, #255 // =0xff
+; CHECK-NEXT:    ptrue p0.s
+; CHECK-NEXT:    uunpkhi z1.h, z1.b
+; CHECK-NEXT:    uunpklo z4.s, z3.h
+; CHECK-NEXT:    uunpkhi z3.s, z3.h
+; CHECK-NEXT:    mla z0.s, p0/m, z4.s, z2.s
+; CHECK-NEXT:    uunpklo z4.s, z1.h
+; CHECK-NEXT:    uunpkhi z1.s, z1.h
+; CHECK-NEXT:    mla z0.s, p0/m, z3.s, z2.s
+; CHECK-NEXT:    mla z0.s, p0/m, z4.s, z2.s
+; CHECK-NEXT:    mla z0.s, p0/m, z1.s, z2.s
+; CHECK-NEXT:    ret
+;
+; CHECK-NEWLOWERING-LABEL: udot_imm:
+; CHECK-NEWLOWERING:       // %bb.0: // %entry
+; CHECK-NEWLOWERING-NEXT:    mov z2.b, #-1 // =0xffffffffffffffff
+; CHECK-NEWLOWERING-NEXT:    udot z0.s, z1.b, z2.b
+; CHECK-NEWLOWERING-NEXT:    ret
+entry:
+  %a.wide = zext <vscale x 16 x i8> %a to <vscale x 16 x i32>
+  %mult = mul nuw nsw <vscale x 16 x i32> %a.wide, splat(i32 255)
+  %partial.reduce = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> %acc, <vscale x 16 x i32> %mult)
+  ret <vscale x 4 x i32> %partial.reduce
+}
+
+define <vscale x 4 x i32> @udot_imm_does_not_fit(<vscale x 4 x i32> %acc, <vscale x 16 x i8> %a) {
+; CHECK-LABEL: udot_imm_does_not_fit:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    uunpklo z2.h, z1.b
+; CHECK-NEXT:    uunpkhi z1.h, z1.b
+; CHECK-NEXT:    uunpklo z3.s, z2.h
+; CHECK-NEXT:    uunpkhi z2.s, z2.h
+; CHECK-NEXT:    uunpklo z4.s, z1.h
+; CHECK-NEXT:    uunpkhi z1.s, z1.h
+; CHECK-NEXT:    lsl z4.s, z4.s, #8
+; CHECK-NEXT:    lsl z2.s, z2.s, #8
+; CHECK-NEXT:    lsl z3.s, z3.s, #8
+; CHECK-NEXT:    lsl z1.s, z1.s, #8
+; CHECK-NEXT:    add z0.s, z0.s, z3.s
+; CHECK-NEXT:    add z2.s, z2.s, z4.s
+; CHECK-NEXT:    add z0.s, z0.s, z2.s
+; CHECK-NEXT:    add z0.s, z0.s, z1.s
+; CHECK-NEXT:    ret
+;
+; CHECK-NEWLOWERING-LABEL: udot_imm_does_not_fit:
+; CHECK-NEWLOWERING:       // %bb.0: // %entry
+; CHECK-NEWLOWERING-NEXT:    uunpklo z2.h, z1.b
+; CHECK-NEWLOWERING-NEXT:    uunpkhi z1.h, z1.b
+; CHECK-NEWLOWERING-NEXT:    uunpklo z3.s, z2.h
+; CHECK-NEWLOWERING-NEXT:    uunpkhi z2.s, z2.h
+; CHECK-NEWLOWERING-NEXT:    uunpklo z4.s, z1.h
+; CHECK-NEWLOWERING-NEXT:    uunpkhi z1.s, z1.h
+; CHECK-NEWLOWERING-NEXT:    lsl z4.s, z4.s, #8
+; CHECK-NEWLOWERING-NEXT:    lsl z2.s, z2.s, #8
+; CHECK-NEWLOWERING-NEXT:    lsl z3.s, z3.s, #8
+; CHECK-NEWLOWERING-NEXT:    lsl z1.s, z1.s, #8
+; CHECK-NEWLOWERING-NEXT:    add z0.s, z0.s, z3.s
+; CHECK-NEWLOWERING-NEXT:    add z2.s, z2.s, z4.s
+; CHECK-NEWLOWERING-NEXT:    add z0.s, z0.s, z2.s
+; CHECK-NEWLOWERING-NEXT:    add z0.s, z0.s, z1.s
+; CHECK-NEWLOWERING-NEXT:    ret
+entry:
+  %a.wide = zext <vscale x 16 x i8> %a to <vscale x 16 x i32>
+  %mult = mul nuw nsw <vscale x 16 x i32> %a.wide, splat(i32 256)
+  %partial.reduce = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> %acc, <vscale x 16 x i32> %mult)
+  ret <vscale x 4 x i32> %partial.reduce
+}
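
For reference, a minimal standalone sketch of the constant-fitting check exercised by the tests above: the splat constant is truncated to the pre-extend element type and re-extended with the matching signedness, and the fold only fires if that round-trip reproduces the original value. The helper name splatFitsInNarrowType and the printf driver are illustrative assumptions, not code from the patch; the llvm::APInt calls (trunc, zext, sext) are the ones the combine relies on.

#include "llvm/ADT/APInt.h"
#include <cstdio>

using llvm::APInt;

// Does a wide constant C survive truncation to NarrowBits followed by
// re-extension with the given signedness? If so, mul(ext(x), splat(C)) can be
// rewritten as a widening multiply of x by the truncated constant.
static bool splatFitsInNarrowType(const APInt &C, unsigned NarrowBits,
                                  bool IsSigned) {
  unsigned WideBits = C.getBitWidth();
  APInt CTrunc = C.trunc(NarrowBits);
  return IsSigned ? CTrunc.sext(WideBits) == C : CTrunc.zext(WideBits) == C;
}

int main() {
  // The i32 splat constants from the new tests, narrowed to the i8 inputs.
  APInt MinusOne = APInt::getAllOnes(32); // splat(i32 -1) in sdot_imm
  APInt C255(32, 255);                    // splat(i32 255) in udot_imm
  APInt C256(32, 256);                    // splat(i32 256) in *_does_not_fit
  std::printf("-1  signed fit:   %d\n", splatFitsInNarrowType(MinusOne, 8, true));  // 1
  std::printf("255 unsigned fit: %d\n", splatFitsInNarrowType(C255, 8, false));     // 1
  std::printf("256 signed fit:   %d\n", splatFitsInNarrowType(C256, 8, true));      // 0
  std::printf("256 unsigned fit: %d\n", splatFitsInNarrowType(C256, 8, false));     // 0
  return 0;
}

The first two round-trips correspond to the folded sdot/udot forms in the CHECK-NEWLOWERING output above; 256 needs nine bits, so both *_does_not_fit tests keep the extended multiply.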

@llvmbot
Member

llvmbot commented May 2, 2025

@llvm/pr-subscribers-backend-aarch64

Author: Sander de Smalen (sdesmalen-arm)

Changes

partial_reduce_*mla(acc, mul(ext(x), splat(C)), splat(1))
-> partial_reduce_*mla(acc, x, C)


Full diff: https://github.com/llvm/llvm-project/pull/138289.diff

partial_reduce_*mla(acc, mul(ext(x), splat(C)), splat(1))
-> partial_reduce_*mla(acc, x, C)
sdesmalen-arm force-pushed the partial-reduction-fold branch from f734038 to 6c9e8ec on May 2, 2025 at 19:32

@MacDue (Member) left a comment:

LGTM (bar a nit pick)

sdesmalen-arm merged commit d90cac9 into llvm:main on May 6, 2025 (9 of 11 checks passed)
GeorgeARM pushed a commit to GeorgeARM/llvm-project that referenced this pull request May 7, 2025
partial_reduce_*mla(acc, mul(ext(x), splat(C)), splat(1))
-> partial_reduce_*mla(acc, x, C)