[OpenMP][Offload][Runtime] Add map types and map behaviour tweaks for descriptor and descriptor base address #138754

Open
wants to merge 1 commit into
base: main
agozillon
Contributor

This PR adds some modifications to the OpenMP offload runtime for Fortran descriptor types, with the new behaviour enabled by the new map type modifiers OMP_*_DESCRIPTOR and OMP_*_BASE_ADDR.
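
As an illustration of how modifier bits of this kind are tested, here is a minimal sketch. The `DESCRIPTOR` (0x4000) and `DESCRIPTOR_BASE_ADDR` (0x8000) values come from the diff in this PR, and `TO` (0x001) and `ALWAYS` (0x004) mirror the existing `tgt_map_type` values; the `MapFlag` enum and the helper names are hypothetical and exist only for this example. The diff checks the descriptor bit both by equality (when picking the allocation kind) and by bitwise AND (when skipping the data pointer), and the two only agree when no other bits are set.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical stand-ins for the runtime's map-type bits. DESCRIPTOR and
// DESCRIPTOR_BASE_ADDR use the values this PR adds; TO and ALWAYS mirror the
// existing OMP_TGT_MAPTYPE_TO / OMP_TGT_MAPTYPE_ALWAYS values.
enum MapFlag : int64_t {
  MAP_TO = 0x001,
  MAP_ALWAYS = 0x004,
  MAP_DESCRIPTOR = 0x4000,
  MAP_DESCRIPTOR_BASE_ADDR = 0x8000,
};

// True when the descriptor bit is set, regardless of any other bits.
inline bool isDescriptor(int64_t TypeFlags) {
  return (TypeFlags & MAP_DESCRIPTOR) != 0;
}

// True only when the descriptor bit is the sole bit set.
inline bool isExactlyDescriptor(int64_t TypeFlags) {
  return TypeFlags == MAP_DESCRIPTOR;
}

// Small usage example contrasting the two checks.
inline void exampleUsage() {
  const int64_t Flags = MAP_TO | MAP_ALWAYS | MAP_DESCRIPTOR;
  assert(isDescriptor(Flags));         // bit test sees the descriptor bit
  assert(!isExactlyDescriptor(Flags)); // equality does not, TO/ALWAYS are set
}
```

Which of the two checks is appropriate depends on whether the compiler ever combines the descriptor bit with other map-type bits, which this sketch does not try to answer.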

In particular, it modifies the runtime to use shared memory for small descriptor allocations where feasible (the size threshold is configurable). It also tweaks two pieces of mapping behaviour: always mapping for descriptors, and shadow pointer map-back from the device. For the former, if a descriptor has been marked as always, we skip over the data pointer (the first 8 bytes) so as not to overwrite or otherwise modify it; all pointer control is left to the additional BASE_ADDR mapping. For the latter, we make sure the host pointer is not null before mapping back shadow pointers for base addresses, with the intent of avoiding re-allocation of deallocated host data from device data. This can occur when a large derived type with multiple allocatable components has been mapped back in pieces and deallocated: without these changes, the shadow pointers in the map-backs of subsequent components can reallocate the previously deallocated pieces (or at least trick the descriptor into thinking it is allocated, breaking certain runtime functions around presence checking).
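
A minimal sketch of the data-pointer skip, assuming a simplified descriptor whose first field is the data pointer. `FakeDescriptor` and `copyDescriptorSkippingBaseAddr` are hypothetical names, `std::memcpy` stands in for the device transfer, and the sketch uses `sizeof(void *)` plus a size guard where the PR hard-codes an 8-byte offset.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstring>

// Hypothetical, simplified descriptor: the data pointer comes first, followed
// by metadata of the kind a Fortran runtime reads. Not the real flang layout.
struct FakeDescriptor {
  void *BaseAddr;
  int64_t ElemLen;
  int64_t Extent;
};

// Copy a descriptor from Src to Dst while leaving Dst's leading data pointer
// untouched. Returns the number of bytes actually copied.
inline size_t copyDescriptorSkippingBaseAddr(void *Dst, const void *Src,
                                             size_t Size) {
  constexpr size_t DataPtrSize = sizeof(void *); // the PR assumes 8 bytes
  if (Size <= DataPtrSize) // guard added for this sketch only
    return 0;
  std::memcpy(static_cast<char *>(Dst) + DataPtrSize,
              static_cast<const char *>(Src) + DataPtrSize,
              Size - DataPtrSize);
  return Size - DataPtrSize;
}

// Small usage example: metadata is refreshed, the "device" pointer survives.
inline void exampleUsage() {
  int HostData = 1, DeviceData = 2;
  FakeDescriptor Host{&HostData, 4, 10};
  FakeDescriptor Device{&DeviceData, 0, 0};
  copyDescriptorSkippingBaseAddr(&Device, &Host, sizeof(FakeDescriptor));
  assert(Device.BaseAddr == &DeviceData); // pointer preserved
  assert(Device.Extent == 10);            // metadata updated
}
```

The point of the offset is that the pointer/data attachment established by an earlier mapping stays intact while bounds and similar fields are brought up to date.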

Co-author: Matthew Curtis [email protected]

@llvmbot
Member

llvmbot commented May 6, 2025

@llvm/pr-subscribers-backend-amdgpu
@llvm/pr-subscribers-offload

@llvm/pr-subscribers-flang-openmp

Author: None (agozillon)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/138754.diff

9 Files Affected:

  • (modified) llvm/include/llvm/Frontend/OpenMP/OMPConstants.h (+4)
  • (modified) offload/include/OpenMP/Mapping.h (+11-8)
  • (modified) offload/include/omptarget.h (+4)
  • (modified) offload/libomptarget/OpenMP/Mapping.cpp (+66-26)
  • (modified) offload/libomptarget/PluginManager.cpp (+2-2)
  • (modified) offload/libomptarget/omptarget.cpp (+21-5)
  • (modified) offload/plugins-nextgen/amdgpu/src/rtl.cpp (+13-3)
  • (modified) offload/plugins-nextgen/common/include/PluginInterface.h (+8)
  • (modified) offload/plugins-nextgen/common/src/PluginInterface.cpp (+7)
diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPConstants.h b/llvm/include/llvm/Frontend/OpenMP/OMPConstants.h
index 338b56226f204..9c203162bdcf8 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMPConstants.h
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPConstants.h
@@ -236,6 +236,10 @@ enum class OpenMPOffloadMappingFlags : uint64_t {
   // dynamic.
   // This is an OpenMP extension for the sake of OpenACC support.
   OMP_MAP_OMPX_HOLD = 0x2000,
+  // Mapping is for a descriptor (a.k.a. dope vector)
+  OMP_MAP_DESCRIPTOR = 0x4000,
+  // Mapping is for a descriptor's (a.k.a. dope vector) data base address
+  OMP_MAP_DESCRIPTOR_BASE_ADDR = 0x8000,
   /// Signal that the runtime library should use args as an array of
   /// descriptor_dim pointers and use args_size as dims. Used when we have
   /// non-contiguous list items in target update directive
diff --git a/offload/include/OpenMP/Mapping.h b/offload/include/OpenMP/Mapping.h
index b9f5c16582931..50234d9fdd45d 100644
--- a/offload/include/OpenMP/Mapping.h
+++ b/offload/include/OpenMP/Mapping.h
@@ -52,6 +52,7 @@ struct ShadowPtrInfoTy {
   void *HstPtrVal = nullptr;
   void **TgtPtrAddr = nullptr;
   void *TgtPtrVal = nullptr;
+  bool IsDescriptorBaseAddr = false;
 
   bool operator==(const ShadowPtrInfoTy &Other) const {
     return HstPtrAddr == Other.HstPtrAddr;
@@ -68,6 +69,7 @@ struct HostDataToTargetTy {
   const uintptr_t HstPtrBegin;
   const uintptr_t HstPtrEnd;       // non-inclusive.
   const map_var_info_t HstPtrName; // Optional source name of mapped variable.
+  const int32_t AllocKind;
 
   const uintptr_t TgtAllocBegin; // allocated target memory
   const uintptr_t TgtPtrBegin; // mapped target memory = TgtAllocBegin + padding
@@ -124,10 +126,11 @@ struct HostDataToTargetTy {
 public:
   HostDataToTargetTy(uintptr_t BP, uintptr_t B, uintptr_t E,
                      uintptr_t TgtAllocBegin, uintptr_t TgtPtrBegin,
-                     bool UseHoldRefCount, map_var_info_t Name = nullptr,
-                     bool IsINF = false)
+                     bool UseHoldRefCount, int32_t AllocKind,
+                     map_var_info_t Name = nullptr, bool IsINF = false)
       : HstPtrBase(BP), HstPtrBegin(B), HstPtrEnd(E), HstPtrName(Name),
-        TgtAllocBegin(TgtAllocBegin), TgtPtrBegin(TgtPtrBegin),
+        AllocKind(AllocKind), TgtAllocBegin(TgtAllocBegin),
+        TgtPtrBegin(TgtPtrBegin),
         States(std::make_unique<StatesTy>(UseHoldRefCount ? 0
                                           : IsINF         ? INFRefCount
                                                           : 1,
@@ -479,11 +482,11 @@ struct MappingInfoTy {
   /// - Data transfer issue fails.
   TargetPointerResultTy getTargetPointer(
       HDTTMapAccessorTy &HDTTMap, void *HstPtrBegin, void *HstPtrBase,
-      int64_t TgtPadding, int64_t Size, map_var_info_t HstPtrName,
-      bool HasFlagTo, bool HasFlagAlways, bool IsImplicit, bool UpdateRefCount,
-      bool HasCloseModifier, bool HasPresentModifier, bool HasHoldModifier,
-      AsyncInfoTy &AsyncInfo, HostDataToTargetTy *OwnedTPR = nullptr,
-      bool ReleaseHDTTMap = true);
+      int64_t TgtPadding, int64_t Size, int64_t TypeFlags,
+      map_var_info_t HstPtrName, bool HasFlagTo, bool HasFlagAlways,
+      bool IsImplicit, bool UpdateRefCount, bool HasCloseModifier,
+      bool HasPresentModifier, bool HasHoldModifier, AsyncInfoTy &AsyncInfo,
+      HostDataToTargetTy *OwnedTPR = nullptr, bool ReleaseHDTTMap = true);
 
   /// Return the target pointer for \p HstPtrBegin in \p HDTTMap. The accessor
   /// ensures exclusive access to the HDTT map.
diff --git a/offload/include/omptarget.h b/offload/include/omptarget.h
index 6971780c7bdb5..8c711773067a2 100644
--- a/offload/include/omptarget.h
+++ b/offload/include/omptarget.h
@@ -80,6 +80,10 @@ enum tgt_map_type {
   // the structured region
   // This is an OpenMP extension for the sake of OpenACC support.
   OMP_TGT_MAPTYPE_OMPX_HOLD       = 0x2000,
+  // mapping is for a descriptor (a.k.a. dope vector)
+  OMP_TGT_MAPTYPE_DESCRIPTOR      = 0x4000,
+  // Mapping is for a descriptor's (a.k.a. dope vector) data base address
+  OMP_TGT_MAPTYPE_DESCRIPTOR_BASE_ADDR = 0x8000,
   // descriptor for non-contiguous target-update
   OMP_TGT_MAPTYPE_NON_CONTIG      = 0x100000000000,
   // member of struct, member given by [16 MSBs] - 1
diff --git a/offload/libomptarget/OpenMP/Mapping.cpp b/offload/libomptarget/OpenMP/Mapping.cpp
index 14f5e7dc9d19f..8620f5f2fafe8 100644
--- a/offload/libomptarget/OpenMP/Mapping.cpp
+++ b/offload/libomptarget/OpenMP/Mapping.cpp
@@ -77,7 +77,8 @@ int MappingInfoTy::associatePtr(void *HstPtrBegin, void *TgtPtrBegin,
                /*HstPtrEnd=*/(uintptr_t)HstPtrBegin + Size,
                /*TgtAllocBegin=*/(uintptr_t)TgtPtrBegin,
                /*TgtPtrBegin=*/(uintptr_t)TgtPtrBegin,
-               /*UseHoldRefCount=*/false, /*Name=*/nullptr,
+               /*UseHoldRefCount=*/false, /*AllocKind=*/TARGET_ALLOC_DEFAULT,
+               /*Name=*/nullptr,
                /*IsRefCountINF=*/true))
            .first->HDTT;
   DP("Creating new map entry: HstBase=" DPxMOD ", HstBegin=" DPxMOD
@@ -199,10 +200,11 @@ LookupResult MappingInfoTy::lookupMapping(HDTTMapAccessorTy &HDTTMap,
 
 TargetPointerResultTy MappingInfoTy::getTargetPointer(
     HDTTMapAccessorTy &HDTTMap, void *HstPtrBegin, void *HstPtrBase,
-    int64_t TgtPadding, int64_t Size, map_var_info_t HstPtrName, bool HasFlagTo,
-    bool HasFlagAlways, bool IsImplicit, bool UpdateRefCount,
-    bool HasCloseModifier, bool HasPresentModifier, bool HasHoldModifier,
-    AsyncInfoTy &AsyncInfo, HostDataToTargetTy *OwnedTPR, bool ReleaseHDTTMap) {
+    int64_t TgtPadding, int64_t Size, int64_t TypeFlags,
+    map_var_info_t HstPtrName, bool HasFlagTo, bool HasFlagAlways,
+    bool IsImplicit, bool UpdateRefCount, bool HasCloseModifier,
+    bool HasPresentModifier, bool HasHoldModifier, AsyncInfoTy &AsyncInfo,
+    HostDataToTargetTy *OwnedTPR, bool ReleaseHDTTMap) {
 
   LookupResult LR = lookupMapping(HDTTMap, HstPtrBegin, Size, OwnedTPR);
   LR.TPR.Flags.IsPresent = true;
@@ -286,17 +288,28 @@ TargetPointerResultTy MappingInfoTy::getTargetPointer(
   } else if (Size) {
     // If it is not contained and Size > 0, we should create a new entry for it.
     LR.TPR.Flags.IsNewEntry = true;
+
+    int32_t AllocKind = TARGET_ALLOC_DEFAULT;
+
+    if (TypeFlags == OMP_TGT_MAPTYPE_DESCRIPTOR &&
+        Device.RTL->use_shared_mem_for_descriptor(Device.DeviceID, Size)) {
+      AllocKind = TARGET_ALLOC_SHARED;
+      INFO(OMP_INFOTYPE_MAPPING_CHANGED, Device.DeviceID,
+           "Using shared memory for descriptor allocation of size=%zu\n", Size);
+    }
+
     uintptr_t TgtAllocBegin =
-        (uintptr_t)Device.allocData(TgtPadding + Size, HstPtrBegin);
+        (uintptr_t)Device.allocData(TgtPadding + Size, HstPtrBegin, AllocKind);
     uintptr_t TgtPtrBegin = TgtAllocBegin + TgtPadding;
     // Release the mapping table lock only after the entry is locked by
     // attaching it to TPR.
-    LR.TPR.setEntry(HDTTMap
-                        ->emplace(new HostDataToTargetTy(
-                            (uintptr_t)HstPtrBase, (uintptr_t)HstPtrBegin,
-                            (uintptr_t)HstPtrBegin + Size, TgtAllocBegin,
-                            TgtPtrBegin, HasHoldModifier, HstPtrName))
-                        .first->HDTT);
+    LR.TPR.setEntry(
+        HDTTMap
+            ->emplace(new HostDataToTargetTy(
+                (uintptr_t)HstPtrBase, (uintptr_t)HstPtrBegin,
+                (uintptr_t)HstPtrBegin + Size, TgtAllocBegin, TgtPtrBegin,
+                HasHoldModifier, AllocKind, HstPtrName))
+            .first->HDTT);
     INFO(OMP_INFOTYPE_MAPPING_CHANGED, Device.DeviceID,
          "Creating new map entry with HstPtrBase=" DPxMOD
          ", HstPtrBegin=" DPxMOD ", TgtAllocBegin=" DPxMOD
@@ -326,20 +339,47 @@ TargetPointerResultTy MappingInfoTy::getTargetPointer(
   // data transfer.
   if (LR.TPR.TargetPointer && !LR.TPR.Flags.IsHostPointer && HasFlagTo &&
       (LR.TPR.Flags.IsNewEntry || HasFlagAlways) && Size != 0) {
-    DP("Moving %" PRId64 " bytes (hst:" DPxMOD ") -> (tgt:" DPxMOD ")\n", Size,
-       DPxPTR(HstPtrBegin), DPxPTR(LR.TPR.TargetPointer));
-
-    int Ret = Device.submitData(LR.TPR.TargetPointer, HstPtrBegin, Size,
-                                AsyncInfo, LR.TPR.getEntry());
-    if (Ret != OFFLOAD_SUCCESS) {
-      REPORT("Copying data to device failed.\n");
-      // We will also return nullptr if the data movement fails because that
-      // pointer points to a corrupted memory region so it doesn't make any
-      // sense to continue to use it.
-      LR.TPR.TargetPointer = nullptr;
-    } else if (LR.TPR.getEntry()->addEventIfNecessary(Device, AsyncInfo) !=
-               OFFLOAD_SUCCESS)
-      return TargetPointerResultTy{};
+    if (LR.TPR.Flags.IsNewEntry ||
+        LR.TPR.getEntry()->AllocKind != TARGET_ALLOC_SHARED) {
+      DP("Moving %" PRId64 " bytes (hst:" DPxMOD ") -> (tgt:" DPxMOD ")\n",
+         Size, DPxPTR(HstPtrBegin), DPxPTR(LR.TPR.TargetPointer));
+
+      // If we are mapping a descriptor/dope vector, we map it with always, as
+      // this information should always be up to date. Additionally, an edge
+      // case with declare target can prevent this information from being
+      // initialized on device, so we have to force-initialize it with always.
+      // However, to avoid overwriting the data pointer and breaking any
+      // pointer <-> data attachment that previous mappings may have
+      // established, we skip over the data pointer stored in the dope
+      // vector/descriptor. A subsequent separate mapping of the pointer and
+      // data by the compiler should correctly establish any required data
+      // mappings; the descriptor mapping primarily just populates the
+      // descriptor data fields that the Fortran runtime depends on for bounds
+      // calculation and related things. The pointer is always in the same
+      // place, the first field of the descriptor structure, so we skip it by
+      // offsetting by 8 bytes. On architectures with more varied pointer
+      // sizes this may need further thought, but so would much of the data
+      // mapping, presumably, if host and device pointer sizes are mismatched.
+      if ((TypeFlags & OMP_TGT_MAPTYPE_DESCRIPTOR) && HasFlagAlways) {
+        uintptr_t DescDataPtrOffset = 8;
+        HstPtrBegin = (void *)((uintptr_t)HstPtrBegin + DescDataPtrOffset);
+        LR.TPR.TargetPointer =
+            (void *)((uintptr_t)LR.TPR.TargetPointer + DescDataPtrOffset);
+        Size = Size - DescDataPtrOffset;
+      }
+
+      int Ret = Device.submitData(LR.TPR.TargetPointer, HstPtrBegin, Size,
+                                  AsyncInfo, LR.TPR.getEntry());
+      if (Ret != OFFLOAD_SUCCESS) {
+        REPORT("Copying data to device failed.\n");
+        // We will also return nullptr if the data movement fails because that
+        // pointer points to a corrupted memory region so it doesn't make any
+        // sense to continue to use it.
+        LR.TPR.TargetPointer = nullptr;
+      } else if (LR.TPR.getEntry()->addEventIfNecessary(Device, AsyncInfo) !=
+                 OFFLOAD_SUCCESS)
+        return TargetPointerResultTy{};
+    }
   } else {
     // If not a host pointer and no present modifier, we need to wait for the
     // event if it exists.
diff --git a/offload/libomptarget/PluginManager.cpp b/offload/libomptarget/PluginManager.cpp
index d6d529a207587..e508b67bf38bb 100644
--- a/offload/libomptarget/PluginManager.cpp
+++ b/offload/libomptarget/PluginManager.cpp
@@ -497,8 +497,8 @@ static int loadImagesOntoDevice(DeviceTy &Device) {
                 CurrHostEntry->Size /*HstPtrEnd*/,
             (uintptr_t)CurrDeviceEntryAddr /*TgtAllocBegin*/,
             (uintptr_t)CurrDeviceEntryAddr /*TgtPtrBegin*/,
-            false /*UseHoldRefCount*/, CurrHostEntry->SymbolName,
-            true /*IsRefCountINF*/));
+            false /*UseHoldRefCount*/, TARGET_ALLOC_DEFAULT /*AllocKind*/,
+            CurrHostEntry->SymbolName, true /*IsRefCountINF*/));
 
         // Notify about the new mapping.
         if (Device.notifyDataMapped(CurrHostEntry->Address,
diff --git a/offload/libomptarget/omptarget.cpp b/offload/libomptarget/omptarget.cpp
index 5b25d955dd320..f900f17708b60 100644
--- a/offload/libomptarget/omptarget.cpp
+++ b/offload/libomptarget/omptarget.cpp
@@ -422,7 +422,7 @@ int targetDataBegin(ident_t *Loc, DeviceTy &Device, int32_t ArgNum,
       // when HasPresentModifier.
       PointerTpr = Device.getMappingInfo().getTargetPointer(
           HDTTMap, HstPtrBase, HstPtrBase, /*TgtPadding=*/0, sizeof(void *),
-          /*HstPtrName=*/nullptr,
+          ArgTypes[I], /*HstPtrName=*/nullptr,
           /*HasFlagTo=*/false, /*HasFlagAlways=*/false, IsImplicit, UpdateRef,
           HasCloseModifier, HasPresentModifier, HasHoldModifier, AsyncInfo,
           /*OwnedTPR=*/nullptr, /*ReleaseHDTTMap=*/false);
@@ -451,9 +451,10 @@ int targetDataBegin(ident_t *Loc, DeviceTy &Device, int32_t ArgNum,
     const bool HasFlagAlways = ArgTypes[I] & OMP_TGT_MAPTYPE_ALWAYS;
     // Note that HDTTMap will be released in getTargetPointer.
     auto TPR = Device.getMappingInfo().getTargetPointer(
-        HDTTMap, HstPtrBegin, HstPtrBase, TgtPadding, DataSize, HstPtrName,
-        HasFlagTo, HasFlagAlways, IsImplicit, UpdateRef, HasCloseModifier,
-        HasPresentModifier, HasHoldModifier, AsyncInfo, PointerTpr.getEntry());
+        HDTTMap, HstPtrBegin, HstPtrBase, TgtPadding, DataSize, ArgTypes[I],
+        HstPtrName, HasFlagTo, HasFlagAlways, IsImplicit, UpdateRef,
+        HasCloseModifier, HasPresentModifier, HasHoldModifier, AsyncInfo,
+        PointerTpr.getEntry());
     void *TgtPtrBegin = TPR.TargetPointer;
     IsHostPtr = TPR.Flags.IsHostPointer;
     // If data_size==0, then the argument could be a zero-length pointer to
@@ -482,7 +483,9 @@ int targetDataBegin(ident_t *Loc, DeviceTy &Device, int32_t ArgNum,
 
       if (PointerTpr.getEntry()->addShadowPointer(ShadowPtrInfoTy{
               (void **)PointerHstPtrBegin, HstPtrBase,
-              (void **)PointerTgtPtrBegin, ExpectedTgtPtrBase})) {
+              (void **)PointerTgtPtrBegin, ExpectedTgtPtrBase,
+              static_cast<bool>(ArgTypes[I] &
+                                OMP_TGT_MAPTYPE_DESCRIPTOR_BASE_ADDR)})) {
         DP("Update pointer (" DPxMOD ") -> [" DPxMOD "]\n",
            DPxPTR(PointerTgtPtrBegin), DPxPTR(TgtPtrBegin));
 
@@ -591,6 +594,12 @@ postProcessingTargetDataEnd(DeviceTy *Device,
     const bool HasFrom = ArgType & OMP_TGT_MAPTYPE_FROM;
     if (HasFrom) {
       Entry->foreachShadowPointerInfo([&](const ShadowPtrInfoTy &ShadowPtr) {
+        // For Fortran descriptors/dope vectors, it is possible that we have
+        // deallocated the data on host while the descriptor persists, as it
+        // is a separate entity. We do not want to map the data back to host
+        // in these cases when releasing the dope vector.
+        if (*ShadowPtr.HstPtrAddr == nullptr && ShadowPtr.IsDescriptorBaseAddr)
+          return OFFLOAD_SUCCESS;
         *ShadowPtr.HstPtrAddr = ShadowPtr.HstPtrVal;
         DP("Restoring original host pointer value " DPxMOD " for host "
            "pointer " DPxMOD "\n",
@@ -833,6 +842,13 @@ static int targetDataContiguous(ident_t *Loc, DeviceTy &Device, void *ArgsBase,
       AsyncInfo.addPostProcessingFunction([=]() -> int {
         int Ret = Entry->foreachShadowPointerInfo(
             [&](const ShadowPtrInfoTy &ShadowPtr) {
+              // For Fortran descriptors/dope vectors, the data may have been
+              // deallocated on host while the descriptor persists as a separate
+              // entity; we do not want to map the data back to host in these
+              // cases when releasing the dope vector.
+              if (*ShadowPtr.HstPtrAddr == nullptr &&
+                  ShadowPtr.IsDescriptorBaseAddr)
+                return OFFLOAD_SUCCESS;
               *ShadowPtr.HstPtrAddr = ShadowPtr.HstPtrVal;
               DP("Restoring original host pointer value " DPxMOD
                  " for host pointer " DPxMOD "\n",
diff --git a/offload/plugins-nextgen/amdgpu/src/rtl.cpp b/offload/plugins-nextgen/amdgpu/src/rtl.cpp
index ed575f2213f28..0be756a5d5d70 100644
--- a/offload/plugins-nextgen/amdgpu/src/rtl.cpp
+++ b/offload/plugins-nextgen/amdgpu/src/rtl.cpp
@@ -1906,9 +1906,11 @@ struct AMDGPUDeviceTy : public GenericDeviceTy, AMDGenericDeviceTy {
         OMPX_StreamBusyWait("LIBOMPTARGET_AMDGPU_STREAM_BUSYWAIT", 2000000),
         OMPX_UseMultipleSdmaEngines(
             "LIBOMPTARGET_AMDGPU_USE_MULTIPLE_SDMA_ENGINES", false),
-        OMPX_ApuMaps("OMPX_APU_MAPS", false), AMDGPUStreamManager(*this, Agent),
-        AMDGPUEventManager(*this), AMDGPUSignalManager(*this), Agent(Agent),
-        HostDevice(HostDevice) {}
+        OMPX_ApuMaps("OMPX_APU_MAPS", false),
+        OMPX_SharedDescriptorMaxSize("LIBOMPTARGET_SHARED_DESCRIPTOR_MAX_SIZE",
+                                     96),
+        AMDGPUStreamManager(*this, Agent), AMDGPUEventManager(*this),
+        AMDGPUSignalManager(*this), Agent(Agent), HostDevice(HostDevice) {}
 
   ~AMDGPUDeviceTy() {}
 
@@ -2813,6 +2815,10 @@ struct AMDGPUDeviceTy : public GenericDeviceTy, AMDGenericDeviceTy {
 
   bool useMultipleSdmaEngines() const { return OMPX_UseMultipleSdmaEngines; }
 
+  bool useSharedMemForDescriptor(int64_t Size) override {
+    return Size <= OMPX_SharedDescriptorMaxSize;
+  }
+
 private:
   using AMDGPUEventRef = AMDGPUResourceRef<AMDGPUEventTy>;
   using AMDGPUEventManagerTy = GenericDeviceResourceManagerTy<AMDGPUEventRef>;
@@ -2911,6 +2917,10 @@ struct AMDGPUDeviceTy : public GenericDeviceTy, AMDGenericDeviceTy {
   /// automatic zero-copy behavior on non-APU GPUs.
   BoolEnvar OMPX_ApuMaps;
 
+  /// Descriptors of size <= this value will be allocated using shared
+  /// memory. Default value is 96.
+  UInt32Envar OMPX_SharedDescriptorMaxSize;
+
   /// Stream manager for AMDGPU streams.
   AMDGPUStreamManagerTy AMDGPUStreamManager;
 
diff --git a/offload/plugins-nextgen/common/include/PluginInterface.h b/offload/plugins-nextgen/common/include/PluginInterface.h
index e54a8afdd3f4f..0fcb7504c09e6 100644
--- a/offload/plugins-nextgen/common/include/PluginInterface.h
+++ b/offload/plugins-nextgen/common/include/PluginInterface.h
@@ -942,6 +942,10 @@ struct GenericDeviceTy : public DeviceAllocatorTy {
   /// Allocate and construct a kernel object.
   virtual Expected<GenericKernelTy &> constructKernel(const char *Name) = 0;
 
+  /// Return true if a descriptor of size 'Size' should be allocated using
+  /// shared memory. Default implementation returns 'false'.
+  virtual bool useSharedMemForDescriptor(int64_t Size);
+
   /// Reference to the underlying plugin that created this device.
   GenericPluginTy &Plugin;
 
@@ -1344,6 +1348,10 @@ struct GenericPluginTy {
   int32_t get_function(__tgt_device_binary Binary, const char *Name,
                        void **KernelPtr);
 
+  /// Return true if a descriptor of size 'Size' should be allocated using
+  /// shared memory.
+  bool use_shared_mem_for_descriptor(int32_t DeviceId, int64_t Size);
+
 private:
   /// Indicates if the platform runtime has been fully initialized.
   bool Initialized = false;
diff --git a/offload/plugins-nextgen/common/src/PluginInterface.cpp b/offload/plugins-nextgen/common/src/PluginInterface.cpp
index 059f14f59c38b..9cf0af6065bd6 100644
--- a/offload/plugins-nextgen/common/src/PluginInterface.cpp
+++ b/offload/plugins-nextgen/common/src/PluginInterface.cpp
@@ -1597,6 +1597,8 @@ Error GenericDeviceTy::syncEvent(void *EventPtr) {
 
 bool GenericDeviceTy::useAutoZeroCopy() { return useAutoZeroCopyImpl(); }
 
+bool GenericDeviceTy::useSharedMemForDescriptor(int64_t Size) { return false; }
+
 Error GenericPluginTy::init() {
   if (Initialized)
     return Plugin::success();
@@ -2199,3 +2201,8 @@ int32_t GenericPluginTy::get_function(__tgt_device_binary Binary,
   *KernelPtr = &Kernel;
   return OFFLOAD_SUCCESS;
 }
+
+bool GenericPluginTy::use_shared_mem_for_descriptor(int32_t DeviceId,
+                                                    int64_t Size) {
+  return getDevice(DeviceId).useSharedMemForDescriptor(Size);
+}
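
The null check added to the two shadow-pointer restore paths in the diff can be sketched in isolation as follows. `ShadowPtrSketch` is a hypothetical, trimmed-down stand-in for `ShadowPtrInfoTy`, and `restoreHostPointer` is an illustrative helper rather than a runtime function; the real code returns `OFFLOAD_SUCCESS` from a lambda instead of a bool.

```cpp
#include <cassert>

// Hypothetical stand-in for ShadowPtrInfoTy with only the fields used here.
struct ShadowPtrSketch {
  void **HstPtrAddr;         // location of the pointer on the host
  void *HstPtrVal;           // host pointer value recorded at map time
  bool IsDescriptorBaseAddr; // set for descriptor base-address mappings
};

// Write the recorded host pointer value back, unless this is a descriptor
// base address that the host has since nulled (i.e. deallocated).
// Returns true if the pointer was restored.
inline bool restoreHostPointer(const ShadowPtrSketch &ShadowPtr) {
  if (*ShadowPtr.HstPtrAddr == nullptr && ShadowPtr.IsDescriptorBaseAddr)
    return false; // leave the deallocated descriptor alone
  *ShadowPtr.HstPtrAddr = ShadowPtr.HstPtrVal;
  return true;
}

// Small usage example: a nulled descriptor base address stays null.
inline void exampleUsage() {
  int Data = 0;
  void *Slot = nullptr; // host side already deallocated
  ShadowPtrSketch Desc{&Slot, &Data, /*IsDescriptorBaseAddr=*/true};
  assert(!restoreHostPointer(Desc));
  assert(Slot == nullptr);
}
```

Without the guard, the restore would write the stale recorded value back into `Slot`, which is the "descriptor tricked into thinking it's allocated" situation the PR description mentions.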

@agozillon
Contributor Author

This is not an area of the project I usually open PRs against, so please feel free to tag additional relevant reviewers!
