[DAGCombiner] Add DAG combine for PARTIAL_REDUCE_MLA when no mul op #131326


Merged

Conversation

NickGuy-Arm
Contributor

Add a generic DAG combine for ISD::PARTIAL_REDUCE_U/SMLA that converts PARTIAL_REDUCE_*MLA(Acc, ZEXT(UnextOp1), Splat(1)) into PARTIAL_REDUCE_UMLA(Acc, UnextOp1, TRUNC(Splat(1))), and
PARTIAL_REDUCE_*MLA(Acc, SEXT(UnextOp1), Splat(1)) into PARTIAL_REDUCE_SMLA(Acc, UnextOp1, TRUNC(Splat(1))).

@JamesChesterman is the original author.
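
For context, the IR pattern this fold targets is the one exercised by the udot_no_bin_op test updated below: an extend feeding @llvm.experimental.vector.partial.reduce.add with no multiply in between, which corresponds to the PARTIAL_REDUCE_*MLA(Acc, ZEXT(UnextOp1), Splat(1)) form described above. The two instructions in the body are copied from the test diff; the surrounding define/ret/declare lines are reconstructed from the truncated patch, so treat this as a sketch rather than the exact test file contents.

define <vscale x 4 x i32> @udot_no_bin_op(<vscale x 4 x i32> %acc, <vscale x 16 x i8> %a) {
  ; Zero-extend then partial-reduce, with no binary op on the extended value.
  %a.ext = zext <vscale x 16 x i8> %a to <vscale x 16 x i32>
  %partial.reduce = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> %acc, <vscale x 16 x i32> %a.ext)
  ret <vscale x 4 x i32> %partial.reduce
}

; Intrinsic declaration, reconstructed here only to keep the sketch self-contained.
declare <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32>, <vscale x 16 x i32>)

With the combine, the CHECK-NEWLOWERING output for this test drops the uunpklo/uunpkhi-and-add chain in favour of a splat of 1 and a single udot, as shown in the test diff below.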

@NickGuy-Arm force-pushed the JamesChesterman/legal-partial-reduction/3 branch from ac63453 to 130813b on April 15, 2025 16:15
@NickGuy-Arm force-pushed the JamesChesterman/legal-partial-reduction/3 branch from 130813b to f3d9a54 on April 23, 2025 13:12
@NickGuy-Arm marked this pull request as ready for review on April 23, 2025 13:17
@llvmbot added the backend:AArch64 and llvm:SelectionDAG labels on Apr 23, 2025
@llvmbot
Member

llvmbot commented Apr 23, 2025

@llvm/pr-subscribers-backend-aarch64

@llvm/pr-subscribers-llvm-selectiondag

Author: Nicholas Guy (NickGuy-Arm)

Changes

Generic DAG combine for ISD::PARTIAL_REDUCE_U/SMLA to convert: PARTIAL_REDUCE_*MLA(Acc, ZEXT(UnextOp1), Splat(1)) into PARTIAL_REDUCE_UMLA(Acc, UnextOp1, TRUNC(Splat(1))) and
PARTIAL_REDUCE_*MLA(Acc, SEXT(UnextOp1), Splat(1)) into PARTIAL_REDUCE_SMLA(Acc, UnextOp1, TRUNC(Splat(1))).

@JamesChesterman is the original author.


Patch is 25.30 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/131326.diff

3 Files Affected:

  • (modified) llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp (+53-1)
  • (modified) llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll (+76-138)
  • (modified) llvm/test/CodeGen/AArch64/sve-partial-reduce-wide-add.ll (+99-31)
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index cb5943eca82f5..d59b54768af6c 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -618,6 +618,8 @@ namespace {
     SDValue CombineConsecutiveLoads(SDNode *N, EVT VT);
     SDValue foldBitcastedFPLogic(SDNode *N, SelectionDAG &DAG,
                                  const TargetLowering &TLI);
+    SDValue foldPartialReduceMLAMulOp(SDNode *N);
+    SDValue foldPartialReduceMLANoMulOp(SDNode *N);
 
     SDValue CombineExtLoad(SDNode *N);
     SDValue CombineZExtLogicopShiftLoad(SDNode *N);
@@ -12612,13 +12614,21 @@ SDValue DAGCombiner::visitMHISTOGRAM(SDNode *N) {
   return SDValue();
 }
 
+SDValue DAGCombiner::visitPARTIAL_REDUCE_MLA(SDNode *N) {
+  if (SDValue Res = foldPartialReduceMLAMulOp(N))
+    return Res;
+  if (SDValue Res = foldPartialReduceMLANoMulOp(N))
+    return Res;
+  return SDValue();
+}
+
 // Makes PARTIAL_REDUCE_*MLA(Acc, MUL(ZEXT(LHSExtOp), ZEXT(RHSExtOp)),
 // Splat(1)) into
 // PARTIAL_REDUCE_UMLA(Acc, LHSExtOp, RHSExtOp).
 // Makes PARTIAL_REDUCE_*MLA(Acc, MUL(SEXT(LHSExtOp), SEXT(RHSExtOp)),
 // Splat(1)) into
 // PARTIAL_REDUCE_SMLA(Acc, LHSExtOp, RHSExtOp).
-SDValue DAGCombiner::visitPARTIAL_REDUCE_MLA(SDNode *N) {
+SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
   SDLoc DL(N);
 
   SDValue Acc = N->getOperand(0);
@@ -12669,6 +12679,48 @@ SDValue DAGCombiner::visitPARTIAL_REDUCE_MLA(SDNode *N) {
                      RHSExtOp);
 }
 
+// Makes PARTIAL_REDUCE_*MLA(Acc, ZEXT(UnextOp1), Splat(1)) into
+// PARTIAL_REDUCE_UMLA(Acc, Op, TRUNC(Splat(1)))
+// Makes PARTIAL_REDUCE_*MLA(Acc, SEXT(UnextOp1), Splat(1)) into
+// PARTIAL_REDUCE_SMLA(Acc, Op, TRUNC(Splat(1)))
+SDValue DAGCombiner::foldPartialReduceMLANoMulOp(SDNode *N) {
+  SDLoc DL(N);
+  SDValue Acc = N->getOperand(0);
+  SDValue Op1 = N->getOperand(1);
+  SDValue Op2 = N->getOperand(2);
+
+  APInt ConstantOne;
+  if (!ISD::isConstantSplatVector(Op2.getNode(), ConstantOne) ||
+      !ConstantOne.isOne())
+    return SDValue();
+
+  unsigned Op1Opcode = Op1.getOpcode();
+  if (!ISD::isExtOpcode(Op1Opcode))
+    return SDValue();
+
+  SDValue UnextOp1 = Op1.getOperand(0);
+  EVT UnextOp1VT = UnextOp1.getValueType();
+
+  if (!TLI.isPartialReduceMLALegalOrCustom(N->getValueType(0), UnextOp1VT))
+    return SDValue();
+
+  SDValue TruncOp2 = DAG.getNode(ISD::TRUNCATE, DL, UnextOp1VT, Op2);
+
+  bool Op1IsSigned = Op1Opcode == ISD::SIGN_EXTEND;
+
+  bool NodeIsSigned = N->getOpcode() == ISD::PARTIAL_REDUCE_SMLA;
+  EVT AccElemVT = Acc.getValueType().getVectorElementType();
+  if (Op1IsSigned != NodeIsSigned &&
+      (Op1.getValueType().getVectorElementType() != AccElemVT ||
+       Op2.getValueType().getVectorElementType() != AccElemVT))
+    return SDValue();
+
+  unsigned NewOpcode =
+      Op1IsSigned ? ISD::PARTIAL_REDUCE_SMLA : ISD::PARTIAL_REDUCE_UMLA;
+  return DAG.getNode(NewOpcode, DL, N->getValueType(0), Acc, UnextOp1,
+                     TruncOp2);
+}
+
 SDValue DAGCombiner::visitVP_STRIDED_LOAD(SDNode *N) {
   auto *SLD = cast<VPStridedLoadSDNode>(N);
   EVT EltVT = SLD->getValueType(0).getVectorElementType();
diff --git a/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll b/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
index ed27f40aba774..e9e3bfa255462 100644
--- a/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
+++ b/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
@@ -423,45 +423,45 @@ define <vscale x 4 x i64> @usdot_8to64(<vscale x 4 x i64> %acc, <vscale x 16 x i
 ; CHECK-NEWLOWERING-NEXT:    .cfi_offset w29, -16
 ; CHECK-NEWLOWERING-NEXT:    .cfi_escape 0x10, 0x48, 0x0a, 0x11, 0x70, 0x22, 0x11, 0x78, 0x92, 0x2e, 0x00, 0x1e, 0x22 // $d8 @ cfa - 16 - 8 * VG
 ; CHECK-NEWLOWERING-NEXT:    .cfi_escape 0x10, 0x49, 0x0a, 0x11, 0x70, 0x22, 0x11, 0x70, 0x92, 0x2e, 0x00, 0x1e, 0x22 // $d9 @ cfa - 16 - 16 * VG
-; CHECK-NEWLOWERING-NEXT:    uunpklo z4.h, z2.b
+; CHECK-NEWLOWERING-NEXT:    uunpkhi z4.h, z2.b
+; CHECK-NEWLOWERING-NEXT:    uunpklo z2.h, z2.b
 ; CHECK-NEWLOWERING-NEXT:    sunpklo z5.h, z3.b
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z2.h, z2.b
 ; CHECK-NEWLOWERING-NEXT:    sunpkhi z3.h, z3.b
 ; CHECK-NEWLOWERING-NEXT:    ptrue p0.d
 ; CHECK-NEWLOWERING-NEXT:    uunpklo z6.s, z4.h
+; CHECK-NEWLOWERING-NEXT:    uunpklo z7.s, z2.h
+; CHECK-NEWLOWERING-NEXT:    sunpklo z24.s, z5.h
+; CHECK-NEWLOWERING-NEXT:    sunpklo z25.s, z3.h
 ; CHECK-NEWLOWERING-NEXT:    uunpkhi z4.s, z4.h
-; CHECK-NEWLOWERING-NEXT:    sunpklo z7.s, z5.h
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z5.s, z5.h
-; CHECK-NEWLOWERING-NEXT:    uunpklo z24.s, z2.h
 ; CHECK-NEWLOWERING-NEXT:    uunpkhi z2.s, z2.h
-; CHECK-NEWLOWERING-NEXT:    sunpklo z25.s, z3.h
 ; CHECK-NEWLOWERING-NEXT:    sunpkhi z3.s, z3.h
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z26.d, z6.s
-; CHECK-NEWLOWERING-NEXT:    uunpklo z6.d, z6.s
-; CHECK-NEWLOWERING-NEXT:    uunpklo z27.d, z4.s
-; CHECK-NEWLOWERING-NEXT:    sunpklo z28.d, z7.s
-; CHECK-NEWLOWERING-NEXT:    sunpklo z29.d, z5.s
+; CHECK-NEWLOWERING-NEXT:    sunpkhi z5.s, z5.h
+; CHECK-NEWLOWERING-NEXT:    uunpklo z26.d, z6.s
+; CHECK-NEWLOWERING-NEXT:    uunpklo z27.d, z7.s
+; CHECK-NEWLOWERING-NEXT:    uunpkhi z7.d, z7.s
+; CHECK-NEWLOWERING-NEXT:    sunpklo z28.d, z24.s
+; CHECK-NEWLOWERING-NEXT:    sunpkhi z24.d, z24.s
+; CHECK-NEWLOWERING-NEXT:    uunpkhi z6.d, z6.s
+; CHECK-NEWLOWERING-NEXT:    sunpklo z29.d, z25.s
+; CHECK-NEWLOWERING-NEXT:    sunpkhi z25.d, z25.s
+; CHECK-NEWLOWERING-NEXT:    uunpklo z30.d, z4.s
 ; CHECK-NEWLOWERING-NEXT:    uunpkhi z4.d, z4.s
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z7.d, z7.s
+; CHECK-NEWLOWERING-NEXT:    uunpklo z31.d, z2.s
+; CHECK-NEWLOWERING-NEXT:    uunpkhi z2.d, z2.s
+; CHECK-NEWLOWERING-NEXT:    sunpklo z8.d, z3.s
+; CHECK-NEWLOWERING-NEXT:    sunpklo z9.d, z5.s
 ; CHECK-NEWLOWERING-NEXT:    sunpkhi z5.d, z5.s
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z30.d, z24.s
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z31.d, z2.s
-; CHECK-NEWLOWERING-NEXT:    uunpklo z24.d, z24.s
-; CHECK-NEWLOWERING-NEXT:    uunpklo z2.d, z2.s
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z8.d, z25.s
-; CHECK-NEWLOWERING-NEXT:    sunpklo z25.d, z25.s
-; CHECK-NEWLOWERING-NEXT:    sunpklo z9.d, z3.s
-; CHECK-NEWLOWERING-NEXT:    mul z27.d, z27.d, z29.d
-; CHECK-NEWLOWERING-NEXT:    mla z0.d, p0/m, z6.d, z28.d
+; CHECK-NEWLOWERING-NEXT:    mul z7.d, z7.d, z24.d
+; CHECK-NEWLOWERING-NEXT:    mla z0.d, p0/m, z27.d, z28.d
 ; CHECK-NEWLOWERING-NEXT:    sunpkhi z3.d, z3.s
-; CHECK-NEWLOWERING-NEXT:    mul z4.d, z4.d, z5.d
-; CHECK-NEWLOWERING-NEXT:    mla z1.d, p0/m, z26.d, z7.d
-; CHECK-NEWLOWERING-NEXT:    mla z0.d, p0/m, z2.d, z9.d
-; CHECK-NEWLOWERING-NEXT:    movprfx z2, z27
-; CHECK-NEWLOWERING-NEXT:    mla z2.d, p0/m, z24.d, z25.d
+; CHECK-NEWLOWERING-NEXT:    mul z6.d, z6.d, z25.d
+; CHECK-NEWLOWERING-NEXT:    mla z1.d, p0/m, z26.d, z29.d
+; CHECK-NEWLOWERING-NEXT:    mla z0.d, p0/m, z2.d, z5.d
+; CHECK-NEWLOWERING-NEXT:    movprfx z2, z7
+; CHECK-NEWLOWERING-NEXT:    mla z2.d, p0/m, z31.d, z9.d
 ; CHECK-NEWLOWERING-NEXT:    ldr z9, [sp] // 16-byte Folded Reload
-; CHECK-NEWLOWERING-NEXT:    mla z1.d, p0/m, z31.d, z3.d
-; CHECK-NEWLOWERING-NEXT:    movprfx z3, z4
+; CHECK-NEWLOWERING-NEXT:    mla z1.d, p0/m, z4.d, z3.d
+; CHECK-NEWLOWERING-NEXT:    movprfx z3, z6
 ; CHECK-NEWLOWERING-NEXT:    mla z3.d, p0/m, z30.d, z8.d
 ; CHECK-NEWLOWERING-NEXT:    ldr z8, [sp, #1, mul vl] // 16-byte Folded Reload
 ; CHECK-NEWLOWERING-NEXT:    add z0.d, z2.d, z0.d
@@ -556,45 +556,45 @@ define <vscale x 4 x i64> @sudot_8to64(<vscale x 4 x i64> %acc, <vscale x 16 x i
 ; CHECK-NEWLOWERING-NEXT:    .cfi_offset w29, -16
 ; CHECK-NEWLOWERING-NEXT:    .cfi_escape 0x10, 0x48, 0x0a, 0x11, 0x70, 0x22, 0x11, 0x78, 0x92, 0x2e, 0x00, 0x1e, 0x22 // $d8 @ cfa - 16 - 8 * VG
 ; CHECK-NEWLOWERING-NEXT:    .cfi_escape 0x10, 0x49, 0x0a, 0x11, 0x70, 0x22, 0x11, 0x70, 0x92, 0x2e, 0x00, 0x1e, 0x22 // $d9 @ cfa - 16 - 16 * VG
-; CHECK-NEWLOWERING-NEXT:    sunpklo z4.h, z2.b
+; CHECK-NEWLOWERING-NEXT:    sunpkhi z4.h, z2.b
+; CHECK-NEWLOWERING-NEXT:    sunpklo z2.h, z2.b
 ; CHECK-NEWLOWERING-NEXT:    uunpklo z5.h, z3.b
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z2.h, z2.b
 ; CHECK-NEWLOWERING-NEXT:    uunpkhi z3.h, z3.b
 ; CHECK-NEWLOWERING-NEXT:    ptrue p0.d
 ; CHECK-NEWLOWERING-NEXT:    sunpklo z6.s, z4.h
+; CHECK-NEWLOWERING-NEXT:    sunpklo z7.s, z2.h
+; CHECK-NEWLOWERING-NEXT:    uunpklo z24.s, z5.h
+; CHECK-NEWLOWERING-NEXT:    uunpklo z25.s, z3.h
 ; CHECK-NEWLOWERING-NEXT:    sunpkhi z4.s, z4.h
-; CHECK-NEWLOWERING-NEXT:    uunpklo z7.s, z5.h
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z5.s, z5.h
-; CHECK-NEWLOWERING-NEXT:    sunpklo z24.s, z2.h
 ; CHECK-NEWLOWERING-NEXT:    sunpkhi z2.s, z2.h
-; CHECK-NEWLOWERING-NEXT:    uunpklo z25.s, z3.h
 ; CHECK-NEWLOWERING-NEXT:    uunpkhi z3.s, z3.h
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z26.d, z6.s
-; CHECK-NEWLOWERING-NEXT:    sunpklo z6.d, z6.s
-; CHECK-NEWLOWERING-NEXT:    sunpklo z27.d, z4.s
-; CHECK-NEWLOWERING-NEXT:    uunpklo z28.d, z7.s
-; CHECK-NEWLOWERING-NEXT:    uunpklo z29.d, z5.s
+; CHECK-NEWLOWERING-NEXT:    uunpkhi z5.s, z5.h
+; CHECK-NEWLOWERING-NEXT:    sunpklo z26.d, z6.s
+; CHECK-NEWLOWERING-NEXT:    sunpklo z27.d, z7.s
+; CHECK-NEWLOWERING-NEXT:    sunpkhi z7.d, z7.s
+; CHECK-NEWLOWERING-NEXT:    uunpklo z28.d, z24.s
+; CHECK-NEWLOWERING-NEXT:    uunpkhi z24.d, z24.s
+; CHECK-NEWLOWERING-NEXT:    sunpkhi z6.d, z6.s
+; CHECK-NEWLOWERING-NEXT:    uunpklo z29.d, z25.s
+; CHECK-NEWLOWERING-NEXT:    uunpkhi z25.d, z25.s
+; CHECK-NEWLOWERING-NEXT:    sunpklo z30.d, z4.s
 ; CHECK-NEWLOWERING-NEXT:    sunpkhi z4.d, z4.s
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z7.d, z7.s
+; CHECK-NEWLOWERING-NEXT:    sunpklo z31.d, z2.s
+; CHECK-NEWLOWERING-NEXT:    sunpkhi z2.d, z2.s
+; CHECK-NEWLOWERING-NEXT:    uunpklo z8.d, z3.s
+; CHECK-NEWLOWERING-NEXT:    uunpklo z9.d, z5.s
 ; CHECK-NEWLOWERING-NEXT:    uunpkhi z5.d, z5.s
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z30.d, z24.s
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z31.d, z2.s
-; CHECK-NEWLOWERING-NEXT:    sunpklo z24.d, z24.s
-; CHECK-NEWLOWERING-NEXT:    sunpklo z2.d, z2.s
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z8.d, z25.s
-; CHECK-NEWLOWERING-NEXT:    uunpklo z25.d, z25.s
-; CHECK-NEWLOWERING-NEXT:    uunpklo z9.d, z3.s
-; CHECK-NEWLOWERING-NEXT:    mul z27.d, z27.d, z29.d
-; CHECK-NEWLOWERING-NEXT:    mla z0.d, p0/m, z6.d, z28.d
+; CHECK-NEWLOWERING-NEXT:    mul z7.d, z7.d, z24.d
+; CHECK-NEWLOWERING-NEXT:    mla z0.d, p0/m, z27.d, z28.d
 ; CHECK-NEWLOWERING-NEXT:    uunpkhi z3.d, z3.s
-; CHECK-NEWLOWERING-NEXT:    mul z4.d, z4.d, z5.d
-; CHECK-NEWLOWERING-NEXT:    mla z1.d, p0/m, z26.d, z7.d
-; CHECK-NEWLOWERING-NEXT:    mla z0.d, p0/m, z2.d, z9.d
-; CHECK-NEWLOWERING-NEXT:    movprfx z2, z27
-; CHECK-NEWLOWERING-NEXT:    mla z2.d, p0/m, z24.d, z25.d
+; CHECK-NEWLOWERING-NEXT:    mul z6.d, z6.d, z25.d
+; CHECK-NEWLOWERING-NEXT:    mla z1.d, p0/m, z26.d, z29.d
+; CHECK-NEWLOWERING-NEXT:    mla z0.d, p0/m, z2.d, z5.d
+; CHECK-NEWLOWERING-NEXT:    movprfx z2, z7
+; CHECK-NEWLOWERING-NEXT:    mla z2.d, p0/m, z31.d, z9.d
 ; CHECK-NEWLOWERING-NEXT:    ldr z9, [sp] // 16-byte Folded Reload
-; CHECK-NEWLOWERING-NEXT:    mla z1.d, p0/m, z31.d, z3.d
-; CHECK-NEWLOWERING-NEXT:    movprfx z3, z4
+; CHECK-NEWLOWERING-NEXT:    mla z1.d, p0/m, z4.d, z3.d
+; CHECK-NEWLOWERING-NEXT:    movprfx z3, z6
 ; CHECK-NEWLOWERING-NEXT:    mla z3.d, p0/m, z30.d, z8.d
 ; CHECK-NEWLOWERING-NEXT:    ldr z8, [sp, #1, mul vl] // 16-byte Folded Reload
 ; CHECK-NEWLOWERING-NEXT:    add z0.d, z2.d, z0.d
@@ -620,16 +620,8 @@ define <vscale x 4 x i32> @udot_no_bin_op(<vscale x 4 x i32> %acc, <vscale x 16
 ;
 ; CHECK-NEWLOWERING-LABEL: udot_no_bin_op:
 ; CHECK-NEWLOWERING:       // %bb.0:
-; CHECK-NEWLOWERING-NEXT:    uunpklo z2.h, z1.b
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z1.h, z1.b
-; CHECK-NEWLOWERING-NEXT:    uunpklo z3.s, z2.h
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z4.s, z1.h
-; CHECK-NEWLOWERING-NEXT:    uunpklo z1.s, z1.h
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z2.s, z2.h
-; CHECK-NEWLOWERING-NEXT:    add z0.s, z0.s, z3.s
-; CHECK-NEWLOWERING-NEXT:    add z1.s, z2.s, z1.s
-; CHECK-NEWLOWERING-NEXT:    add z0.s, z4.s, z0.s
-; CHECK-NEWLOWERING-NEXT:    add z0.s, z1.s, z0.s
+; CHECK-NEWLOWERING-NEXT:    mov z2.b, #1 // =0x1
+; CHECK-NEWLOWERING-NEXT:    udot z0.s, z1.b, z2.b
 ; CHECK-NEWLOWERING-NEXT:    ret
   %a.ext = zext <vscale x 16 x i8> %a to <vscale x 16 x i32>
   %partial.reduce = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> %acc, <vscale x 16 x i32> %a.ext)
@@ -645,16 +637,8 @@ define <vscale x 4 x i32> @sdot_no_bin_op(<vscale x 4 x i32> %acc, <vscale x 16
 ;
 ; CHECK-NEWLOWERING-LABEL: sdot_no_bin_op:
 ; CHECK-NEWLOWERING:       // %bb.0:
-; CHECK-NEWLOWERING-NEXT:    sunpklo z2.h, z1.b
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z1.h, z1.b
-; CHECK-NEWLOWERING-NEXT:    sunpklo z3.s, z2.h
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z4.s, z1.h
-; CHECK-NEWLOWERING-NEXT:    sunpklo z1.s, z1.h
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z2.s, z2.h
-; CHECK-NEWLOWERING-NEXT:    add z0.s, z0.s, z3.s
-; CHECK-NEWLOWERING-NEXT:    add z1.s, z2.s, z1.s
-; CHECK-NEWLOWERING-NEXT:    add z0.s, z4.s, z0.s
-; CHECK-NEWLOWERING-NEXT:    add z0.s, z1.s, z0.s
+; CHECK-NEWLOWERING-NEXT:    mov z2.b, #1 // =0x1
+; CHECK-NEWLOWERING-NEXT:    sdot z0.s, z1.b, z2.b
 ; CHECK-NEWLOWERING-NEXT:    ret
   %a.ext = sext <vscale x 16 x i8> %a to <vscale x 16 x i32>
   %partial.reduce = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> %acc, <vscale x 16 x i32> %a.ext)
@@ -670,16 +654,8 @@ define <vscale x 2 x i64> @udot_no_bin_op_wide(<vscale x 2 x i64> %acc, <vscale
 ;
 ; CHECK-NEWLOWERING-LABEL: udot_no_bin_op_wide:
 ; CHECK-NEWLOWERING:       // %bb.0: // %entry
-; CHECK-NEWLOWERING-NEXT:    uunpklo z2.s, z1.h
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z1.s, z1.h
-; CHECK-NEWLOWERING-NEXT:    uunpklo z3.d, z2.s
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z4.d, z1.s
-; CHECK-NEWLOWERING-NEXT:    uunpklo z1.d, z1.s
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z2.d, z2.s
-; CHECK-NEWLOWERING-NEXT:    add z0.d, z0.d, z3.d
-; CHECK-NEWLOWERING-NEXT:    add z1.d, z2.d, z1.d
-; CHECK-NEWLOWERING-NEXT:    add z0.d, z4.d, z0.d
-; CHECK-NEWLOWERING-NEXT:    add z0.d, z1.d, z0.d
+; CHECK-NEWLOWERING-NEXT:    mov z2.h, #1 // =0x1
+; CHECK-NEWLOWERING-NEXT:    udot z0.d, z1.h, z2.h
 ; CHECK-NEWLOWERING-NEXT:    ret
 entry:
   %a.wide = zext <vscale x 8 x i16> %a to <vscale x 8 x i64>
@@ -696,16 +672,8 @@ define <vscale x 2 x i64> @sdot_no_bin_op_wide(<vscale x 2 x i64> %acc, <vscale
 ;
 ; CHECK-NEWLOWERING-LABEL: sdot_no_bin_op_wide:
 ; CHECK-NEWLOWERING:       // %bb.0: // %entry
-; CHECK-NEWLOWERING-NEXT:    sunpklo z2.s, z1.h
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z1.s, z1.h
-; CHECK-NEWLOWERING-NEXT:    sunpklo z3.d, z2.s
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z4.d, z1.s
-; CHECK-NEWLOWERING-NEXT:    sunpklo z1.d, z1.s
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z2.d, z2.s
-; CHECK-NEWLOWERING-NEXT:    add z0.d, z0.d, z3.d
-; CHECK-NEWLOWERING-NEXT:    add z1.d, z2.d, z1.d
-; CHECK-NEWLOWERING-NEXT:    add z0.d, z4.d, z0.d
-; CHECK-NEWLOWERING-NEXT:    add z0.d, z1.d, z0.d
+; CHECK-NEWLOWERING-NEXT:    mov z2.h, #1 // =0x1
+; CHECK-NEWLOWERING-NEXT:    sdot z0.d, z1.h, z2.h
 ; CHECK-NEWLOWERING-NEXT:    ret
 entry:
   %a.wide = sext <vscale x 8 x i16> %a to <vscale x 8 x i64>
@@ -727,28 +695,13 @@ define <vscale x 4 x i64> @udot_no_bin_op_8to64(<vscale x 4 x i64> %acc, <vscale
 ;
 ; CHECK-NEWLOWERING-LABEL: udot_no_bin_op_8to64:
 ; CHECK-NEWLOWERING:       // %bb.0:
-; CHECK-NEWLOWERING-NEXT:    uunpklo z3.h, z2.b
+; CHECK-NEWLOWERING-NEXT:    mov z3.b, #1 // =0x1
+; CHECK-NEWLOWERING-NEXT:    uunpklo z5.h, z2.b
 ; CHECK-NEWLOWERING-NEXT:    uunpkhi z2.h, z2.b
-; CHECK-NEWLOWERING-NEXT:    uunpklo z4.s, z3.h
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z5.s, z2.h
-; CHECK-NEWLOWERING-NEXT:    uunpklo z2.s, z2.h
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z3.s, z3.h
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z6.d, z4.s
-; CHECK-NEWLOWERING-NEXT:    uunpklo z4.d, z4.s
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z7.d, z5.s
-; CHECK-NEWLOWERING-NEXT:    uunpklo z24.d, z2.s
-; CHECK-NEWLOWERING-NEXT:    uunpklo z25.d, z3.s
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z2.d, z2.s
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z3.d, z3.s
-; CHECK-NEWLOWERING-NEXT:    uunpklo z5.d, z5.s
-; CHECK-NEWLOWERING-NEXT:    add z0.d, z0.d, z4.d
-; CHECK-NEWLOWERING-NEXT:    add z1.d, z1.d, z6.d
-; CHECK-NEWLOWERING-NEXT:    add z4.d, z25.d, z24.d
-; CHECK-NEWLOWERING-NEXT:    add z2.d, z3.d, z2.d
-; CHECK-NEWLOWERING-NEXT:    add z0.d, z5.d, z0.d
-; CHECK-NEWLOWERING-NEXT:    add z1.d, z7.d, z1.d
-; CHECK-NEWLOWERING-NEXT:    add z0.d, z4.d, z0.d
-; CHECK-NEWLOWERING-NEXT:    add z1.d, z2.d, z1.d
+; CHECK-NEWLOWERING-NEXT:    uunpklo z4.h, z3.b
+; CHECK-NEWLOWERING-NEXT:    uunpkhi z3.h, z3.b
+; CHECK-NEWLOWERING-NEXT:    udot z0.d, z5.h, z4.h
+; CHECK-NEWLOWERING-NEXT:    udot z1.d, z2.h, z3.h
 ; CHECK-NEWLOWERING-NEXT:    ret
   %a.ext = zext <vscale x 16 x i8> %a to <vscale x 16 x i64>
   %partial.reduce = tail call <vscale x 4 x i64> @llvm.experimental.vector.partial.reduce.add.nxv4i64.nxv16i64(<vscale x 4 x i64> %acc, <vscale x 16 x i64> %a.ext)
@@ -769,28 +722,13 @@ define <vscale x 4 x i64> @sdot_no_bin_op_8to64(<vscale x 4 x i64> %acc, <vscale
 ;
 ; CHECK-NEWLOWERING-LABEL: sdot_no_bin_op_8to64:
 ; CHECK-NEWLOWERING:       // %bb.0:
-; CHECK-NEWLOWERING-NEXT:    sunpklo z3.h, z2.b
+; CHECK-NEWLOWERING-NEXT:    mov z3.b, #1 // =0x1
+; CHECK-NEWLOWERING-NEXT:    sunpklo z5.h, z2.b
 ; CHECK-NEWLOWERING-NEXT:    sunpkhi z2.h, z2.b
-; CHECK-NEWLOWERING-NEXT:    sunpklo z4.s, z3.h
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z5.s, z2.h
-; CHECK-NEWLOWERING-NEXT:    sunpklo z2.s, z2.h
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z3.s, z3.h
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z6.d, z4.s
-; CHECK-NEWLOWERING-NEXT:    sunpklo z4.d, z4.s
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z7.d, z5.s
-; CHECK-NEWLOWERING-NEXT:    sunpklo z24.d, z2.s
-; CHECK-NEWLOWERING-NEXT:    sunpklo z25.d, z3.s
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z2.d, z2.s
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z3.d, z3.s
-; CHECK-NEWLOWERING-NEXT:    sunpklo z5.d, z5.s
-; CHECK-NEWLOWERING-NEXT:    add z0.d, z0.d, z4.d
-; CHECK-NEWLOWERING-NEXT:    add z1.d, z1.d, z6.d
-; CHECK-NEWLOWERING-NEXT:    add z4.d, z25.d, z24.d
-; CHECK-NEWLOWERING-NEXT:    add z2.d, z3.d, z2.d
-; CHECK-NEWLOWERING-NEXT:    add z0.d, z5.d, z0.d
-; CHECK-NEWLOWERING-NEXT:    add z1.d, z7.d, z1.d
-; CHECK-NEWLOWERING-NEXT:    add z0.d, z4.d, z0.d
-; CHECK-NEWLOWERING-NEXT:    add z1.d, z2.d, z1.d
+; CHECK-NEWLOWERING-NEXT:    sunpklo z4.h, z3.b
+; CHECK-NEWLOWERING-NEXT:    sunpkhi z3.h, z3.b
+; CHECK-NEWLOWERING-NEXT:    sdot z0.d, z5.h, z4.h
+; CHECK-NEWLOWERING-NEXT:    sdot z1.d, z2.h, z3.h
 ; CHECK-NEWLOWERING-NEXT:    ret
   %a.ext = sext <vscale x 16 x i8> %a to <vscale x 16 x i64>
   %partial.reduce = tail call <vscale x 4 x i64> @llvm.experimental.vector.partial.reduce.add.nxv4i64.nxv16i64(<vscale x 4 x i64> %acc, <vscale x 16 x i64> %a.ext)
diff --git a/llvm/test/CodeGen/AArch64/sve-partial-reduce-wide-add.ll b/llvm/test/CodeGen/AArch64/sve-partial-reduce-wide-add.ll
index 11fb60ead4fb2..5773c47e001d2 100644
--- a/llvm/test/CodeGen/AArch64/sve-partial-reduce-wide-add.ll
+++ b/llvm/test/CodeGen/AArch64/sve-partial-reduce-wide-add.ll
@@ -172,15 +172,35 @@ entry:
 }
 
 define <vscale x 2 x i32> @signed_wide_add_nxv4i16(<vscale x 2 x i32> %acc, <vscale x 4 x i16> %input){
-; CHECK-LABEL: signed_wide_add_nxv4i16:
-; CHECK:       // %bb.0: // %entry
-; CHECK-NEXT:    ptrue p0.s
-; CHECK-NEXT:    sxth z1.s, p0/m, z1.s
-; CHECK-NEXT:    uunpklo z2.d, z1.s
-; CHECK-NEXT:    uunpkhi z1.d, z1.s
-; CHECK-NEXT:    add z0.d, z0.d, z2.d
-; CHECK-NEXT:    add z0.d, z1.d, z0.d
-; CHECK-NEXT:    ret
+; CHECK-SVE2-LABEL: signed_wide_add_nxv4i16:
+; CHECK-SVE2:       // %bb.0: // %entry
+; CHECK-SVE2-NEXT:    ptrue p0.s
+; CHECK-SVE2-NEXT:    sxth z1.s, p0/m, z1.s
+; CHECK-SVE2-NEXT:    uunpklo z2.d, z1.s
+; CHECK-SVE2-NEXT:    uunpkhi z1.d, z1.s
+; CHECK-SVE2-NEXT:    add z0.d, z0.d, z2.d
+; CHECK-SVE2-NEXT:    add z0.d, z1.d, z0.d
+; CHECK-SVE2-NEXT:    ret
+;
+; CHECK-SVE-LABEL: signed_wide_add_nxv4i16:
+; CHECK-SVE:       // %bb.0: // %entry
+; CHECK-SV...
[truncated]

@NickGuy-Arm
Contributor Author

This PR implicitly depends on #130935

@NickGuy-Arm force-pushed the JamesChesterman/legal-partial-reduction/3 branch from f3d9a54 to a708b96 on May 2, 2025 13:29
Collaborator

@sdesmalen-arm left a comment


Mostly some nits, otherwise looks good to me.

Comment on lines 12692 to 12694
APInt ConstantOne;
if (!ISD::isConstantSplatVector(Op2.getNode(), ConstantOne) ||
!ConstantOne.isOne())

nit:

Suggested change
-  APInt ConstantOne;
-  if (!ISD::isConstantSplatVector(Op2.getNode(), ConstantOne) ||
-      !ConstantOne.isOne())
+  APInt C;
+  if (!ISD::isConstantSplatVector(Op2.getNode(), C) || !C.isOne())

EVT AccElemVT = Acc.getValueType().getVectorElementType();
if (Op1IsSigned != NodeIsSigned &&
(Op1.getValueType().getVectorElementType() != AccElemVT ||
Op2.getValueType().getVectorElementType() != AccElemVT))

I don't think you need to test the Op2 case here, because the type of Op1 must match that of Op2.

if (!TLI.isPartialReduceMLALegalOrCustom(N->getValueType(0), UnextOp1VT))
return SDValue();

SDValue TruncOp2 = DAG.getNode(ISD::TRUNCATE, DL, UnextOp1VT, Op2);

This can just be DAG.getConstant(1, DL, UnextOp1VT) (and better to just inline that in the use below).


SDValue UnextOp1 = Op1.getOperand(0);
EVT UnextOp1VT = UnextOp1.getValueType();


nit: remove newline.

SDValue TruncOp2 = DAG.getNode(ISD::TRUNCATE, DL, UnextOp1VT, Op2);

bool Op1IsSigned = Op1Opcode == ISD::SIGN_EXTEND;


nit: remove newline.

// partial.reduce.umla(acc, op, splat(trunc(1)))
// Makes partial.reduce.smla(acc, sext(op1), splat(1)) into
// partial.reduce.smla(acc, op, splat(trunc(1)))
SDValue DAGCombiner::foldPartialReduceMLANoMulOp(SDNode *N) {

nit: maybe rename this to foldPartialReduceAdd?

Comment on lines 12682 to 12685
// Makes partial.reduce.umla(acc, zext(op1), splat(1)) into
// partial.reduce.umla(acc, op, splat(trunc(1)))
// Makes partial.reduce.smla(acc, sext(op1), splat(1)) into
// partial.reduce.smla(acc, op, splat(trunc(1)))

nit:

Suggested change
-// Makes partial.reduce.umla(acc, zext(op1), splat(1)) into
-// partial.reduce.umla(acc, op, splat(trunc(1)))
-// Makes partial.reduce.smla(acc, sext(op1), splat(1)) into
-// partial.reduce.smla(acc, op, splat(trunc(1)))
+// partial.reduce.umla(acc, zext(op1), splat(1))
+// -> partial.reduce.umla(acc, op, splat(trunc(1)))
+// partial.reduce.smla(acc, sext(op1), splat(1))
+// -> partial.reduce.smla(acc, op, splat(trunc(1)))

Generic DAG combine for ISD::PARTIAL_REDUCE_U/SMLA to convert:
PARTIAL_REDUCE_*MLA(Acc, ZEXT(UnextOp1), Splat(1)) into
PARTIAL_REDUCE_UMLA(Acc, UnextOp1, TRUNC(Splat(1)))
and
PARTIAL_REDUCE_*MLA(Acc, SEXT(UnextOp1), Splat(1)) into
PARTIAL_REDUCE_SMLA(Acc, UnextOp1, TRUNC(Splat(1))).
@NickGuy-Arm force-pushed the JamesChesterman/legal-partial-reduction/3 branch from ffd894f to bc54324 on May 6, 2025 13:29
@NickGuy-Arm merged commit a8ed244 into llvm:main on May 6, 2025
9 of 11 checks passed
GeorgeARM pushed a commit to GeorgeARM/llvm-project that referenced this pull request May 7, 2025
[DAGCombiner] Add DAG combine for PARTIAL_REDUCE_MLA when no mul op (llvm#131326)

Generic DAG combine for ISD::PARTIAL_REDUCE_U/SMLA to convert:
PARTIAL_REDUCE_*MLA(Acc, ZEXT(UnextOp1), Splat(1)) into
PARTIAL_REDUCE_UMLA(Acc, UnextOp1, TRUNC(Splat(1))) and
PARTIAL_REDUCE_*MLA(Acc, SEXT(UnextOp1), Splat(1)) into
PARTIAL_REDUCE_SMLA(Acc, UnextOp1, TRUNC(Splat(1))).

---------

Co-authored-by: James Chesterman <[email protected]>
Labels: backend:AArch64, llvm:SelectionDAG
5 participants