
[NVPTX] Vectorize and lower 256-bit global loads/stores for sm_100+/ptx88+ #139292


Merged
merged 13 commits into llvm:main from github/dkersnar/256-bit on May 13, 2025

Conversation

dakersnar
Contributor

PTX 8.8+ introduces 256-bit-wide vector loads/stores under certain conditions. This change extends the backend to lower these loads/stores. It also overrides getLoadStoreVecRegBitWidth for NVPTX, allowing the LoadStoreVectorizer to create these wider vector operations.

See the spec for the three relevant PTX instructions here:
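As a concrete illustration, here is a sketch of the kind of 256-bit global access this enables. This is not taken from the patch; the function name and the PTX register numbers are made up:

; With sm_100/ptx88, the LoadStoreVectorizer can form this <8 x i32> load
; from the global address space, and the backend lowers it to a single
; v8 instruction instead of splitting it into two v4 loads.
define <8 x i32> @ld_global_v8i32(ptr addrspace(1) %ptr) {
  %v = load <8 x i32>, ptr addrspace(1) %ptr, align 32
  ret <8 x i32> %v
}

; Expected PTX (sketch):
;   ld.global.v8.u32 {%r1, %r2, %r3, %r4, %r5, %r6, %r7, %r8}, [%rd1];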

@llvmbot
Member

llvmbot commented May 9, 2025

@llvm/pr-subscribers-backend-nvptx

@llvm/pr-subscribers-llvm-transforms

Author: Drew Kersnar (dakersnar)

Changes

Patch is 183.90 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/139292.diff

15 Files Affected:

  • (modified) llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp (+3)
  • (modified) llvm/lib/Target/NVPTX/NVPTX.h (+2-1)
  • (modified) llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp (+72-9)
  • (modified) llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp (+54-5)
  • (modified) llvm/lib/Target/NVPTX/NVPTXISelLowering.h (+2)
  • (modified) llvm/lib/Target/NVPTX/NVPTXInstrInfo.td (+27-6)
  • (modified) llvm/lib/Target/NVPTX/NVPTXIntrinsics.td (+10)
  • (modified) llvm/lib/Target/NVPTX/NVPTXSubtarget.h (+3)
  • (modified) llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp (+7)
  • (modified) llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h (+2)
  • (added) llvm/test/CodeGen/NVPTX/ldg-invariant-256.ll (+520)
  • (added) llvm/test/CodeGen/NVPTX/load-store-256-addressing-invariant.ll (+549)
  • (added) llvm/test/CodeGen/NVPTX/load-store-256-addressing.ll (+543)
  • (added) llvm/test/CodeGen/NVPTX/load-store-vectors-256.ll (+1442)
  • (added) llvm/test/Transforms/LoadStoreVectorizer/NVPTX/256-bit.ll (+728)
diff --git a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp
index 0b137250e4e59..ab1c3c19168af 100644
--- a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp
+++ b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp
@@ -319,6 +319,9 @@ void NVPTXInstPrinter::printLdStCode(const MCInst *MI, int OpNum,
     case NVPTX::PTXLdStInstCode::V4:
       O << ".v4";
       return;
+    case NVPTX::PTXLdStInstCode::V8:
+      O << ".v8";
+      return;
     }
     // TODO: evaluate whether cases not covered by this switch are bugs
     return;
diff --git a/llvm/lib/Target/NVPTX/NVPTX.h b/llvm/lib/Target/NVPTX/NVPTX.h
index 83090ab720c73..2468b8f43ae94 100644
--- a/llvm/lib/Target/NVPTX/NVPTX.h
+++ b/llvm/lib/Target/NVPTX/NVPTX.h
@@ -199,7 +199,8 @@ enum FromType {
 enum VecType {
   Scalar = 1,
   V2 = 2,
-  V4 = 4
+  V4 = 4,
+  V8 = 8
 };
 } // namespace PTXLdStInstCode
 
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
index 6f6084b99dda2..74594837d92cc 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
@@ -129,6 +129,7 @@ void NVPTXDAGToDAGISel::Select(SDNode *N) {
     return;
   case NVPTXISD::LoadV2:
   case NVPTXISD::LoadV4:
+  case NVPTXISD::LoadV8:
     if (tryLoadVector(N))
       return;
     break;
@@ -139,6 +140,7 @@ void NVPTXDAGToDAGISel::Select(SDNode *N) {
     break;
   case NVPTXISD::StoreV2:
   case NVPTXISD::StoreV4:
+  case NVPTXISD::StoreV8:
     if (tryStoreVector(N))
       return;
     break;
@@ -1195,6 +1197,12 @@ bool NVPTXDAGToDAGISel::tryLoadVector(SDNode *N) {
     FromTypeWidth = TotalWidth / 4;
     VecType = NVPTX::PTXLdStInstCode::V4;
     break;
+  case NVPTXISD::LoadV8:
+    if (!Subtarget->has256BitMaskedLoadStore())
+      return false;
+    FromTypeWidth = TotalWidth / 8;
+    VecType = NVPTX::PTXLdStInstCode::V8;
+    break;
   default:
     return false;
   }
@@ -1205,7 +1213,7 @@ bool NVPTXDAGToDAGISel::tryLoadVector(SDNode *N) {
   }
 
   assert(isPowerOf2_32(FromTypeWidth) && FromTypeWidth >= 8 &&
-         FromTypeWidth <= 128 && TotalWidth <= 128 && "Invalid width for load");
+         FromTypeWidth <= 128 && TotalWidth <= 256 && "Invalid width for load");
 
   SDValue Offset, Base;
   SelectADDR(N->getOperand(1), Base, Offset);
@@ -1230,9 +1238,22 @@ bool NVPTXDAGToDAGISel::tryLoadVector(SDNode *N) {
                         NVPTX::LDV_f32_v2, NVPTX::LDV_f64_v2);
     break;
   case NVPTXISD::LoadV4:
-    Opcode = pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy, NVPTX::LDV_i8_v4,
-                             NVPTX::LDV_i16_v4, NVPTX::LDV_i32_v4, std::nullopt,
-                             NVPTX::LDV_f32_v4, std::nullopt);
+    Opcode =
+        pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy, NVPTX::LDV_i8_v4,
+                        NVPTX::LDV_i16_v4, NVPTX::LDV_i32_v4, NVPTX::LDV_i64_v4,
+                        NVPTX::LDV_f32_v4, NVPTX::LDV_f64_v4);
+    break;
+  case NVPTXISD::LoadV8:
+    switch (EltVT.getSimpleVT().SimpleTy) {
+    case MVT::i32:
+      Opcode = NVPTX::LDV_i32_v8;
+      break;
+    case MVT::f32:
+      Opcode = NVPTX::LDV_f32_v8;
+      break;
+    default:
+      return false;
+    }
     break;
   }
   if (!Opcode)
@@ -1328,7 +1349,8 @@ bool NVPTXDAGToDAGISel::tryLDGLDU(SDNode *N) {
     Opcode = pickOpcodeForVT(
         EltVT.getSimpleVT().SimpleTy, NVPTX::INT_PTX_LDG_G_v4i8_ELE,
         NVPTX::INT_PTX_LDG_G_v4i16_ELE, NVPTX::INT_PTX_LDG_G_v4i32_ELE,
-        std::nullopt, NVPTX::INT_PTX_LDG_G_v4f32_ELE, std::nullopt);
+        NVPTX::INT_PTX_LDG_G_v4i64_ELE, NVPTX::INT_PTX_LDG_G_v4f32_ELE,
+        NVPTX::INT_PTX_LDG_G_v4f64_ELE);
     break;
   case NVPTXISD::LDUV4:
     Opcode = pickOpcodeForVT(
@@ -1336,6 +1358,24 @@ bool NVPTXDAGToDAGISel::tryLDGLDU(SDNode *N) {
         NVPTX::INT_PTX_LDU_G_v4i16_ELE, NVPTX::INT_PTX_LDU_G_v4i32_ELE,
         std::nullopt, NVPTX::INT_PTX_LDU_G_v4f32_ELE, std::nullopt);
     break;
+  case NVPTXISD::LoadV8:
+    switch (EltVT.getSimpleVT().SimpleTy) {
+    case MVT::i32:
+      Opcode = NVPTX::INT_PTX_LDG_G_v8i32_ELE;
+      break;
+    case MVT::f32:
+      Opcode = NVPTX::INT_PTX_LDG_G_v8f32_ELE;
+      break;
+    case MVT::v2i16:
+    case MVT::v2f16:
+    case MVT::v2bf16:
+    case MVT::v4i8:
+      Opcode = NVPTX::INT_PTX_LDG_G_v8i32_ELE;
+      break;
+    default:
+      return false;
+    }
+    break;
   }
   if (!Opcode)
     return false;
@@ -1502,6 +1542,16 @@ bool NVPTXDAGToDAGISel::tryStoreVector(SDNode *N) {
     N2 = N->getOperand(5);
     ToTypeWidth = TotalWidth / 4;
     break;
+  case NVPTXISD::StoreV8:
+    if (!Subtarget->has256BitMaskedLoadStore())
+      return false;
+    VecType = NVPTX::PTXLdStInstCode::V8;
+    Ops.append({N->getOperand(1), N->getOperand(2), N->getOperand(3),
+                N->getOperand(4), N->getOperand(5), N->getOperand(6),
+                N->getOperand(7), N->getOperand(8)});
+    N2 = N->getOperand(9);
+    ToTypeWidth = TotalWidth / 8;
+    break;
   default:
     return false;
   }
@@ -1512,7 +1562,7 @@ bool NVPTXDAGToDAGISel::tryStoreVector(SDNode *N) {
   }
 
   assert(isPowerOf2_32(ToTypeWidth) && ToTypeWidth >= 8 && ToTypeWidth <= 128 &&
-         TotalWidth <= 128 && "Invalid width for store");
+         TotalWidth <= 256 && "Invalid width for store");
 
   SDValue Offset, Base;
   SelectADDR(N2, Base, Offset);
@@ -1533,9 +1583,22 @@ bool NVPTXDAGToDAGISel::tryStoreVector(SDNode *N) {
                         NVPTX::STV_f32_v2, NVPTX::STV_f64_v2);
     break;
   case NVPTXISD::StoreV4:
-    Opcode = pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy, NVPTX::STV_i8_v4,
-                             NVPTX::STV_i16_v4, NVPTX::STV_i32_v4, std::nullopt,
-                             NVPTX::STV_f32_v4, std::nullopt);
+    Opcode =
+        pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy, NVPTX::STV_i8_v4,
+                        NVPTX::STV_i16_v4, NVPTX::STV_i32_v4, NVPTX::STV_i64_v4,
+                        NVPTX::STV_f32_v4, NVPTX::STV_f64_v4);
+    break;
+  case NVPTXISD::StoreV8:
+    switch (EltVT.getSimpleVT().SimpleTy) {
+    case MVT::i32:
+      Opcode = NVPTX::STV_i32_v8;
+      break;
+    case MVT::f32:
+      Opcode = NVPTX::STV_f32_v8;
+      break;
+    default:
+      return false;
+    }
     break;
   }
 
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 3769aae7b620f..d7883b5d526aa 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -162,6 +162,14 @@ static bool IsPTXVectorType(MVT VT) {
   case MVT::v2f32:
   case MVT::v4f32:
   case MVT::v2f64:
+  case MVT::v4i64:
+  case MVT::v4f64:
+  case MVT::v8i32:
+  case MVT::v8f32:
+  case MVT::v16f16:  // <8 x f16x2>
+  case MVT::v16bf16: // <8 x bf16x2>
+  case MVT::v16i16:  // <8 x i16x2>
+  case MVT::v32i8:   // <8 x i8x4>
     return true;
   }
 }
@@ -179,7 +187,7 @@ static bool Is16bitsType(MVT VT) {
 //    - unsigned int NumElts - The number of elements in the final vector
 //    - EVT EltVT - The type of the elements in the final vector
 static std::optional<std::pair<unsigned int, MVT>>
-getVectorLoweringShape(EVT VectorEVT) {
+getVectorLoweringShape(EVT VectorEVT, bool CanLowerTo256Bit) {
   if (!VectorEVT.isSimple())
     return std::nullopt;
   const MVT VectorVT = VectorEVT.getSimpleVT();
@@ -199,6 +207,15 @@ getVectorLoweringShape(EVT VectorEVT) {
   switch (VectorVT.SimpleTy) {
   default:
     return std::nullopt;
+  case MVT::v4i64:
+  case MVT::v4f64:
+  case MVT::v8i32:
+  case MVT::v8f32:
+    // This is a "native" vector type iff the address space is global
+    // and the target supports 256-bit loads/stores
+    if (!CanLowerTo256Bit)
+      return std::nullopt;
+    LLVM_FALLTHROUGH;
   case MVT::v2i8:
   case MVT::v2i16:
   case MVT::v2i32:
@@ -215,6 +232,15 @@ getVectorLoweringShape(EVT VectorEVT) {
   case MVT::v4f32:
     // This is a "native" vector type
     return std::pair(NumElts, EltVT);
+  case MVT::v16f16:  // <8 x f16x2>
+  case MVT::v16bf16: // <8 x bf16x2>
+  case MVT::v16i16:  // <8 x i16x2>
+  case MVT::v32i8:   // <8 x i8x4>
+    // This can be upsized into a "native" vector type iff the address space is
+    // global and the target supports 256-bit loads/stores.
+    if (!CanLowerTo256Bit)
+      return std::nullopt;
+    LLVM_FALLTHROUGH;
   case MVT::v8i8:   // <2 x i8x4>
   case MVT::v8f16:  // <4 x f16x2>
   case MVT::v8bf16: // <4 x bf16x2>
@@ -1070,10 +1096,12 @@ const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const {
     MAKE_CASE(NVPTXISD::ProxyReg)
     MAKE_CASE(NVPTXISD::LoadV2)
     MAKE_CASE(NVPTXISD::LoadV4)
+    MAKE_CASE(NVPTXISD::LoadV8)
     MAKE_CASE(NVPTXISD::LDUV2)
     MAKE_CASE(NVPTXISD::LDUV4)
     MAKE_CASE(NVPTXISD::StoreV2)
     MAKE_CASE(NVPTXISD::StoreV4)
+    MAKE_CASE(NVPTXISD::StoreV8)
     MAKE_CASE(NVPTXISD::FSHL_CLAMP)
     MAKE_CASE(NVPTXISD::FSHR_CLAMP)
     MAKE_CASE(NVPTXISD::BFE)
@@ -3201,7 +3229,12 @@ NVPTXTargetLowering::LowerSTOREVector(SDValue Op, SelectionDAG &DAG) const {
   if (ValVT != MemVT)
     return SDValue();
 
-  const auto NumEltsAndEltVT = getVectorLoweringShape(ValVT);
+  // 256-bit vectors are only allowed iff the address is global
+  // and the target supports 256-bit loads/stores
+  unsigned AddrSpace = cast<MemSDNode>(N)->getAddressSpace();
+  bool CanLowerTo256Bit =
+      AddrSpace == ADDRESS_SPACE_GLOBAL && STI.has256BitMaskedLoadStore();
+  const auto NumEltsAndEltVT = getVectorLoweringShape(ValVT, CanLowerTo256Bit);
   if (!NumEltsAndEltVT)
     return SDValue();
   const auto [NumElts, EltVT] = NumEltsAndEltVT.value();
@@ -3229,6 +3262,9 @@ NVPTXTargetLowering::LowerSTOREVector(SDValue Op, SelectionDAG &DAG) const {
   case 4:
     Opcode = NVPTXISD::StoreV4;
     break;
+  case 8:
+    Opcode = NVPTXISD::StoreV8;
+    break;
   }
 
   SmallVector<SDValue, 8> Ops;
@@ -5765,7 +5801,8 @@ static void ReplaceBITCAST(SDNode *Node, SelectionDAG &DAG,
 
 /// ReplaceVectorLoad - Convert vector loads into multi-output scalar loads.
 static void ReplaceLoadVector(SDNode *N, SelectionDAG &DAG,
-                              SmallVectorImpl<SDValue> &Results) {
+                              SmallVectorImpl<SDValue> &Results,
+                              bool TargetHas256BitVectorLoadStore) {
   LoadSDNode *LD = cast<LoadSDNode>(N);
   const EVT ResVT = LD->getValueType(0);
   const EVT MemVT = LD->getMemoryVT();
@@ -5775,7 +5812,12 @@ static void ReplaceLoadVector(SDNode *N, SelectionDAG &DAG,
   if (ResVT != MemVT)
     return;
 
-  const auto NumEltsAndEltVT = getVectorLoweringShape(ResVT);
+  // 256-bit vectors are only allowed iff the address is global
+  // and the target supports 256-bit loads/stores
+  unsigned AddrSpace = cast<MemSDNode>(N)->getAddressSpace();
+  bool CanLowerTo256Bit =
+      AddrSpace == ADDRESS_SPACE_GLOBAL && TargetHas256BitVectorLoadStore;
+  const auto NumEltsAndEltVT = getVectorLoweringShape(ResVT, CanLowerTo256Bit);
   if (!NumEltsAndEltVT)
     return;
   const auto [NumElts, EltVT] = NumEltsAndEltVT.value();
@@ -5812,6 +5854,13 @@ static void ReplaceLoadVector(SDNode *N, SelectionDAG &DAG,
         DAG.getVTList({LoadEltVT, LoadEltVT, LoadEltVT, LoadEltVT, MVT::Other});
     break;
   }
+  case 8: {
+    Opcode = NVPTXISD::LoadV8;
+    EVT ListVTs[] = {LoadEltVT, LoadEltVT, LoadEltVT, LoadEltVT, LoadEltVT,
+                     LoadEltVT, LoadEltVT, LoadEltVT, MVT::Other};
+    LdResVTs = DAG.getVTList(ListVTs);
+    break;
+  }
   }
   SDLoc DL(LD);
 
@@ -6084,7 +6133,7 @@ void NVPTXTargetLowering::ReplaceNodeResults(
     ReplaceBITCAST(N, DAG, Results);
     return;
   case ISD::LOAD:
-    ReplaceLoadVector(N, DAG, Results);
+    ReplaceLoadVector(N, DAG, Results, STI.has256BitMaskedLoadStore());
     return;
   case ISD::INTRINSIC_W_CHAIN:
     ReplaceINTRINSIC_W_CHAIN(N, DAG, Results);
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
index 7a8bf3bf33a94..3dff83d74538b 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
@@ -84,10 +84,12 @@ enum NodeType : unsigned {
   FIRST_MEMORY_OPCODE,
   LoadV2 = FIRST_MEMORY_OPCODE,
   LoadV4,
+  LoadV8,
   LDUV2, // LDU.v2
   LDUV4, // LDU.v4
   StoreV2,
   StoreV4,
+  StoreV8,
   LoadParam,
   LoadParamV2,
   LoadParamV4,
diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
index a384cb79d645a..d0f3fb4ec1c1d 100644
--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
+++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
@@ -2425,7 +2425,7 @@ let mayStore=1, hasSideEffects=0 in {
 // The following is used only in and after vector elementizations.  Vector
 // elementization happens at the machine instruction level, so the following
 // instructions never appear in the DAG.
-multiclass LD_VEC<NVPTXRegClass regclass> {
+multiclass LD_VEC<NVPTXRegClass regclass, bit support_v8 = false> {
   def _v2 : NVPTXInst<
     (outs regclass:$dst1, regclass:$dst2),
     (ins LdStCode:$sem, LdStCode:$scope, LdStCode:$addsp, LdStCode:$Vec,
@@ -2438,17 +2438,27 @@ multiclass LD_VEC<NVPTXRegClass regclass> {
          LdStCode:$Sign, i32imm:$fromWidth, ADDR:$addr),
     "ld${sem:sem}${scope:scope}${addsp:addsp}${Vec:vec}.${Sign:sign}$fromWidth "
     "\t{{$dst1, $dst2, $dst3, $dst4}}, [$addr];", []>;
+  if support_v8 then {
+    def _v8 : NVPTXInst<
+      (outs regclass:$dst1, regclass:$dst2, regclass:$dst3, regclass:$dst4,
+            regclass:$dst5, regclass:$dst6, regclass:$dst7, regclass:$dst8),
+      (ins LdStCode:$sem, LdStCode:$scope, LdStCode:$addsp, LdStCode:$Vec, LdStCode:$Sign,
+           i32imm:$fromWidth, ADDR:$addr),
+      "ld${sem:sem}${scope:scope}${addsp:addsp}${Vec:vec}.${Sign:sign}$fromWidth "
+      "\t{{$dst1, $dst2, $dst3, $dst4, $dst5, $dst6, $dst7, $dst8}}, "
+      "[$addr];", []>;
+  }
 }
 let mayLoad=1, hasSideEffects=0 in {
   defm LDV_i8  : LD_VEC<Int16Regs>;
   defm LDV_i16 : LD_VEC<Int16Regs>;
-  defm LDV_i32 : LD_VEC<Int32Regs>;
+  defm LDV_i32 : LD_VEC<Int32Regs, true>;
   defm LDV_i64 : LD_VEC<Int64Regs>;
-  defm LDV_f32 : LD_VEC<Float32Regs>;
+  defm LDV_f32 : LD_VEC<Float32Regs, true>;
   defm LDV_f64 : LD_VEC<Float64Regs>;
 }
 
-multiclass ST_VEC<NVPTXRegClass regclass> {
+multiclass ST_VEC<NVPTXRegClass regclass, bit support_v8 = false> {
   def _v2 : NVPTXInst<
     (outs),
     (ins regclass:$src1, regclass:$src2, LdStCode:$sem, LdStCode:$scope,
@@ -2463,14 +2473,25 @@ multiclass ST_VEC<NVPTXRegClass regclass> {
          LdStCode:$Sign, i32imm:$fromWidth, ADDR:$addr),
     "st${sem:sem}${scope:scope}${addsp:addsp}${Vec:vec}.${Sign:sign}$fromWidth "
     "\t[$addr], {{$src1, $src2, $src3, $src4}};", []>;
+  if support_v8 then {
+    def _v8 : NVPTXInst<
+      (outs),
+      (ins regclass:$src1, regclass:$src2, regclass:$src3, regclass:$src4,
+           regclass:$src5, regclass:$src6, regclass:$src7, regclass:$src8,
+           LdStCode:$sem, LdStCode:$scope, LdStCode:$addsp, LdStCode:$Vec, LdStCode:$Sign,
+           i32imm:$fromWidth, ADDR:$addr),
+      "st${sem:sem}${scope:scope}${addsp:addsp}${Vec:vec}.${Sign:sign}$fromWidth "
+      "\t[$addr], "
+      "{{$src1, $src2, $src3, $src4, $src5, $src6, $src7, $src8}};", []>;
+  }
 }
 
 let mayStore=1, hasSideEffects=0 in {
   defm STV_i8  : ST_VEC<Int16Regs>;
   defm STV_i16 : ST_VEC<Int16Regs>;
-  defm STV_i32 : ST_VEC<Int32Regs>;
+  defm STV_i32 : ST_VEC<Int32Regs, true>;
   defm STV_i64 : ST_VEC<Int64Regs>;
-  defm STV_f32 : ST_VEC<Float32Regs>;
+  defm STV_f32 : ST_VEC<Float32Regs, true>;
   defm STV_f64 : ST_VEC<Float64Regs>;
 }
 
diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
index 7b139d7b79e7d..cdbf29c140429 100644
--- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
+++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
@@ -2388,6 +2388,12 @@ class VLDG_G_ELE_V4<string TyStr, NVPTXRegClass regclass> :
             (ins ADDR:$src),
             "ld.global.nc.v4." # TyStr # " \t{{$dst1, $dst2, $dst3, $dst4}}, [$src];", []>;
 
+class VLDG_G_ELE_V8<string TyStr, NVPTXRegClass regclass> :
+  NVPTXInst<(outs regclass:$dst1, regclass:$dst2, regclass:$dst3, regclass:$dst4,
+                  regclass:$dst5, regclass:$dst6, regclass:$dst7, regclass:$dst8),
+             (ins ADDR:$src),
+             "ld.global.nc.v8." # TyStr # " \t{{$dst1, $dst2, $dst3, $dst4, $dst5, $dst6, $dst7, $dst8}}, [$src];", []>;
+
 // FIXME: 8-bit LDG should be fixed once LDG/LDU nodes are made into proper loads.
 def INT_PTX_LDG_G_v2i8_ELE : VLDG_G_ELE_V2<"u8", Int16Regs>;
 def INT_PTX_LDG_G_v2i16_ELE : VLDG_G_ELE_V2<"u16", Int16Regs>;
@@ -2401,6 +2407,10 @@ def INT_PTX_LDG_G_v4i16_ELE : VLDG_G_ELE_V4<"u16", Int16Regs>;
 def INT_PTX_LDG_G_v4i32_ELE : VLDG_G_ELE_V4<"u32", Int32Regs>;
 def INT_PTX_LDG_G_v4f32_ELE : VLDG_G_ELE_V4<"f32", Float32Regs>;
 
+def INT_PTX_LDG_G_v4i64_ELE : VLDG_G_ELE_V4<"u64", Int64Regs>;
+def INT_PTX_LDG_G_v4f64_ELE : VLDG_G_ELE_V4<"f64", Float64Regs>;
+def INT_PTX_LDG_G_v8i32_ELE : VLDG_G_ELE_V8<"u32", Int32Regs>;
+def INT_PTX_LDG_G_v8f32_ELE : VLDG_G_ELE_V8<"f32", Float32Regs>;
 
 multiclass NG_TO_G<string Str, bit Supports32 = 1, list<Predicate> Preds = []> {
   if Supports32 then
diff --git a/llvm/lib/Target/NVPTX/NVPTXSubtarget.h b/llvm/lib/Target/NVPTX/NVPTXSubtarget.h
index 0a4fc8d1435be..5552bba728160 100644
--- a/llvm/lib/Target/NVPTX/NVPTXSubtarget.h
+++ b/llvm/lib/Target/NVPTX/NVPTXSubtarget.h
@@ -72,6 +72,9 @@ class NVPTXSubtarget : public NVPTXGenSubtargetInfo {
 
   const SelectionDAGTargetInfo *getSelectionDAGInfo() const override;
 
+  bool has256BitMaskedLoadStore() const {
+    return SmVersion >= 100 && PTXVersion >= 88;
+  }
   bool hasAtomAddF64() const { return SmVersion >= 60; }
   bool hasAtomScope() const { return SmVersion >= 60; }
   bool hasAtomBitwise64() const { return SmVersion >= 32; }
diff --git a/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp b/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp
index 66c5139f8c2cc..1d8525fd4656f 100644
--- a/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp
@@ -591,6 +591,13 @@ Value *NVPTXTTIImpl::rewriteIntrinsicWithAddressSpace(IntrinsicInst *II,
   return nullptr;
 }
 
+unsigned NVPTXTTIImpl::getLoadStoreVecRegBitWidth(unsigned AddrSpace) const {
+  // 256 bit loads/stores are currently only supported for global address space
+  if (AddrSpace == ADDRESS_SPACE_GLOBAL && ST->has256BitMaskedLoadStore())
+    return 256;
+  return 128;
+}
+
 unsigned NVPTXTTIImpl::getAssumedAddrSpace(const Value *V) const {
   if (isa<AllocaInst>(V))
     return ADDRESS_SPACE_LOCAL;
diff --git a/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h b/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h
index a9bd5a0d01043..98aea4e535f0a 100644
--- a/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h
+++ b/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h
@@ -173,6 +173,8 @@ class NVPTXTTIImpl : public BasicTTIImplBase<NVPTXTTIImpl> {
   bool collectFlatAddressOperands(SmallVectorImpl<int> &OpIndexes,
                                   Intrinsic::ID IID) const override;
 
+  unsigned getLoadStoreVecRegBitWidth(unsigned AddrSpace) const override;
+
   Value *rewriteIntrinsicWithAddressSpace(IntrinsicInst *II, Value *OldV,
                                           Value *NewV) const override;
   unsigned getAssumedAddrSpace(const Value *V) const override;
diff --git a/llvm/test/CodeGen/NVPTX/ldg-invariant-256.ll b/llvm/test/CodeGen/NVPTX/ldg-invariant-256.ll
new file mode 100644
index 0000000000000..f4abcb37aa894
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/ldg-invariant-256.ll
@@ -0,0 +1,520 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc < %s -mtriple=nvptx64 -mcpu=sm_90 -mattr=+ptx87 -verify-machineinstrs | FileCheck %s -check-prefixes=SM90
+; RUN: %if ptxas %{ llc < %s -mtriple=nvptx64 -mcpu=sm_90 -mattr=+ptx87 | %ptxas-verify -arch=sm_90 %}
+; RUN: llc < %s -mtriple=nvptx64 -mcpu=sm_100 -mattr=+ptx88 -verify-machineinstrs | FileCheck %s -check-prefixes=SM100
+; RUN: %if ptxas %{ llc < %s -mtriple=nvptx64 -mcpu=sm_100 -mattr=+ptx88 | %ptxas-verify -arch=sm_100 %}
+
+; For 256-bit vectors, check that invariant loads from the
+; global addrspace are lowered to ld.globa...
[truncated]
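The added tests exercise this end to end. As a sketch, one of them can be reproduced locally with an llc built with the NVPTX target (flags taken from the RUN lines above):

llc < llvm/test/CodeGen/NVPTX/load-store-vectors-256.ll -mtriple=nvptx64 -mcpu=sm_100 -mattr=+ptx88 -verify-machineinstrs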

@AlexMaclean AlexMaclean requested review from Artem-B and AlexMaclean May 9, 2025 17:03

github-actions bot commented May 9, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

@dakersnar
Contributor Author

Do we want the above formatting change? I figured it was laid out the way it is for a reason, but maybe that reason is just "it was written before we started checking formatting" 😀.

Member

@AlexMaclean AlexMaclean left a comment

Nice.

Member

@Artem-B Artem-B left a comment

Nice!

@dakersnar
Contributor Author

I think I addressed all current feedback, although the removal of the VecType operand is causing some tests to fail. I'll need to investigate what I'm missing on Monday, but I'm saving my progress for now.

Also, @AlexMaclean, do you know if NVPTXForwardParams.cpp needs to be updated to include the v8 opcodes?

@AlexMaclean
Member

I think I addressed all current feedback, although the removal of the VecType operand is causing some tests to fail. I'll need to investigate what I'm missing on Monday, but I'm saving my progress for now.

Unfortunately, we have lots of brittle code which uses hard-coded operand indices for various machine instructions. Here is one example I personally am guilty of introducing:

constexpr unsigned LDInstBasePtrOpIdx = 6;
constexpr unsigned LDInstAddrSpaceOpIdx = 2;

There are likely a few other places where this occurs (not all of which are as nicely called out), which you'll need to update now that the load and store instructions have one fewer operand. You might want to look at what changed the last time an operand was added.
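To make the failure mode concrete, here is a sketch with a made-up operand layout (not the actual NVPTX one):

// Hypothetical operand list: dst, sem, scope, addsp, Vec, Sign, fromWidth, addr.
// Removing Vec shifts every later index down by one, so a hard-coded
// constant silently points at the wrong operand afterwards.
constexpr unsigned LDInstBasePtrOpIdx = 7;                      // correct before the change, stale after
const MachineOperand &Base = MI.getOperand(LDInstBasePtrOpIdx); // no longer the base pointer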

Also, @AlexMaclean, do you know if NVPTXForwardParams.cpp needs to be updated to include the v8 opcodes?

I don't think that would make sense. This pass is about replacing ld.local with ld.param, so I think with the current address-space restrictions it is impossible to ever encounter the v8 opcodes. If you can construct a test that demonstrates I'm incorrect, adding v8 sounds good to me.

Member

@AlexMaclean AlexMaclean left a comment

LGTM

@dakersnar
Contributor Author

dakersnar commented May 12, 2025

Unfortunately, we have lots of brittle code which uses hard-coded operand indices for various machine instructions. Here is one example I personally am guilty of introducing:

Have you thought about how hard it would be to fix this problem across the codebase? I'd agree that the hard-coded operand indices seem to be one of the more dangerous patterns I've noticed in our code, but I also wouldn't be surprised if it would be too much churn to fix.

I'm curious, does anyone know if other LLVM backends (especially ones that use the SelectionDAG framework fully as intended) manage to avoid this issue, or is it a flaw baked into the framework?

Edit: after investigating I can see why we are in no position to solve this codebase-wide... regardless, I only needed to update 3 places. Could be worse.
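For reference, one mechanism other backends use to avoid this: instructions that set UseNamedOperandTable = 1 in TableGen get a generated index-by-name lookup, which AMDGPU relies on heavily. A minimal sketch, assuming the AMDGPU accessors:

// Look up the operand index by name instead of hard-coding it.
int AddrIdx = AMDGPU::getNamedOperandIdx(MI.getOpcode(), AMDGPU::OpName::vaddr);
if (AddrIdx != -1) {
  const MachineOperand &BasePtr = MI.getOperand(AddrIdx);
  // The index stays correct even if unrelated operands are added or removed.
}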

@dakersnar dakersnar force-pushed the github/dkersnar/256-bit branch from 369a9cf to 7bc66d0 Compare May 12, 2025 19:20
@dakersnar
Contributor Author

dakersnar commented May 12, 2025

This should be ready to merge, assuming changes since the last review are ok. I don't have merge permissions yet (I think I can get them after this PR merges, since this is my third), so if somebody could merge this, ideally using my nvidia.com email if it lets you select it, that would be great.

Member

@Artem-B Artem-B left a comment

LGTM with a tiny nit.

@AlexMaclean AlexMaclean merged commit a1e1a84 into llvm:main May 13, 2025
11 checks passed