[NVPTX] Add intrinsics for the szext instruction #139126

Merged
merged 3 commits into llvm:main on May 9, 2025

Conversation

AlexMaclean
Member

This change adds support for llvm.nvvm.{sext,zext}.inreg.{wrap,clamp} intrinsics.
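For orientation, here is a minimal IR sketch of how the new intrinsics are meant to be used, based on the semantics documented in the NVPTXUsage.rst changes below; the function name and the concrete widths in the comments are illustrative only:

    declare i32 @llvm.nvvm.zext.inreg.clamp(i32, i32)
    declare i32 @llvm.nvvm.zext.inreg.wrap(i32, i32)

    define i32 @keep_low_bits(i32 %x) {
      ; clamp: a width of 40 is clamped to 32, so all 32 bits are kept and %x is returned unchanged.
      %c = call i32 @llvm.nvvm.zext.inreg.clamp(i32 %x, i32 40)
      ; wrap: a width of 40 wraps to 40 mod 32 = 8, so only the low 8 bits of %c are kept,
      ; zero-extended back to 32 bits.
      %w = call i32 @llvm.nvvm.zext.inreg.wrap(i32 %c, i32 40)
      ret i32 %w
    }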

@llvmbot
Member

llvmbot commented May 8, 2025

@llvm/pr-subscribers-backend-nvptx

@llvm/pr-subscribers-llvm-ir

Author: Alex MacLean (AlexMaclean)

Changes

This change adds support for llvm.nvvm.{sext,zext}.inreg.{wrap,clamp} intrinsics.


Full diff: https://github.com/llvm/llvm-project/pull/139126.diff

5 Files Affected:

  • (modified) llvm/docs/NVPTXUsage.rst (+93)
  • (modified) llvm/include/llvm/IR/IntrinsicsNVVM.td (+11)
  • (modified) llvm/lib/Target/NVPTX/NVPTXInstrInfo.td (+33-41)
  • (modified) llvm/lib/Target/NVPTX/NVPTXIntrinsics.td (+15)
  • (added) llvm/test/CodeGen/NVPTX/szext.ll (+65)
diff --git a/llvm/docs/NVPTXUsage.rst b/llvm/docs/NVPTXUsage.rst
index c1426823d87af..331a4b8e08883 100644
--- a/llvm/docs/NVPTXUsage.rst
+++ b/llvm/docs/NVPTXUsage.rst
@@ -568,6 +568,99 @@ to left-shift the found bit into the most-significant bit position, otherwise
 the result is the shift amount needed to right-shift the found bit into the
 least-significant bit position. 0xffffffff is returned if no 1 bit is found.
 
+'``llvm.nvvm.zext.inreg.clamp``' Intrinsic
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+Syntax:
+"""""""
+
+.. code-block:: llvm
+
+    declare i32 @llvm.nvvm.zext.inreg.clamp(i32 %a, i32 %b)
+
+Overview:
+"""""""""
+
+The '``llvm.nvvm.zext.inreg.clamp``' intrinsic extracts the low bits of the
+input value, and zero-extends them back to the original width.
+
+Semantics:
+""""""""""
+
+The '``llvm.nvvm.zext.inreg.clamp``' returns the zero-extension of N lowest bits
+of operand %a. N is the value of operand %b clamped to the range [0, 32]. If N
+is 0, the result is 0.
+
+'``llvm.nvvm.zext.inreg.wrap``' Intrinsic
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+Syntax:
+"""""""
+
+.. code-block:: llvm
+
+    declare i32 @llvm.nvvm.zext.inreg.wrap(i32 %a, i32 %b)
+
+Overview:
+"""""""""
+
+The '``llvm.nvvm.zext.inreg.wrap``' intrinsic extracts the low bits of the
+input value, and zero-extends them back to the original width.
+
+Semantics:
+""""""""""
+
+The '``llvm.nvvm.zext.inreg.wrap``' returns the zero-extension of N lowest bits
+of operand %a. N is the value of operand %b modulo 32. If N is 0, the result
+is 0.
+
+'``llvm.nvvm.sext.inreg.clamp``' Intrinsic
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+Syntax:
+"""""""
+
+.. code-block:: llvm
+
+    declare i32 @llvm.nvvm.sext.inreg.clamp(i32 %a, i32 %b)
+
+Overview:
+"""""""""
+
+The '``llvm.nvvm.sext.inreg.clamp``' intrinsic extracts the low bits of the
+input value, and sign-extends them back to the original width.
+
+Semantics:
+""""""""""
+
+The '``llvm.nvvm.sext.inreg.clamp``' returns the sign-extension of N lowest bits
+of operand %a. N is the value of operand %b clamped to the range [0, 32]. If N
+is 0, the result is 0.
+
+
+'``llvm.nvvm.sext.inreg.wrap``' Intrinsic
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+Syntax:
+"""""""
+
+.. code-block:: llvm
+
+    declare i32 @llvm.nvvm.sext.inreg.wrap(i32 %a, i32 %b)
+
+Overview:
+"""""""""
+
+The '``llvm.nvvm.sext.inreg.wrap``' intrinsic extracts the low bits of the
+input value, and sign-extends them back to the original width.
+
+Semantics:
+""""""""""
+
+The '``llvm.nvvm.sext.inreg.wrap``' returns the sign-extension of N lowest bits
+of operand %a. N is the value of operand %b modulo 32. If N is 0, the result
+is 0.
+
 TMA family of Intrinsics
 ------------------------
 
diff --git a/llvm/include/llvm/IR/IntrinsicsNVVM.td b/llvm/include/llvm/IR/IntrinsicsNVVM.td
index 8b87822d3fdda..65f0e2209fc6b 100644
--- a/llvm/include/llvm/IR/IntrinsicsNVVM.td
+++ b/llvm/include/llvm/IR/IntrinsicsNVVM.td
@@ -1356,6 +1356,17 @@ let TargetPrefix = "nvvm" in {
         [llvm_anyint_ty, llvm_i1_ty],
         [IntrNoMem, IntrSpeculatable, IntrWillReturn, ImmArg<ArgIndex<1>>]>;
 
+
+//
+// szext
+//
+  foreach ext = ["sext", "zext"] in
+    foreach mode = ["wrap", "clamp"] in
+      def int_nvvm_ # ext # _inreg_ # mode :
+        DefaultAttrsIntrinsic<[llvm_i32_ty],
+          [llvm_i32_ty, llvm_i32_ty],
+          [IntrNoMem, IntrSpeculatable]>;
+
 //
 // Convert
 //
diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
index 11d77599d4ac3..dae6c929eea9e 100644
--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
+++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
@@ -240,26 +240,33 @@ def F16X2RT  : RegTyInfo<v2f16, Int32Regs, ?, ?, supports_imm = 0>;
 def BF16X2RT : RegTyInfo<v2bf16, Int32Regs, ?, ?, supports_imm = 0>;
 
 
+multiclass I3Inst<string op_str, SDPatternOperator op_node, RegTyInfo t,
+                  bit commutative, list<Predicate> requires = []> {
+  defvar asmstr = op_str # " \t$dst, $a, $b;";
+
+  def rr :
+    NVPTXInst<(outs t.RC:$dst), (ins t.RC:$a, t.RC:$b),
+              asmstr,
+              [(set t.Ty:$dst, (op_node t.Ty:$a, t.Ty:$b))]>,
+              Requires<requires>;
+  def ri :
+    NVPTXInst<(outs t.RC:$dst), (ins t.RC:$a, t.Imm:$b),
+              asmstr,
+              [(set t.Ty:$dst, (op_node t.RC:$a, imm:$b))]>,
+              Requires<requires>;
+  if !not(commutative) then
+    def ir :
+      NVPTXInst<(outs t.RC:$dst), (ins t.Imm:$a, t.RC:$b),
+                asmstr,
+                [(set t.Ty:$dst, (op_node imm:$a, t.RC:$b))]>,
+                Requires<requires>;
+}
+
 // Template for instructions which take three int64, int32, or int16 args.
 // The instructions are named "<OpcStr><Width>" (e.g. "add.s64").
-multiclass I3<string OpcStr, SDNode OpNode, bit commutative> {
-  foreach t = [I16RT, I32RT, I64RT] in {
-    defvar asmstr = OpcStr # t.Size # " \t$dst, $a, $b;";
-
-    def t.Ty # rr :
-      NVPTXInst<(outs t.RC:$dst), (ins t.RC:$a, t.RC:$b),
-                asmstr,
-                [(set t.Ty:$dst, (OpNode t.Ty:$a, t.Ty:$b))]>;
-    def t.Ty # ri :
-      NVPTXInst<(outs t.RC:$dst), (ins t.RC:$a, t.Imm:$b),
-                asmstr,
-                [(set t.Ty:$dst, (OpNode t.RC:$a, imm:$b))]>;
-    if !not(commutative) then
-      def t.Ty # ir :
-        NVPTXInst<(outs t.RC:$dst), (ins t.Imm:$a, t.RC:$b),
-                  asmstr,
-                  [(set t.Ty:$dst, (OpNode imm:$a, t.RC:$b))]>;
-  }
+multiclass I3<string op_str, SDPatternOperator op_node, bit commutative> {
+  foreach t = [I16RT, I32RT, I64RT] in
+    defm t.Ty# : I3Inst<op_str # t.Size, op_node, t, commutative>;
 }
 
 class I16x2<string OpcStr, SDNode OpNode> :
@@ -270,26 +277,11 @@ class I16x2<string OpcStr, SDNode OpNode> :
 
 // Template for instructions which take 3 int args.  The instructions are
 // named "<OpcStr>.s32" (e.g. "addc.cc.s32").
-multiclass ADD_SUB_INT_CARRY<string OpcStr, SDNode OpNode> {
+multiclass ADD_SUB_INT_CARRY<string op_str, SDNode op_node, bit commutative> {
   let hasSideEffects = 1 in {
-    def i32rr :
-      NVPTXInst<(outs Int32Regs:$dst), (ins Int32Regs:$a, Int32Regs:$b),
-                !strconcat(OpcStr, ".s32 \t$dst, $a, $b;"),
-                [(set i32:$dst, (OpNode i32:$a, i32:$b))]>;
-    def i32ri :
-      NVPTXInst<(outs Int32Regs:$dst), (ins Int32Regs:$a, i32imm:$b),
-                !strconcat(OpcStr, ".s32 \t$dst, $a, $b;"),
-                [(set i32:$dst, (OpNode i32:$a, imm:$b))]>;
-    def i64rr :
-      NVPTXInst<(outs Int64Regs:$dst), (ins Int64Regs:$a, Int64Regs:$b),
-                !strconcat(OpcStr, ".s64 \t$dst, $a, $b;"),
-                [(set i64:$dst, (OpNode i64:$a, i64:$b))]>,
-      Requires<[hasPTX<43>]>;
-    def i64ri :
-      NVPTXInst<(outs Int64Regs:$dst), (ins Int64Regs:$a, i64imm:$b),
-                !strconcat(OpcStr, ".s64 \t$dst, $a, $b;"),
-                [(set i64:$dst, (OpNode i64:$a, imm:$b))]>,
-      Requires<[hasPTX<43>]>;
+    defm i32 : I3Inst<op_str # ".s32", op_node, I32RT, commutative>;
+    defm i64 : I3Inst<op_str # ".s64", op_node, I64RT, commutative,
+                     requires = [hasPTX<43>]>;
   }
 }
 
@@ -847,12 +839,12 @@ defm SUB : I3<"sub.s", sub, /*commutative=*/ false>;
 def ADD16x2 : I16x2<"add.s", add>;
 
 // in32 and int64 addition and subtraction with carry-out.
-defm ADDCC : ADD_SUB_INT_CARRY<"add.cc", addc>;
-defm SUBCC : ADD_SUB_INT_CARRY<"sub.cc", subc>;
+defm ADDCC : ADD_SUB_INT_CARRY<"add.cc", addc, commutative = true>;
+defm SUBCC : ADD_SUB_INT_CARRY<"sub.cc", subc, commutative = false>;
 
 // int32 and int64 addition and subtraction with carry-in and carry-out.
-defm ADDCCC : ADD_SUB_INT_CARRY<"addc.cc", adde>;
-defm SUBCCC : ADD_SUB_INT_CARRY<"subc.cc", sube>;
+defm ADDCCC : ADD_SUB_INT_CARRY<"addc.cc", adde, commutative = true>;
+defm SUBCCC : ADD_SUB_INT_CARRY<"subc.cc", sube, commutative = false>;
 
 defm MULT : I3<"mul.lo.s", mul, /*commutative=*/ true>;
 
diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
index 3eedb43e4c81a..12c886cb0ca4e 100644
--- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
+++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
@@ -1678,6 +1678,21 @@ foreach t = [I32RT, I64RT] in {
   }
 }
 
+//
+// szext
+//
+
+foreach sign = ["s", "u"] in {
+  foreach mode = ["wrap", "clamp"] in {
+    defvar ext = !if(!eq(sign, "s"), "sext", "zext");
+    defvar intrin = !cast<Intrinsic>("int_nvvm_" # ext # "_inreg_" # mode);
+    defm SZEXT_ # sign # _ # mode
+      : I3Inst<"szext." # mode # "." # sign # "32",
+               intrin, I32RT, commutative = false,
+               requires = [hasSM<70>, hasPTX<76>]>;
+  }
+}
+
 //
 // Convert
 //
diff --git a/llvm/test/CodeGen/NVPTX/szext.ll b/llvm/test/CodeGen/NVPTX/szext.ll
new file mode 100644
index 0000000000000..a86c06c24ed98
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/szext.ll
@@ -0,0 +1,65 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc -o - < %s -mcpu=sm_70 -mattr=+ptx76 | FileCheck %s
+
+target triple = "nvptx-unknown-cuda"
+
+define i32 @szext_wrap_u32(i32 %a, i32 %b) {
+; CHECK-LABEL: szext_wrap_u32(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b32 %r<4>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.u32 %r1, [szext_wrap_u32_param_0];
+; CHECK-NEXT:    ld.param.u32 %r2, [szext_wrap_u32_param_1];
+; CHECK-NEXT:    szext.wrap.u32 %r3, %r1, %r2;
+; CHECK-NEXT:    st.param.b32 [func_retval0], %r3;
+; CHECK-NEXT:    ret;
+  %c = call i32 @llvm.nvvm.zext.inreg.wrap(i32 %a, i32 %b)
+  ret i32 %c
+}
+
+define i32 @szext_clamp_u32(i32 %a, i32 %b) {
+; CHECK-LABEL: szext_clamp_u32(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b32 %r<4>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.u32 %r1, [szext_clamp_u32_param_0];
+; CHECK-NEXT:    ld.param.u32 %r2, [szext_clamp_u32_param_1];
+; CHECK-NEXT:    szext.clamp.u32 %r3, %r1, %r2;
+; CHECK-NEXT:    st.param.b32 [func_retval0], %r3;
+; CHECK-NEXT:    ret;
+  %c = call i32 @llvm.nvvm.zext.inreg.clamp(i32 %a, i32 %b)
+  ret i32 %c
+}
+
+define i32 @szext_wrap_s32(i32 %a, i32 %b) {
+; CHECK-LABEL: szext_wrap_s32(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b32 %r<4>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.u32 %r1, [szext_wrap_s32_param_0];
+; CHECK-NEXT:    ld.param.u32 %r2, [szext_wrap_s32_param_1];
+; CHECK-NEXT:    szext.wrap.s32 %r3, %r1, %r2;
+; CHECK-NEXT:    st.param.b32 [func_retval0], %r3;
+; CHECK-NEXT:    ret;
+  %c = call i32 @llvm.nvvm.sext.inreg.wrap(i32 %a, i32 %b)
+  ret i32 %c
+}
+
+define i32 @szext_clamp_s32(i32 %a, i32 %b) {
+; CHECK-LABEL: szext_clamp_s32(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b32 %r<4>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.u32 %r1, [szext_clamp_s32_param_0];
+; CHECK-NEXT:    ld.param.u32 %r2, [szext_clamp_s32_param_1];
+; CHECK-NEXT:    szext.clamp.s32 %r3, %r1, %r2;
+; CHECK-NEXT:    st.param.b32 [func_retval0], %r3;
+; CHECK-NEXT:    ret;
+  %c = call i32 @llvm.nvvm.sext.inreg.clamp(i32 %a, i32 %b)
+  ret i32 %c
+}
+

@@ -568,6 +568,99 @@ to left-shift the found bit into the most-significant bit position, otherwise
the result is the shift amount needed to right-shift the found bit into the
least-significant bit position. 0xffffffff is returned if no 1 bit is found.

'``llvm.nvvm.zext.inreg.clamp``' Intrinsic
Member

Do we need the inreg part?

It sounds like an implementation detail at best, and misleading at worst. I.e., I'd assume that it implies an in-place (i.e. in the same register) conversion, which is not the case.

Member Author

I'll remove it. The intent was to match the convention of the ISD::SEXT_INREG node, which performs a similar operation when %b is a constant.
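(For readers unfamiliar with that node: a shift pair with a constant amount is what typically gets folded to ISD::SEXT_INREG during selection. The sketch below, with an illustrative width of 8, shows the constant-width pattern that the new intrinsics generalize to a runtime width.)

    ; Sign-extend the low 8 bits of %a within an i32 -- the classic pattern that is
    ; usually combined into ISD::SEXT_INREG. The intrinsic covers the case where the
    ; width is a runtime value rather than a constant.
    define i32 @sext_low8(i32 %a) {
      %shl = shl i32 %a, 24
      %res = ashr i32 %shl, 24
      ret i32 %res
    }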

Comment on lines 660 to 662
The '``llvm.nvvm.sext.inreg.wrap``' returns the sign-extension of N lowest bits
of operand %a. N is the value of operand %b modulo 32. If N is 0, the result
is 0.
Member

@Artem-B Artem-B May 8, 2025

There's a lot of unnecessary redundancy in the per-variant descriptions.
I'd just combine them into the list of intrinsics, and description of sext/zext and wrap/clamp variants, as they are orthogonal. Sort of similar to how PTX spec describes the instruction itself: this initialization time could be even longer if the CUDA driver had been unloaded due to GPU inactivity

Member Author

Sounds good. I've consolidated into a single entry.

this initialization time could be even longer if the CUDA driver had been unloaded due to GPU inactivity

I'm confused by this bit, not sure if this was intended to be included here.

Member

Ugh. Apparently the link to the docs didn't get copied and I've pasted previously copied stuff from somewhere else. Fixed now.

; CHECK-NEXT: szext.wrap.u32 %r3, %r1, %r2;
; CHECK-NEXT: st.param.b32 [func_retval0], %r3;
; CHECK-NEXT: ret;
%c = call i32 @llvm.nvvm.zext.inreg.wrap(i32 %a, i32 %b)
Member

We're not testing immediate arguments.

I think we will have trouble if both args are immediates. LLVM will likely instcombine those away for add/sub, but it can't do so for intrinsics, and I3Inst does not seem to have an ii variant.

Member Author

The ii case is already well handled. During ISel we match first on the intrinsic, converting it to the ir variant. Then the other immediate operand is matched against and converted to a mov. I've added a couple of tests to demonstrate that all immediate combinations work.

; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0:
; CHECK-NEXT: mov.b32 %r1, 3;
; CHECK-NEXT: szext.clamp.s32 %r2, %r1, 4;
Member

Both arguments must be immediate values for ii

Member Author

Both arguments are immediate values in the LLVM IR. This lowering seems correct to me.

Member

@Artem-B Artem-B May 8, 2025

What I mean is that we're still missing the ii instruction variants. LLVM is smart enough to pass the immediate via a register. So, I agree that the lowering is working correctly. The new tests do address my original concern.

We do have that extra mov.b32 %r1, 3; now, but it's a cosmetic issue. I'd add an ii variant of the instruction, but that's a nit. We have other places where we're not passing immediates as instruction arguments.

Up to you.

Member Author

I think in general, the all-immediate case for pure functions is one we don't add patterns for. These cases should be pretty rare and should be handled by constant folding. Adding code to ISel would likely hurt compile time on the margin without any clear benefit.

Member

Agreed, in general, for pure functions that the compiler knows and can reason about, which generally means that it can constant-fold them.

For intrinsics it's less clear cut, as they tend to be special cases the compiler can't do much about. Those constant parameters may or may not allow the intrinsic to be constant-folded, and even when they do, the fold would still have to be implemented as a special case. E.g. for szext instructions, we could get LLVM to calculate the result of the operation when both constants are known, but we currently don't.

An extra move will be dealt with by ptxas, so I'm OK without ii lowering.
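(To make the constant-folding point concrete, here is a hypothetical fold for the all-constant call shown in the test above; per the discussion, this is not something the compiler currently does.)

    declare i32 @llvm.nvvm.sext.inreg.clamp(i32, i32)

    define i32 @const_operands() {
      ; Both operands are constants. If a fold were implemented, this call could be
      ; replaced by the constant 3: the low 4 bits of 3 are 0b0011, the sign bit
      ; (bit 3) is clear, so sign-extending them back to 32 bits gives 3.
      %c = call i32 @llvm.nvvm.sext.inreg.clamp(i32 3, i32 4)
      ret i32 %c
    }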

@AlexMaclean AlexMaclean merged commit 802a2e3 into llvm:main May 9, 2025
10 of 12 checks passed