Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit aa78e94

Browse files
committed
[Libomptarget] Support mapping indirect host calls to device functions
The changes in D157738 allowed us to emit stub globals on the device in the offloading entry section. These globals contain the addresses of device functions and allow us to map host functions to their corresponding device equivalents. This patch provides the initial support required to build a table on the device to look up the associated value. This is done by finding these entries and creating a global table on the device that can be searched with a simple binary search. This requires an allocation, which supposedly should be automatically freed at plugin shutdown. This includes a basic test which looks up device pointers via a host pointer using the added function. This will need to be built upon to provide full support for these calls in the runtime. To support reverse offloading it would also be useful to provide a reverse table that allows us to get host functions from device stubs. Depends on D157738. Reviewed By: jdoerfert. Differential Revision: https://reviews.llvm.org/D157918
1 parent 79cf24e commit aa78e94

9 files changed

Lines changed: 176 additions & 5 deletions

File tree

llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -327,7 +327,7 @@ class OffloadEntriesInfoManager {
327327
/// Mark the entry as having no declare target entry kind.
328328
OMPTargetGlobalVarEntryNone = 0x3,
329329
/// Mark the entry as a declare target indirect global.
330-
OMPTargetGlobalVarEntryIndirect = 0x4,
330+
OMPTargetGlobalVarEntryIndirect = 0x8,
331331
};
332332

333333
/// Kind of device clause for declare target variables

openmp/libomptarget/DeviceRTL/include/Configuration.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,12 @@ uint64_t getDynamicMemorySize();
4040
/// Returns the cycles per second of the device's fixed frequency clock.
4141
uint64_t getClockFrequency();
4242

43+
/// Returns the pointer to the beginning of the indirect call table.
44+
void *getIndirectCallTablePtr();
45+
46+
/// Returns the size of the indirect call table.
47+
uint64_t getIndirectCallTableSize();
48+
4349
/// Return if debugging is enabled for the given debug kind.
4450
bool isDebugMode(DebugKind Level);
4551

openmp/libomptarget/DeviceRTL/src/Configuration.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,15 @@ uint64_t config::getClockFrequency() {
5050
return __omp_rtl_device_environment.ClockFrequency;
5151
}
5252

53+
void *config::getIndirectCallTablePtr() {
54+
return reinterpret_cast<void *>(
55+
__omp_rtl_device_environment.IndirectCallTable);
56+
}
57+
58+
uint64_t config::getIndirectCallTableSize() {
59+
return __omp_rtl_device_environment.IndirectCallTableSize;
60+
}
61+
5362
/// Tell whether any bit of \p Kind is set in the value returned by
/// config::getDebugKind().
bool config::isDebugMode(config::DebugKind Kind) {
  const bool IsEnabled = config::getDebugKind() & Kind;
  return IsEnabled;
}

openmp/libomptarget/DeviceRTL/src/Misc.cpp

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,47 @@ double getWTime() {
6969

7070
#pragma omp end declare variant
7171

72+
/// Lookup a device-side function using a host pointer /p HstPtr using the table
73+
/// provided by the device plugin. The table is an ordered pair of host and
74+
/// device pointers sorted on the value of the host pointer.
75+
void *indirectCallLookup(void *HstPtr) {
76+
if (!HstPtr)
77+
return nullptr;
78+
79+
struct IndirectCallTable {
80+
void *HstPtr;
81+
void *DevPtr;
82+
};
83+
IndirectCallTable *Table =
84+
reinterpret_cast<IndirectCallTable *>(config::getIndirectCallTablePtr());
85+
uint64_t TableSize = config::getIndirectCallTableSize();
86+
87+
// If the table is empty we assume this is device pointer.
88+
if (!Table || !TableSize)
89+
return HstPtr;
90+
91+
uint32_t Left = 0;
92+
uint32_t Right = TableSize;
93+
94+
// If the pointer is definitely not contained in the table we exit early.
95+
if (HstPtr < Table[Left].HstPtr || HstPtr > Table[Right - 1].HstPtr)
96+
return HstPtr;
97+
98+
while (Left != Right) {
99+
uint32_t Current = Left + (Right - Left) / 2;
100+
if (Table[Current].HstPtr == HstPtr)
101+
return Table[Current].DevPtr;
102+
103+
if (HstPtr < Table[Current].HstPtr)
104+
Right = Current;
105+
else
106+
Left = Current;
107+
}
108+
109+
// If we searched the whole table and found nothing this is a device pointer.
110+
return HstPtr;
111+
}
112+
72113
} // namespace impl
73114
} // namespace ompx
74115

@@ -84,6 +125,10 @@ int32_t __kmpc_cancel(IdentTy *, int32_t, int32_t) { return 0; }
84125
double omp_get_wtick(void) {
  // Forward to the implementation in ompx::impl.
  return ompx::impl::getWTick();
}
85126

86127
double omp_get_wtime(void) {
  // Forward to the implementation in ompx::impl.
  return ompx::impl::getWTime();
}
128+
129+
void *__llvm_omp_indirect_call_lookup(void *HstPtr) {
130+
return ompx::impl::indirectCallLookup(HstPtr);
131+
}
87132
}
88133

89134
///}

openmp/libomptarget/include/Environment.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@ struct DeviceEnvironmentTy {
3131
uint32_t DeviceNum;
3232
uint32_t DynamicMemSize;
3333
uint64_t ClockFrequency;
34+
uintptr_t IndirectCallTable;
35+
uint64_t IndirectCallTableSize;
3436
};
3537

3638
// NOTE: Please don't change the order of those members as their indices are

openmp/libomptarget/include/omptarget.h

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -83,13 +83,16 @@ enum tgt_map_type {
8383
OMP_TGT_MAPTYPE_MEMBER_OF = 0xffff000000000000
8484
};
8585

86+
/// Flags for offload entries.
8687
enum OpenMPOffloadingDeclareTargetFlags {
87-
/// Mark the entry as having a 'link' attribute.
88+
/// Mark the entry global as having a 'link' attribute.
8889
OMP_DECLARE_TARGET_LINK = 0x01,
89-
/// Mark the entry as being a global constructor.
90+
/// Mark the entry kernel as being a global constructor.
9091
OMP_DECLARE_TARGET_CTOR = 0x02,
91-
/// Mark the entry as being a global destructor.
92-
OMP_DECLARE_TARGET_DTOR = 0x04
92+
/// Mark the entry kernel as being a global destructor.
93+
OMP_DECLARE_TARGET_DTOR = 0x04,
94+
/// Mark the entry global as being an indirectly callable function.
95+
OMP_DECLARE_TARGET_INDIRECT = 0x08
9396
};
9497

9598
enum OpenMPOffloadingRequiresDirFlags {

openmp/libomptarget/plugins-nextgen/common/PluginInterface/PluginInterface.cpp

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -267,6 +267,53 @@ struct RecordReplayTy {
267267

268268
} RecordReplay;
269269

270+
// Extract the mapping of host function pointers to device function pointers
271+
// from the entry table. Functions marked as 'indirect' in OpenMP will have
272+
// offloading entries generated for them which map the host's function pointer
273+
// to a global containing the corresponding function pointer on the device.
274+
static Expected<std::pair<void *, uint64_t>>
275+
setupIndirectCallTable(GenericPluginTy &Plugin, GenericDeviceTy &Device,
276+
DeviceImageTy &Image) {
277+
GenericGlobalHandlerTy &Handler = Plugin.getGlobalHandler();
278+
279+
llvm::ArrayRef<__tgt_offload_entry> Entries(Image.getTgtImage()->EntriesBegin,
280+
Image.getTgtImage()->EntriesEnd);
281+
llvm::SmallVector<std::pair<void *, void *>> IndirectCallTable;
282+
for (const auto &Entry : Entries) {
283+
if (Entry.size == 0 || !(Entry.flags & OMP_DECLARE_TARGET_INDIRECT))
284+
continue;
285+
286+
assert(Entry.size == sizeof(void *) && "Global not a function pointer?");
287+
auto &[HstPtr, DevPtr] = IndirectCallTable.emplace_back();
288+
289+
GlobalTy DeviceGlobal(Entry.name, Entry.size);
290+
if (auto Err =
291+
Handler.getGlobalMetadataFromDevice(Device, Image, DeviceGlobal))
292+
return std::move(Err);
293+
294+
HstPtr = Entry.addr;
295+
if (auto Err = Device.dataRetrieve(&DevPtr, DeviceGlobal.getPtr(),
296+
Entry.size, nullptr))
297+
return std::move(Err);
298+
}
299+
300+
// If we do not have any indirect globals we exit early.
301+
if (IndirectCallTable.empty())
302+
return std::pair{nullptr, 0};
303+
304+
// Sort the array to allow for more efficient lookup of device pointers.
305+
llvm::sort(IndirectCallTable,
306+
[](const auto &x, const auto &y) { return x.first < y.first; });
307+
308+
uint64_t TableSize =
309+
IndirectCallTable.size() * sizeof(std::pair<void *, void *>);
310+
void *DevicePtr = Device.allocate(TableSize, nullptr, TARGET_ALLOC_DEVICE);
311+
if (auto Err = Device.dataSubmit(DevicePtr, IndirectCallTable.data(),
312+
TableSize, nullptr))
313+
return std::move(Err);
314+
return std::pair<void *, uint64_t>(DevicePtr, IndirectCallTable.size());
315+
}
316+
270317
AsyncInfoWrapperTy::AsyncInfoWrapperTy(GenericDeviceTy &Device,
271318
__tgt_async_info *AsyncInfoPtr)
272319
: Device(Device),
@@ -626,13 +673,21 @@ Error GenericDeviceTy::setupDeviceEnvironment(GenericPluginTy &Plugin,
626673
if (!shouldSetupDeviceEnvironment())
627674
return Plugin::success();
628675

676+
// Obtain a table mapping host function pointers to device function pointers.
677+
auto CallTablePairOrErr = setupIndirectCallTable(Plugin, *this, Image);
678+
if (!CallTablePairOrErr)
679+
return CallTablePairOrErr.takeError();
680+
629681
DeviceEnvironmentTy DeviceEnvironment;
630682
DeviceEnvironment.DebugKind = OMPX_DebugKind;
631683
DeviceEnvironment.NumDevices = Plugin.getNumDevices();
632684
// TODO: The device ID used here is not the real device ID used by OpenMP.
633685
DeviceEnvironment.DeviceNum = DeviceId;
634686
DeviceEnvironment.DynamicMemSize = OMPX_SharedMemorySize;
635687
DeviceEnvironment.ClockFrequency = getClockFrequency();
688+
DeviceEnvironment.IndirectCallTable =
689+
reinterpret_cast<uintptr_t>(CallTablePairOrErr->first);
690+
DeviceEnvironment.IndirectCallTableSize = CallTablePairOrErr->second;
636691

637692
// Create the metainfo of the device environment global.
638693
GlobalTy DevEnvGlobal("__omp_rtl_device_environment",

openmp/libomptarget/src/rtl.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -303,6 +303,10 @@ static void registerGlobalCtorsDtorsForImage(__tgt_bin_desc *Desc,
303303
Device.HasPendingGlobals = true;
304304
for (__tgt_offload_entry *Entry = Img->EntriesBegin;
305305
Entry != Img->EntriesEnd; ++Entry) {
306+
// Globals are not callable and use a different set of flags.
307+
if (Entry->size != 0)
308+
continue;
309+
306310
if (Entry->flags & OMP_DECLARE_TARGET_CTOR) {
307311
DP("Adding ctor " DPxMOD " to the pending list.\n",
308312
DPxPTR(Entry->addr));
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
// RUN: %libomptarget-compile-run-and-check-generic
2+
3+
#include <assert.h>
4+
#include <stdio.h>
5+
6+
#pragma omp begin declare variant match(device = {kind(gpu)})
7+
// Provided by the runtime.
8+
void *__llvm_omp_indirect_call_lookup(void *host_ptr);
9+
#pragma omp declare target to(__llvm_omp_indirect_call_lookup) \
10+
device_type(nohost)
11+
#pragma omp end declare variant
12+
13+
#pragma omp begin declare variant match(device = {kind(cpu)})
14+
// We assume unified addressing on the CPU target.
15+
void *__llvm_omp_indirect_call_lookup(void *host_ptr) { return host_ptr; }
16+
#pragma omp end declare variant
17+
18+
#pragma omp begin declare target indirect
19+
// Add one to the integer that x points to.
void foo(int *x) {
  *x += 1;
}
20+
// Add one to the integer that x points to.
void bar(int *x) {
  *x += 1;
}
21+
// Add one to the integer that x points to.
void baz(int *x) {
  *x += 1;
}
22+
#pragma omp end declare target
23+
24+
int main() {
25+
void *foo_ptr = foo;
26+
void *bar_ptr = bar;
27+
void *baz_ptr = baz;
28+
29+
int count = 0;
30+
void *foo_res;
31+
void *bar_res;
32+
void *baz_res;
33+
#pragma omp target map(to : foo_ptr, bar_ptr, baz_ptr) map(tofrom : count)
34+
{
35+
foo_res = __llvm_omp_indirect_call_lookup(foo_ptr);
36+
((void (*)(int *))foo_res)(&count);
37+
bar_res = __llvm_omp_indirect_call_lookup(bar_ptr);
38+
((void (*)(int *))bar_res)(&count);
39+
baz_res = __llvm_omp_indirect_call_lookup(baz_ptr);
40+
((void (*)(int *))baz_res)(&count);
41+
}
42+
43+
assert(count == 3 && "Calling failed");
44+
45+
// CHECK: PASS
46+
printf("PASS\n");
47+
}

0 commit comments

Comments
 (0)