-
Notifications
You must be signed in to change notification settings - Fork 13.4k
[SelectionDAG] Improve type legalisation for PARTIAL_REDUCE_MLA #130935
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[SelectionDAG] Improve type legalisation for PARTIAL_REDUCE_MLA #130935
Conversation
45fd2fd
to
9ebecba
Compare
Implement proper splitting functions for PARTIAL_REDUCE_MLA ISD nodes. This makes the udot_8to64 and sdot_8to64 tests generate dot product instructions for when the new ISD nodes are used.
9ebecba
to
4ee4990
Compare
@llvm/pr-subscribers-llvm-selectiondag @llvm/pr-subscribers-backend-aarch64 Author: Nicholas Guy (NickGuy-Arm) ChangesImplement proper splitting functions for PARTIAL_REDUCE_MLA ISD nodes. @JamesChesterman is the original author Full diff: https://github.com/llvm/llvm-project/pull/130935.diff 4 Files Affected:
diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h
index abe261728a3e6..7b0e15f951681 100644
--- a/llvm/include/llvm/CodeGen/TargetLowering.h
+++ b/llvm/include/llvm/CodeGen/TargetLowering.h
@@ -1668,6 +1668,12 @@ class TargetLoweringBase {
return Action == Legal || Action == Custom;
}
+ /// Return true if a PARTIAL_REDUCE_U/SMLA node with the specified types is
+ /// legal for this target.
+ bool isPartialReduceMLALegal(EVT AccVT, EVT InputVT) const {
+ return getPartialReduceMLAAction(AccVT, InputVT) == Legal;
+ }
+
/// If the action for this operation is to promote, this method returns the
/// ValueType to promote to.
MVT getTypeToPromoteTo(unsigned Op, MVT VT) const {
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
index a01e1cff74564..d0ae436a8758f 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
@@ -3220,8 +3220,26 @@ void DAGTypeLegalizer::SplitVecRes_VP_REVERSE(SDNode *N, SDValue &Lo,
void DAGTypeLegalizer::SplitVecRes_PARTIAL_REDUCE_MLA(SDNode *N, SDValue &Lo,
SDValue &Hi) {
SDLoc DL(N);
- SDValue Expanded = TLI.expandPartialReduceMLA(N, DAG);
- std::tie(Lo, Hi) = DAG.SplitVector(Expanded, DL);
+ SDValue Acc = N->getOperand(0);
+ SDValue Input1 = N->getOperand(1);
+
+ // If the node has not gone through the DAG combine, then do not attempt to
+ // legalise, just expand.
+ if (!TLI.isPartialReduceMLALegal(Acc.getValueType(), Input1.getValueType())) {
+ SDValue Expanded = TLI.expandPartialReduceMLA(N, DAG);
+ std::tie(Lo, Hi) = DAG.SplitVector(Expanded, DL);
+ return;
+ }
+
+ SDValue AccLo, AccHi, Input1Lo, Input1Hi, Input2Lo, Input2Hi;
+ std::tie(AccLo, AccHi) = DAG.SplitVector(Acc, DL);
+ std::tie(Input1Lo, Input1Hi) = DAG.SplitVector(Input1, DL);
+ std::tie(Input2Lo, Input2Hi) = DAG.SplitVector(N->getOperand(2), DL);
+ unsigned Opcode = N->getOpcode();
+ EVT ResultVT = AccLo.getValueType();
+
+ Lo = DAG.getNode(Opcode, DL, ResultVT, AccLo, Input1Lo, Input2Lo);
+ Hi = DAG.getNode(Opcode, DL, ResultVT, AccHi, Input1Hi, Input2Hi);
}
void DAGTypeLegalizer::SplitVecRes_VECTOR_DEINTERLEAVE(SDNode *N) {
@@ -4501,7 +4519,9 @@ SDValue DAGTypeLegalizer::SplitVecOp_VECTOR_HISTOGRAM(SDNode *N) {
}
SDValue DAGTypeLegalizer::SplitVecOp_PARTIAL_REDUCE_MLA(SDNode *N) {
- return TLI.expandPartialReduceMLA(N, DAG);
+ SDValue Lo, Hi;
+ SplitVecRes_PARTIAL_REDUCE_MLA(N, Lo, Hi);
+ return DAG.getNode(ISD::CONCAT_VECTORS, SDLoc(N), N->getValueType(0), Lo, Hi);
}
//===----------------------------------------------------------------------===//
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 447794cc2b744..810d42635e7b2 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1604,6 +1604,26 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
setOperationAction(ISD::MSTORE, VT, Custom);
}
+ if (EnablePartialReduceNodes) {
+ for (MVT VT : MVT::integer_scalable_vector_valuetypes()) {
+ for (MVT InnerVT : MVT::integer_scalable_vector_valuetypes()) {
+ // 1. Set all combinations where a type is illegal to "Legal"
+ // - These will be legalized to a legal type pair
+ // - Avoid expanding them too early (or preventing folds)
+ if (!isTypeLegal(VT) || !isTypeLegal(InnerVT)) {
+ setPartialReduceMLAAction(VT, InnerVT, Legal);
+ continue;
+ }
+ // 2. Set all legal combinations to "Expand"
+ // - Not all of these can be lowered (via a Legal or Custom lowering).
+ setPartialReduceMLAAction(VT, InnerVT, Expand);
+ }
+ }
+ // 3. Mark known legal pairs as 'Legal' (these will expand to USDOT).
+ setPartialReduceMLAAction(MVT::nxv2i64, MVT::nxv8i16, Legal);
+ setPartialReduceMLAAction(MVT::nxv4i32, MVT::nxv16i8, Legal);
+ }
+
// Firstly, exclude all scalable vector extending loads/truncating stores,
// include both integer and floating scalable vector.
for (MVT VT : MVT::scalable_vector_valuetypes()) {
@@ -1856,6 +1876,14 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
// Other pairs will default to 'Expand'.
setPartialReduceMLAAction(MVT::nxv2i64, MVT::nxv8i16, Legal);
setPartialReduceMLAAction(MVT::nxv4i32, MVT::nxv16i8, Legal);
+
+ setPartialReduceMLAAction(MVT::nxv2i64, MVT::nxv8i64, Custom);
+ setPartialReduceMLAAction(MVT::nxv4i32, MVT::nxv16i32, Custom);
+
+ setPartialReduceMLAAction(MVT::nxv2i64, MVT::nxv4i64, Custom);
+ setPartialReduceMLAAction(MVT::nxv4i32, MVT::nxv8i32, Custom);
+ setPartialReduceMLAAction(MVT::nxv8i16, MVT::nxv16i16, Custom);
+ setPartialReduceMLAAction(MVT::nxv16i8, MVT::nxv32i8, Custom);
}
// Handle operations that are only available in non-streaming SVE mode.
diff --git a/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll b/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
index ed27f40aba774..71936b686be15 100644
--- a/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
+++ b/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
@@ -259,6 +259,8 @@ define <vscale x 4 x i64> @udot_8to64(<vscale x 4 x i64> %acc, <vscale x 16 x i8
; CHECK-NEWLOWERING-NEXT: add z1.d, z3.d, z1.d
; CHECK-NEWLOWERING-NEXT: addvl sp, sp, #2
; CHECK-NEWLOWERING-NEXT: ldr x29, [sp], #16 // 8-byte Folded Reload
+; CHECK-NEWLOWERING-NEXT: udot z0.d, z5.h, z4.h
+; CHECK-NEWLOWERING-NEXT: udot z1.d, z2.h, z3.h
; CHECK-NEWLOWERING-NEXT: ret
entry:
%a.wide = zext <vscale x 16 x i8> %a to <vscale x 16 x i64>
@@ -293,6 +295,8 @@ define <vscale x 4 x i64> @sdot_8to64(<vscale x 4 x i64> %acc, <vscale x 16 x i8
; CHECK-NEWLOWERING-NEXT: sunpklo z4.h, z2.b
; CHECK-NEWLOWERING-NEXT: sunpklo z5.h, z3.b
; CHECK-NEWLOWERING-NEXT: sunpkhi z2.h, z2.b
+; CHECK-NEWLOWERING-NEXT: sdot z0.d, z5.h, z4.h
+; CHECK-NEWLOWERING-NEXT: sdot z1.d, z2.h, z3.h
; CHECK-NEWLOWERING-NEXT: sunpkhi z3.h, z3.b
; CHECK-NEWLOWERING-NEXT: ptrue p0.d
; CHECK-NEWLOWERING-NEXT: sunpklo z6.s, z4.h
|
setPartialReduceMLAAction(MVT::nxv2i64, MVT::nxv8i64, Custom);
setPartialReduceMLAAction(MVT::nxv4i32, MVT::nxv16i32, Custom);

setPartialReduceMLAAction(MVT::nxv2i64, MVT::nxv4i64, Custom);
setPartialReduceMLAAction(MVT::nxv4i32, MVT::nxv8i32, Custom);
setPartialReduceMLAAction(MVT::nxv8i16, MVT::nxv16i16, Custom);
setPartialReduceMLAAction(MVT::nxv16i8, MVT::nxv32i8, Custom);
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There's no custom lowering implemented for any of these pairs of types in AArch64TargetLowering::LowerOperation
. If a custom lowering was attempted we'd hit an "unimplemented operand" assertion.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Moved these to the commit that introduces the lowering implementation.
; CHECK-NEWLOWERING-NEXT: udot z0.d, z5.h, z4.h
; CHECK-NEWLOWERING-NEXT: udot z1.d, z2.h, z3.h
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
These test changes look odd? Looks like nothing changed but we now have extra sdot/udot instructions (i.e. no deletions or changes to the code above this)?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not sure where those came from. Removed
// If the node has not gone through the DAG combine, then do not attempt to
// legalise, just expand.
if (!TLI.isPartialReduceMLALegal(Acc.getValueType(), Input1.getValueType())) {
  SDValue Expanded = TLI.expandPartialReduceMLA(N, DAG);
  std::tie(Lo, Hi) = DAG.SplitVector(Expanded, DL);
  return;
}
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
So is this to disable the splitting for all the pairs marked Custom
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
When it gets here, it has already been decided that the PARTIAL_REDUCE_MLA form is the right form, just that one/both of the operand needs splitting. There's no need to reconsider this decision.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Removed this check, and altered SplitVecOp_PARTIAL_REDUCE_MLA to split only the operands, instead of the operands and the accumulator.
@@ -293,6 +295,8 @@ define <vscale x 4 x i64> @sdot_8to64(<vscale x 4 x i64> %acc, <vscale x 16 x i8
; CHECK-NEWLOWERING-NEXT: sunpklo z4.h, z2.b
; CHECK-NEWLOWERING-NEXT: sunpklo z5.h, z3.b
; CHECK-NEWLOWERING-NEXT: sunpkhi z2.h, z2.b
; CHECK-NEWLOWERING-NEXT: sdot z0.d, z5.h, z4.h
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Both the s/udot_8to64
cases would be better handled with a DAG combine that changes the accumulator to i32
(which is what happens with the old lowering).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Still looking into this; Ideally we'd make use of the existing code used by the old lowering, so I intend to find out what is needed for that before simply reimplementing the DAG combine.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Quick follow-up; The old lowering takes in the partial reduction intrinsic (via ISD::INTRINSIC_WO_CHAIN), and would take a while to unpick, so I'll push it out to a separate PR.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think there is a need for a DAG combine here. You can mark PARTIAL_REDUCE_UMLA(nxv2i64, nxv16i8, nxv16i8)
as Custom
, and then lower this to:
nxv2i64 partial.reduce.umla(nxv2i64 %acc, nxv16i8 %op1, nxv16i8 %op2)
->
%t = nxv4i32 partial.reduce.umla(nxv4i32 zeroinitializer, nxv16i8 %op1, nxv16i8 %op2)
%t.lo = nxv2i64 UUNPKLO nxv4i32 %t
%t.hi = nxv2i64 UUNPKHI nxv4i32 %t
%t.add = nxv2i64 ADD %t.lo, %t.hi
%acc.add = nxv2i64 ADD %acc, %t.add
  return;
}

SDValue AccLo, AccHi, Input1Lo, Input1Hi, Input2Lo, Input2Hi;
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If the input types don't need splitting, then instead of actually splitting the vector result and input operands it would be simpler (and more efficient) to reduce into the lower/higher half of the accumulator.
e.g.
nxv4i64 partial.reduce.mla(nxv4i64 %acc, nxv8i64 mul(zext(nxv8i16 %a), zext(nxv8i16 %b))
=>
nxv4i64 insert.subvector(nxv4i64 %acc,
nxv2i64 partial.reduce.mla(nxv2i64 extract.subvector(nxv4i64 %acc, 0),
nxv8i64 mul(zext(nxv8i16 %a), zext(nxv8i16 %b), 0)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I've opted to have the partial.reduce.mla share the accumulator, as the semantics of the partial reduction allows for the vector elements to be processed in any order.
This has caused a number of additional test changes however, but as far as I can tell they are all still correct in behaviour.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not really sure that I understand what you mean about sharing the accumulator, but the example I had in mind is:
define <vscale x 8 x i32> @udot_split_result_only(<vscale x 8 x i32> %acc, <vscale x 16 x i8> %a, <vscale x 16 x i8> %b) {
entry:
%a.wide = zext <vscale x 16 x i8> %a to <vscale x 16 x i32>
%b.wide = zext <vscale x 16 x i8> %b to <vscale x 16 x i32>
%mult = mul nuw nsw <vscale x 16 x i32> %a.wide, %b.wide
%partial.reduce = tail call <vscale x 8 x i32> @llvm.experimental.vector.partial.reduce.add.nxv8i32.nxv16i32(<vscale x 8 x i32> %acc, <vscale x 16 x i32> %mult)
ret <vscale x 8 x i32> %partial.reduce
}
When I write the following code to handle that:
unsigned Opcode = N->getOpcode();
SDValue AccLo, AccHi;
std::tie(AccLo, AccHi) = DAG.SplitVector(Acc, DL);
// If the input types don't need splitting, just accumulate into the
// low part of the accumulator.
if (getTypeAction(Input1.getValueType()) != TargetLowering::TypeSplitVector) {
Lo = DAG.getNode(Opcode, DL, AccLo.getValueType(), AccLo, Input1, Input2);
Hi = AccHi;
return;
}
then the above test results in a single dot instruction, and I don't see any existing tests fail/change when I try this (which suggests there is no test-coverage yet for this case)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry, I got my wires crossed. The accumulator is shared when we only need to split the operands (as part of SplitVecOp_PARTIAL_REDUCE_MLA
).
Your suggestion works for the inverse case though, thanks.
I'll look into rectifying the missing test coverage now.
// If the node has not gone through the DAG combine, then do not attempt to
// legalise, just expand.
if (!TLI.isPartialReduceMLALegal(Acc.getValueType(), Input1.getValueType())) {
  SDValue Expanded = TLI.expandPartialReduceMLA(N, DAG);
  std::tie(Lo, Hi) = DAG.SplitVector(Expanded, DL);
  return;
}
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
When it gets here, it has already been decided that the PARTIAL_REDUCE_MLA form is the right form, just that one/both of the operand needs splitting. There's no need to reconsider this decision.
  return;
}

SDValue AccLo, AccHi, Input1Lo, Input1Hi, Input2Lo, Input2Hi;
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not really sure that I understand what you mean about sharing the accumulator, but the example I had in mind is:
define <vscale x 8 x i32> @udot_split_result_only(<vscale x 8 x i32> %acc, <vscale x 16 x i8> %a, <vscale x 16 x i8> %b) {
entry:
%a.wide = zext <vscale x 16 x i8> %a to <vscale x 16 x i32>
%b.wide = zext <vscale x 16 x i8> %b to <vscale x 16 x i32>
%mult = mul nuw nsw <vscale x 16 x i32> %a.wide, %b.wide
%partial.reduce = tail call <vscale x 8 x i32> @llvm.experimental.vector.partial.reduce.add.nxv8i32.nxv16i32(<vscale x 8 x i32> %acc, <vscale x 16 x i32> %mult)
ret <vscale x 8 x i32> %partial.reduce
}
When I write the following code to handle that:
unsigned Opcode = N->getOpcode();
SDValue AccLo, AccHi;
std::tie(AccLo, AccHi) = DAG.SplitVector(Acc, DL);
// If the input types don't need splitting, just accumulate into the
// low part of the accumulator.
if (getTypeAction(Input1.getValueType()) != TargetLowering::TypeSplitVector) {
Lo = DAG.getNode(Opcode, DL, AccLo.getValueType(), AccLo, Input1, Input2);
Hi = AccHi;
return;
}
then the above test results in a single dot instruction, and I don't see any existing tests fail/change when I try this (which suggests there is no test-coverage yet for this case)
/// Return true if a PARTIAL_REDUCE_U/SMLA node with the specified types is
/// legal for this target.
bool isPartialReduceMLALegal(EVT AccVT, EVT InputVT) const {
  return getPartialReduceMLAAction(AccVT, InputVT) == Legal;
}
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is unused in the current patch.
/// Return true if a PARTIAL_REDUCE_U/SMLA node with the specified types is
/// legal for this target.
bool isPartialReduceMLALegal(EVT AccVT, EVT InputVT) const {
  return getPartialReduceMLAAction(AccVT, InputVT) == Legal;
}
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good catch, removed.
// If the input types don't need splitting, just accumulate into the
// low part of the accumulator.
if (getTypeAction(Input1.getValueType()) == TargetLowering::TypeSplitVector) {
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This condition should check that Input1
does not need splitting.
- if (getTypeAction(Input1.getValueType()) == TargetLowering::TypeSplitVector) {
+ if (getTypeAction(Input1.getValueType()) != TargetLowering::TypeSplitVector) {
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not sure I can even call that a typo, with them being on opposite ends of the keyboard.
Fixed, and the new test does catch when this is wrong/not present.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks! LGTM with nits addressed.
…#130935) Implement proper splitting functions for PARTIAL_REDUCE_MLA ISD nodes. This makes the udot_8to64 and sdot_8to64 tests generate dot product instructions for when the new ISD nodes are used. --------- Co-authored-by: James Chesterman <[email protected]>
…#130935) Implement proper splitting functions for PARTIAL_REDUCE_MLA ISD nodes. This makes the udot_8to64 and sdot_8to64 tests generate dot product instructions for when the new ISD nodes are used. --------- Co-authored-by: James Chesterman <[email protected]>
…#130935) Implement proper splitting functions for PARTIAL_REDUCE_MLA ISD nodes. This makes the udot_8to64 and sdot_8to64 tests generate dot product instructions for when the new ISD nodes are used. --------- Co-authored-by: James Chesterman <[email protected]>
…#130935) Implement proper splitting functions for PARTIAL_REDUCE_MLA ISD nodes. This makes the udot_8to64 and sdot_8to64 tests generate dot product instructions for when the new ISD nodes are used. --------- Co-authored-by: James Chesterman <[email protected]>
Implement proper splitting functions for PARTIAL_REDUCE_MLA ISD nodes.
This makes the udot_8to64 and sdot_8to64 tests generate dot product instructions for when the new ISD nodes are used.
@JamesChesterman is the original author