@@ -30,15 +30,23 @@ void deleter(void* ptr)
3030*/
3131
3232template <class T >
33- at::Tensor blob_view (T* raw_ptr, std::vector<int64_t > shape, const at::TensorOptions& options)
33+ at::Tensor blob_view (T* raw_ptr, std::vector<int64_t > shape, const at::TensorOptions& options, bool channels_last )
3434{
35- std::vector<int64_t > strides (shape.size ());
3635 size_t size = 1 ;
37- int idx = strides.size ();
38- for (auto it = shape.rbegin (); it != shape.rend (); ++it)
39- {
40- strides[--idx] = size;
41- size *= *it;
36+ std::vector<int64_t > strides (shape.size ());
37+ if (channels_last) {
38+ assert (shape.size () == 4 );
39+ strides[0 ] = shape[1 ]*shape[2 ]*shape[3 ];
40+ strides[1 ] = 1 ;
41+ strides[2 ] = shape[1 ]*shape[3 ];
42+ strides[3 ] = shape[1 ];
43+ } else {
44+ int idx = strides.size ();
45+ for (auto it = shape.rbegin (); it != shape.rend (); ++it)
46+ {
47+ strides[--idx] = size;
48+ size *= *it;
49+ }
4250 }
4351 size *= sizeof (T);
4452 // TODO: Implement dynamic reuse of pooled peer memory.
@@ -139,11 +147,11 @@ __device__ void strided_copy_kernel(
139147 }
140148}
141149
150+ template <bool wait, bool clear>
142151__device__ void dual_signal_wait_clear (
143152 volatile int * signal1_flag, volatile int * wait1_flag,
144153 volatile int * signal2_flag, volatile int * wait2_flag,
145- const int v1, const int v2, const int v3, const int v4,
146- const bool clear
154+ const int v1, const int v2, const int v3, const int v4
147155 )
148156{
149157 register int r1, r2, r3, r4, r5, r6, r7, r8;
@@ -152,17 +160,20 @@ __device__ void dual_signal_wait_clear(
152160 if (is_main_thread) {
153161 asm volatile (" st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};" :: " l" (signal1_flag), " r" (v1), " r" (v2), " r" (v3), " r" (v4) : " memory" );
154162 asm volatile (" st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};" :: " l" (signal2_flag), " r" (v1), " r" (v2), " r" (v3), " r" (v4) : " memory" );
155- do {
156- asm volatile (" ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];" : " =r" (r1), " =r" (r2), " =r" (r3), " =r" (r4) : " l" (wait1_flag) : " memory" );
157- asm volatile (" ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];" : " =r" (r5), " =r" (r6), " =r" (r7), " =r" (r8) : " l" (wait2_flag) : " memory" );
158- } while (r1 != v1 || r5 != v1 || r2 != v2 || r6 != v2 || r3 != v3 || r7 != v3 || r4 != v4 || r8 != v4);
163+ if (wait) {
164+ do {
165+ asm volatile (" ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];" : " =r" (r1), " =r" (r2), " =r" (r3), " =r" (r4) : " l" (wait1_flag) : " memory" );
166+ asm volatile (" ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];" : " =r" (r5), " =r" (r6), " =r" (r7), " =r" (r8) : " l" (wait2_flag) : " memory" );
167+ } while (r1 != v1 || r5 != v1 || r2 != v2 || r6 != v2 || r3 != v3 || r7 != v3 || r4 != v4 || r8 != v4);
168+ }
159169 }
160170 cg::this_grid ().sync ();
161- // optionally clear wait flag
162- if (clear && is_main_thread) {
163- r1 = 0 ; r2 = 0 ; r3 = 0 ; r4 = 0 ;
164- asm volatile (" st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};" :: " l" (wait1_flag), " r" (r1), " r" (r2), " r" (r3), " r" (r4) : " memory" );
165- asm volatile (" st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};" :: " l" (wait2_flag), " r" (r1), " r" (r2), " r" (r3), " r" (r4) : " memory" );
171+ if (clear) {
172+ if (is_main_thread) {
173+ r1 = 0 ; r2 = 0 ; r3 = 0 ; r4 = 0 ;
174+ asm volatile (" st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};" :: " l" (wait1_flag), " r" (r1), " r" (r2), " r" (r3), " r" (r4) : " memory" );
175+ asm volatile (" st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};" :: " l" (wait2_flag), " r" (r1), " r" (r2), " r" (r3), " r" (r4) : " memory" );
176+ }
166177 }
167178}
168179
@@ -173,12 +184,14 @@ __launch_bounds__(128, 16)
173184__global__ void push_pull_halos_1d_kernel (
174185 // top halo,
175186 const T* toh, int toh_stride_C, int toh_stride_H, int toh_stride_W, // top output halo
176- T* tox, int tox_stride_C, int tox_stride_H, int tox_stride_W, // top tx buffer
187+ T* tox, int tox_stride_C, int tox_stride_H, int tox_stride_W, // top output tx buffer
188+ T* tix, int tix_stride_C, int tix_stride_H, int tix_stride_W, // top input tx buffer
177189 T* tih, int tih_stride_C, int tih_stride_H, int tih_stride_W, // top input halo
178190 // btm halo
179- const T* boh, int boh_stride_C, int boh_stride_H, int boh_stride_W, // top output halo
180- T* box, int box_stride_C, int box_stride_H, int box_stride_W, // top tx buffer
181- T* bih, int bih_stride_C, int bih_stride_H, int bih_stride_W, // top input halo
191+ const T* boh, int boh_stride_C, int boh_stride_H, int boh_stride_W, // btm output halo
192+ T* box, int box_stride_C, int box_stride_H, int box_stride_W, // btm output tx buffer
193+ T* bix, int bix_stride_C, int bix_stride_H, int bix_stride_W, // btm input tx buffer
194+ T* bih, int bih_stride_C, int bih_stride_H, int bih_stride_W, // btm input halo
182195 // dimensions
183196 int NC, int NH, int NW,
184197 // signals
@@ -194,11 +207,11 @@ __global__ void push_pull_halos_1d_kernel(
194207 strided_copy_kernel<T,is_HWC>(box, box_stride_C, box_stride_H, box_stride_W, boh, boh_stride_C, boh_stride_H, boh_stride_W, NC, NH, NW);
195208 // signal to top and btm neigbhbors that output halos are ready to be read
196209 // the choice of values for v1-v4 is arbitrary and does not matter, as long as all ranks use the same values
197- dual_signal_wait_clear (signal1_flag, wait1_flag, signal2_flag, wait2_flag, -987751720 , 840868300 , -225529332 , 281513358 , true );
210+ dual_signal_wait_clear< true , true > (signal1_flag, wait1_flag, signal2_flag, wait2_flag, -987751720 , 840868300 , -225529332 , 281513358 );
198211 // pull top halo from transfer buffer in peer memory to input
199- strided_copy_kernel<T,is_HWC>(tox, tox_stride_C, tox_stride_H, tox_stride_W, tih, tih_stride_C, tih_stride_H, tih_stride_W , NC, NH, NW);
212+ strided_copy_kernel<T,is_HWC>(tih, tih_stride_C, tih_stride_H, tih_stride_W, tix, tix_stride_C, tix_stride_H, tix_stride_W , NC, NH, NW);
200213 // pull btm halo from transfer buffer in peer memory to input
201- strided_copy_kernel<T,is_HWC>(box, box_stride_C, box_stride_H, box_stride_W, bih, bih_stride_C, bih_stride_H, bih_stride_W , NC, NH, NW);
214+ strided_copy_kernel<T,is_HWC>(bih, bih_stride_C, bih_stride_H, bih_stride_W, bix, bix_stride_C, bix_stride_H, bix_stride_W , NC, NH, NW);
202215}
203216
204217}
@@ -246,29 +259,32 @@ std::vector<int64_t> get_raw_peers(at::Tensor ipc_addresses, int peer_rank, int6
246259 return results;
247260}
248261
249- at::Tensor blob_view_half (int64_t raw, std::vector<int64_t > shape)
262+ at::Tensor blob_view_half (int64_t raw, std::vector<int64_t > shape, bool channels_last )
250263{
251- return blob_view<at::Half>((at::Half*)raw, shape, torch::dtype (torch::kFloat16 ).device (torch::kCUDA ));
264+ return blob_view<at::Half>((at::Half*)raw, shape, torch::dtype (torch::kFloat16 ).device (torch::kCUDA ), channels_last );
252265}
253266
254- at::Tensor blob_view_float (int64_t raw, std::vector<int64_t > shape)
267+ at::Tensor blob_view_float (int64_t raw, std::vector<int64_t > shape, bool channels_last )
255268{
256- return blob_view<float >((float *)raw, shape, torch::dtype (torch::kFloat16 ).device (torch::kCUDA ));
269+ return blob_view<float >((float *)raw, shape, torch::dtype (torch::kFloat32 ).device (torch::kCUDA ), channels_last );
257270}
258271
259- at::Tensor blob_view_int (int64_t raw, std::vector<int64_t > shape)
272+ at::Tensor blob_view_int (int64_t raw, std::vector<int64_t > shape, bool channels_last )
260273{
261- return blob_view<int >((int *)raw, shape, torch::dtype (torch::kFloat16 ).device (torch::kCUDA ));
274+ return blob_view<int >((int *)raw, shape, torch::dtype (torch::kInt32 ).device (torch::kCUDA ), channels_last );
262275}
263276
264277void push_pull_halos_1d (
278+ bool diagnostics,
265279 bool explicit_nhwc,
266280 int numSM, // number of SMs to use
267281 at::Tensor top_out_halo, // top output halo in sender device memory
268282 at::Tensor top_out_tx, // top output transfer buffer in sender peer pool memory
283+ at::Tensor top_inp_tx, // top input transfer buffer in top neighbor peer pool memory
269284 at::Tensor top_inp_halo, // top input halo in receiver device memory
270285 at::Tensor btm_out_halo, // btm output halo in sender device memory
271286 at::Tensor btm_out_tx, // btm output transfer buffer in sender peer pool memory
287+ at::Tensor btm_inp_tx, // btm input transfer buffer in btm neighbor peer pool memory
272288 at::Tensor btm_inp_halo, // btm input halo in receiver device memory
273289 at::Tensor top_signal, // top input signal in receiver device memory
274290 at::Tensor btm_signal, // btm input signal in receiver device memory
@@ -278,9 +294,11 @@ void push_pull_halos_1d(
278294 // basic checks of inputs
279295 TORCH_CHECK (top_out_halo.is_cuda ());
280296 TORCH_CHECK (top_out_tx.is_cuda ());
297+ TORCH_CHECK (top_inp_tx.is_cuda ());
281298 TORCH_CHECK (top_inp_halo.is_cuda ());
282299 TORCH_CHECK (btm_out_halo.is_cuda ());
283300 TORCH_CHECK (btm_out_tx.is_cuda ());
301+ TORCH_CHECK (btm_inp_tx.is_cuda ());
284302 TORCH_CHECK (btm_inp_halo.is_cuda ());
285303 TORCH_CHECK (top_signal.is_cuda ());
286304 TORCH_CHECK (btm_signal.is_cuda ());
@@ -291,46 +309,56 @@ void push_pull_halos_1d(
291309 tensor_shape (top_out_halo, explicit_nhwc, toh_N, toh_C, toh_H, toh_W);
292310 int tox_N, tox_C, tox_H, tox_W;
293311 tensor_shape (top_out_tx, explicit_nhwc, tox_N, tox_C, tox_H, tox_W);
312+ int tix_N, tix_C, tix_H, tix_W;
313+ tensor_shape (top_inp_tx, explicit_nhwc, tix_N, tix_C, tix_H, tix_W);
294314 int tih_N, tih_C, tih_H, tih_W;
295315 tensor_shape (top_inp_halo, explicit_nhwc, tih_N, tih_C, tih_H, tih_W);
296316 TORCH_CHECK (
297- (toh_N == tox_N && tox_N == tih_N) &&
298- (toh_C == tox_C && tox_C == tih_C) &&
299- (toh_H == tox_H && tox_H == tih_H) &&
300- (toh_W == tox_W && tox_W == tih_W));
317+ (toh_N == tox_N && tox_N == tix_N && tix_N == tih_N) &&
318+ (toh_C == tox_C && tox_C == tix_C && tix_C == tih_C) &&
319+ (toh_H == tox_H && tox_H == tix_H && tix_H == tih_H) &&
320+ (toh_W == tox_W && tox_W == tix_W && tix_W == tih_W));
301321 int boh_N, boh_C, boh_H, boh_W;
302322 tensor_shape (btm_out_halo, explicit_nhwc, boh_N, boh_C, boh_H, boh_W);
303323 int box_N, box_C, box_H, box_W;
304324 tensor_shape (btm_out_tx, explicit_nhwc, box_N, box_C, box_H, box_W);
325+ int bix_N, bix_C, bix_H, bix_W;
326+ tensor_shape (btm_inp_tx, explicit_nhwc, bix_N, bix_C, bix_H, bix_W);
305327 int bih_N, bih_C, bih_H, bih_W;
306328 tensor_shape (btm_inp_halo, explicit_nhwc, bih_N, bih_C, bih_H, bih_W);
307329 TORCH_CHECK (
308- (boh_N == box_N && box_N == bih_N) &&
309- (boh_C == box_C && box_C == bih_C) &&
310- (boh_H == box_H && box_H == bih_H) &&
311- (boh_W == box_W && box_W == bih_W));
330+ (boh_N == box_N && box_N == bix_N && bix_N == bih_N) &&
331+ (boh_C == box_C && box_C == bix_C && bix_C == bih_C) &&
332+ (boh_H == box_H && box_H == bix_H && bix_H == bih_H) &&
333+ (boh_W == box_W && box_W == bix_W && bix_W == bih_W));
312334 TORCH_CHECK (
313335 (toh_N == boh_N) &&
314336 (toh_C == boh_C) &&
315337 (toh_H == boh_H) &&
316338 (toh_W == boh_W));
317339 int NC=toh_C, NH=toh_H, NW=toh_W;
340+ if (diagnostics) printf (" NC=%d, NH=%d, NW=%d\n " ,NC,NH,NW);
318341
319342 int toh_stride_N, toh_stride_C, toh_stride_H, toh_stride_W;
320343 tensor_strides (top_out_halo, explicit_nhwc, toh_stride_N, toh_stride_C, toh_stride_H, toh_stride_W);
321344 int tox_stride_N, tox_stride_C, tox_stride_H, tox_stride_W;
322345 tensor_strides (top_out_tx, explicit_nhwc, tox_stride_N, tox_stride_C, tox_stride_H, tox_stride_W);
346+ int tix_stride_N, tix_stride_C, tix_stride_H, tix_stride_W;
347+ tensor_strides (top_inp_tx, explicit_nhwc, tix_stride_N, tix_stride_C, tix_stride_H, tix_stride_W);
323348 int tih_stride_N, tih_stride_C, tih_stride_H, tih_stride_W;
324349 tensor_strides (top_inp_halo, explicit_nhwc, tih_stride_N, tih_stride_C, tih_stride_H, tih_stride_W);
325350 int boh_stride_N, boh_stride_C, boh_stride_H, boh_stride_W;
326351 tensor_strides (btm_out_halo, explicit_nhwc, boh_stride_N, boh_stride_C, boh_stride_H, boh_stride_W);
327352 int box_stride_N, box_stride_C, box_stride_H, box_stride_W;
328353 tensor_strides (btm_out_tx, explicit_nhwc, box_stride_N, box_stride_C, box_stride_H, box_stride_W);
354+ int bix_stride_N, bix_stride_C, bix_stride_H, bix_stride_W;
355+ tensor_strides (btm_inp_tx, explicit_nhwc, bix_stride_N, bix_stride_C, bix_stride_H, bix_stride_W);
329356 int bih_stride_N, bih_stride_C, bih_stride_H, bih_stride_W;
330357 tensor_strides (btm_inp_halo, explicit_nhwc, bih_stride_N, bih_stride_C, bih_stride_H, bih_stride_W);
331358
332359 // determine if nhwc
333360 auto is_nhwc = (toh_stride_C == 1 ) ? true : false ;
361+ if (diagnostics) printf (" is_nhwc = %s\n " ,is_nhwc?" true" :" false" );
334362
335363 // figure out launch parameters
336364 int device;
@@ -342,35 +370,59 @@ void push_pull_halos_1d(
342370 const int numThreads = 128 ;
343371 dim3 block (numThreads,1 ,1 );
344372 AT_DISPATCH_ALL_TYPES_AND (at::ScalarType::Half, top_out_halo.scalar_type (), " push_pull_halos_1d_kernel" , [&]{
373+ if (diagnostics) printf (" size(scalar_t) = %d\n " ,sizeof (scalar_t ));
345374 scalar_t * toh_p = top_out_halo.data_ptr <scalar_t >();
346375 scalar_t * tox_p = top_out_tx.data_ptr <scalar_t >();
376+ scalar_t * tix_p = top_inp_tx.data_ptr <scalar_t >();
347377 scalar_t * tih_p = top_inp_halo.data_ptr <scalar_t >();
348378 scalar_t * boh_p = btm_out_halo.data_ptr <scalar_t >();
349379 scalar_t * box_p = btm_out_tx.data_ptr <scalar_t >();
380+ scalar_t * bix_p = btm_inp_tx.data_ptr <scalar_t >();
350381 scalar_t * bih_p = btm_inp_halo.data_ptr <scalar_t >();
351- int * top_signal_p = top_signal.data_ptr <int >();
352- int * btm_signal_p = btm_signal.data_ptr <int >() + 4 ;
382+ if (diagnostics) printf (" waypoint1\n " );
383+ int * top_signal_p = top_signal.data_ptr <int >() + 4 ;
384+ int * btm_signal_p = btm_signal.data_ptr <int >();
353385 int * top_wait_p = waits.data_ptr <int >();
354386 int * btm_wait_p = waits.data_ptr <int >() + 4 ;
387+ if (diagnostics) printf (" waypoint2\n " );
355388
356389 // do int4 vector loads if channel count permits
357390 int elem_size_in_bytes = toh_C * sizeof (scalar_t );
358391 int elem_size_in_int4 = (elem_size_in_bytes / 16 );
392+ if (diagnostics) printf (" elem_size_in_bytes = %d, elem_size_in_int4 = %d\n " ,elem_size_in_bytes,elem_size_in_int4);
359393 if (is_nhwc && elem_size_in_int4*16 == elem_size_in_bytes) {
360394 // can do int4 transfers
361- int divisor = elem_size_in_bytes / elem_size_in_int4;
395+ int divisor = toh_C / elem_size_in_int4;
396+ if (diagnostics) printf (" CAN DO INT4 :: divisor = %d\n " ,divisor);
362397 toh_stride_N /= divisor; toh_stride_H /= divisor; toh_stride_W /= divisor;
363398 tox_stride_N /= divisor; tox_stride_H /= divisor; tox_stride_W /= divisor;
399+ tix_stride_N /= divisor; tix_stride_H /= divisor; tix_stride_W /= divisor;
364400 tih_stride_N /= divisor; tih_stride_H /= divisor; tih_stride_W /= divisor;
365401 boh_stride_N /= divisor; boh_stride_H /= divisor; boh_stride_W /= divisor;
366402 box_stride_N /= divisor; box_stride_H /= divisor; box_stride_W /= divisor;
403+ bix_stride_N /= divisor; bix_stride_H /= divisor; bix_stride_W /= divisor;
367404 bih_stride_N /= divisor; bih_stride_H /= divisor; bih_stride_W /= divisor;
405+ NC /= divisor;
406+ if (diagnostics) {
407+ printf (" divisor=%d\n " ,divisor);
408+ printf (" toh_stride :: N=%d, C=%d, H=%d, W=%d\n " ,toh_stride_N,toh_stride_C,toh_stride_H,toh_stride_W);
409+ printf (" tox_stride :: N=%d, C=%d, H=%d, W=%d\n " ,tox_stride_N,tox_stride_C,tox_stride_H,tox_stride_W);
410+ printf (" tix_stride :: N=%d, C=%d, H=%d, W=%d\n " ,tix_stride_N,tix_stride_C,tix_stride_H,tix_stride_W);
411+ printf (" tih_stride :: N=%d, C=%d, H=%d, W=%d\n " ,tih_stride_N,tih_stride_C,tih_stride_H,tih_stride_W);
412+ printf (" boh_stride :: N=%d, C=%d, H=%d, W=%d\n " ,boh_stride_N,boh_stride_C,boh_stride_H,boh_stride_W);
413+ printf (" box_stride :: N=%d, C=%d, H=%d, W=%d\n " ,box_stride_N,box_stride_C,box_stride_H,box_stride_W);
414+ printf (" bix_stride :: N=%d, C=%d, H=%d, W=%d\n " ,bix_stride_N,bix_stride_C,bix_stride_H,bix_stride_W);
415+ printf (" bih_stride :: N=%d, C=%d, H=%d, W=%d\n " ,bih_stride_N,bih_stride_C,bih_stride_H,bih_stride_W);
416+ printf (" NC=%d, NH=%d, NW=%d\n " ,NC,NH,NW);
417+ }
368418 void *kernelArgs[] = {
369419 (int4 **)&toh_p, &toh_stride_C, &toh_stride_H, &toh_stride_W,
370420 (int4 **)&tox_p, &tox_stride_C, &tox_stride_H, &tox_stride_W,
421+ (int4 **)&tix_p, &tix_stride_C, &tix_stride_H, &tix_stride_W,
371422 (int4 **)&tih_p, &tih_stride_C, &tih_stride_H, &tih_stride_W,
372423 (int4 **)&boh_p, &boh_stride_C, &boh_stride_H, &boh_stride_W,
373424 (int4 **)&box_p, &box_stride_C, &box_stride_H, &box_stride_W,
425+ (int4 **)&bix_p, &bix_stride_C, &bix_stride_H, &bix_stride_W,
374426 (int4 **)&bih_p, &bih_stride_C, &bih_stride_H, &bih_stride_W,
375427 &NC, &NH, &NW,
376428 &top_signal_p, &btm_signal_p, &top_wait_p, &btm_wait_p
@@ -381,12 +433,15 @@ void push_pull_halos_1d(
381433 cudaLaunchCooperativeKernel ((void *)push_pull_halos_1d_kernel<int4 ,true >, grid, block, kernelArgs, 0 , current_stream);
382434 } else {
383435 // cannot do int4 transfers
436+ if (diagnostics) printf (" CAN NOT DO INT4\n " );
384437 void *kernelArgs[] = {
385438 &toh_p, &toh_stride_C, &toh_stride_H, &toh_stride_W,
386439 &tox_p, &tox_stride_C, &tox_stride_H, &tox_stride_W,
440+ &tix_p, &tix_stride_C, &tix_stride_H, &tix_stride_W,
387441 &tih_p, &tih_stride_C, &tih_stride_H, &tih_stride_W,
388442 &boh_p, &boh_stride_C, &boh_stride_H, &boh_stride_W,
389443 &box_p, &box_stride_C, &box_stride_H, &box_stride_W,
444+ &bix_p, &bix_stride_C, &bix_stride_H, &bix_stride_W,
390445 &bih_p, &bih_stride_C, &bih_stride_H, &bih_stride_W,
391446 &NC, &NH, &NW,
392447 &top_signal_p, &btm_signal_p, &top_wait_p, &btm_wait_p
0 commit comments