Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit e6a72a1

Browse files
authored
[RISCV] Combine ADDD+WMULSU to WMACCSU (llvm#180454)
Extend the existing combineADDDToWMACC DAG combine to also match RISCVISD::WMULSU and produce RISCVISD::WMACCSU. This is similar to how ADDD+UMUL_LOHI is combined to WMACCU and ADDD+SMUL_LOHI is combined to WMACC. This patch was generated by AI, but I reviewed it.
1 parent e6c73eb commit e6a72a1

4 files changed

Lines changed: 73 additions & 15 deletions

File tree

llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1910,12 +1910,13 @@ void RISCVDAGToDAGISel::Select(SDNode *Node) {
19101910
CurDAG->RemoveDeadNode(Node);
19111911
return;
19121912
}
1913+
case RISCVISD::WMACCSU:
19131914
case RISCVISD::WMACCU:
19141915
case RISCVISD::WMACC: {
19151916
assert(!Subtarget->is64Bit() && Subtarget->hasStdExtP() &&
19161917
"Unexpected opcode");
19171918

1918-
// WMACCU/WMACC has 4 operands: (m1, m2, addlo, addhi) -> (lo, hi)
1919+
// WMACCU/WMACC/WMACCSU has 4 operands: (m1, m2, addlo, addhi) -> (lo, hi)
19191920
SDValue M1 = Node->getOperand(0);
19201921
SDValue M2 = Node->getOperand(1);
19211922
SDValue AddLo = Node->getOperand(2);
@@ -1930,8 +1931,20 @@ void RISCVDAGToDAGISel::Select(SDNode *Node) {
19301931
MVT::Untyped, AccOps),
19311932
0);
19321933

1933-
unsigned Opc =
1934-
Node->getOpcode() == RISCVISD::WMACCU ? RISCV::WMACCU : RISCV::WMACC;
1934+
unsigned Opc;
1935+
switch (Node->getOpcode()) {
1936+
default:
1937+
llvm_unreachable("Unexpected WMACC opcode");
1938+
case RISCVISD::WMACCU:
1939+
Opc = RISCV::WMACCU;
1940+
break;
1941+
case RISCVISD::WMACC:
1942+
Opc = RISCV::WMACC;
1943+
break;
1944+
case RISCVISD::WMACCSU:
1945+
Opc = RISCV::WMACCSU;
1946+
break;
1947+
}
19351948

19361949
// Instruction format: WMACCU rd, rs1, rs2 (rd is accumulator, comes first)
19371950
MachineSDNode *New =

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Lines changed: 28 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -21067,35 +21067,39 @@ static SDValue performSHLCombine(SDNode *N,
2106721067
// (WMACCU x, y, a, b).
2106821068
// Combine (ADDD (SMUL_LOHI x, y).0, (SMUL_LOHI x, y).1, a, b) into
2106921069
// (WMACC x, y, a, b).
21070+
// Combine (ADDD (WMULSU x, y).0, (WMULSU x, y).1, a, b) into
21071+
// (WMACCSU x, y, a, b).
2107021072
static SDValue combineADDDToWMACC(SDNode *N, SelectionDAG &DAG,
2107121073
const RISCVSubtarget &Subtarget) {
2107221074
assert(N->getOpcode() == RISCVISD::ADDD && "Expected ADDD");
2107321075
assert(!Subtarget.is64Bit() && Subtarget.hasStdExtP() &&
2107421076
"ADDD requires RV32 with P extension");
2107521077

2107621078
// ADDD has 4 operands: (op0_lo, op0_hi, op1_lo, op1_hi)
21077-
// Try to match UMUL_LOHI or SMUL_LOHI in either operand pair due to
21079+
// Try to match UMUL_LOHI, SMUL_LOHI, or WMULSU in either operand pair due to
2107821080
// commutativity
2107921081
SDValue Op0Lo = N->getOperand(0);
2108021082
SDValue Op0Hi = N->getOperand(1);
2108121083
SDValue Op1Lo = N->getOperand(2);
2108221084
SDValue Op1Hi = N->getOperand(3);
2108321085

21086+
auto IsSupportedMul = [](unsigned Opc) {
21087+
return Opc == ISD::UMUL_LOHI || Opc == ISD::SMUL_LOHI ||
21088+
Opc == RISCVISD::WMULSU;
21089+
};
21090+
2108421091
SDNode *MulNode = nullptr;
2108521092
SDValue AddLo, AddHi;
2108621093

21087-
// Check if first operand pair is UMUL_LOHI or SMUL_LOHI
21088-
if ((Op0Lo.getOpcode() == ISD::UMUL_LOHI ||
21089-
Op0Lo.getOpcode() == ISD::SMUL_LOHI) &&
21090-
Op0Lo.getNode() == Op0Hi.getNode() && Op0Lo.getResNo() == 0 &&
21091-
Op0Hi.getResNo() == 1) {
21094+
// Check if first operand pair is a supported multiply
21095+
if (IsSupportedMul(Op0Lo.getOpcode()) && Op0Lo.getNode() == Op0Hi.getNode() &&
21096+
Op0Lo.getResNo() == 0 && Op0Hi.getResNo() == 1) {
2109221097
MulNode = Op0Lo.getNode();
2109321098
AddLo = Op1Lo;
2109421099
AddHi = Op1Hi;
2109521100
}
21096-
// Check if second operand pair is UMUL_LOHI or SMUL_LOHI (commutative case)
21097-
else if ((Op1Lo.getOpcode() == ISD::UMUL_LOHI ||
21098-
Op1Lo.getOpcode() == ISD::SMUL_LOHI) &&
21101+
// Check if second operand pair is a supported multiply (commutative case)
21102+
else if (IsSupportedMul(Op1Lo.getOpcode()) &&
2109921103
Op1Lo.getNode() == Op1Hi.getNode() && Op1Lo.getResNo() == 0 &&
2110021104
Op1Hi.getResNo() == 1) {
2110121105
MulNode = Op1Lo.getNode();
@@ -21113,10 +21117,22 @@ static SDValue combineADDDToWMACC(SDNode *N, SelectionDAG &DAG,
2111321117
SDValue MulOp0 = MulNode->getOperand(0);
2111421118
SDValue MulOp1 = MulNode->getOperand(1);
2111521119

21116-
// Create WMACCU or WMACC node: (m1, m2, addlo, addhi) -> (lo, hi)
21120+
// Create WMACCU, WMACC, or WMACCSU node: (m1, m2, addlo, addhi) -> (lo, hi)
2111721121
SDLoc DL(N);
21118-
bool IsSigned = MulNode->getOpcode() == ISD::SMUL_LOHI;
21119-
unsigned Opc = IsSigned ? RISCVISD::WMACC : RISCVISD::WMACCU;
21122+
unsigned Opc;
21123+
switch (MulNode->getOpcode()) {
21124+
default:
21125+
llvm_unreachable("Unexpected multiply opcode");
21126+
case ISD::UMUL_LOHI:
21127+
Opc = RISCVISD::WMACCU;
21128+
break;
21129+
case ISD::SMUL_LOHI:
21130+
Opc = RISCVISD::WMACC;
21131+
break;
21132+
case RISCVISD::WMULSU:
21133+
Opc = RISCVISD::WMACCSU;
21134+
break;
21135+
}
2112021136
return DAG.getNode(Opc, DL, DAG.getVTList(MVT::i32, MVT::i32), MulOp0, MulOp1,
2112121137
AddLo, AddHi);
2112221138
}

llvm/lib/Target/RISCV/RISCVInstrInfoP.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1510,6 +1510,7 @@ def riscv_wmacc : RVSDNode<"WMACC", SDT_RISCVWideningMAccW,
15101510
[SDNPCommutative]>;
15111511
def riscv_wmaccu : RVSDNode<"WMACCU", SDT_RISCVWideningMAccW,
15121512
[SDNPCommutative]>;
1513+
def riscv_wmaccsu : RVSDNode<"WMACCSU", SDT_RISCVWideningMAccW>;
15131514

15141515
// MULH/MULHU/MULHSU with rounding.
15151516
def riscv_mulhr : RVSDNode<"MULHR", SDTIntBinOp>;

llvm/test/CodeGen/RISCV/rv32p.ll

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -678,6 +678,34 @@ define i64 @wmacc_commute(i32 %a, i32 %b, i64 %c) nounwind {
678678
ret i64 %result
679679
}
680680

681+
define i64 @wmaccsu(i32 %a, i32 %b, i64 %c) nounwind {
682+
; CHECK-LABEL: wmaccsu:
683+
; CHECK: # %bb.0:
684+
; CHECK-NEXT: wmaccsu a2, a0, a1
685+
; CHECK-NEXT: mv a0, a2
686+
; CHECK-NEXT: mv a1, a3
687+
; CHECK-NEXT: ret
688+
%aext = sext i32 %a to i64
689+
%bext = zext i32 %b to i64
690+
%mul = mul i64 %aext, %bext
691+
%result = add i64 %c, %mul
692+
ret i64 %result
693+
}
694+
695+
define i64 @wmaccsu_commute(i32 %a, i32 %b, i64 %c) nounwind {
696+
; CHECK-LABEL: wmaccsu_commute:
697+
; CHECK: # %bb.0:
698+
; CHECK-NEXT: wmaccsu a2, a0, a1
699+
; CHECK-NEXT: mv a0, a2
700+
; CHECK-NEXT: mv a1, a3
701+
; CHECK-NEXT: ret
702+
%aext = sext i32 %a to i64
703+
%bext = zext i32 %b to i64
704+
%mul = mul i64 %aext, %bext
705+
%result = add i64 %mul, %c
706+
ret i64 %result
707+
}
708+
681709
; Negative test: multiply result has multiple uses, should not combine
682710
define void @wmaccu_multiple_uses(i32 %a, i32 %b, i64 %c, ptr %out1, ptr %out2) nounwind {
683711
; CHECK-LABEL: wmaccu_multiple_uses:

0 commit comments

Comments
 (0)