Thanks to visit codestin.com
Credit goes to github.com

Skip to content

[DAGCombiner] Fold pattern for srl-shl-zext #138290

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 9 commits into
base: main
Choose a base branch
from

Conversation

apeskov
Copy link
Contributor

@apeskov apeskov commented May 2, 2025

Fold (srl (lop x, (shl (zext y), c1)), c1) -> (lop (srl x, c1), (zext y)) where c1 <= leadingzeros(zext(y)).

This is equivalent of existing fold chain (srl (shl (zext y), c1), c1) -> (and (zext y), mask) -> (zext y), but logical op in the middle prevents it from combining.

Profit : Allow to reduce the number of instructions.

@llvmbot llvmbot added backend:NVPTX llvm:SelectionDAG SelectionDAGISel as well labels May 2, 2025
@llvmbot
Copy link
Member

llvmbot commented May 2, 2025

@llvm/pr-subscribers-backend-nvptx

Author: Alexander Peskov (apeskov)

Changes

Fold (srl (lop x, (shl (zext y), c1)), c1) -&gt; (lop (srl x, c1), (zext y)) where c1 <= leadingzeros(zext(y)).

This is equivalent of existing fold chain (srl (shl (zext y), c1), c1) -&gt; (and (zext y), mask) -&gt; (zext y), but logical op in the middle prevents it from combining.

Profit : Allow to reduce the number of instructions.


Full diff: https://github.com/llvm/llvm-project/pull/138290.diff

2 Files Affected:

  • (modified) llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp (+33)
  • (added) llvm/test/CodeGen/NVPTX/shift-opt.ll (+40)
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index ea1435c3934be..9ddac013be280 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -10979,6 +10979,39 @@ SDValue DAGCombiner::visitSRL(SDNode *N) {
       return DAG.getNode(ISD::SRL, DL, VT, N0, NewOp1);
   }
 
+  // fold (srl (or x, (shl (zext y), c1)), c1) -> (or (srl x, c1), (zext y))
+  // c1 <= leadingzeros(zext(y))
+  if (N1C && (N0.getOpcode() == ISD::OR || N0.getOpcode() == ISD::AND ||
+              N0.getOpcode() == ISD::XOR)) {
+    SDValue lhs = N0.getOperand(0);
+    SDValue rhs = N0.getOperand(1);
+    SDValue shl;
+    SDValue other;
+    if (lhs.getOpcode() == ISD::SHL) {
+      shl = lhs;
+      other = rhs;
+    } else if (rhs.getOpcode() == ISD::SHL) {
+      shl = rhs;
+      other = lhs;
+    }
+    if (shl.getNode()) {
+      if (shl.getOperand(1).getNode() == N1C) {
+        SDValue zext = shl.getOperand(0);
+        if (zext.getOpcode() == ISD::ZERO_EXTEND) {
+          unsigned numLeadingZeros =
+              zext.getValueType().getSizeInBits() -
+              zext.getOperand(0).getValueType().getSizeInBits();
+          if (N1C->getZExtValue() <= numLeadingZeros) {
+            return DAG.getNode(
+                N0.getOpcode(), SDLoc(N0), VT,
+                DAG.getNode(ISD::SRL, SDLoc(N0), VT, other, SDValue(N1C, 0)),
+                zext);
+          }
+        }
+      }
+    }
+  }
+
   // fold operands of srl based on knowledge that the low bits are not
   // demanded.
   if (SimplifyDemandedBits(SDValue(N, 0)))
diff --git a/llvm/test/CodeGen/NVPTX/shift-opt.ll b/llvm/test/CodeGen/NVPTX/shift-opt.ll
new file mode 100644
index 0000000000000..6686e8d840c6b
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/shift-opt.ll
@@ -0,0 +1,40 @@
+; RUN: llc < %s -mtriple=nvptx64 | FileCheck %s
+
+define i64 @test1(i64 %x, i32 %y) {
+;
+; srl (or (x, shl(zext(y),c1)),c1) -> or(srl(x,c1), zext(y))
+; c1 <= leadingzeros(zext(y))
+;
+; CHECK-LABEL: test1
+; CHECK: ld.param.u64 %[[X:rd[0-9]+]], [test1_param_0];
+; CHECK: ld.param.u32 %[[Y:rd[0-9]+]], [test1_param_1];
+; CHECK: shr.u64      %[[SHR:rd[0-9]+]], %[[X]], 5;
+; CHECK: or.b64       %[[OR:rd[0-9]+]], %[[SHR]], %[[Y]];
+; CHECK: st.param.b64 [func_retval0], %[[OR]];
+;
+  %ext = zext i32 %y to i64
+  %shl = shl i64 %ext, 5
+  %or = or i64 %x, %shl
+  %srl = lshr i64 %or, 5
+  ret i64 %srl
+}
+
+define i64 @test2(i64 %x, i32 %y) {
+;
+; srl (or (x, shl(zext(y),c1)),c1) -> or(srl(x,c1), zext(y))
+; c1 > leadingzeros(zext(y)).
+;
+; CHECK-LABEL: test2
+; CHECK: ld.param.u64 %[[X:rd[0-9]+]], [test2_param_0];
+; CHECK: ld.param.u32 %[[Y:rd[0-9]+]], [test2_param_1];
+; CHECK: shl.b64      %[[SHL:rd[0-9]+]], %[[Y]], 33;
+; CHECK: or.b64       %[[OR:rd[0-9]+]], %[[X]], %[[SHL]];
+; CHECK: shr.u64      %[[SHR:rd[0-9]+]], %[[OR]], 33;
+; CHECK: st.param.b64 [func_retval0], %[[SHR]];
+;
+  %ext = zext i32 %y to i64
+  %shl = shl i64 %ext, 33
+  %or = or i64 %x, %shl
+  %srl = lshr i64 %or, 33
+  ret i64 %srl
+}

@llvmbot
Copy link
Member

llvmbot commented May 2, 2025

@llvm/pr-subscribers-llvm-selectiondag

Author: Alexander Peskov (apeskov)

Changes

Fold (srl (lop x, (shl (zext y), c1)), c1) -&gt; (lop (srl x, c1), (zext y)) where c1 <= leadingzeros(zext(y)).

This is equivalent of existing fold chain (srl (shl (zext y), c1), c1) -&gt; (and (zext y), mask) -&gt; (zext y), but logical op in the middle prevents it from combining.

Profit : Allow to reduce the number of instructions.


Full diff: https://github.com/llvm/llvm-project/pull/138290.diff

2 Files Affected:

  • (modified) llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp (+33)
  • (added) llvm/test/CodeGen/NVPTX/shift-opt.ll (+40)
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index ea1435c3934be..9ddac013be280 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -10979,6 +10979,39 @@ SDValue DAGCombiner::visitSRL(SDNode *N) {
       return DAG.getNode(ISD::SRL, DL, VT, N0, NewOp1);
   }
 
+  // fold (srl (or x, (shl (zext y), c1)), c1) -> (or (srl x, c1), (zext y))
+  // c1 <= leadingzeros(zext(y))
+  if (N1C && (N0.getOpcode() == ISD::OR || N0.getOpcode() == ISD::AND ||
+              N0.getOpcode() == ISD::XOR)) {
+    SDValue lhs = N0.getOperand(0);
+    SDValue rhs = N0.getOperand(1);
+    SDValue shl;
+    SDValue other;
+    if (lhs.getOpcode() == ISD::SHL) {
+      shl = lhs;
+      other = rhs;
+    } else if (rhs.getOpcode() == ISD::SHL) {
+      shl = rhs;
+      other = lhs;
+    }
+    if (shl.getNode()) {
+      if (shl.getOperand(1).getNode() == N1C) {
+        SDValue zext = shl.getOperand(0);
+        if (zext.getOpcode() == ISD::ZERO_EXTEND) {
+          unsigned numLeadingZeros =
+              zext.getValueType().getSizeInBits() -
+              zext.getOperand(0).getValueType().getSizeInBits();
+          if (N1C->getZExtValue() <= numLeadingZeros) {
+            return DAG.getNode(
+                N0.getOpcode(), SDLoc(N0), VT,
+                DAG.getNode(ISD::SRL, SDLoc(N0), VT, other, SDValue(N1C, 0)),
+                zext);
+          }
+        }
+      }
+    }
+  }
+
   // fold operands of srl based on knowledge that the low bits are not
   // demanded.
   if (SimplifyDemandedBits(SDValue(N, 0)))
diff --git a/llvm/test/CodeGen/NVPTX/shift-opt.ll b/llvm/test/CodeGen/NVPTX/shift-opt.ll
new file mode 100644
index 0000000000000..6686e8d840c6b
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/shift-opt.ll
@@ -0,0 +1,40 @@
+; RUN: llc < %s -mtriple=nvptx64 | FileCheck %s
+
+define i64 @test1(i64 %x, i32 %y) {
+;
+; srl (or (x, shl(zext(y),c1)),c1) -> or(srl(x,c1), zext(y))
+; c1 <= leadingzeros(zext(y))
+;
+; CHECK-LABEL: test1
+; CHECK: ld.param.u64 %[[X:rd[0-9]+]], [test1_param_0];
+; CHECK: ld.param.u32 %[[Y:rd[0-9]+]], [test1_param_1];
+; CHECK: shr.u64      %[[SHR:rd[0-9]+]], %[[X]], 5;
+; CHECK: or.b64       %[[OR:rd[0-9]+]], %[[SHR]], %[[Y]];
+; CHECK: st.param.b64 [func_retval0], %[[OR]];
+;
+  %ext = zext i32 %y to i64
+  %shl = shl i64 %ext, 5
+  %or = or i64 %x, %shl
+  %srl = lshr i64 %or, 5
+  ret i64 %srl
+}
+
+define i64 @test2(i64 %x, i32 %y) {
+;
+; srl (or (x, shl(zext(y),c1)),c1) -> or(srl(x,c1), zext(y))
+; c1 > leadingzeros(zext(y)).
+;
+; CHECK-LABEL: test2
+; CHECK: ld.param.u64 %[[X:rd[0-9]+]], [test2_param_0];
+; CHECK: ld.param.u32 %[[Y:rd[0-9]+]], [test2_param_1];
+; CHECK: shl.b64      %[[SHL:rd[0-9]+]], %[[Y]], 33;
+; CHECK: or.b64       %[[OR:rd[0-9]+]], %[[X]], %[[SHL]];
+; CHECK: shr.u64      %[[SHR:rd[0-9]+]], %[[OR]], 33;
+; CHECK: st.param.b64 [func_retval0], %[[SHR]];
+;
+  %ext = zext i32 %y to i64
+  %shl = shl i64 %ext, 33
+  %or = or i64 %x, %shl
+  %srl = lshr i64 %or, 33
+  ret i64 %srl
+}

;
%ext = zext i32 %y to i64
%shl = shl i64 %ext, 5
%or = or i64 %x, %shl
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

code has and and xor but those aren't tested here. Also should test vector cases, and negative tests for multiple uses, and not enough known bits

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added tests for:

  • xor and and
  • Vector or
  • Negative. Multiple uses of logic_op and shl

What did you mean by "not enough known bits"? Case of "c1 > leadingzeros(zext(y))" was already covered by the test.

@@ -0,0 +1,40 @@
; RUN: llc < %s -mtriple=nvptx64 | FileCheck %s

define i64 @test1(i64 %x, i32 %y) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Descriptive function name?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I reworded test names. Not sure they've become much more descriptive. But significantly better than numbered version.

shl = rhs;
other = lhs;
}
if (shl.getNode()) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't need getNode

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

other = lhs;
}
if (shl.getNode()) {
if (shl.getOperand(1).getNode() == N1C) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't be using getNode

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

// c1 <= leadingzeros(zext(y))
if (N1C && (N0.getOpcode() == ISD::OR || N0.getOpcode() == ISD::AND ||
N0.getOpcode() == ISD::XOR)) {
SDValue lhs = N0.getOperand(0);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Capitalize variable names

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Capitalized.

SDValue zext = shl.getOperand(0);
if (zext.getOpcode() == ISD::ZERO_EXTEND) {
unsigned numLeadingZeros =
zext.getValueType().getSizeInBits() -
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Need to use getScalarSizeInBits to properly support vectors.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done. Vector test is also added.

@topperc
Copy link
Collaborator

topperc commented May 2, 2025

Title says "patterm" instead of "pattern"

@apeskov apeskov changed the title [DAGCombiner] Fold patterm for srl-shl-zext [DAGCombiner] Fold pattern for srl-shl-zext May 2, 2025
if (zext.getOpcode() == ISD::ZERO_EXTEND) {
unsigned numLeadingZeros =
zext.getValueType().getSizeInBits() -
zext.getOperand(0).getValueType().getSizeInBits();
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

use getScalarSizeInBits() so this correctly handles vector types

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

// fold (srl (or x, (shl (zext y), c1)), c1) -> (or (srl x, c1), (zext y))
// c1 <= leadingzeros(zext(y))
if (N1C && (N0.getOpcode() == ISD::OR || N0.getOpcode() == ISD::AND ||
N0.getOpcode() == ISD::XOR)) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

N1C && ISD::isBitwiseLogicOp(N0.getOpcode()))

(ideally we'd use sd_match but we're missing m_BitwiseLogic)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done. ISD::isBitwiseLogicOp helper is utilized.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since #138301 we now have m_BitwiseLogic if you wanted to use SDPatternMatch to simplify the commutative matching - but this is is optional.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok. I reworded with sd_match. Definetly, it's more concise.

I didn't find builtin functionality to matched node with specific opcode, so I used next construction: m_AllOf(m_Value(ZExtY), m_Opc(ISD::ZERO_EXTEND)). If you know a more elegant solution, please point it out.

other = lhs;
}
if (shl.getNode()) {
if (shl.getOperand(1).getNode() == N1C) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should compare against N1 instead of N1C. N1 might be a constant build_vector or splat_vector in which case N1C is an operand of the build_vector/splat_vector not the srl. So we should check that the shl/srl uses the same build_vector/splat_vector.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agree. Switched to comparison with N1. The corresponding test is also provided.

Copy link

github-actions bot commented May 6, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

@apeskov apeskov requested review from arsenm, topperc and RKSimon May 6, 2025 19:32
@apeskov apeskov force-pushed the ap/fold-srl-zext-shl branch from 4b2cdbb to a3ac511 Compare May 7, 2025 14:01
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
backend:NVPTX llvm:SelectionDAG SelectionDAGISel as well
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants