Thanks for visiting codestin.com
Credit goes to github.com

Skip to content

Commit a4eb97f

Browse files
committed
Bug fixes
1 parent 40a0e02 commit a4eb97f

3 files changed

Lines changed: 148 additions & 70 deletions

File tree

apex/contrib/csrc/peer_memory/peer_memory_cuda.cu

Lines changed: 97 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -30,15 +30,23 @@ void deleter(void* ptr)
3030
*/
3131

3232
template<class T>
33-
at::Tensor blob_view(T* raw_ptr, std::vector<int64_t> shape, const at::TensorOptions& options)
33+
at::Tensor blob_view(T* raw_ptr, std::vector<int64_t> shape, const at::TensorOptions& options, bool channels_last)
3434
{
35-
std::vector<int64_t> strides(shape.size());
3635
size_t size = 1;
37-
int idx = strides.size();
38-
for (auto it = shape.rbegin(); it != shape.rend(); ++it)
39-
{
40-
strides[--idx] = size;
41-
size *= *it;
36+
std::vector<int64_t> strides(shape.size());
37+
if (channels_last) {
38+
assert(shape.size() == 4);
39+
strides[0] = shape[1]*shape[2]*shape[3];
40+
strides[1] = 1;
41+
strides[2] = shape[1]*shape[3];
42+
strides[3] = shape[1];
43+
} else {
44+
int idx = strides.size();
45+
for (auto it = shape.rbegin(); it != shape.rend(); ++it)
46+
{
47+
strides[--idx] = size;
48+
size *= *it;
49+
}
4250
}
4351
size *= sizeof(T);
4452
// TODO: Implement dynamic reuse of pooled peer memory.
@@ -139,11 +147,11 @@ __device__ void strided_copy_kernel(
139147
}
140148
}
141149

150+
template<bool wait, bool clear>
142151
__device__ void dual_signal_wait_clear(
143152
volatile int* signal1_flag, volatile int* wait1_flag,
144153
volatile int* signal2_flag, volatile int* wait2_flag,
145-
const int v1, const int v2, const int v3, const int v4,
146-
const bool clear
154+
const int v1, const int v2, const int v3, const int v4
147155
)
148156
{
149157
register int r1, r2, r3, r4, r5, r6, r7, r8;
@@ -152,17 +160,20 @@ __device__ void dual_signal_wait_clear(
152160
if (is_main_thread) {
153161
asm volatile("st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};" :: "l"(signal1_flag), "r"(v1), "r"(v2), "r"(v3), "r"(v4) : "memory");
154162
asm volatile("st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};" :: "l"(signal2_flag), "r"(v1), "r"(v2), "r"(v3), "r"(v4) : "memory");
155-
do {
156-
asm volatile("ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];" : "=r"(r1), "=r"(r2), "=r"(r3), "=r"(r4) : "l"(wait1_flag) : "memory");
157-
asm volatile("ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];" : "=r"(r5), "=r"(r6), "=r"(r7), "=r"(r8) : "l"(wait2_flag) : "memory");
158-
} while (r1 != v1 || r5 != v1 || r2 != v2 || r6 != v2 || r3 != v3 || r7 != v3 || r4 != v4 || r8 != v4);
163+
if (wait) {
164+
do {
165+
asm volatile("ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];" : "=r"(r1), "=r"(r2), "=r"(r3), "=r"(r4) : "l"(wait1_flag) : "memory");
166+
asm volatile("ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];" : "=r"(r5), "=r"(r6), "=r"(r7), "=r"(r8) : "l"(wait2_flag) : "memory");
167+
} while (r1 != v1 || r5 != v1 || r2 != v2 || r6 != v2 || r3 != v3 || r7 != v3 || r4 != v4 || r8 != v4);
168+
}
159169
}
160170
cg::this_grid().sync();
161-
// optionally clear wait flag
162-
if (clear && is_main_thread) {
163-
r1 = 0; r2 = 0; r3 = 0; r4 = 0;
164-
asm volatile("st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};" :: "l"(wait1_flag), "r"(r1), "r"(r2), "r"(r3), "r"(r4) : "memory");
165-
asm volatile("st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};" :: "l"(wait2_flag), "r"(r1), "r"(r2), "r"(r3), "r"(r4) : "memory");
171+
if (clear) {
172+
if (is_main_thread) {
173+
r1 = 0; r2 = 0; r3 = 0; r4 = 0;
174+
asm volatile("st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};" :: "l"(wait1_flag), "r"(r1), "r"(r2), "r"(r3), "r"(r4) : "memory");
175+
asm volatile("st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};" :: "l"(wait2_flag), "r"(r1), "r"(r2), "r"(r3), "r"(r4) : "memory");
176+
}
166177
}
167178
}
168179

@@ -173,12 +184,14 @@ __launch_bounds__(128, 16)
173184
__global__ void push_pull_halos_1d_kernel(
174185
// top halo,
175186
const T* toh, int toh_stride_C, int toh_stride_H, int toh_stride_W, // top output halo
176-
T* tox, int tox_stride_C, int tox_stride_H, int tox_stride_W, // top tx buffer
187+
T* tox, int tox_stride_C, int tox_stride_H, int tox_stride_W, // top output tx buffer
188+
T* tix, int tix_stride_C, int tix_stride_H, int tix_stride_W, // top input tx buffer
177189
T* tih, int tih_stride_C, int tih_stride_H, int tih_stride_W, // top input halo
178190
// btm halo
179-
const T* boh, int boh_stride_C, int boh_stride_H, int boh_stride_W, // top output halo
180-
T* box, int box_stride_C, int box_stride_H, int box_stride_W, // top tx buffer
181-
T* bih, int bih_stride_C, int bih_stride_H, int bih_stride_W, // top input halo
191+
const T* boh, int boh_stride_C, int boh_stride_H, int boh_stride_W, // btm output halo
192+
T* box, int box_stride_C, int box_stride_H, int box_stride_W, // btm output tx buffer
193+
T* bix, int bix_stride_C, int bix_stride_H, int bix_stride_W, // btm input tx buffer
194+
T* bih, int bih_stride_C, int bih_stride_H, int bih_stride_W, // btm input halo
182195
// dimensions
183196
int NC, int NH, int NW,
184197
// signals
@@ -194,11 +207,11 @@ __global__ void push_pull_halos_1d_kernel(
194207
strided_copy_kernel<T,is_HWC>(box, box_stride_C, box_stride_H, box_stride_W, boh, boh_stride_C, boh_stride_H, boh_stride_W, NC, NH, NW);
195208
// signal to top and btm neighbors that output halos are ready to be read
196209
// the choice of values for v1-v4 is arbitrary and does not matter, as long as all ranks use the same values
197-
dual_signal_wait_clear(signal1_flag, wait1_flag, signal2_flag, wait2_flag, -987751720, 840868300, -225529332, 281513358, true);
210+
dual_signal_wait_clear<true,true>(signal1_flag, wait1_flag, signal2_flag, wait2_flag, -987751720, 840868300, -225529332, 281513358);
198211
// pull top halo from transfer buffer in peer memory to input
199-
strided_copy_kernel<T,is_HWC>(tox, tox_stride_C, tox_stride_H, tox_stride_W, tih, tih_stride_C, tih_stride_H, tih_stride_W, NC, NH, NW);
212+
strided_copy_kernel<T,is_HWC>(tih, tih_stride_C, tih_stride_H, tih_stride_W, tix, tix_stride_C, tix_stride_H, tix_stride_W, NC, NH, NW);
200213
// pull btm halo from transfer buffer in peer memory to input
201-
strided_copy_kernel<T,is_HWC>(box, box_stride_C, box_stride_H, box_stride_W, bih, bih_stride_C, bih_stride_H, bih_stride_W, NC, NH, NW);
214+
strided_copy_kernel<T,is_HWC>(bih, bih_stride_C, bih_stride_H, bih_stride_W, bix, bix_stride_C, bix_stride_H, bix_stride_W, NC, NH, NW);
202215
}
203216

204217
}
@@ -246,29 +259,32 @@ std::vector<int64_t> get_raw_peers(at::Tensor ipc_addresses, int peer_rank, int6
246259
return results;
247260
}
248261

249-
at::Tensor blob_view_half(int64_t raw, std::vector<int64_t> shape)
262+
at::Tensor blob_view_half(int64_t raw, std::vector<int64_t> shape, bool channels_last)
250263
{
251-
return blob_view<at::Half>((at::Half*)raw, shape, torch::dtype(torch::kFloat16).device(torch::kCUDA));
264+
return blob_view<at::Half>((at::Half*)raw, shape, torch::dtype(torch::kFloat16).device(torch::kCUDA), channels_last);
252265
}
253266

254-
at::Tensor blob_view_float(int64_t raw, std::vector<int64_t> shape)
267+
at::Tensor blob_view_float(int64_t raw, std::vector<int64_t> shape, bool channels_last)
255268
{
256-
return blob_view<float>((float*)raw, shape, torch::dtype(torch::kFloat16).device(torch::kCUDA));
269+
return blob_view<float>((float*)raw, shape, torch::dtype(torch::kFloat32).device(torch::kCUDA), channels_last);
257270
}
258271

259-
at::Tensor blob_view_int(int64_t raw, std::vector<int64_t> shape)
272+
at::Tensor blob_view_int(int64_t raw, std::vector<int64_t> shape, bool channels_last)
260273
{
261-
return blob_view<int>((int*)raw, shape, torch::dtype(torch::kFloat16).device(torch::kCUDA));
274+
return blob_view<int>((int*)raw, shape, torch::dtype(torch::kInt32).device(torch::kCUDA), channels_last);
262275
}
263276

264277
void push_pull_halos_1d(
278+
bool diagnostics,
265279
bool explicit_nhwc,
266280
int numSM, // number of SMs to use
267281
at::Tensor top_out_halo, // top output halo in sender device memory
268282
at::Tensor top_out_tx, // top output transfer buffer in sender peer pool memory
283+
at::Tensor top_inp_tx, // top input transfer buffer in top neighbor peer pool memory
269284
at::Tensor top_inp_halo, // top input halo in receiver device memory
270285
at::Tensor btm_out_halo, // btm output halo in sender device memory
271286
at::Tensor btm_out_tx, // btm output transfer buffer in sender peer pool memory
287+
at::Tensor btm_inp_tx, // btm input transfer buffer in btm neighbor peer pool memory
272288
at::Tensor btm_inp_halo, // btm input halo in receiver device memory
273289
at::Tensor top_signal, // top input signal in receiver device memory
274290
at::Tensor btm_signal, // btm input signal in receiver device memory
@@ -278,9 +294,11 @@ void push_pull_halos_1d(
278294
// basic checks of inputs
279295
TORCH_CHECK(top_out_halo.is_cuda());
280296
TORCH_CHECK(top_out_tx.is_cuda());
297+
TORCH_CHECK(top_inp_tx.is_cuda());
281298
TORCH_CHECK(top_inp_halo.is_cuda());
282299
TORCH_CHECK(btm_out_halo.is_cuda());
283300
TORCH_CHECK(btm_out_tx.is_cuda());
301+
TORCH_CHECK(btm_inp_tx.is_cuda());
284302
TORCH_CHECK(btm_inp_halo.is_cuda());
285303
TORCH_CHECK(top_signal.is_cuda());
286304
TORCH_CHECK(btm_signal.is_cuda());
@@ -291,46 +309,56 @@ void push_pull_halos_1d(
291309
tensor_shape(top_out_halo, explicit_nhwc, toh_N, toh_C, toh_H, toh_W);
292310
int tox_N, tox_C, tox_H, tox_W;
293311
tensor_shape(top_out_tx, explicit_nhwc, tox_N, tox_C, tox_H, tox_W);
312+
int tix_N, tix_C, tix_H, tix_W;
313+
tensor_shape(top_inp_tx, explicit_nhwc, tix_N, tix_C, tix_H, tix_W);
294314
int tih_N, tih_C, tih_H, tih_W;
295315
tensor_shape(top_inp_halo, explicit_nhwc, tih_N, tih_C, tih_H, tih_W);
296316
TORCH_CHECK(
297-
(toh_N == tox_N && tox_N == tih_N) &&
298-
(toh_C == tox_C && tox_C == tih_C) &&
299-
(toh_H == tox_H && tox_H == tih_H) &&
300-
(toh_W == tox_W && tox_W == tih_W));
317+
(toh_N == tox_N && tox_N == tix_N && tix_N == tih_N) &&
318+
(toh_C == tox_C && tox_C == tix_C && tix_C == tih_C) &&
319+
(toh_H == tox_H && tox_H == tix_H && tix_H == tih_H) &&
320+
(toh_W == tox_W && tox_W == tix_W && tix_W == tih_W));
301321
int boh_N, boh_C, boh_H, boh_W;
302322
tensor_shape(btm_out_halo, explicit_nhwc, boh_N, boh_C, boh_H, boh_W);
303323
int box_N, box_C, box_H, box_W;
304324
tensor_shape(btm_out_tx, explicit_nhwc, box_N, box_C, box_H, box_W);
325+
int bix_N, bix_C, bix_H, bix_W;
326+
tensor_shape(btm_inp_tx, explicit_nhwc, bix_N, bix_C, bix_H, bix_W);
305327
int bih_N, bih_C, bih_H, bih_W;
306328
tensor_shape(btm_inp_halo, explicit_nhwc, bih_N, bih_C, bih_H, bih_W);
307329
TORCH_CHECK(
308-
(boh_N == box_N && box_N == bih_N) &&
309-
(boh_C == box_C && box_C == bih_C) &&
310-
(boh_H == box_H && box_H == bih_H) &&
311-
(boh_W == box_W && box_W == bih_W));
330+
(boh_N == box_N && box_N == bix_N && bix_N == bih_N) &&
331+
(boh_C == box_C && box_C == bix_C && bix_C == bih_C) &&
332+
(boh_H == box_H && box_H == bix_H && bix_H == bih_H) &&
333+
(boh_W == box_W && box_W == bix_W && bix_W == bih_W));
312334
TORCH_CHECK(
313335
(toh_N == boh_N) &&
314336
(toh_C == boh_C) &&
315337
(toh_H == boh_H) &&
316338
(toh_W == boh_W));
317339
int NC=toh_C, NH=toh_H, NW=toh_W;
340+
if (diagnostics) printf("NC=%d, NH=%d, NW=%d\n",NC,NH,NW);
318341

319342
int toh_stride_N, toh_stride_C, toh_stride_H, toh_stride_W;
320343
tensor_strides(top_out_halo, explicit_nhwc, toh_stride_N, toh_stride_C, toh_stride_H, toh_stride_W);
321344
int tox_stride_N, tox_stride_C, tox_stride_H, tox_stride_W;
322345
tensor_strides(top_out_tx, explicit_nhwc, tox_stride_N, tox_stride_C, tox_stride_H, tox_stride_W);
346+
int tix_stride_N, tix_stride_C, tix_stride_H, tix_stride_W;
347+
tensor_strides(top_inp_tx, explicit_nhwc, tix_stride_N, tix_stride_C, tix_stride_H, tix_stride_W);
323348
int tih_stride_N, tih_stride_C, tih_stride_H, tih_stride_W;
324349
tensor_strides(top_inp_halo, explicit_nhwc, tih_stride_N, tih_stride_C, tih_stride_H, tih_stride_W);
325350
int boh_stride_N, boh_stride_C, boh_stride_H, boh_stride_W;
326351
tensor_strides(btm_out_halo, explicit_nhwc, boh_stride_N, boh_stride_C, boh_stride_H, boh_stride_W);
327352
int box_stride_N, box_stride_C, box_stride_H, box_stride_W;
328353
tensor_strides(btm_out_tx, explicit_nhwc, box_stride_N, box_stride_C, box_stride_H, box_stride_W);
354+
int bix_stride_N, bix_stride_C, bix_stride_H, bix_stride_W;
355+
tensor_strides(btm_inp_tx, explicit_nhwc, bix_stride_N, bix_stride_C, bix_stride_H, bix_stride_W);
329356
int bih_stride_N, bih_stride_C, bih_stride_H, bih_stride_W;
330357
tensor_strides(btm_inp_halo, explicit_nhwc, bih_stride_N, bih_stride_C, bih_stride_H, bih_stride_W);
331358

332359
// determine if nhwc
333360
auto is_nhwc = (toh_stride_C == 1) ? true : false;
361+
if (diagnostics) printf("is_nhwc = %s\n",is_nhwc?"true":"false");
334362

335363
// figure out launch parameters
336364
int device;
@@ -342,35 +370,59 @@ void push_pull_halos_1d(
342370
const int numThreads = 128;
343371
dim3 block(numThreads,1,1);
344372
AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, top_out_halo.scalar_type(), "push_pull_halos_1d_kernel", [&]{
373+
if (diagnostics) printf("size(scalar_t) = %d\n",sizeof(scalar_t));
345374
scalar_t* toh_p = top_out_halo.data_ptr<scalar_t>();
346375
scalar_t* tox_p = top_out_tx.data_ptr<scalar_t>();
376+
scalar_t* tix_p = top_inp_tx.data_ptr<scalar_t>();
347377
scalar_t* tih_p = top_inp_halo.data_ptr<scalar_t>();
348378
scalar_t* boh_p = btm_out_halo.data_ptr<scalar_t>();
349379
scalar_t* box_p = btm_out_tx.data_ptr<scalar_t>();
380+
scalar_t* bix_p = btm_inp_tx.data_ptr<scalar_t>();
350381
scalar_t* bih_p = btm_inp_halo.data_ptr<scalar_t>();
351-
int* top_signal_p = top_signal.data_ptr<int>();
352-
int* btm_signal_p = btm_signal.data_ptr<int>() + 4;
382+
if (diagnostics) printf("waypoint1\n");
383+
int* top_signal_p = top_signal.data_ptr<int>() + 4;
384+
int* btm_signal_p = btm_signal.data_ptr<int>();
353385
int* top_wait_p = waits.data_ptr<int>();
354386
int* btm_wait_p = waits.data_ptr<int>() + 4;
387+
if (diagnostics) printf("waypoint2\n");
355388

356389
// do int4 vector loads if channel count permits
357390
int elem_size_in_bytes = toh_C * sizeof(scalar_t);
358391
int elem_size_in_int4 = (elem_size_in_bytes / 16);
392+
if (diagnostics) printf("elem_size_in_bytes = %d, elem_size_in_int4 = %d\n",elem_size_in_bytes,elem_size_in_int4);
359393
if (is_nhwc && elem_size_in_int4*16 == elem_size_in_bytes) {
360394
// can do int4 transfers
361-
int divisor = elem_size_in_bytes / elem_size_in_int4;
395+
int divisor = toh_C / elem_size_in_int4;
396+
if (diagnostics) printf("CAN DO INT4 :: divisor = %d\n",divisor);
362397
toh_stride_N /= divisor; toh_stride_H /= divisor; toh_stride_W /= divisor;
363398
tox_stride_N /= divisor; tox_stride_H /= divisor; tox_stride_W /= divisor;
399+
tix_stride_N /= divisor; tix_stride_H /= divisor; tix_stride_W /= divisor;
364400
tih_stride_N /= divisor; tih_stride_H /= divisor; tih_stride_W /= divisor;
365401
boh_stride_N /= divisor; boh_stride_H /= divisor; boh_stride_W /= divisor;
366402
box_stride_N /= divisor; box_stride_H /= divisor; box_stride_W /= divisor;
403+
bix_stride_N /= divisor; bix_stride_H /= divisor; bix_stride_W /= divisor;
367404
bih_stride_N /= divisor; bih_stride_H /= divisor; bih_stride_W /= divisor;
405+
NC /= divisor;
406+
if (diagnostics) {
407+
printf("divisor=%d\n",divisor);
408+
printf("toh_stride :: N=%d, C=%d, H=%d, W=%d\n",toh_stride_N,toh_stride_C,toh_stride_H,toh_stride_W);
409+
printf("tox_stride :: N=%d, C=%d, H=%d, W=%d\n",tox_stride_N,tox_stride_C,tox_stride_H,tox_stride_W);
410+
printf("tix_stride :: N=%d, C=%d, H=%d, W=%d\n",tix_stride_N,tix_stride_C,tix_stride_H,tix_stride_W);
411+
printf("tih_stride :: N=%d, C=%d, H=%d, W=%d\n",tih_stride_N,tih_stride_C,tih_stride_H,tih_stride_W);
412+
printf("boh_stride :: N=%d, C=%d, H=%d, W=%d\n",boh_stride_N,boh_stride_C,boh_stride_H,boh_stride_W);
413+
printf("box_stride :: N=%d, C=%d, H=%d, W=%d\n",box_stride_N,box_stride_C,box_stride_H,box_stride_W);
414+
printf("bix_stride :: N=%d, C=%d, H=%d, W=%d\n",bix_stride_N,bix_stride_C,bix_stride_H,bix_stride_W);
415+
printf("bih_stride :: N=%d, C=%d, H=%d, W=%d\n",bih_stride_N,bih_stride_C,bih_stride_H,bih_stride_W);
416+
printf("NC=%d, NH=%d, NW=%d\n",NC,NH,NW);
417+
}
368418
void *kernelArgs[] = {
369419
(int4**)&toh_p, &toh_stride_C, &toh_stride_H, &toh_stride_W,
370420
(int4**)&tox_p, &tox_stride_C, &tox_stride_H, &tox_stride_W,
421+
(int4**)&tix_p, &tix_stride_C, &tix_stride_H, &tix_stride_W,
371422
(int4**)&tih_p, &tih_stride_C, &tih_stride_H, &tih_stride_W,
372423
(int4**)&boh_p, &boh_stride_C, &boh_stride_H, &boh_stride_W,
373424
(int4**)&box_p, &box_stride_C, &box_stride_H, &box_stride_W,
425+
(int4**)&bix_p, &bix_stride_C, &bix_stride_H, &bix_stride_W,
374426
(int4**)&bih_p, &bih_stride_C, &bih_stride_H, &bih_stride_W,
375427
&NC, &NH, &NW,
376428
&top_signal_p, &btm_signal_p, &top_wait_p, &btm_wait_p
@@ -381,12 +433,15 @@ void push_pull_halos_1d(
381433
cudaLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<int4,true>, grid, block, kernelArgs, 0, current_stream);
382434
} else {
383435
// cannot do int4 transfers
436+
if (diagnostics) printf("CAN NOT DO INT4\n");
384437
void *kernelArgs[] = {
385438
&toh_p, &toh_stride_C, &toh_stride_H, &toh_stride_W,
386439
&tox_p, &tox_stride_C, &tox_stride_H, &tox_stride_W,
440+
&tix_p, &tix_stride_C, &tix_stride_H, &tix_stride_W,
387441
&tih_p, &tih_stride_C, &tih_stride_H, &tih_stride_W,
388442
&boh_p, &boh_stride_C, &boh_stride_H, &boh_stride_W,
389443
&box_p, &box_stride_C, &box_stride_H, &box_stride_W,
444+
&bix_p, &bix_stride_C, &bix_stride_H, &bix_stride_W,
390445
&bih_p, &bih_stride_C, &bih_stride_H, &bih_stride_W,
391446
&NC, &NH, &NW,
392447
&top_signal_p, &btm_signal_p, &top_wait_p, &btm_wait_p

apex/contrib/csrc/peer_memory/peer_memory_cuda.cuh

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,17 +24,20 @@ namespace apex { namespace peer_memory {
2424
void free_raw(int64_t raw);
2525
at::Tensor get_raw_ipc_address(int64_t raw);
2626
std::vector<int64_t> get_raw_peers(at::Tensor ipc_addresses, int peer_rank, int64_t raw);
27-
at::Tensor blob_view_half(int64_t raw, std::vector<int64_t> shape);
28-
at::Tensor blob_view_float(int64_t raw, std::vector<int64_t> shape);
29-
at::Tensor blob_view_int(int64_t raw, std::vector<int64_t> shape);
27+
at::Tensor blob_view_half(int64_t raw, std::vector<int64_t> shape, bool channels_last);
28+
at::Tensor blob_view_float(int64_t raw, std::vector<int64_t> shape, bool channels_last);
29+
at::Tensor blob_view_int(int64_t raw, std::vector<int64_t> shape, bool channels_last);
3030
void push_pull_halos_1d(
31+
bool diagnostics,
3132
bool explicit_nhwc,
3233
int numSM, // number of SMs to use
3334
at::Tensor top_out_halo, // top output halo in sender device memory
3435
at::Tensor top_out_tx, // top output transfer buffer in sender peer pool memory
36+
at::Tensor top_inp_tx, // top input transfer buffer in top neighbor peer pool memory
3537
at::Tensor top_inp_halo, // top input halo in receiver device memory
3638
at::Tensor btm_out_halo, // btm output halo in sender device memory
3739
at::Tensor btm_out_tx, // btm output transfer buffer in sender peer pool memory
40+
at::Tensor btm_inp_tx, // btm input transfer buffer in btm neighbor peer pool memory
3841
at::Tensor btm_inp_halo, // btm input halo in receiver device memory
3942
at::Tensor top_signal, // top input signal in receiver device memory
4043
at::Tensor btm_signal, // btm input signal in receiver device memory

0 commit comments

Comments
 (0)