-
Notifications
You must be signed in to change notification settings - Fork 13.4k
[DAGCombiner] Fold pattern for srl-shl-zext #138290
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
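@llvm/pr-subscribers-backend-nvptx Author: Alexander Peskov (apeskov) Changes: Fold `(srl (lop x, (shl (zext y), c1)), c1) -> (lop (srl x, c1), (zext y))` where `c1 <= leadingzeros(zext(y))`. This is equivalent to the existing fold chain `(srl (shl (zext y), c1), c1) -> (and (zext y), mask) -> (zext y)`, but the logical op in the middle prevents it from combining. Profit: reduces the number of instructions. Full diff: https://github.com/llvm/llvm-project/pull/138290.diff 2 Files Affected: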
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index ea1435c3934be..9ddac013be280 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -10979,6 +10979,39 @@ SDValue DAGCombiner::visitSRL(SDNode *N) {
return DAG.getNode(ISD::SRL, DL, VT, N0, NewOp1);
}
+ // fold (srl (or x, (shl (zext y), c1)), c1) -> (or (srl x, c1), (zext y))
+ // c1 <= leadingzeros(zext(y))
+ if (N1C && (N0.getOpcode() == ISD::OR || N0.getOpcode() == ISD::AND ||
+ N0.getOpcode() == ISD::XOR)) {
+ SDValue lhs = N0.getOperand(0);
+ SDValue rhs = N0.getOperand(1);
+ SDValue shl;
+ SDValue other;
+ if (lhs.getOpcode() == ISD::SHL) {
+ shl = lhs;
+ other = rhs;
+ } else if (rhs.getOpcode() == ISD::SHL) {
+ shl = rhs;
+ other = lhs;
+ }
+ if (shl.getNode()) {
+ if (shl.getOperand(1).getNode() == N1C) {
+ SDValue zext = shl.getOperand(0);
+ if (zext.getOpcode() == ISD::ZERO_EXTEND) {
+ unsigned numLeadingZeros =
+ zext.getValueType().getSizeInBits() -
+ zext.getOperand(0).getValueType().getSizeInBits();
+ if (N1C->getZExtValue() <= numLeadingZeros) {
+ return DAG.getNode(
+ N0.getOpcode(), SDLoc(N0), VT,
+ DAG.getNode(ISD::SRL, SDLoc(N0), VT, other, SDValue(N1C, 0)),
+ zext);
+ }
+ }
+ }
+ }
+ }
+
// fold operands of srl based on knowledge that the low bits are not
// demanded.
if (SimplifyDemandedBits(SDValue(N, 0)))
diff --git a/llvm/test/CodeGen/NVPTX/shift-opt.ll b/llvm/test/CodeGen/NVPTX/shift-opt.ll
new file mode 100644
index 0000000000000..6686e8d840c6b
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/shift-opt.ll
@@ -0,0 +1,40 @@
+; RUN: llc < %s -mtriple=nvptx64 | FileCheck %s
+
+define i64 @test1(i64 %x, i32 %y) {
+;
+; srl (or (x, shl(zext(y),c1)),c1) -> or(srl(x,c1), zext(y))
+; c1 <= leadingzeros(zext(y))
+;
+; CHECK-LABEL: test1
+; CHECK: ld.param.u64 %[[X:rd[0-9]+]], [test1_param_0];
+; CHECK: ld.param.u32 %[[Y:rd[0-9]+]], [test1_param_1];
+; CHECK: shr.u64 %[[SHR:rd[0-9]+]], %[[X]], 5;
+; CHECK: or.b64 %[[OR:rd[0-9]+]], %[[SHR]], %[[Y]];
+; CHECK: st.param.b64 [func_retval0], %[[OR]];
+;
+ %ext = zext i32 %y to i64
+ %shl = shl i64 %ext, 5
+ %or = or i64 %x, %shl
+ %srl = lshr i64 %or, 5
+ ret i64 %srl
+}
+
+define i64 @test2(i64 %x, i32 %y) {
+;
+; srl (or (x, shl(zext(y),c1)),c1) -> or(srl(x,c1), zext(y))
+; c1 > leadingzeros(zext(y)).
+;
+; CHECK-LABEL: test2
+; CHECK: ld.param.u64 %[[X:rd[0-9]+]], [test2_param_0];
+; CHECK: ld.param.u32 %[[Y:rd[0-9]+]], [test2_param_1];
+; CHECK: shl.b64 %[[SHL:rd[0-9]+]], %[[Y]], 33;
+; CHECK: or.b64 %[[OR:rd[0-9]+]], %[[X]], %[[SHL]];
+; CHECK: shr.u64 %[[SHR:rd[0-9]+]], %[[OR]], 33;
+; CHECK: st.param.b64 [func_retval0], %[[SHR]];
+;
+ %ext = zext i32 %y to i64
+ %shl = shl i64 %ext, 33
+ %or = or i64 %x, %shl
+ %srl = lshr i64 %or, 33
+ ret i64 %srl
+}
|
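@llvm/pr-subscribers-llvm-selectiondag Author: Alexander Peskov (apeskov) Changes: Fold `(srl (lop x, (shl (zext y), c1)), c1) -> (lop (srl x, c1), (zext y))` where `c1 <= leadingzeros(zext(y))`. This is equivalent to the existing fold chain `(srl (shl (zext y), c1), c1) -> (and (zext y), mask) -> (zext y)`, but the logical op in the middle prevents it from combining. Profit: reduces the number of instructions. Full diff: https://github.com/llvm/llvm-project/pull/138290.diff 2 Files Affected: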
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index ea1435c3934be..9ddac013be280 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -10979,6 +10979,39 @@ SDValue DAGCombiner::visitSRL(SDNode *N) {
return DAG.getNode(ISD::SRL, DL, VT, N0, NewOp1);
}
+ // fold (srl (or x, (shl (zext y), c1)), c1) -> (or (srl x, c1), (zext y))
+ // c1 <= leadingzeros(zext(y))
+ if (N1C && (N0.getOpcode() == ISD::OR || N0.getOpcode() == ISD::AND ||
+ N0.getOpcode() == ISD::XOR)) {
+ SDValue lhs = N0.getOperand(0);
+ SDValue rhs = N0.getOperand(1);
+ SDValue shl;
+ SDValue other;
+ if (lhs.getOpcode() == ISD::SHL) {
+ shl = lhs;
+ other = rhs;
+ } else if (rhs.getOpcode() == ISD::SHL) {
+ shl = rhs;
+ other = lhs;
+ }
+ if (shl.getNode()) {
+ if (shl.getOperand(1).getNode() == N1C) {
+ SDValue zext = shl.getOperand(0);
+ if (zext.getOpcode() == ISD::ZERO_EXTEND) {
+ unsigned numLeadingZeros =
+ zext.getValueType().getSizeInBits() -
+ zext.getOperand(0).getValueType().getSizeInBits();
+ if (N1C->getZExtValue() <= numLeadingZeros) {
+ return DAG.getNode(
+ N0.getOpcode(), SDLoc(N0), VT,
+ DAG.getNode(ISD::SRL, SDLoc(N0), VT, other, SDValue(N1C, 0)),
+ zext);
+ }
+ }
+ }
+ }
+ }
+
// fold operands of srl based on knowledge that the low bits are not
// demanded.
if (SimplifyDemandedBits(SDValue(N, 0)))
diff --git a/llvm/test/CodeGen/NVPTX/shift-opt.ll b/llvm/test/CodeGen/NVPTX/shift-opt.ll
new file mode 100644
index 0000000000000..6686e8d840c6b
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/shift-opt.ll
@@ -0,0 +1,40 @@
+; RUN: llc < %s -mtriple=nvptx64 | FileCheck %s
+
+define i64 @test1(i64 %x, i32 %y) {
+;
+; srl (or (x, shl(zext(y),c1)),c1) -> or(srl(x,c1), zext(y))
+; c1 <= leadingzeros(zext(y))
+;
+; CHECK-LABEL: test1
+; CHECK: ld.param.u64 %[[X:rd[0-9]+]], [test1_param_0];
+; CHECK: ld.param.u32 %[[Y:rd[0-9]+]], [test1_param_1];
+; CHECK: shr.u64 %[[SHR:rd[0-9]+]], %[[X]], 5;
+; CHECK: or.b64 %[[OR:rd[0-9]+]], %[[SHR]], %[[Y]];
+; CHECK: st.param.b64 [func_retval0], %[[OR]];
+;
+ %ext = zext i32 %y to i64
+ %shl = shl i64 %ext, 5
+ %or = or i64 %x, %shl
+ %srl = lshr i64 %or, 5
+ ret i64 %srl
+}
+
+define i64 @test2(i64 %x, i32 %y) {
+;
+; srl (or (x, shl(zext(y),c1)),c1) -> or(srl(x,c1), zext(y))
+; c1 > leadingzeros(zext(y)).
+;
+; CHECK-LABEL: test2
+; CHECK: ld.param.u64 %[[X:rd[0-9]+]], [test2_param_0];
+; CHECK: ld.param.u32 %[[Y:rd[0-9]+]], [test2_param_1];
+; CHECK: shl.b64 %[[SHL:rd[0-9]+]], %[[Y]], 33;
+; CHECK: or.b64 %[[OR:rd[0-9]+]], %[[X]], %[[SHL]];
+; CHECK: shr.u64 %[[SHR:rd[0-9]+]], %[[OR]], 33;
+; CHECK: st.param.b64 [func_retval0], %[[SHR]];
+;
+ %ext = zext i32 %y to i64
+ %shl = shl i64 %ext, 33
+ %or = or i64 %x, %shl
+ %srl = lshr i64 %or, 33
+ ret i64 %srl
+}
|
; | ||
%ext = zext i32 %y to i64 | ||
%shl = shl i64 %ext, 5 | ||
%or = or i64 %x, %shl |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The code handles `and` and `xor`, but those aren't tested here. Also should test vector cases, and add negative tests for multiple uses and for not enough known bits.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I added tests for:
- `xor` and `and`
- Vector `or`
- Negative. Multiple uses of `logic_op` and `shl`
What did you mean by "not enough known bits"? The case "c1 > leadingzeros(zext(y))" was already covered by the test.
llvm/test/CodeGen/NVPTX/shift-opt.ll
Outdated
@@ -0,0 +1,40 @@ | |||
; RUN: llc < %s -mtriple=nvptx64 | FileCheck %s | |||
|
|||
define i64 @test1(i64 %x, i32 %y) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Descriptive function name?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I reworded the test names. I'm not sure they've become much more descriptive, but they are significantly better than the numbered versions.
shl = rhs; | ||
other = lhs; | ||
} | ||
if (shl.getNode()) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Don't need getNode
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
other = lhs; | ||
} | ||
if (shl.getNode()) { | ||
if (shl.getOperand(1).getNode() == N1C) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Shouldn't be using getNode
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
// c1 <= leadingzeros(zext(y)) | ||
if (N1C && (N0.getOpcode() == ISD::OR || N0.getOpcode() == ISD::AND || | ||
N0.getOpcode() == ISD::XOR)) { | ||
SDValue lhs = N0.getOperand(0); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Capitalize variable names
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Capitalized.
SDValue zext = shl.getOperand(0); | ||
if (zext.getOpcode() == ISD::ZERO_EXTEND) { | ||
unsigned numLeadingZeros = | ||
zext.getValueType().getSizeInBits() - |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Need to use getScalarSizeInBits to properly support vectors.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done. Vector test is also added.
Title says "patterm" instead of "pattern" |
if (zext.getOpcode() == ISD::ZERO_EXTEND) { | ||
unsigned numLeadingZeros = | ||
zext.getValueType().getSizeInBits() - | ||
zext.getOperand(0).getValueType().getSizeInBits(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
use getScalarSizeInBits() so this correctly handles vector types
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
// fold (srl (or x, (shl (zext y), c1)), c1) -> (or (srl x, c1), (zext y)) | ||
// c1 <= leadingzeros(zext(y)) | ||
if (N1C && (N0.getOpcode() == ISD::OR || N0.getOpcode() == ISD::AND || | ||
N0.getOpcode() == ISD::XOR)) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
N1C && ISD::isBitwiseLogicOp(N0.getOpcode()))
(ideally we'd use sd_match but we're missing m_BitwiseLogic)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done. The `ISD::isBitwiseLogicOp` helper is now used.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Since #138301 we now have m_BitwiseLogic if you want to use SDPatternMatch to simplify the commutative matching - but this is optional.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ok. I reworded it with `sd_match`. Definitely, it's more concise.
I didn't find built-in functionality to match a node with a specific opcode, so I used the following construction: `m_AllOf(m_Value(ZExtY), m_Opc(ISD::ZERO_EXTEND))`. If you know a more elegant solution, please point it out.
other = lhs; | ||
} | ||
if (shl.getNode()) { | ||
if (shl.getOperand(1).getNode() == N1C) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This should compare against N1 instead of N1C. N1 might be a constant build_vector or splat_vector in which case N1C is an operand of the build_vector/splat_vector not the srl. So we should check that the shl/srl uses the same build_vector/splat_vector.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Agree. Switched to comparison with N1. The corresponding test is also provided.
✅ With the latest revision this PR passed the C/C++ code formatter. |
fold (srl (or x, (shl (zext y), c1)), c1) -> (or (srl x, c1), (zext y)) for c1 <= leadingzeros(zext(y))
Signed-off-by: Alexander Peskov <[email protected]>
4b2cdbb
to
a3ac511
Compare
Fold `(srl (lop x, (shl (zext y), c1)), c1) -> (lop (srl x, c1), (zext y))` where `c1 <= leadingzeros(zext(y))`.
This is equivalent to the existing fold chain `(srl (shl (zext y), c1), c1) -> (and (zext y), mask) -> (zext y)`, but the logical op in the middle prevents it from combining.
Profit: reduces the number of instructions.