[NVPTX] Vectorize and lower 256-bit global loads/stores for sm_100+/ptx88+ #139292
Conversation
@llvm/pr-subscribers-backend-nvptx @llvm/pr-subscribers-llvm-transforms
Author: Drew Kersnar (dakersnar)
Changes: PTX 8.8+ introduces 256-bit-wide vector loads/stores under certain conditions. This change extends the backend to lower these loads/stores. It also overrides getLoadStoreVecRegBitWidth for NVPTX, allowing the LoadStoreVectorizer to create these wider vector operations. See the spec for the three relevant PTX instructions here:
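For orientation, here is a minimal sketch of how the gating fits together, condensed from the NVPTXSubtarget.h and NVPTXTargetTransformInfo.cpp hunks below (the standalone struct, the address-space constant, and the driver in main are illustrative, not the real classes): the subtarget predicate requires sm_100 and PTX ISA 8.8, and the TTI hook reports a 256-bit vector register width only for the global address space, so the LoadStoreVectorizer can form 256-bit accesses only where the backend can actually lower them.

```cpp
// Sketch only -- condensed from this patch's NVPTXSubtarget.h and
// NVPTXTargetTransformInfo.cpp hunks; the struct, constant, and values
// below are illustrative stand-ins, not the actual LLVM classes.
struct NVPTXSubtargetSketch {
  unsigned SmVersion;
  unsigned PTXVersion;

  // 256-bit vector loads/stores need both sm_100+ and PTX ISA 8.8+.
  bool has256BitMaskedLoadStore() const {
    return SmVersion >= 100 && PTXVersion >= 88;
  }
};

constexpr unsigned AddressSpaceGlobal = 1; // NVPTX global address space

// Mirrors NVPTXTTIImpl::getLoadStoreVecRegBitWidth: only global-address-space
// accesses on capable targets get the wider 256-bit limit; everything else
// keeps the existing 128-bit limit.
unsigned getLoadStoreVecRegBitWidth(const NVPTXSubtargetSketch &ST,
                                    unsigned AddrSpace) {
  if (AddrSpace == AddressSpaceGlobal && ST.has256BitMaskedLoadStore())
    return 256;
  return 128;
}

int main() {
  NVPTXSubtargetSketch SM100{/*SmVersion=*/100, /*PTXVersion=*/88};
  NVPTXSubtargetSketch SM90{/*SmVersion=*/90, /*PTXVersion=*/87};
  // Expect 256 for global on sm_100/ptx88, 128 otherwise.
  bool OK = getLoadStoreVecRegBitWidth(SM100, AddressSpaceGlobal) == 256 &&
            getLoadStoreVecRegBitWidth(SM90, AddressSpaceGlobal) == 128;
  return OK ? 0 : 1;
}
```

The LoadStoreVectorizer queries this width per address space, so accesses outside the global space are unaffected by the patch.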
Patch is 183.90 KiB, truncated to 20.00 KiB below; full version: https://github.com/llvm/llvm-project/pull/139292.diff
15 Files Affected:
diff --git a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp
index 0b137250e4e59..ab1c3c19168af 100644
--- a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp
+++ b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp
@@ -319,6 +319,9 @@ void NVPTXInstPrinter::printLdStCode(const MCInst *MI, int OpNum,
case NVPTX::PTXLdStInstCode::V4:
O << ".v4";
return;
+ case NVPTX::PTXLdStInstCode::V8:
+ O << ".v8";
+ return;
}
// TODO: evaluate whether cases not covered by this switch are bugs
return;
diff --git a/llvm/lib/Target/NVPTX/NVPTX.h b/llvm/lib/Target/NVPTX/NVPTX.h
index 83090ab720c73..2468b8f43ae94 100644
--- a/llvm/lib/Target/NVPTX/NVPTX.h
+++ b/llvm/lib/Target/NVPTX/NVPTX.h
@@ -199,7 +199,8 @@ enum FromType {
enum VecType {
Scalar = 1,
V2 = 2,
- V4 = 4
+ V4 = 4,
+ V8 = 8
};
} // namespace PTXLdStInstCode
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
index 6f6084b99dda2..74594837d92cc 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
@@ -129,6 +129,7 @@ void NVPTXDAGToDAGISel::Select(SDNode *N) {
return;
case NVPTXISD::LoadV2:
case NVPTXISD::LoadV4:
+ case NVPTXISD::LoadV8:
if (tryLoadVector(N))
return;
break;
@@ -139,6 +140,7 @@ void NVPTXDAGToDAGISel::Select(SDNode *N) {
break;
case NVPTXISD::StoreV2:
case NVPTXISD::StoreV4:
+ case NVPTXISD::StoreV8:
if (tryStoreVector(N))
return;
break;
@@ -1195,6 +1197,12 @@ bool NVPTXDAGToDAGISel::tryLoadVector(SDNode *N) {
FromTypeWidth = TotalWidth / 4;
VecType = NVPTX::PTXLdStInstCode::V4;
break;
+ case NVPTXISD::LoadV8:
+ if (!Subtarget->has256BitMaskedLoadStore())
+ return false;
+ FromTypeWidth = TotalWidth / 8;
+ VecType = NVPTX::PTXLdStInstCode::V8;
+ break;
default:
return false;
}
@@ -1205,7 +1213,7 @@ bool NVPTXDAGToDAGISel::tryLoadVector(SDNode *N) {
}
assert(isPowerOf2_32(FromTypeWidth) && FromTypeWidth >= 8 &&
- FromTypeWidth <= 128 && TotalWidth <= 128 && "Invalid width for load");
+ FromTypeWidth <= 128 && TotalWidth <= 256 && "Invalid width for load");
SDValue Offset, Base;
SelectADDR(N->getOperand(1), Base, Offset);
@@ -1230,9 +1238,22 @@ bool NVPTXDAGToDAGISel::tryLoadVector(SDNode *N) {
NVPTX::LDV_f32_v2, NVPTX::LDV_f64_v2);
break;
case NVPTXISD::LoadV4:
- Opcode = pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy, NVPTX::LDV_i8_v4,
- NVPTX::LDV_i16_v4, NVPTX::LDV_i32_v4, std::nullopt,
- NVPTX::LDV_f32_v4, std::nullopt);
+ Opcode =
+ pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy, NVPTX::LDV_i8_v4,
+ NVPTX::LDV_i16_v4, NVPTX::LDV_i32_v4, NVPTX::LDV_i64_v4,
+ NVPTX::LDV_f32_v4, NVPTX::LDV_f64_v4);
+ break;
+ case NVPTXISD::LoadV8:
+ switch (EltVT.getSimpleVT().SimpleTy) {
+ case MVT::i32:
+ Opcode = NVPTX::LDV_i32_v8;
+ break;
+ case MVT::f32:
+ Opcode = NVPTX::LDV_f32_v8;
+ break;
+ default:
+ return false;
+ }
break;
}
if (!Opcode)
@@ -1328,7 +1349,8 @@ bool NVPTXDAGToDAGISel::tryLDGLDU(SDNode *N) {
Opcode = pickOpcodeForVT(
EltVT.getSimpleVT().SimpleTy, NVPTX::INT_PTX_LDG_G_v4i8_ELE,
NVPTX::INT_PTX_LDG_G_v4i16_ELE, NVPTX::INT_PTX_LDG_G_v4i32_ELE,
- std::nullopt, NVPTX::INT_PTX_LDG_G_v4f32_ELE, std::nullopt);
+ NVPTX::INT_PTX_LDG_G_v4i64_ELE, NVPTX::INT_PTX_LDG_G_v4f32_ELE,
+ NVPTX::INT_PTX_LDG_G_v4f64_ELE);
break;
case NVPTXISD::LDUV4:
Opcode = pickOpcodeForVT(
@@ -1336,6 +1358,24 @@ bool NVPTXDAGToDAGISel::tryLDGLDU(SDNode *N) {
NVPTX::INT_PTX_LDU_G_v4i16_ELE, NVPTX::INT_PTX_LDU_G_v4i32_ELE,
std::nullopt, NVPTX::INT_PTX_LDU_G_v4f32_ELE, std::nullopt);
break;
+ case NVPTXISD::LoadV8:
+ switch (EltVT.getSimpleVT().SimpleTy) {
+ case MVT::i32:
+ Opcode = NVPTX::INT_PTX_LDG_G_v8i32_ELE;
+ break;
+ case MVT::f32:
+ Opcode = NVPTX::INT_PTX_LDG_G_v8f32_ELE;
+ break;
+ case MVT::v2i16:
+ case MVT::v2f16:
+ case MVT::v2bf16:
+ case MVT::v4i8:
+ Opcode = NVPTX::INT_PTX_LDG_G_v8i32_ELE;
+ break;
+ default:
+ return false;
+ }
+ break;
}
if (!Opcode)
return false;
@@ -1502,6 +1542,16 @@ bool NVPTXDAGToDAGISel::tryStoreVector(SDNode *N) {
N2 = N->getOperand(5);
ToTypeWidth = TotalWidth / 4;
break;
+ case NVPTXISD::StoreV8:
+ if (!Subtarget->has256BitMaskedLoadStore())
+ return false;
+ VecType = NVPTX::PTXLdStInstCode::V8;
+ Ops.append({N->getOperand(1), N->getOperand(2), N->getOperand(3),
+ N->getOperand(4), N->getOperand(5), N->getOperand(6),
+ N->getOperand(7), N->getOperand(8)});
+ N2 = N->getOperand(9);
+ ToTypeWidth = TotalWidth / 8;
+ break;
default:
return false;
}
@@ -1512,7 +1562,7 @@ bool NVPTXDAGToDAGISel::tryStoreVector(SDNode *N) {
}
assert(isPowerOf2_32(ToTypeWidth) && ToTypeWidth >= 8 && ToTypeWidth <= 128 &&
- TotalWidth <= 128 && "Invalid width for store");
+ TotalWidth <= 256 && "Invalid width for store");
SDValue Offset, Base;
SelectADDR(N2, Base, Offset);
@@ -1533,9 +1583,22 @@ bool NVPTXDAGToDAGISel::tryStoreVector(SDNode *N) {
NVPTX::STV_f32_v2, NVPTX::STV_f64_v2);
break;
case NVPTXISD::StoreV4:
- Opcode = pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy, NVPTX::STV_i8_v4,
- NVPTX::STV_i16_v4, NVPTX::STV_i32_v4, std::nullopt,
- NVPTX::STV_f32_v4, std::nullopt);
+ Opcode =
+ pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy, NVPTX::STV_i8_v4,
+ NVPTX::STV_i16_v4, NVPTX::STV_i32_v4, NVPTX::STV_i64_v4,
+ NVPTX::STV_f32_v4, NVPTX::STV_f64_v4);
+ break;
+ case NVPTXISD::StoreV8:
+ switch (EltVT.getSimpleVT().SimpleTy) {
+ case MVT::i32:
+ Opcode = NVPTX::STV_i32_v8;
+ break;
+ case MVT::f32:
+ Opcode = NVPTX::STV_f32_v8;
+ break;
+ default:
+ return false;
+ }
break;
}
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 3769aae7b620f..d7883b5d526aa 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -162,6 +162,14 @@ static bool IsPTXVectorType(MVT VT) {
case MVT::v2f32:
case MVT::v4f32:
case MVT::v2f64:
+ case MVT::v4i64:
+ case MVT::v4f64:
+ case MVT::v8i32:
+ case MVT::v8f32:
+ case MVT::v16f16: // <8 x f16x2>
+ case MVT::v16bf16: // <8 x bf16x2>
+ case MVT::v16i16: // <8 x i16x2>
+ case MVT::v32i8: // <8 x i8x4>
return true;
}
}
@@ -179,7 +187,7 @@ static bool Is16bitsType(MVT VT) {
// - unsigned int NumElts - The number of elements in the final vector
// - EVT EltVT - The type of the elements in the final vector
static std::optional<std::pair<unsigned int, MVT>>
-getVectorLoweringShape(EVT VectorEVT) {
+getVectorLoweringShape(EVT VectorEVT, bool CanLowerTo256Bit) {
if (!VectorEVT.isSimple())
return std::nullopt;
const MVT VectorVT = VectorEVT.getSimpleVT();
@@ -199,6 +207,15 @@ getVectorLoweringShape(EVT VectorEVT) {
switch (VectorVT.SimpleTy) {
default:
return std::nullopt;
+ case MVT::v4i64:
+ case MVT::v4f64:
+ case MVT::v8i32:
+ case MVT::v8f32:
+ // This is a "native" vector type iff the address space is global
+ // and the target supports 256-bit loads/stores
+ if (!CanLowerTo256Bit)
+ return std::nullopt;
+ LLVM_FALLTHROUGH;
case MVT::v2i8:
case MVT::v2i16:
case MVT::v2i32:
@@ -215,6 +232,15 @@ getVectorLoweringShape(EVT VectorEVT) {
case MVT::v4f32:
// This is a "native" vector type
return std::pair(NumElts, EltVT);
+ case MVT::v16f16: // <8 x f16x2>
+ case MVT::v16bf16: // <8 x bf16x2>
+ case MVT::v16i16: // <8 x i16x2>
+ case MVT::v32i8: // <8 x i8x4>
+ // This can be upsized into a "native" vector type iff the address space is
+ // global and the target supports 256-bit loads/stores.
+ if (!CanLowerTo256Bit)
+ return std::nullopt;
+ LLVM_FALLTHROUGH;
case MVT::v8i8: // <2 x i8x4>
case MVT::v8f16: // <4 x f16x2>
case MVT::v8bf16: // <4 x bf16x2>
@@ -1070,10 +1096,12 @@ const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const {
MAKE_CASE(NVPTXISD::ProxyReg)
MAKE_CASE(NVPTXISD::LoadV2)
MAKE_CASE(NVPTXISD::LoadV4)
+ MAKE_CASE(NVPTXISD::LoadV8)
MAKE_CASE(NVPTXISD::LDUV2)
MAKE_CASE(NVPTXISD::LDUV4)
MAKE_CASE(NVPTXISD::StoreV2)
MAKE_CASE(NVPTXISD::StoreV4)
+ MAKE_CASE(NVPTXISD::StoreV8)
MAKE_CASE(NVPTXISD::FSHL_CLAMP)
MAKE_CASE(NVPTXISD::FSHR_CLAMP)
MAKE_CASE(NVPTXISD::BFE)
@@ -3201,7 +3229,12 @@ NVPTXTargetLowering::LowerSTOREVector(SDValue Op, SelectionDAG &DAG) const {
if (ValVT != MemVT)
return SDValue();
- const auto NumEltsAndEltVT = getVectorLoweringShape(ValVT);
+ // 256-bit vectors are only allowed iff the address is global
+ // and the target supports 256-bit loads/stores
+ unsigned AddrSpace = cast<MemSDNode>(N)->getAddressSpace();
+ bool CanLowerTo256Bit =
+ AddrSpace == ADDRESS_SPACE_GLOBAL && STI.has256BitMaskedLoadStore();
+ const auto NumEltsAndEltVT = getVectorLoweringShape(ValVT, CanLowerTo256Bit);
if (!NumEltsAndEltVT)
return SDValue();
const auto [NumElts, EltVT] = NumEltsAndEltVT.value();
@@ -3229,6 +3262,9 @@ NVPTXTargetLowering::LowerSTOREVector(SDValue Op, SelectionDAG &DAG) const {
case 4:
Opcode = NVPTXISD::StoreV4;
break;
+ case 8:
+ Opcode = NVPTXISD::StoreV8;
+ break;
}
SmallVector<SDValue, 8> Ops;
@@ -5765,7 +5801,8 @@ static void ReplaceBITCAST(SDNode *Node, SelectionDAG &DAG,
/// ReplaceVectorLoad - Convert vector loads into multi-output scalar loads.
static void ReplaceLoadVector(SDNode *N, SelectionDAG &DAG,
- SmallVectorImpl<SDValue> &Results) {
+ SmallVectorImpl<SDValue> &Results,
+ bool TargetHas256BitVectorLoadStore) {
LoadSDNode *LD = cast<LoadSDNode>(N);
const EVT ResVT = LD->getValueType(0);
const EVT MemVT = LD->getMemoryVT();
@@ -5775,7 +5812,12 @@ static void ReplaceLoadVector(SDNode *N, SelectionDAG &DAG,
if (ResVT != MemVT)
return;
- const auto NumEltsAndEltVT = getVectorLoweringShape(ResVT);
+ // 256-bit vectors are only allowed iff the address is global
+ // and the target supports 256-bit loads/stores
+ unsigned AddrSpace = cast<MemSDNode>(N)->getAddressSpace();
+ bool CanLowerTo256Bit =
+ AddrSpace == ADDRESS_SPACE_GLOBAL && TargetHas256BitVectorLoadStore;
+ const auto NumEltsAndEltVT = getVectorLoweringShape(ResVT, CanLowerTo256Bit);
if (!NumEltsAndEltVT)
return;
const auto [NumElts, EltVT] = NumEltsAndEltVT.value();
@@ -5812,6 +5854,13 @@ static void ReplaceLoadVector(SDNode *N, SelectionDAG &DAG,
DAG.getVTList({LoadEltVT, LoadEltVT, LoadEltVT, LoadEltVT, MVT::Other});
break;
}
+ case 8: {
+ Opcode = NVPTXISD::LoadV8;
+ EVT ListVTs[] = {LoadEltVT, LoadEltVT, LoadEltVT, LoadEltVT, LoadEltVT,
+ LoadEltVT, LoadEltVT, LoadEltVT, MVT::Other};
+ LdResVTs = DAG.getVTList(ListVTs);
+ break;
+ }
}
SDLoc DL(LD);
@@ -6084,7 +6133,7 @@ void NVPTXTargetLowering::ReplaceNodeResults(
ReplaceBITCAST(N, DAG, Results);
return;
case ISD::LOAD:
- ReplaceLoadVector(N, DAG, Results);
+ ReplaceLoadVector(N, DAG, Results, STI.has256BitMaskedLoadStore());
return;
case ISD::INTRINSIC_W_CHAIN:
ReplaceINTRINSIC_W_CHAIN(N, DAG, Results);
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
index 7a8bf3bf33a94..3dff83d74538b 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
@@ -84,10 +84,12 @@ enum NodeType : unsigned {
FIRST_MEMORY_OPCODE,
LoadV2 = FIRST_MEMORY_OPCODE,
LoadV4,
+ LoadV8,
LDUV2, // LDU.v2
LDUV4, // LDU.v4
StoreV2,
StoreV4,
+ StoreV8,
LoadParam,
LoadParamV2,
LoadParamV4,
diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
index a384cb79d645a..d0f3fb4ec1c1d 100644
--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
+++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
@@ -2425,7 +2425,7 @@ let mayStore=1, hasSideEffects=0 in {
// The following is used only in and after vector elementizations. Vector
// elementization happens at the machine instruction level, so the following
// instructions never appear in the DAG.
-multiclass LD_VEC<NVPTXRegClass regclass> {
+multiclass LD_VEC<NVPTXRegClass regclass, bit support_v8 = false> {
def _v2 : NVPTXInst<
(outs regclass:$dst1, regclass:$dst2),
(ins LdStCode:$sem, LdStCode:$scope, LdStCode:$addsp, LdStCode:$Vec,
@@ -2438,17 +2438,27 @@ multiclass LD_VEC<NVPTXRegClass regclass> {
LdStCode:$Sign, i32imm:$fromWidth, ADDR:$addr),
"ld${sem:sem}${scope:scope}${addsp:addsp}${Vec:vec}.${Sign:sign}$fromWidth "
"\t{{$dst1, $dst2, $dst3, $dst4}}, [$addr];", []>;
+ if support_v8 then {
+ def _v8 : NVPTXInst<
+ (outs regclass:$dst1, regclass:$dst2, regclass:$dst3, regclass:$dst4,
+ regclass:$dst5, regclass:$dst6, regclass:$dst7, regclass:$dst8),
+ (ins LdStCode:$sem, LdStCode:$scope, LdStCode:$addsp, LdStCode:$Vec, LdStCode:$Sign,
+ i32imm:$fromWidth, ADDR:$addr),
+ "ld${sem:sem}${scope:scope}${addsp:addsp}${Vec:vec}.${Sign:sign}$fromWidth "
+ "\t{{$dst1, $dst2, $dst3, $dst4, $dst5, $dst6, $dst7, $dst8}}, "
+ "[$addr];", []>;
+ }
}
let mayLoad=1, hasSideEffects=0 in {
defm LDV_i8 : LD_VEC<Int16Regs>;
defm LDV_i16 : LD_VEC<Int16Regs>;
- defm LDV_i32 : LD_VEC<Int32Regs>;
+ defm LDV_i32 : LD_VEC<Int32Regs, true>;
defm LDV_i64 : LD_VEC<Int64Regs>;
- defm LDV_f32 : LD_VEC<Float32Regs>;
+ defm LDV_f32 : LD_VEC<Float32Regs, true>;
defm LDV_f64 : LD_VEC<Float64Regs>;
}
-multiclass ST_VEC<NVPTXRegClass regclass> {
+multiclass ST_VEC<NVPTXRegClass regclass, bit support_v8 = false> {
def _v2 : NVPTXInst<
(outs),
(ins regclass:$src1, regclass:$src2, LdStCode:$sem, LdStCode:$scope,
@@ -2463,14 +2473,25 @@ multiclass ST_VEC<NVPTXRegClass regclass> {
LdStCode:$Sign, i32imm:$fromWidth, ADDR:$addr),
"st${sem:sem}${scope:scope}${addsp:addsp}${Vec:vec}.${Sign:sign}$fromWidth "
"\t[$addr], {{$src1, $src2, $src3, $src4}};", []>;
+ if support_v8 then {
+ def _v8 : NVPTXInst<
+ (outs),
+ (ins regclass:$src1, regclass:$src2, regclass:$src3, regclass:$src4,
+ regclass:$src5, regclass:$src6, regclass:$src7, regclass:$src8,
+ LdStCode:$sem, LdStCode:$scope, LdStCode:$addsp, LdStCode:$Vec, LdStCode:$Sign,
+ i32imm:$fromWidth, ADDR:$addr),
+ "st${sem:sem}${scope:scope}${addsp:addsp}${Vec:vec}.${Sign:sign}$fromWidth "
+ "\t[$addr], "
+ "{{$src1, $src2, $src3, $src4, $src5, $src6, $src7, $src8}};", []>;
+ }
}
let mayStore=1, hasSideEffects=0 in {
defm STV_i8 : ST_VEC<Int16Regs>;
defm STV_i16 : ST_VEC<Int16Regs>;
- defm STV_i32 : ST_VEC<Int32Regs>;
+ defm STV_i32 : ST_VEC<Int32Regs, true>;
defm STV_i64 : ST_VEC<Int64Regs>;
- defm STV_f32 : ST_VEC<Float32Regs>;
+ defm STV_f32 : ST_VEC<Float32Regs, true>;
defm STV_f64 : ST_VEC<Float64Regs>;
}
diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
index 7b139d7b79e7d..cdbf29c140429 100644
--- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
+++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
@@ -2388,6 +2388,12 @@ class VLDG_G_ELE_V4<string TyStr, NVPTXRegClass regclass> :
(ins ADDR:$src),
"ld.global.nc.v4." # TyStr # " \t{{$dst1, $dst2, $dst3, $dst4}}, [$src];", []>;
+class VLDG_G_ELE_V8<string TyStr, NVPTXRegClass regclass> :
+ NVPTXInst<(outs regclass:$dst1, regclass:$dst2, regclass:$dst3, regclass:$dst4,
+ regclass:$dst5, regclass:$dst6, regclass:$dst7, regclass:$dst8),
+ (ins ADDR:$src),
+ "ld.global.nc.v8." # TyStr # " \t{{$dst1, $dst2, $dst3, $dst4, $dst5, $dst6, $dst7, $dst8}}, [$src];", []>;
+
// FIXME: 8-bit LDG should be fixed once LDG/LDU nodes are made into proper loads.
def INT_PTX_LDG_G_v2i8_ELE : VLDG_G_ELE_V2<"u8", Int16Regs>;
def INT_PTX_LDG_G_v2i16_ELE : VLDG_G_ELE_V2<"u16", Int16Regs>;
@@ -2401,6 +2407,10 @@ def INT_PTX_LDG_G_v4i16_ELE : VLDG_G_ELE_V4<"u16", Int16Regs>;
def INT_PTX_LDG_G_v4i32_ELE : VLDG_G_ELE_V4<"u32", Int32Regs>;
def INT_PTX_LDG_G_v4f32_ELE : VLDG_G_ELE_V4<"f32", Float32Regs>;
+def INT_PTX_LDG_G_v4i64_ELE : VLDG_G_ELE_V4<"u64", Int64Regs>;
+def INT_PTX_LDG_G_v4f64_ELE : VLDG_G_ELE_V4<"f64", Float64Regs>;
+def INT_PTX_LDG_G_v8i32_ELE : VLDG_G_ELE_V8<"u32", Int32Regs>;
+def INT_PTX_LDG_G_v8f32_ELE : VLDG_G_ELE_V8<"f32", Float32Regs>;
multiclass NG_TO_G<string Str, bit Supports32 = 1, list<Predicate> Preds = []> {
if Supports32 then
diff --git a/llvm/lib/Target/NVPTX/NVPTXSubtarget.h b/llvm/lib/Target/NVPTX/NVPTXSubtarget.h
index 0a4fc8d1435be..5552bba728160 100644
--- a/llvm/lib/Target/NVPTX/NVPTXSubtarget.h
+++ b/llvm/lib/Target/NVPTX/NVPTXSubtarget.h
@@ -72,6 +72,9 @@ class NVPTXSubtarget : public NVPTXGenSubtargetInfo {
const SelectionDAGTargetInfo *getSelectionDAGInfo() const override;
+ bool has256BitMaskedLoadStore() const {
+ return SmVersion >= 100 && PTXVersion >= 88;
+ }
bool hasAtomAddF64() const { return SmVersion >= 60; }
bool hasAtomScope() const { return SmVersion >= 60; }
bool hasAtomBitwise64() const { return SmVersion >= 32; }
diff --git a/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp b/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp
index 66c5139f8c2cc..1d8525fd4656f 100644
--- a/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp
@@ -591,6 +591,13 @@ Value *NVPTXTTIImpl::rewriteIntrinsicWithAddressSpace(IntrinsicInst *II,
return nullptr;
}
+unsigned NVPTXTTIImpl::getLoadStoreVecRegBitWidth(unsigned AddrSpace) const {
+ // 256 bit loads/stores are currently only supported for global address space
+ if (AddrSpace == ADDRESS_SPACE_GLOBAL && ST->has256BitMaskedLoadStore())
+ return 256;
+ return 128;
+}
+
unsigned NVPTXTTIImpl::getAssumedAddrSpace(const Value *V) const {
if (isa<AllocaInst>(V))
return ADDRESS_SPACE_LOCAL;
diff --git a/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h b/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h
index a9bd5a0d01043..98aea4e535f0a 100644
--- a/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h
+++ b/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h
@@ -173,6 +173,8 @@ class NVPTXTTIImpl : public BasicTTIImplBase<NVPTXTTIImpl> {
bool collectFlatAddressOperands(SmallVectorImpl<int> &OpIndexes,
Intrinsic::ID IID) const override;
+ unsigned getLoadStoreVecRegBitWidth(unsigned AddrSpace) const override;
+
Value *rewriteIntrinsicWithAddressSpace(IntrinsicInst *II, Value *OldV,
Value *NewV) const override;
unsigned getAssumedAddrSpace(const Value *V) const override;
diff --git a/llvm/test/CodeGen/NVPTX/ldg-invariant-256.ll b/llvm/test/CodeGen/NVPTX/ldg-invariant-256.ll
new file mode 100644
index 0000000000000..f4abcb37aa894
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/ldg-invariant-256.ll
@@ -0,0 +1,520 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc < %s -mtriple=nvptx64 -mcpu=sm_90 -mattr=+ptx87 -verify-machineinstrs | FileCheck %s -check-prefixes=SM90
+; RUN: %if ptxas %{ llc < %s -mtriple=nvptx64 -mcpu=sm_90 -mattr=+ptx87 | %ptxas-verify -arch=sm_90 %}
+; RUN: llc < %s -mtriple=nvptx64 -mcpu=sm_100 -mattr=+ptx88 -verify-machineinstrs | FileCheck %s -check-prefixes=SM100
+; RUN: %if ptxas %{ llc < %s -mtriple=nvptx64 -mcpu=sm_100 -mattr=+ptx88 | %ptxas-verify -arch=sm_100 %}
+
+; For 256-bit vectors, check that invariant loads from the
+; global addrspace are lowered to ld.globa...
[truncated]
✅ With the latest revision this PR passed the C/C++ code formatter.
Do we want the above formatting change? I figured it was laid out the way it is for a reason, but maybe that reason is just "it was written before we started checking formatting" 😀.
Nice.
Nice!
I think I addressed all current feedback, although the removal of the VecType operand is causing some tests to fail. I'll need to investigate what I'm missing on Monday, but I'm saving my progress for now. Also, @AlexMaclean, do you know if NVPTXForwardParams.cpp needs to be updated to include the v8 opcodes?
Unfortunately, we have lots of brittle code which uses hard-coded operand indices for various machine instructions. Here is one example I personally am guilty of introducing: llvm-project/llvm/lib/Target/NVPTX/NVPTXForwardParams.cpp Lines 108 to 109 in 2da57f8
There are likely a few other places where this occurs (not all of which are as nicely called out), which you'll need to update now that the load and store instructions have one fewer operand. You might want to look at what changed the last time an operand was added.
I don't think that would make sense. This pass is about replacing
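For illustration only, a hypothetical example of the hard-coded operand-index pattern described above (made-up helper and operand layout, not the actual NVPTXForwardParams code): the pass bakes in today's operand positions, so removing an operand such as VecType silently shifts every later index.

```cpp
// Hypothetical sketch of the brittle pattern -- not real NVPTX pass code.
#include "llvm/CodeGen/MachineInstr.h"

using namespace llvm;

// Suppose a vector-load MI is currently laid out as:
//   dst..., sem, scope, addsp, sign, fromWidth, addr
// and a pass grabs the address operand by position:
static void rewriteAddrBase(MachineInstr &MI, Register NewBase) {
  // Fragile: the literal index encodes today's operand layout. If the
  // instruction definition gains or loses an operand (e.g. dropping the
  // VecType operand), this silently targets the wrong operand instead of
  // failing to compile.
  MI.getOperand(6).setReg(NewBase);
}
```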
LGTM
Have you thought about how hard it would be to fix this problem across the codebase? I'd agree that the hard-coded operand indices seem to be one of the more dangerous patterns I've noticed in our code, but I also wouldn't be surprised if it would be too much churn to fix. I'm curious: does anyone know if other LLVM backends (especially ones that use the SelectionDAG framework fully as intended) manage to avoid this issue, or is it a flaw baked into the framework? Edit: after investigating, I can see why we are in no position to solve this codebase-wide... regardless, I only needed to update 3 places. Could be worse.
This should be ready to merge, assuming changes since the last review are ok. I don't have merge permissions yet (I think I can get them after this PR merges, since this is my third), so if somebody could merge this, ideally using my nvidia.com email if it lets you select it, that would be great.
LGTM with a tiny nit.