[DAGCombiner] Add DAG combine for PARTIAL_REDUCE_MLA when no mul op #131326
Conversation
@llvm/pr-subscribers-backend-aarch64 @llvm/pr-subscribers-llvm-selectiondag

Author: Nicholas Guy (NickGuy-Arm)

Changes

Generic DAG combine for ISD::PARTIAL_REDUCE_U/SMLA to convert:
PARTIAL_REDUCE_*MLA(Acc, ZEXT(UnextOp1), Splat(1)) into PARTIAL_REDUCE_UMLA(Acc, UnextOp1, TRUNC(Splat(1))) and
PARTIAL_REDUCE_*MLA(Acc, SEXT(UnextOp1), Splat(1)) into PARTIAL_REDUCE_SMLA(Acc, UnextOp1, TRUNC(Splat(1))).

@JamesChesterman is the original author.

Patch is 25.30 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/131326.diff

3 Files Affected:
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index cb5943eca82f5..d59b54768af6c 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -618,6 +618,8 @@ namespace {
SDValue CombineConsecutiveLoads(SDNode *N, EVT VT);
SDValue foldBitcastedFPLogic(SDNode *N, SelectionDAG &DAG,
const TargetLowering &TLI);
+ SDValue foldPartialReduceMLAMulOp(SDNode *N);
+ SDValue foldPartialReduceMLANoMulOp(SDNode *N);
SDValue CombineExtLoad(SDNode *N);
SDValue CombineZExtLogicopShiftLoad(SDNode *N);
@@ -12612,13 +12614,21 @@ SDValue DAGCombiner::visitMHISTOGRAM(SDNode *N) {
return SDValue();
}
+SDValue DAGCombiner::visitPARTIAL_REDUCE_MLA(SDNode *N) {
+ if (SDValue Res = foldPartialReduceMLAMulOp(N))
+ return Res;
+ if (SDValue Res = foldPartialReduceMLANoMulOp(N))
+ return Res;
+ return SDValue();
+}
+
// Makes PARTIAL_REDUCE_*MLA(Acc, MUL(ZEXT(LHSExtOp), ZEXT(RHSExtOp)),
// Splat(1)) into
// PARTIAL_REDUCE_UMLA(Acc, LHSExtOp, RHSExtOp).
// Makes PARTIAL_REDUCE_*MLA(Acc, MUL(SEXT(LHSExtOp), SEXT(RHSExtOp)),
// Splat(1)) into
// PARTIAL_REDUCE_SMLA(Acc, LHSExtOp, RHSExtOp).
-SDValue DAGCombiner::visitPARTIAL_REDUCE_MLA(SDNode *N) {
+SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
SDLoc DL(N);
SDValue Acc = N->getOperand(0);
@@ -12669,6 +12679,48 @@ SDValue DAGCombiner::visitPARTIAL_REDUCE_MLA(SDNode *N) {
RHSExtOp);
}
+// Makes PARTIAL_REDUCE_*MLA(Acc, ZEXT(UnextOp1), Splat(1)) into
+// PARTIAL_REDUCE_UMLA(Acc, Op, TRUNC(Splat(1)))
+// Makes PARTIAL_REDUCE_*MLA(Acc, SEXT(UnextOp1), Splat(1)) into
+// PARTIAL_REDUCE_SMLA(Acc, Op, TRUNC(Splat(1)))
+SDValue DAGCombiner::foldPartialReduceMLANoMulOp(SDNode *N) {
+ SDLoc DL(N);
+ SDValue Acc = N->getOperand(0);
+ SDValue Op1 = N->getOperand(1);
+ SDValue Op2 = N->getOperand(2);
+
+ APInt ConstantOne;
+ if (!ISD::isConstantSplatVector(Op2.getNode(), ConstantOne) ||
+ !ConstantOne.isOne())
+ return SDValue();
+
+ unsigned Op1Opcode = Op1.getOpcode();
+ if (!ISD::isExtOpcode(Op1Opcode))
+ return SDValue();
+
+ SDValue UnextOp1 = Op1.getOperand(0);
+ EVT UnextOp1VT = UnextOp1.getValueType();
+
+ if (!TLI.isPartialReduceMLALegalOrCustom(N->getValueType(0), UnextOp1VT))
+ return SDValue();
+
+ SDValue TruncOp2 = DAG.getNode(ISD::TRUNCATE, DL, UnextOp1VT, Op2);
+
+ bool Op1IsSigned = Op1Opcode == ISD::SIGN_EXTEND;
+
+ bool NodeIsSigned = N->getOpcode() == ISD::PARTIAL_REDUCE_SMLA;
+ EVT AccElemVT = Acc.getValueType().getVectorElementType();
+ if (Op1IsSigned != NodeIsSigned &&
+ (Op1.getValueType().getVectorElementType() != AccElemVT ||
+ Op2.getValueType().getVectorElementType() != AccElemVT))
+ return SDValue();
+
+ unsigned NewOpcode =
+ Op1IsSigned ? ISD::PARTIAL_REDUCE_SMLA : ISD::PARTIAL_REDUCE_UMLA;
+ return DAG.getNode(NewOpcode, DL, N->getValueType(0), Acc, UnextOp1,
+ TruncOp2);
+}
+
SDValue DAGCombiner::visitVP_STRIDED_LOAD(SDNode *N) {
auto *SLD = cast<VPStridedLoadSDNode>(N);
EVT EltVT = SLD->getValueType(0).getVectorElementType();
diff --git a/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll b/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
index ed27f40aba774..e9e3bfa255462 100644
--- a/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
+++ b/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
@@ -423,45 +423,45 @@ define <vscale x 4 x i64> @usdot_8to64(<vscale x 4 x i64> %acc, <vscale x 16 x i
; CHECK-NEWLOWERING-NEXT: .cfi_offset w29, -16
; CHECK-NEWLOWERING-NEXT: .cfi_escape 0x10, 0x48, 0x0a, 0x11, 0x70, 0x22, 0x11, 0x78, 0x92, 0x2e, 0x00, 0x1e, 0x22 // $d8 @ cfa - 16 - 8 * VG
; CHECK-NEWLOWERING-NEXT: .cfi_escape 0x10, 0x49, 0x0a, 0x11, 0x70, 0x22, 0x11, 0x70, 0x92, 0x2e, 0x00, 0x1e, 0x22 // $d9 @ cfa - 16 - 16 * VG
-; CHECK-NEWLOWERING-NEXT: uunpklo z4.h, z2.b
+; CHECK-NEWLOWERING-NEXT: uunpkhi z4.h, z2.b
+; CHECK-NEWLOWERING-NEXT: uunpklo z2.h, z2.b
; CHECK-NEWLOWERING-NEXT: sunpklo z5.h, z3.b
-; CHECK-NEWLOWERING-NEXT: uunpkhi z2.h, z2.b
; CHECK-NEWLOWERING-NEXT: sunpkhi z3.h, z3.b
; CHECK-NEWLOWERING-NEXT: ptrue p0.d
; CHECK-NEWLOWERING-NEXT: uunpklo z6.s, z4.h
+; CHECK-NEWLOWERING-NEXT: uunpklo z7.s, z2.h
+; CHECK-NEWLOWERING-NEXT: sunpklo z24.s, z5.h
+; CHECK-NEWLOWERING-NEXT: sunpklo z25.s, z3.h
; CHECK-NEWLOWERING-NEXT: uunpkhi z4.s, z4.h
-; CHECK-NEWLOWERING-NEXT: sunpklo z7.s, z5.h
-; CHECK-NEWLOWERING-NEXT: sunpkhi z5.s, z5.h
-; CHECK-NEWLOWERING-NEXT: uunpklo z24.s, z2.h
; CHECK-NEWLOWERING-NEXT: uunpkhi z2.s, z2.h
-; CHECK-NEWLOWERING-NEXT: sunpklo z25.s, z3.h
; CHECK-NEWLOWERING-NEXT: sunpkhi z3.s, z3.h
-; CHECK-NEWLOWERING-NEXT: uunpkhi z26.d, z6.s
-; CHECK-NEWLOWERING-NEXT: uunpklo z6.d, z6.s
-; CHECK-NEWLOWERING-NEXT: uunpklo z27.d, z4.s
-; CHECK-NEWLOWERING-NEXT: sunpklo z28.d, z7.s
-; CHECK-NEWLOWERING-NEXT: sunpklo z29.d, z5.s
+; CHECK-NEWLOWERING-NEXT: sunpkhi z5.s, z5.h
+; CHECK-NEWLOWERING-NEXT: uunpklo z26.d, z6.s
+; CHECK-NEWLOWERING-NEXT: uunpklo z27.d, z7.s
+; CHECK-NEWLOWERING-NEXT: uunpkhi z7.d, z7.s
+; CHECK-NEWLOWERING-NEXT: sunpklo z28.d, z24.s
+; CHECK-NEWLOWERING-NEXT: sunpkhi z24.d, z24.s
+; CHECK-NEWLOWERING-NEXT: uunpkhi z6.d, z6.s
+; CHECK-NEWLOWERING-NEXT: sunpklo z29.d, z25.s
+; CHECK-NEWLOWERING-NEXT: sunpkhi z25.d, z25.s
+; CHECK-NEWLOWERING-NEXT: uunpklo z30.d, z4.s
; CHECK-NEWLOWERING-NEXT: uunpkhi z4.d, z4.s
-; CHECK-NEWLOWERING-NEXT: sunpkhi z7.d, z7.s
+; CHECK-NEWLOWERING-NEXT: uunpklo z31.d, z2.s
+; CHECK-NEWLOWERING-NEXT: uunpkhi z2.d, z2.s
+; CHECK-NEWLOWERING-NEXT: sunpklo z8.d, z3.s
+; CHECK-NEWLOWERING-NEXT: sunpklo z9.d, z5.s
; CHECK-NEWLOWERING-NEXT: sunpkhi z5.d, z5.s
-; CHECK-NEWLOWERING-NEXT: uunpkhi z30.d, z24.s
-; CHECK-NEWLOWERING-NEXT: uunpkhi z31.d, z2.s
-; CHECK-NEWLOWERING-NEXT: uunpklo z24.d, z24.s
-; CHECK-NEWLOWERING-NEXT: uunpklo z2.d, z2.s
-; CHECK-NEWLOWERING-NEXT: sunpkhi z8.d, z25.s
-; CHECK-NEWLOWERING-NEXT: sunpklo z25.d, z25.s
-; CHECK-NEWLOWERING-NEXT: sunpklo z9.d, z3.s
-; CHECK-NEWLOWERING-NEXT: mul z27.d, z27.d, z29.d
-; CHECK-NEWLOWERING-NEXT: mla z0.d, p0/m, z6.d, z28.d
+; CHECK-NEWLOWERING-NEXT: mul z7.d, z7.d, z24.d
+; CHECK-NEWLOWERING-NEXT: mla z0.d, p0/m, z27.d, z28.d
; CHECK-NEWLOWERING-NEXT: sunpkhi z3.d, z3.s
-; CHECK-NEWLOWERING-NEXT: mul z4.d, z4.d, z5.d
-; CHECK-NEWLOWERING-NEXT: mla z1.d, p0/m, z26.d, z7.d
-; CHECK-NEWLOWERING-NEXT: mla z0.d, p0/m, z2.d, z9.d
-; CHECK-NEWLOWERING-NEXT: movprfx z2, z27
-; CHECK-NEWLOWERING-NEXT: mla z2.d, p0/m, z24.d, z25.d
+; CHECK-NEWLOWERING-NEXT: mul z6.d, z6.d, z25.d
+; CHECK-NEWLOWERING-NEXT: mla z1.d, p0/m, z26.d, z29.d
+; CHECK-NEWLOWERING-NEXT: mla z0.d, p0/m, z2.d, z5.d
+; CHECK-NEWLOWERING-NEXT: movprfx z2, z7
+; CHECK-NEWLOWERING-NEXT: mla z2.d, p0/m, z31.d, z9.d
; CHECK-NEWLOWERING-NEXT: ldr z9, [sp] // 16-byte Folded Reload
-; CHECK-NEWLOWERING-NEXT: mla z1.d, p0/m, z31.d, z3.d
-; CHECK-NEWLOWERING-NEXT: movprfx z3, z4
+; CHECK-NEWLOWERING-NEXT: mla z1.d, p0/m, z4.d, z3.d
+; CHECK-NEWLOWERING-NEXT: movprfx z3, z6
; CHECK-NEWLOWERING-NEXT: mla z3.d, p0/m, z30.d, z8.d
; CHECK-NEWLOWERING-NEXT: ldr z8, [sp, #1, mul vl] // 16-byte Folded Reload
; CHECK-NEWLOWERING-NEXT: add z0.d, z2.d, z0.d
@@ -556,45 +556,45 @@ define <vscale x 4 x i64> @sudot_8to64(<vscale x 4 x i64> %acc, <vscale x 16 x i
; CHECK-NEWLOWERING-NEXT: .cfi_offset w29, -16
; CHECK-NEWLOWERING-NEXT: .cfi_escape 0x10, 0x48, 0x0a, 0x11, 0x70, 0x22, 0x11, 0x78, 0x92, 0x2e, 0x00, 0x1e, 0x22 // $d8 @ cfa - 16 - 8 * VG
; CHECK-NEWLOWERING-NEXT: .cfi_escape 0x10, 0x49, 0x0a, 0x11, 0x70, 0x22, 0x11, 0x70, 0x92, 0x2e, 0x00, 0x1e, 0x22 // $d9 @ cfa - 16 - 16 * VG
-; CHECK-NEWLOWERING-NEXT: sunpklo z4.h, z2.b
+; CHECK-NEWLOWERING-NEXT: sunpkhi z4.h, z2.b
+; CHECK-NEWLOWERING-NEXT: sunpklo z2.h, z2.b
; CHECK-NEWLOWERING-NEXT: uunpklo z5.h, z3.b
-; CHECK-NEWLOWERING-NEXT: sunpkhi z2.h, z2.b
; CHECK-NEWLOWERING-NEXT: uunpkhi z3.h, z3.b
; CHECK-NEWLOWERING-NEXT: ptrue p0.d
; CHECK-NEWLOWERING-NEXT: sunpklo z6.s, z4.h
+; CHECK-NEWLOWERING-NEXT: sunpklo z7.s, z2.h
+; CHECK-NEWLOWERING-NEXT: uunpklo z24.s, z5.h
+; CHECK-NEWLOWERING-NEXT: uunpklo z25.s, z3.h
; CHECK-NEWLOWERING-NEXT: sunpkhi z4.s, z4.h
-; CHECK-NEWLOWERING-NEXT: uunpklo z7.s, z5.h
-; CHECK-NEWLOWERING-NEXT: uunpkhi z5.s, z5.h
-; CHECK-NEWLOWERING-NEXT: sunpklo z24.s, z2.h
; CHECK-NEWLOWERING-NEXT: sunpkhi z2.s, z2.h
-; CHECK-NEWLOWERING-NEXT: uunpklo z25.s, z3.h
; CHECK-NEWLOWERING-NEXT: uunpkhi z3.s, z3.h
-; CHECK-NEWLOWERING-NEXT: sunpkhi z26.d, z6.s
-; CHECK-NEWLOWERING-NEXT: sunpklo z6.d, z6.s
-; CHECK-NEWLOWERING-NEXT: sunpklo z27.d, z4.s
-; CHECK-NEWLOWERING-NEXT: uunpklo z28.d, z7.s
-; CHECK-NEWLOWERING-NEXT: uunpklo z29.d, z5.s
+; CHECK-NEWLOWERING-NEXT: uunpkhi z5.s, z5.h
+; CHECK-NEWLOWERING-NEXT: sunpklo z26.d, z6.s
+; CHECK-NEWLOWERING-NEXT: sunpklo z27.d, z7.s
+; CHECK-NEWLOWERING-NEXT: sunpkhi z7.d, z7.s
+; CHECK-NEWLOWERING-NEXT: uunpklo z28.d, z24.s
+; CHECK-NEWLOWERING-NEXT: uunpkhi z24.d, z24.s
+; CHECK-NEWLOWERING-NEXT: sunpkhi z6.d, z6.s
+; CHECK-NEWLOWERING-NEXT: uunpklo z29.d, z25.s
+; CHECK-NEWLOWERING-NEXT: uunpkhi z25.d, z25.s
+; CHECK-NEWLOWERING-NEXT: sunpklo z30.d, z4.s
; CHECK-NEWLOWERING-NEXT: sunpkhi z4.d, z4.s
-; CHECK-NEWLOWERING-NEXT: uunpkhi z7.d, z7.s
+; CHECK-NEWLOWERING-NEXT: sunpklo z31.d, z2.s
+; CHECK-NEWLOWERING-NEXT: sunpkhi z2.d, z2.s
+; CHECK-NEWLOWERING-NEXT: uunpklo z8.d, z3.s
+; CHECK-NEWLOWERING-NEXT: uunpklo z9.d, z5.s
; CHECK-NEWLOWERING-NEXT: uunpkhi z5.d, z5.s
-; CHECK-NEWLOWERING-NEXT: sunpkhi z30.d, z24.s
-; CHECK-NEWLOWERING-NEXT: sunpkhi z31.d, z2.s
-; CHECK-NEWLOWERING-NEXT: sunpklo z24.d, z24.s
-; CHECK-NEWLOWERING-NEXT: sunpklo z2.d, z2.s
-; CHECK-NEWLOWERING-NEXT: uunpkhi z8.d, z25.s
-; CHECK-NEWLOWERING-NEXT: uunpklo z25.d, z25.s
-; CHECK-NEWLOWERING-NEXT: uunpklo z9.d, z3.s
-; CHECK-NEWLOWERING-NEXT: mul z27.d, z27.d, z29.d
-; CHECK-NEWLOWERING-NEXT: mla z0.d, p0/m, z6.d, z28.d
+; CHECK-NEWLOWERING-NEXT: mul z7.d, z7.d, z24.d
+; CHECK-NEWLOWERING-NEXT: mla z0.d, p0/m, z27.d, z28.d
; CHECK-NEWLOWERING-NEXT: uunpkhi z3.d, z3.s
-; CHECK-NEWLOWERING-NEXT: mul z4.d, z4.d, z5.d
-; CHECK-NEWLOWERING-NEXT: mla z1.d, p0/m, z26.d, z7.d
-; CHECK-NEWLOWERING-NEXT: mla z0.d, p0/m, z2.d, z9.d
-; CHECK-NEWLOWERING-NEXT: movprfx z2, z27
-; CHECK-NEWLOWERING-NEXT: mla z2.d, p0/m, z24.d, z25.d
+; CHECK-NEWLOWERING-NEXT: mul z6.d, z6.d, z25.d
+; CHECK-NEWLOWERING-NEXT: mla z1.d, p0/m, z26.d, z29.d
+; CHECK-NEWLOWERING-NEXT: mla z0.d, p0/m, z2.d, z5.d
+; CHECK-NEWLOWERING-NEXT: movprfx z2, z7
+; CHECK-NEWLOWERING-NEXT: mla z2.d, p0/m, z31.d, z9.d
; CHECK-NEWLOWERING-NEXT: ldr z9, [sp] // 16-byte Folded Reload
-; CHECK-NEWLOWERING-NEXT: mla z1.d, p0/m, z31.d, z3.d
-; CHECK-NEWLOWERING-NEXT: movprfx z3, z4
+; CHECK-NEWLOWERING-NEXT: mla z1.d, p0/m, z4.d, z3.d
+; CHECK-NEWLOWERING-NEXT: movprfx z3, z6
; CHECK-NEWLOWERING-NEXT: mla z3.d, p0/m, z30.d, z8.d
; CHECK-NEWLOWERING-NEXT: ldr z8, [sp, #1, mul vl] // 16-byte Folded Reload
; CHECK-NEWLOWERING-NEXT: add z0.d, z2.d, z0.d
@@ -620,16 +620,8 @@ define <vscale x 4 x i32> @udot_no_bin_op(<vscale x 4 x i32> %acc, <vscale x 16
;
; CHECK-NEWLOWERING-LABEL: udot_no_bin_op:
; CHECK-NEWLOWERING: // %bb.0:
-; CHECK-NEWLOWERING-NEXT: uunpklo z2.h, z1.b
-; CHECK-NEWLOWERING-NEXT: uunpkhi z1.h, z1.b
-; CHECK-NEWLOWERING-NEXT: uunpklo z3.s, z2.h
-; CHECK-NEWLOWERING-NEXT: uunpkhi z4.s, z1.h
-; CHECK-NEWLOWERING-NEXT: uunpklo z1.s, z1.h
-; CHECK-NEWLOWERING-NEXT: uunpkhi z2.s, z2.h
-; CHECK-NEWLOWERING-NEXT: add z0.s, z0.s, z3.s
-; CHECK-NEWLOWERING-NEXT: add z1.s, z2.s, z1.s
-; CHECK-NEWLOWERING-NEXT: add z0.s, z4.s, z0.s
-; CHECK-NEWLOWERING-NEXT: add z0.s, z1.s, z0.s
+; CHECK-NEWLOWERING-NEXT: mov z2.b, #1 // =0x1
+; CHECK-NEWLOWERING-NEXT: udot z0.s, z1.b, z2.b
; CHECK-NEWLOWERING-NEXT: ret
%a.ext = zext <vscale x 16 x i8> %a to <vscale x 16 x i32>
%partial.reduce = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> %acc, <vscale x 16 x i32> %a.ext)
@@ -645,16 +637,8 @@ define <vscale x 4 x i32> @sdot_no_bin_op(<vscale x 4 x i32> %acc, <vscale x 16
;
; CHECK-NEWLOWERING-LABEL: sdot_no_bin_op:
; CHECK-NEWLOWERING: // %bb.0:
-; CHECK-NEWLOWERING-NEXT: sunpklo z2.h, z1.b
-; CHECK-NEWLOWERING-NEXT: sunpkhi z1.h, z1.b
-; CHECK-NEWLOWERING-NEXT: sunpklo z3.s, z2.h
-; CHECK-NEWLOWERING-NEXT: sunpkhi z4.s, z1.h
-; CHECK-NEWLOWERING-NEXT: sunpklo z1.s, z1.h
-; CHECK-NEWLOWERING-NEXT: sunpkhi z2.s, z2.h
-; CHECK-NEWLOWERING-NEXT: add z0.s, z0.s, z3.s
-; CHECK-NEWLOWERING-NEXT: add z1.s, z2.s, z1.s
-; CHECK-NEWLOWERING-NEXT: add z0.s, z4.s, z0.s
-; CHECK-NEWLOWERING-NEXT: add z0.s, z1.s, z0.s
+; CHECK-NEWLOWERING-NEXT: mov z2.b, #1 // =0x1
+; CHECK-NEWLOWERING-NEXT: sdot z0.s, z1.b, z2.b
; CHECK-NEWLOWERING-NEXT: ret
%a.ext = sext <vscale x 16 x i8> %a to <vscale x 16 x i32>
%partial.reduce = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> %acc, <vscale x 16 x i32> %a.ext)
@@ -670,16 +654,8 @@ define <vscale x 2 x i64> @udot_no_bin_op_wide(<vscale x 2 x i64> %acc, <vscale
;
; CHECK-NEWLOWERING-LABEL: udot_no_bin_op_wide:
; CHECK-NEWLOWERING: // %bb.0: // %entry
-; CHECK-NEWLOWERING-NEXT: uunpklo z2.s, z1.h
-; CHECK-NEWLOWERING-NEXT: uunpkhi z1.s, z1.h
-; CHECK-NEWLOWERING-NEXT: uunpklo z3.d, z2.s
-; CHECK-NEWLOWERING-NEXT: uunpkhi z4.d, z1.s
-; CHECK-NEWLOWERING-NEXT: uunpklo z1.d, z1.s
-; CHECK-NEWLOWERING-NEXT: uunpkhi z2.d, z2.s
-; CHECK-NEWLOWERING-NEXT: add z0.d, z0.d, z3.d
-; CHECK-NEWLOWERING-NEXT: add z1.d, z2.d, z1.d
-; CHECK-NEWLOWERING-NEXT: add z0.d, z4.d, z0.d
-; CHECK-NEWLOWERING-NEXT: add z0.d, z1.d, z0.d
+; CHECK-NEWLOWERING-NEXT: mov z2.h, #1 // =0x1
+; CHECK-NEWLOWERING-NEXT: udot z0.d, z1.h, z2.h
; CHECK-NEWLOWERING-NEXT: ret
entry:
%a.wide = zext <vscale x 8 x i16> %a to <vscale x 8 x i64>
@@ -696,16 +672,8 @@ define <vscale x 2 x i64> @sdot_no_bin_op_wide(<vscale x 2 x i64> %acc, <vscale
;
; CHECK-NEWLOWERING-LABEL: sdot_no_bin_op_wide:
; CHECK-NEWLOWERING: // %bb.0: // %entry
-; CHECK-NEWLOWERING-NEXT: sunpklo z2.s, z1.h
-; CHECK-NEWLOWERING-NEXT: sunpkhi z1.s, z1.h
-; CHECK-NEWLOWERING-NEXT: sunpklo z3.d, z2.s
-; CHECK-NEWLOWERING-NEXT: sunpkhi z4.d, z1.s
-; CHECK-NEWLOWERING-NEXT: sunpklo z1.d, z1.s
-; CHECK-NEWLOWERING-NEXT: sunpkhi z2.d, z2.s
-; CHECK-NEWLOWERING-NEXT: add z0.d, z0.d, z3.d
-; CHECK-NEWLOWERING-NEXT: add z1.d, z2.d, z1.d
-; CHECK-NEWLOWERING-NEXT: add z0.d, z4.d, z0.d
-; CHECK-NEWLOWERING-NEXT: add z0.d, z1.d, z0.d
+; CHECK-NEWLOWERING-NEXT: mov z2.h, #1 // =0x1
+; CHECK-NEWLOWERING-NEXT: sdot z0.d, z1.h, z2.h
; CHECK-NEWLOWERING-NEXT: ret
entry:
%a.wide = sext <vscale x 8 x i16> %a to <vscale x 8 x i64>
@@ -727,28 +695,13 @@ define <vscale x 4 x i64> @udot_no_bin_op_8to64(<vscale x 4 x i64> %acc, <vscale
;
; CHECK-NEWLOWERING-LABEL: udot_no_bin_op_8to64:
; CHECK-NEWLOWERING: // %bb.0:
-; CHECK-NEWLOWERING-NEXT: uunpklo z3.h, z2.b
+; CHECK-NEWLOWERING-NEXT: mov z3.b, #1 // =0x1
+; CHECK-NEWLOWERING-NEXT: uunpklo z5.h, z2.b
; CHECK-NEWLOWERING-NEXT: uunpkhi z2.h, z2.b
-; CHECK-NEWLOWERING-NEXT: uunpklo z4.s, z3.h
-; CHECK-NEWLOWERING-NEXT: uunpkhi z5.s, z2.h
-; CHECK-NEWLOWERING-NEXT: uunpklo z2.s, z2.h
-; CHECK-NEWLOWERING-NEXT: uunpkhi z3.s, z3.h
-; CHECK-NEWLOWERING-NEXT: uunpkhi z6.d, z4.s
-; CHECK-NEWLOWERING-NEXT: uunpklo z4.d, z4.s
-; CHECK-NEWLOWERING-NEXT: uunpkhi z7.d, z5.s
-; CHECK-NEWLOWERING-NEXT: uunpklo z24.d, z2.s
-; CHECK-NEWLOWERING-NEXT: uunpklo z25.d, z3.s
-; CHECK-NEWLOWERING-NEXT: uunpkhi z2.d, z2.s
-; CHECK-NEWLOWERING-NEXT: uunpkhi z3.d, z3.s
-; CHECK-NEWLOWERING-NEXT: uunpklo z5.d, z5.s
-; CHECK-NEWLOWERING-NEXT: add z0.d, z0.d, z4.d
-; CHECK-NEWLOWERING-NEXT: add z1.d, z1.d, z6.d
-; CHECK-NEWLOWERING-NEXT: add z4.d, z25.d, z24.d
-; CHECK-NEWLOWERING-NEXT: add z2.d, z3.d, z2.d
-; CHECK-NEWLOWERING-NEXT: add z0.d, z5.d, z0.d
-; CHECK-NEWLOWERING-NEXT: add z1.d, z7.d, z1.d
-; CHECK-NEWLOWERING-NEXT: add z0.d, z4.d, z0.d
-; CHECK-NEWLOWERING-NEXT: add z1.d, z2.d, z1.d
+; CHECK-NEWLOWERING-NEXT: uunpklo z4.h, z3.b
+; CHECK-NEWLOWERING-NEXT: uunpkhi z3.h, z3.b
+; CHECK-NEWLOWERING-NEXT: udot z0.d, z5.h, z4.h
+; CHECK-NEWLOWERING-NEXT: udot z1.d, z2.h, z3.h
; CHECK-NEWLOWERING-NEXT: ret
%a.ext = zext <vscale x 16 x i8> %a to <vscale x 16 x i64>
%partial.reduce = tail call <vscale x 4 x i64> @llvm.experimental.vector.partial.reduce.add.nxv4i64.nxv16i64(<vscale x 4 x i64> %acc, <vscale x 16 x i64> %a.ext)
@@ -769,28 +722,13 @@ define <vscale x 4 x i64> @sdot_no_bin_op_8to64(<vscale x 4 x i64> %acc, <vscale
;
; CHECK-NEWLOWERING-LABEL: sdot_no_bin_op_8to64:
; CHECK-NEWLOWERING: // %bb.0:
-; CHECK-NEWLOWERING-NEXT: sunpklo z3.h, z2.b
+; CHECK-NEWLOWERING-NEXT: mov z3.b, #1 // =0x1
+; CHECK-NEWLOWERING-NEXT: sunpklo z5.h, z2.b
; CHECK-NEWLOWERING-NEXT: sunpkhi z2.h, z2.b
-; CHECK-NEWLOWERING-NEXT: sunpklo z4.s, z3.h
-; CHECK-NEWLOWERING-NEXT: sunpkhi z5.s, z2.h
-; CHECK-NEWLOWERING-NEXT: sunpklo z2.s, z2.h
-; CHECK-NEWLOWERING-NEXT: sunpkhi z3.s, z3.h
-; CHECK-NEWLOWERING-NEXT: sunpkhi z6.d, z4.s
-; CHECK-NEWLOWERING-NEXT: sunpklo z4.d, z4.s
-; CHECK-NEWLOWERING-NEXT: sunpkhi z7.d, z5.s
-; CHECK-NEWLOWERING-NEXT: sunpklo z24.d, z2.s
-; CHECK-NEWLOWERING-NEXT: sunpklo z25.d, z3.s
-; CHECK-NEWLOWERING-NEXT: sunpkhi z2.d, z2.s
-; CHECK-NEWLOWERING-NEXT: sunpkhi z3.d, z3.s
-; CHECK-NEWLOWERING-NEXT: sunpklo z5.d, z5.s
-; CHECK-NEWLOWERING-NEXT: add z0.d, z0.d, z4.d
-; CHECK-NEWLOWERING-NEXT: add z1.d, z1.d, z6.d
-; CHECK-NEWLOWERING-NEXT: add z4.d, z25.d, z24.d
-; CHECK-NEWLOWERING-NEXT: add z2.d, z3.d, z2.d
-; CHECK-NEWLOWERING-NEXT: add z0.d, z5.d, z0.d
-; CHECK-NEWLOWERING-NEXT: add z1.d, z7.d, z1.d
-; CHECK-NEWLOWERING-NEXT: add z0.d, z4.d, z0.d
-; CHECK-NEWLOWERING-NEXT: add z1.d, z2.d, z1.d
+; CHECK-NEWLOWERING-NEXT: sunpklo z4.h, z3.b
+; CHECK-NEWLOWERING-NEXT: sunpkhi z3.h, z3.b
+; CHECK-NEWLOWERING-NEXT: sdot z0.d, z5.h, z4.h
+; CHECK-NEWLOWERING-NEXT: sdot z1.d, z2.h, z3.h
; CHECK-NEWLOWERING-NEXT: ret
%a.ext = sext <vscale x 16 x i8> %a to <vscale x 16 x i64>
%partial.reduce = tail call <vscale x 4 x i64> @llvm.experimental.vector.partial.reduce.add.nxv4i64.nxv16i64(<vscale x 4 x i64> %acc, <vscale x 16 x i64> %a.ext)
diff --git a/llvm/test/CodeGen/AArch64/sve-partial-reduce-wide-add.ll b/llvm/test/CodeGen/AArch64/sve-partial-reduce-wide-add.ll
index 11fb60ead4fb2..5773c47e001d2 100644
--- a/llvm/test/CodeGen/AArch64/sve-partial-reduce-wide-add.ll
+++ b/llvm/test/CodeGen/AArch64/sve-partial-reduce-wide-add.ll
@@ -172,15 +172,35 @@ entry:
}
define <vscale x 2 x i32> @signed_wide_add_nxv4i16(<vscale x 2 x i32> %acc, <vscale x 4 x i16> %input){
-; CHECK-LABEL: signed_wide_add_nxv4i16:
-; CHECK: // %bb.0: // %entry
-; CHECK-NEXT: ptrue p0.s
-; CHECK-NEXT: sxth z1.s, p0/m, z1.s
-; CHECK-NEXT: uunpklo z2.d, z1.s
-; CHECK-NEXT: uunpkhi z1.d, z1.s
-; CHECK-NEXT: add z0.d, z0.d, z2.d
-; CHECK-NEXT: add z0.d, z1.d, z0.d
-; CHECK-NEXT: ret
+; CHECK-SVE2-LABEL: signed_wide_add_nxv4i16:
+; CHECK-SVE2: // %bb.0: // %entry
+; CHECK-SVE2-NEXT: ptrue p0.s
+; CHECK-SVE2-NEXT: sxth z1.s, p0/m, z1.s
+; CHECK-SVE2-NEXT: uunpklo z2.d, z1.s
+; CHECK-SVE2-NEXT: uunpkhi z1.d, z1.s
+; CHECK-SVE2-NEXT: add z0.d, z0.d, z2.d
+; CHECK-SVE2-NEXT: add z0.d, z1.d, z0.d
+; CHECK-SVE2-NEXT: ret
+;
+; CHECK-SVE-LABEL: signed_wide_add_nxv4i16:
+; CHECK-SVE: // %bb.0: // %entry
+; CHECK-SV...
[truncated]
This PR implicitly depends on #130935
Mostly some nits, otherwise looks good to me.
  APInt ConstantOne;
  if (!ISD::isConstantSplatVector(Op2.getNode(), ConstantOne) ||
      !ConstantOne.isOne())
nit:

Suggested change:
-  APInt ConstantOne;
-  if (!ISD::isConstantSplatVector(Op2.getNode(), ConstantOne) ||
-      !ConstantOne.isOne())
+  APInt C;
+  if (!ISD::isConstantSplatVector(Op2.getNode(), C) || !C.isOne())
  EVT AccElemVT = Acc.getValueType().getVectorElementType();
  if (Op1IsSigned != NodeIsSigned &&
      (Op1.getValueType().getVectorElementType() != AccElemVT ||
       Op2.getValueType().getVectorElementType() != AccElemVT))
I don't think you need to test the Op2 case here, because the type of Op1 must match that of Op2.
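If that redundant check were dropped as suggested, the guard would presumably reduce to something like the following (a sketch only, not necessarily the form that landed):

  // Sketch: Op1 and Op2 share a vector type, so comparing Op1's element
  // type against the accumulator element type is enough.
  EVT AccElemVT = Acc.getValueType().getVectorElementType();
  if (Op1IsSigned != NodeIsSigned &&
      Op1.getValueType().getVectorElementType() != AccElemVT)
    return SDValue();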
  if (!TLI.isPartialReduceMLALegalOrCustom(N->getValueType(0), UnextOp1VT))
    return SDValue();

  SDValue TruncOp2 = DAG.getNode(ISD::TRUNCATE, DL, UnextOp1VT, Op2);
This can just be DAG.getConstant(1, DL, UnextOp1VT)
(and better to just inline that in the use below).
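To illustrate the suggestion, the final node construction might then look roughly like this, with the splat of one built directly at the use (a sketch under the assumption that the rest of the fold is unchanged):

  // Sketch: materialise the splat-of-one third operand via getConstant
  // instead of truncating the original Op2 splat.
  unsigned NewOpcode =
      Op1IsSigned ? ISD::PARTIAL_REDUCE_SMLA : ISD::PARTIAL_REDUCE_UMLA;
  return DAG.getNode(NewOpcode, DL, N->getValueType(0), Acc, UnextOp1,
                     DAG.getConstant(1, DL, UnextOp1VT));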
  SDValue UnextOp1 = Op1.getOperand(0);
  EVT UnextOp1VT = UnextOp1.getValueType();
nit: remove newline.
  SDValue TruncOp2 = DAG.getNode(ISD::TRUNCATE, DL, UnextOp1VT, Op2);

  bool Op1IsSigned = Op1Opcode == ISD::SIGN_EXTEND;
nit: remove newline.
// partial.reduce.umla(acc, op, splat(trunc(1)))
// Makes partial.reduce.smla(acc, sext(op1), splat(1)) into
// partial.reduce.smla(acc, op, splat(trunc(1)))
SDValue DAGCombiner::foldPartialReduceMLANoMulOp(SDNode *N) { |
nit: maybe rename this to foldPartialReduceAdd?
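If that rename were taken, the dispatcher added earlier in this patch would presumably read along these lines (sketch only; the name is the reviewer's suggestion, not necessarily what was committed):

SDValue DAGCombiner::visitPARTIAL_REDUCE_MLA(SDNode *N) {
  // Try the existing mul-based fold first, then the no-mul fold.
  if (SDValue Res = foldPartialReduceMLAMulOp(N))
    return Res;
  if (SDValue Res = foldPartialReduceAdd(N))
    return Res;
  return SDValue();
}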
// Makes partial.reduce.umla(acc, zext(op1), splat(1)) into
// partial.reduce.umla(acc, op, splat(trunc(1)))
// Makes partial.reduce.smla(acc, sext(op1), splat(1)) into
// partial.reduce.smla(acc, op, splat(trunc(1)))
nit:

Suggested change:
-// Makes partial.reduce.umla(acc, zext(op1), splat(1)) into
-// partial.reduce.umla(acc, op, splat(trunc(1)))
-// Makes partial.reduce.smla(acc, sext(op1), splat(1)) into
-// partial.reduce.smla(acc, op, splat(trunc(1)))
+// partial.reduce.umla(acc, zext(op1), splat(1))
+// -> partial.reduce.umla(acc, op, splat(trunc(1)))
+// partial.reduce.smla(acc, sext(op1), splat(1))
+// -> partial.reduce.smla(acc, op, splat(trunc(1)))
…lvm#131326) Generic DAG combine for ISD::PARTIAL_REDUCE_U/SMLA to convert: PARTIAL_REDUCE_*MLA(Acc, ZEXT(UnextOp1), Splat(1)) into PARTIAL_REDUCE_UMLA(Acc, UnextOp1, TRUNC(Splat(1))) and PARTIAL_REDUCE_*MLA(Acc, SEXT(UnextOp1), Splat(1)) into PARTIAL_REDUCE_SMLA(Acc, UnextOp1, TRUNC(Splat(1))). --------- Co-authored-by: James Chesterman <[email protected]>
Generic DAG combine for ISD::PARTIAL_REDUCE_U/SMLA to convert: PARTIAL_REDUCE_*MLA(Acc, ZEXT(UnextOp1), Splat(1)) into PARTIAL_REDUCE_UMLA(Acc, UnextOp1, TRUNC(Splat(1))) and
PARTIAL_REDUCE_*MLA(Acc, SEXT(UnextOp1), Splat(1)) into PARTIAL_REDUCE_SMLA(Acc, UnextOp1, TRUNC(Splat(1))).
@JamesChesterman is the original author.