[NVPTX] Add intrinsics for the szext instruction #139126
Conversation
@llvm/pr-subscribers-backend-nvptx @llvm/pr-subscribers-llvm-ir

Author: Alex MacLean (AlexMaclean)

Changes: This change adds support for `llvm.nvvm.{sext,zext}.inreg.{wrap,clamp}` intrinsics.

Full diff: https://github.com/llvm/llvm-project/pull/139126.diff

5 Files Affected: llvm/docs/NVPTXUsage.rst, llvm/include/llvm/IR/IntrinsicsNVVM.td, llvm/lib/Target/NVPTX/NVPTXInstrInfo.td, llvm/lib/Target/NVPTX/NVPTXIntrinsics.td, llvm/test/CodeGen/NVPTX/szext.ll
diff --git a/llvm/docs/NVPTXUsage.rst b/llvm/docs/NVPTXUsage.rst
index c1426823d87af..331a4b8e08883 100644
--- a/llvm/docs/NVPTXUsage.rst
+++ b/llvm/docs/NVPTXUsage.rst
@@ -568,6 +568,99 @@ to left-shift the found bit into the most-significant bit position, otherwise
the result is the shift amount needed to right-shift the found bit into the
least-significant bit position. 0xffffffff is returned if no 1 bit is found.
+'``llvm.nvvm.zext.inreg.clamp``' Intrinsic
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+Syntax:
+"""""""
+
+.. code-block:: llvm
+
+ declare i32 @llvm.nvvm.zext.inreg.clamp(i32 %a, i32 %b)
+
+Overview:
+"""""""""
+
+The '``llvm.nvvm.zext.inreg.clamp``' intrinsic extracts the low bits of the
+input value, and zero-extends them back to the original width.
+
+Semantics:
+""""""""""
+
+The '``llvm.nvvm.zext.inreg.clamp``' returns the zero-extension of N lowest bits
+of operand %a. N is the value of operand %b clamped to the range [0, 32]. If N
+is 0, the result is 0.
+
+'``llvm.nvvm.zext.inreg.wrap``' Intrinsic
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+Syntax:
+"""""""
+
+.. code-block:: llvm
+
+ declare i32 @llvm.nvvm.zext.inreg.wrap(i32 %a, i32 %b)
+
+Overview:
+"""""""""
+
+The '``llvm.nvvm.zext.inreg.wrap``' intrinsic extracts the low bits of the
+input value, and zero-extends them back to the original width.
+
+Semantics:
+""""""""""
+
+The '``llvm.nvvm.zext.inreg.wrap``' returns the zero-extension of N lowest bits
+of operand %a. N is the value of operand %b modulo 32. If N is 0, the result
+is 0.
+
+'``llvm.nvvm.sext.inreg.clamp``' Intrinsic
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+Syntax:
+"""""""
+
+.. code-block:: llvm
+
+ declare i32 @llvm.nvvm.sext.inreg.clamp(i32 %a, i32 %b)
+
+Overview:
+"""""""""
+
+The '``llvm.nvvm.sext.inreg.clamp``' intrinsic extracts the low bits of the
+input value, and sign-extends them back to the original width.
+
+Semantics:
+""""""""""
+
+The '``llvm.nvvm.sext.inreg.clamp``' returns the sign-extension of N lowest bits
+of operand %a. N is the value of operand %b clamped to the range [0, 32]. If N
+is 0, the result is 0.
+
+
+'``llvm.nvvm.sext.inreg.wrap``' Intrinsic
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+Syntax:
+"""""""
+
+.. code-block:: llvm
+
+ declare i32 @llvm.nvvm.sext.inreg.wrap(i32 %a, i32 %b)
+
+Overview:
+"""""""""
+
+The '``llvm.nvvm.sext.inreg.wrap``' intrinsic extracts the low bits of the
+input value, and sign-extends them back to the original width.
+
+Semantics:
+""""""""""
+
+The '``llvm.nvvm.sext.inreg.wrap``' returns the sign-extension of N lowest bits
+of operand %a. N is the value of operand %b modulo 32. If N is 0, the result
+is 0.
+
TMA family of Intrinsics
------------------------
diff --git a/llvm/include/llvm/IR/IntrinsicsNVVM.td b/llvm/include/llvm/IR/IntrinsicsNVVM.td
index 8b87822d3fdda..65f0e2209fc6b 100644
--- a/llvm/include/llvm/IR/IntrinsicsNVVM.td
+++ b/llvm/include/llvm/IR/IntrinsicsNVVM.td
@@ -1356,6 +1356,17 @@ let TargetPrefix = "nvvm" in {
[llvm_anyint_ty, llvm_i1_ty],
[IntrNoMem, IntrSpeculatable, IntrWillReturn, ImmArg<ArgIndex<1>>]>;
+
+//
+// szext
+//
+ foreach ext = ["sext", "zext"] in
+ foreach mode = ["wrap", "clamp"] in
+ def int_nvvm_ # ext # _inreg_ # mode :
+ DefaultAttrsIntrinsic<[llvm_i32_ty],
+ [llvm_i32_ty, llvm_i32_ty],
+ [IntrNoMem, IntrSpeculatable]>;
+
//
// Convert
//
diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
index 11d77599d4ac3..dae6c929eea9e 100644
--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
+++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
@@ -240,26 +240,33 @@ def F16X2RT : RegTyInfo<v2f16, Int32Regs, ?, ?, supports_imm = 0>;
def BF16X2RT : RegTyInfo<v2bf16, Int32Regs, ?, ?, supports_imm = 0>;
+multiclass I3Inst<string op_str, SDPatternOperator op_node, RegTyInfo t,
+ bit commutative, list<Predicate> requires = []> {
+ defvar asmstr = op_str # " \t$dst, $a, $b;";
+
+ def rr :
+ NVPTXInst<(outs t.RC:$dst), (ins t.RC:$a, t.RC:$b),
+ asmstr,
+ [(set t.Ty:$dst, (op_node t.Ty:$a, t.Ty:$b))]>,
+ Requires<requires>;
+ def ri :
+ NVPTXInst<(outs t.RC:$dst), (ins t.RC:$a, t.Imm:$b),
+ asmstr,
+ [(set t.Ty:$dst, (op_node t.RC:$a, imm:$b))]>,
+ Requires<requires>;
+ if !not(commutative) then
+ def ir :
+ NVPTXInst<(outs t.RC:$dst), (ins t.Imm:$a, t.RC:$b),
+ asmstr,
+ [(set t.Ty:$dst, (op_node imm:$a, t.RC:$b))]>,
+ Requires<requires>;
+}
+
// Template for instructions which take three int64, int32, or int16 args.
// The instructions are named "<OpcStr><Width>" (e.g. "add.s64").
-multiclass I3<string OpcStr, SDNode OpNode, bit commutative> {
- foreach t = [I16RT, I32RT, I64RT] in {
- defvar asmstr = OpcStr # t.Size # " \t$dst, $a, $b;";
-
- def t.Ty # rr :
- NVPTXInst<(outs t.RC:$dst), (ins t.RC:$a, t.RC:$b),
- asmstr,
- [(set t.Ty:$dst, (OpNode t.Ty:$a, t.Ty:$b))]>;
- def t.Ty # ri :
- NVPTXInst<(outs t.RC:$dst), (ins t.RC:$a, t.Imm:$b),
- asmstr,
- [(set t.Ty:$dst, (OpNode t.RC:$a, imm:$b))]>;
- if !not(commutative) then
- def t.Ty # ir :
- NVPTXInst<(outs t.RC:$dst), (ins t.Imm:$a, t.RC:$b),
- asmstr,
- [(set t.Ty:$dst, (OpNode imm:$a, t.RC:$b))]>;
- }
+multiclass I3<string op_str, SDPatternOperator op_node, bit commutative> {
+ foreach t = [I16RT, I32RT, I64RT] in
+ defm t.Ty# : I3Inst<op_str # t.Size, op_node, t, commutative>;
}
class I16x2<string OpcStr, SDNode OpNode> :
@@ -270,26 +277,11 @@ class I16x2<string OpcStr, SDNode OpNode> :
// Template for instructions which take 3 int args. The instructions are
// named "<OpcStr>.s32" (e.g. "addc.cc.s32").
-multiclass ADD_SUB_INT_CARRY<string OpcStr, SDNode OpNode> {
+multiclass ADD_SUB_INT_CARRY<string op_str, SDNode op_node, bit commutative> {
let hasSideEffects = 1 in {
- def i32rr :
- NVPTXInst<(outs Int32Regs:$dst), (ins Int32Regs:$a, Int32Regs:$b),
- !strconcat(OpcStr, ".s32 \t$dst, $a, $b;"),
- [(set i32:$dst, (OpNode i32:$a, i32:$b))]>;
- def i32ri :
- NVPTXInst<(outs Int32Regs:$dst), (ins Int32Regs:$a, i32imm:$b),
- !strconcat(OpcStr, ".s32 \t$dst, $a, $b;"),
- [(set i32:$dst, (OpNode i32:$a, imm:$b))]>;
- def i64rr :
- NVPTXInst<(outs Int64Regs:$dst), (ins Int64Regs:$a, Int64Regs:$b),
- !strconcat(OpcStr, ".s64 \t$dst, $a, $b;"),
- [(set i64:$dst, (OpNode i64:$a, i64:$b))]>,
- Requires<[hasPTX<43>]>;
- def i64ri :
- NVPTXInst<(outs Int64Regs:$dst), (ins Int64Regs:$a, i64imm:$b),
- !strconcat(OpcStr, ".s64 \t$dst, $a, $b;"),
- [(set i64:$dst, (OpNode i64:$a, imm:$b))]>,
- Requires<[hasPTX<43>]>;
+ defm i32 : I3Inst<op_str # ".s32", op_node, I32RT, commutative>;
+ defm i64 : I3Inst<op_str # ".s64", op_node, I64RT, commutative,
+ requires = [hasPTX<43>]>;
}
}
@@ -847,12 +839,12 @@ defm SUB : I3<"sub.s", sub, /*commutative=*/ false>;
def ADD16x2 : I16x2<"add.s", add>;
// in32 and int64 addition and subtraction with carry-out.
-defm ADDCC : ADD_SUB_INT_CARRY<"add.cc", addc>;
-defm SUBCC : ADD_SUB_INT_CARRY<"sub.cc", subc>;
+defm ADDCC : ADD_SUB_INT_CARRY<"add.cc", addc, commutative = true>;
+defm SUBCC : ADD_SUB_INT_CARRY<"sub.cc", subc, commutative = false>;
// int32 and int64 addition and subtraction with carry-in and carry-out.
-defm ADDCCC : ADD_SUB_INT_CARRY<"addc.cc", adde>;
-defm SUBCCC : ADD_SUB_INT_CARRY<"subc.cc", sube>;
+defm ADDCCC : ADD_SUB_INT_CARRY<"addc.cc", adde, commutative = true>;
+defm SUBCCC : ADD_SUB_INT_CARRY<"subc.cc", sube, commutative = false>;
defm MULT : I3<"mul.lo.s", mul, /*commutative=*/ true>;
diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
index 3eedb43e4c81a..12c886cb0ca4e 100644
--- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
+++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
@@ -1678,6 +1678,21 @@ foreach t = [I32RT, I64RT] in {
}
}
+//
+// szext
+//
+
+foreach sign = ["s", "u"] in {
+ foreach mode = ["wrap", "clamp"] in {
+ defvar ext = !if(!eq(sign, "s"), "sext", "zext");
+ defvar intrin = !cast<Intrinsic>("int_nvvm_" # ext # "_inreg_" # mode);
+ defm SZEXT_ # sign # _ # mode
+ : I3Inst<"szext." # mode # "." # sign # "32",
+ intrin, I32RT, commutative = false,
+ requires = [hasSM<70>, hasPTX<76>]>;
+ }
+}
+
//
// Convert
//
diff --git a/llvm/test/CodeGen/NVPTX/szext.ll b/llvm/test/CodeGen/NVPTX/szext.ll
new file mode 100644
index 0000000000000..a86c06c24ed98
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/szext.ll
@@ -0,0 +1,65 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc -o - < %s -mcpu=sm_70 -mattr=+ptx76 | FileCheck %s
+
+target triple = "nvptx-unknown-cuda"
+
+define i32 @szext_wrap_u32(i32 %a, i32 %b) {
+; CHECK-LABEL: szext_wrap_u32(
+; CHECK: {
+; CHECK-NEXT: .reg .b32 %r<4>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.u32 %r1, [szext_wrap_u32_param_0];
+; CHECK-NEXT: ld.param.u32 %r2, [szext_wrap_u32_param_1];
+; CHECK-NEXT: szext.wrap.u32 %r3, %r1, %r2;
+; CHECK-NEXT: st.param.b32 [func_retval0], %r3;
+; CHECK-NEXT: ret;
+ %c = call i32 @llvm.nvvm.zext.inreg.wrap(i32 %a, i32 %b)
+ ret i32 %c
+}
+
+define i32 @szext_clamp_u32(i32 %a, i32 %b) {
+; CHECK-LABEL: szext_clamp_u32(
+; CHECK: {
+; CHECK-NEXT: .reg .b32 %r<4>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.u32 %r1, [szext_clamp_u32_param_0];
+; CHECK-NEXT: ld.param.u32 %r2, [szext_clamp_u32_param_1];
+; CHECK-NEXT: szext.clamp.u32 %r3, %r1, %r2;
+; CHECK-NEXT: st.param.b32 [func_retval0], %r3;
+; CHECK-NEXT: ret;
+ %c = call i32 @llvm.nvvm.zext.inreg.clamp(i32 %a, i32 %b)
+ ret i32 %c
+}
+
+define i32 @szext_wrap_s32(i32 %a, i32 %b) {
+; CHECK-LABEL: szext_wrap_s32(
+; CHECK: {
+; CHECK-NEXT: .reg .b32 %r<4>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.u32 %r1, [szext_wrap_s32_param_0];
+; CHECK-NEXT: ld.param.u32 %r2, [szext_wrap_s32_param_1];
+; CHECK-NEXT: szext.wrap.s32 %r3, %r1, %r2;
+; CHECK-NEXT: st.param.b32 [func_retval0], %r3;
+; CHECK-NEXT: ret;
+ %c = call i32 @llvm.nvvm.sext.inreg.wrap(i32 %a, i32 %b)
+ ret i32 %c
+}
+
+define i32 @szext_clamp_s32(i32 %a, i32 %b) {
+; CHECK-LABEL: szext_clamp_s32(
+; CHECK: {
+; CHECK-NEXT: .reg .b32 %r<4>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.u32 %r1, [szext_clamp_s32_param_0];
+; CHECK-NEXT: ld.param.u32 %r2, [szext_clamp_s32_param_1];
+; CHECK-NEXT: szext.clamp.s32 %r3, %r1, %r2;
+; CHECK-NEXT: st.param.b32 [func_retval0], %r3;
+; CHECK-NEXT: ret;
+ %c = call i32 @llvm.nvvm.sext.inreg.clamp(i32 %a, i32 %b)
+ ret i32 %c
+}
+
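Before the review notes, a reference reading of the documented semantics may help. Below is an IR-level sketch of what the zext variants compute (a hedged reconstruction from the docs text above, not the actual lowering; the function names are illustrative, and the sext variants are analogous with a shl/ashr pair in place of the final and):

```llvm
; What @llvm.nvvm.zext.inreg.clamp computes, per the docs in this patch:
; N = %b clamped to [0, 32]; result = low N bits of %a, zero-extended.
define i32 @zext_clamp_ref(i32 %a, i32 %b) {
  %ge32 = icmp uge i32 %b, 32
  %n = select i1 %ge32, i32 0, i32 %b            ; keep the shift amount in [0, 31]
  %pow2 = shl i32 1, %n
  %mask.lo = sub i32 %pow2, 1                    ; (1 << N) - 1 when N < 32
  %mask = select i1 %ge32, i32 -1, i32 %mask.lo  ; N clamped to 32 keeps all bits
  %r = and i32 %a, %mask
  ret i32 %r
}

; What @llvm.nvvm.zext.inreg.wrap computes: N = %b modulo 32.
define i32 @zext_wrap_ref(i32 %a, i32 %b) {
  %n = and i32 %b, 31                            ; N = %b mod 32, always in [0, 31]
  %pow2 = shl i32 1, %n
  %mask = sub i32 %pow2, 1                       ; 0 when N == 0, matching the docs
  %r = and i32 %a, %mask
  ret i32 %r
}
```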
llvm/docs/NVPTXUsage.rst (Outdated)

> @@ -568,6 +568,99 @@ to left-shift the found bit into the most-significant bit position, otherwise
> the result is the shift amount needed to right-shift the found bit into the
> least-significant bit position. 0xffffffff is returned if no 1 bit is found.
>
> '``llvm.nvvm.zext.inreg.clamp``' Intrinsic
Do we need the `inreg` part? It sounds like an implementation detail at best, and misleading at worst. I.e., I'd assume that it implies in-place (i.e. in the same register) conversion, which is not the case.
I'll remove it. The intent was to match the convention of the `ISD::SEXT_INREG` node, which performs an operation similar to this if `%b` were a constant.
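For illustration, when the width operand is a constant, the equivalence looks like this in IR (a sketch with N = 8, not part of the patch):

```llvm
; With a constant width of 8, the clamp form reduces to trunc-then-sext:
%t = trunc i32 %a to i8
%r = sext i8 %t to i32   ; same result as @llvm.nvvm.sext.inreg.clamp(i32 %a, i32 8)
```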
llvm/docs/NVPTXUsage.rst (Outdated)

> The '``llvm.nvvm.sext.inreg.wrap``' returns the sign-extension of N lowest bits
> of operand %a. N is the value of operand %b modulo 32. If N is 0, the result
> is 0.
There's a lot of unnecessary redundancy in the per-variant descriptions.
I'd just combine them into a list of the intrinsics and a description of the sext/zext and wrap/clamp variants, as they are orthogonal. Sort of similar to how the PTX spec describes the instruction itself: this initialization time could be even longer if the CUDA driver had been unloaded due to GPU inactivity
Sounds good. I've consolidated into a single entry.

> this initialization time could be even longer if the CUDA driver had been unloaded due to GPU inactivity

I'm confused by this bit, not sure if this was intended to be included here.
Ugh. Apparently the link to the docs didn't get copied and I've pasted previously copied stuff from somewhere else. Fixed now.
llvm/test/CodeGen/NVPTX/szext.ll (Outdated)

> ; CHECK-NEXT: szext.wrap.u32 %r3, %r1, %r2;
> ; CHECK-NEXT: st.param.b32 [func_retval0], %r3;
> ; CHECK-NEXT: ret;
> %c = call i32 @llvm.nvvm.zext.inreg.wrap(i32 %a, i32 %b)
We're not testing immediate arguments. I think we will have trouble if both args are immediates. LLVM will likely instcombine those away for add/sub, but it can't do so for intrinsics, and the `I3Inst` does not seem to have an `ii` variant.
The `ii` case is already well handled. During ISel we match first on the intrinsic, converting it to the `ir` variant. Then the other immediate operand is matched against and converted to a `mov`. I've added a couple of tests to demonstrate that all immediate combinations work.
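Spelled out in IR, the case under discussion looks like this (a sketch; the constants are arbitrary):

```llvm
; Both operands are immediates in the IR. ISel still selects the intrinsic
; pattern that takes one immediate operand and materializes the other
; constant with a mov, as the quoted test checks below show.
%c = call i32 @llvm.nvvm.sext.inreg.clamp(i32 3, i32 4)
```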
> ; CHECK-EMPTY:
> ; CHECK-NEXT: // %bb.0:
> ; CHECK-NEXT: mov.b32 %r1, 3;
> ; CHECK-NEXT: szext.clamp.s32 %r2, %r1, 4;
Both arguments must be immediate values for `ii`.
Both arguments are immediate values in the LLVM IR. This lowering seems correct to me.
What I mean is that we're still missing the `ii` instruction variants. LLVM is smart enough to pass the immediate via a register. So, I agree that the lowering is working correctly. The new tests do address my original concern.

We do have that extra `mov.b32 %r1, 3;` now, but it's a cosmetic issue. I'd add an `ii` variant of the instruction, but that's a nit. We have other places where we're not passing immediates as instruction arguments.

Up to you.
I think in general, the all-immediates case for pure functions is one we don't add patterns for. These cases should be pretty rare, and should be handled by constant folding. Adding code to ISel will likely hurt compile time on the margin without any clear benefit.
Agreed, in general, for pure functions that the compiler knows and can reason about, which generally means that it can constant-fold them.

For intrinsics it's less clear-cut, as they tend to be special cases the compiler can't do much about. Those constant parameters may or may not allow the intrinsic to be const-folded, and even when it can, it would still have to be implemented as a special case. E.g., for szext instructions, we could get LLVM to calculate the result of the operation if both constants are known, but we currently don't.

An extra move will be dealt with by ptxas, so I'm OK without `ii` lowering.
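To make that concrete, a fully-constant call is computable from the documented semantics alone; a hypothetical fold (not something this patch implements) would look like:

```llvm
; Hypothetical constant fold, assuming the documented clamp semantics:
%r = call i32 @llvm.nvvm.zext.inreg.clamp(i32 51966, i32 8) ; 51966 = 0xCAFE
; The low 8 bits of 0xCAFE are 0xFE, so %r could fold to the constant 254.
```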