@@ -268,17 +268,6 @@ def forward(ctx, spatial_group_size, spatial_group_rank, spatial_communicator, s
268268 spatial_halo_exchanger .left_right_halo_exchange (out1 [:,:,:1 ,:], out1 [:,:,Hs - 1 :,:], top_out1_halo , btm_out1_halo )
269269 if spatial_method == 1 :
270270 # overlap mid convolution with halo transfer
271- if spatial_group_rank > 0 :
272- with torch .cuda .stream (stream1 ):
273- if explicit_nhwc :
274- top_fat_halo = torch .empty ((N ,3 ,W ,C ),dtype = out1 .dtype ,device = out1 .device )
275- top_fat_halo [:,:1 ,:,:].copy_ (top_out1_halo )
276- top_fat_halo [:,1 :3 ,:,:].copy_ (out1 [:,:2 ,:,:])
277- else :
278- top_fat_halo = torch .empty ((N ,C ,3 ,W ),dtype = out1 .dtype ,device = out1 .device )
279- top_fat_halo [:,:,:1 ,:].copy_ (top_out1_halo )
280- top_fat_halo [:,:,1 :3 ,:].copy_ (out1 [:,:,:2 ,:])
281- top_out2 = fast_bottleneck .forward_out2_halo (explicit_nhwc , top_fat_halo , args )
282271 if spatial_group_rank < spatial_group_size - 1 :
283272 stream2 .wait_stream (stream1 )
284273 with torch .cuda .stream (stream2 ):
@@ -291,6 +280,17 @@ def forward(ctx, spatial_group_size, spatial_group_rank, spatial_communicator, s
291280 btm_fat_halo [:,:,0 :2 ,:].copy_ (out1 [:,:,Hs - 2 :,:])
292281 btm_fat_halo [:,:,2 :,:].copy_ (btm_out1_halo )
293282 btm_out2 = fast_bottleneck .forward_out2_halo (explicit_nhwc , btm_fat_halo , args )
283+ if spatial_group_rank > 0 :
284+ with torch .cuda .stream (stream1 ):
285+ if explicit_nhwc :
286+ top_fat_halo = torch .empty ((N ,3 ,W ,C ),dtype = out1 .dtype ,device = out1 .device )
287+ top_fat_halo [:,:1 ,:,:].copy_ (top_out1_halo )
288+ top_fat_halo [:,1 :3 ,:,:].copy_ (out1 [:,:2 ,:,:])
289+ else :
290+ top_fat_halo = torch .empty ((N ,C ,3 ,W ),dtype = out1 .dtype ,device = out1 .device )
291+ top_fat_halo [:,:,:1 ,:].copy_ (top_out1_halo )
292+ top_fat_halo [:,:,1 :3 ,:].copy_ (out1 [:,:,:2 ,:])
293+ top_out2 = fast_bottleneck .forward_out2_halo (explicit_nhwc , top_fat_halo , args )
294294 inc .add_delay (10 )
295295 elif spatial_method != 2 and spatial_method != 3 :
296296 assert (False ), "spatial_method must be 1, 2 or 3"
@@ -329,13 +329,6 @@ def forward(ctx, spatial_group_size, spatial_group_rank, spatial_communicator, s
329329 # to wait for out2_mask to finish, but itself has to finish before
330330 # the first kernel of _forward_rest can launch.
331331 # At least we can overlap the two halo correction kernels.
332- if spatial_group_rank > 0 :
333- stream1 .wait_stream (torch .cuda .current_stream ()) # wait for *_out2_mask to finish
334- with torch .cuda .stream (stream1 ):
335- w1by3 = args [2 ][:,:1 ,:,:].clone ()
336- top_out1_halo = top_out1_halo .clone ()
337- top_out2 = fast_bottleneck .forward_out2_halo_corr (explicit_nhwc , top_out1_halo , args , w1by3 , top_out2_halo .clone ())
338- top_out2_halo .copy_ (top_out2 )
339332 if spatial_group_rank < spatial_group_size - 1 :
340333 stream2 .wait_stream (torch .cuda .current_stream ()) # wait for *_out2_mask to finish
341334 with torch .cuda .stream (stream2 ):
@@ -344,9 +337,16 @@ def forward(ctx, spatial_group_size, spatial_group_rank, spatial_communicator, s
344337 btm_out2 = fast_bottleneck .forward_out2_halo_corr (explicit_nhwc , btm_out1_halo , args , w1by3 , btm_out2_halo .clone ())
345338 btm_out2_halo .copy_ (btm_out2 )
346339 if spatial_group_rank > 0 :
347- torch .cuda .current_stream ().wait_stream (stream1 )
340+ stream1 .wait_stream (torch .cuda .current_stream ()) # wait for *_out2_mask to finish
341+ with torch .cuda .stream (stream1 ):
342+ w1by3 = args [2 ][:,:1 ,:,:].clone ()
343+ top_out1_halo = top_out1_halo .clone ()
344+ top_out2 = fast_bottleneck .forward_out2_halo_corr (explicit_nhwc , top_out1_halo , args , w1by3 , top_out2_halo .clone ())
345+ top_out2_halo .copy_ (top_out2 )
348346 if spatial_group_rank < spatial_group_size - 1 :
349347 torch .cuda .current_stream ().wait_stream (stream2 )
348+ if spatial_group_rank > 0 :
349+ torch .cuda .current_stream ().wait_stream (stream1 )
350350
351351 fast_bottleneck .forward_rest (explicit_nhwc , stride_1x1 , args , outputs )
352352 # save halos for backward pass
@@ -365,6 +365,8 @@ def forward(ctx, spatial_group_size, spatial_group_rank, spatial_communicator, s
365365 ctx .spatial_group_rank = spatial_group_rank
366366 ctx .spatial_halo_exchanger = spatial_halo_exchanger
367367 ctx .spatial_method = spatial_method
368+ ctx .thresholdTop = thresholdTop
369+ ctx .thresholdBottom = thresholdBottom
368370 ctx .stream1 = stream1
369371 ctx .stream2 = stream2
370372 ctx .stream3 = stream3
@@ -414,50 +416,55 @@ def backward(ctx, grad_o):
414416 with torch .cuda .stream (ctx .stream1 ):
415417 top_halo , btm_halo = ctx .spatial_halo_exchanger .left_right_halo_exchange (grad_out2 [:,:1 ,:,:], grad_out2 [:,Hs - 1 :,:,:])
416418 # copy halos to send buffer
417- if ctx .spatial_group_rank < ctx .spatial_group_size - 1 :
418- ctx .stream2 .wait_stream (ctx .stream1 )
419- with torch .cuda .stream (ctx .stream2 ):
420- if ctx .explicit_nhwc :
421- btm_fat_halo = torch .empty ((N ,3 ,W ,C ),dtype = grad_out2 .dtype ,device = grad_out2 .device )
422- btm_relu_halo = torch .empty ((N ,3 ,W ,C ),dtype = grad_out2 .dtype ,device = grad_out2 .device )
423- btm_fat_halo [:,:2 ,:,:].copy_ (grad_out2 [:,Hs - 2 :,:,:])
424- btm_fat_halo [:,2 :,:,:].copy_ (btm_halo )
425- btm_relu_halo [:,:2 ,:,:].copy_ (relu1 [:,Hs - 2 :,:,:])
426- btm_relu_halo [:,2 :,:,:].zero_ ()
427- else :
428- btm_fat_halo = torch .empty ((N ,C ,3 ,W ),dtype = grad_out2 .dtype ,device = grad_out2 .device )
429- btm_relu_halo = torch .empty ((N ,C ,3 ,W ),dtype = grad_out2 .dtype ,device = grad_out2 .device )
430- btm_fat_halo [:,:,:2 ,:].copy_ (grad_out2 [:,:,Hs - 2 :,:])
431- btm_fat_halo [:,:,2 :,:].copy_ (btm_halo )
432- btm_relu_halo [:,:,:2 ,:].copy_ (relu1 [:,:,Hs - 2 :,:])
433- btm_relu_halo [:,:,2 :,:].zero_ ()
434- btm_grad_out1_halo = fast_bottleneck .backward_grad_out1_halo (ctx .explicit_nhwc , ctx .stride_1x1 , t_list , grads , btm_fat_halo , btm_relu_halo )
435- if ctx .explicit_nhwc :
436- btm_grad_out1_halo = btm_grad_out1_halo [:,1 :2 ,:,:]
437- else :
438- btm_grad_out1_halo = btm_grad_out1_halo [:,:,1 :2 ,:]
439- if ctx .spatial_group_rank > 0 :
440- with torch .cuda .stream (ctx .stream1 ):
441- if ctx .explicit_nhwc :
442- top_fat_halo = torch .empty ((N ,3 ,W ,C ),dtype = grad_out2 .dtype ,device = grad_out2 .device )
443- top_relu_halo = torch .empty ((N ,3 ,W ,C ),dtype = grad_out2 .dtype ,device = grad_out2 .device )
444- top_fat_halo [:,:1 ,:,:].copy_ (top_halo )
445- top_fat_halo [:,1 :,:,:].copy_ (grad_out2 [:,:2 ,:,:])
446- top_relu_halo [:,:1 ,:,:].zero_ ()
447- top_relu_halo [:,1 :,:,:].copy_ (relu1 [:,:2 ,:,:])
448- else :
449- top_fat_halo = torch .empty ((N ,C ,3 ,W ),dtype = grad_out2 .dtype ,device = grad_out2 .device )
450- top_relu_halo = torch .empty ((N ,C ,3 ,W ),dtype = grad_out2 .dtype ,device = grad_out2 .device )
451- top_fat_halo [:,:,:1 ,:].copy_ (top_halo )
452- top_fat_halo [:,:,1 :,:].copy_ (grad_out2 [:,:,:2 ,:])
453- top_relu_halo [:,:,:1 ,:].zero_ ()
454- top_relu_halo [:,:,1 :,:].copy_ (relu1 [:,:,:2 ,:])
455- top_grad_out1_halo = fast_bottleneck .backward_grad_out1_halo (ctx .explicit_nhwc , ctx .stride_1x1 , t_list , grads , top_fat_halo , top_relu_halo )
456- if ctx .explicit_nhwc :
457- top_grad_out1_halo = top_grad_out1_halo [:,1 :2 ,:,:]
458- else :
459- top_grad_out1_halo = top_grad_out1_halo [:,:,1 :2 ,:]
460- inc .add_delay (10 )
419+ if ctx .spatial_method == 1 or ctx .spatial_method == 2 :
420+ # 1 -> halo recompute approach
421+ # 2 -> wait for concatenated halos, then do single conv on full input (not implemented yet for bprop)
422+ if ctx .spatial_group_rank < ctx .spatial_group_size - 1 :
423+ ctx .stream2 .wait_stream (ctx .stream1 )
424+ with torch .cuda .stream (ctx .stream2 ):
425+ if ctx .explicit_nhwc :
426+ btm_fat_halo = torch .empty ((N ,3 ,W ,C ),dtype = grad_out2 .dtype ,device = grad_out2 .device )
427+ btm_fat_halo [:,:2 ,:,:].copy_ (grad_out2 [:,Hs - 2 :,:,:])
428+ btm_fat_halo [:,2 :,:,:].copy_ (btm_halo )
429+ btm_fat_relu_halo = torch .empty ((N ,3 ,W ,C ),dtype = grad_out2 .dtype ,device = grad_out2 .device )
430+ btm_fat_relu_halo [:,:2 ,:,:].copy_ (relu1 [:,Hs - 2 :,:,:])
431+ btm_fat_relu_halo [:,2 :,:,:].zero_ ()
432+ else :
433+ btm_fat_halo = torch .empty ((N ,C ,3 ,W ),dtype = grad_out2 .dtype ,device = grad_out2 .device )
434+ btm_fat_halo [:,:,:2 ,:].copy_ (grad_out2 [:,:,Hs - 2 :,:])
435+ btm_fat_halo [:,:,2 :,:].copy_ (btm_halo )
436+ btm_fat_relu_halo = torch .empty ((N ,C ,3 ,W ),dtype = grad_out2 .dtype ,device = grad_out2 .device )
437+ btm_fat_relu_halo [:,:,:2 ,:].copy_ (relu1 [:,:,Hs - 2 :,:])
438+ btm_fat_relu_halo [:,:,2 :,:].zero_ ()
439+ btm_grad_out1_halo = fast_bottleneck .backward_grad_out1_halo (ctx .explicit_nhwc , ctx .stride_1x1 , t_list , grads , btm_fat_halo , btm_fat_relu_halo )
440+ if ctx .explicit_nhwc :
441+ btm_grad_out1_halo = btm_grad_out1_halo [:,1 :2 ,:,:]
442+ else :
443+ btm_grad_out1_halo = btm_grad_out1_halo [:,:,1 :2 ,:]
444+ if ctx .spatial_group_rank > 0 :
445+ with torch .cuda .stream (ctx .stream1 ):
446+ if ctx .explicit_nhwc :
447+ top_fat_halo = torch .empty ((N ,3 ,W ,C ),dtype = grad_out2 .dtype ,device = grad_out2 .device )
448+ top_fat_halo [:,:1 ,:,:].copy_ (top_halo )
449+ top_fat_halo [:,1 :,:,:].copy_ (grad_out2 [:,:2 ,:,:])
450+ top_fat_relu_halo = torch .empty ((N ,3 ,W ,C ),dtype = grad_out2 .dtype ,device = grad_out2 .device )
451+ top_fat_relu_halo [:,:1 ,:,:].zero_ ()
452+ top_fat_relu_halo [:,1 :,:,:].copy_ (relu1 [:,:2 ,:,:])
453+ else :
454+ top_fat_halo = torch .empty ((N ,C ,3 ,W ),dtype = grad_out2 .dtype ,device = grad_out2 .device )
455+ top_fat_halo [:,:,:1 ,:].copy_ (top_halo )
456+ top_fat_halo [:,:,1 :,:].copy_ (grad_out2 [:,:,:2 ,:])
457+ top_fat_relu_halo = torch .empty ((N ,C ,3 ,W ),dtype = grad_out2 .dtype ,device = grad_out2 .device )
458+ top_fat_relu_halo [:,:,:1 ,:].zero_ ()
459+ top_fat_relu_halo [:,:,1 :,:].copy_ (relu1 [:,:,:2 ,:])
460+ top_grad_out1_halo = fast_bottleneck .backward_grad_out1_halo (ctx .explicit_nhwc , ctx .stride_1x1 , t_list , grads , top_fat_halo , top_fat_relu_halo )
461+ if ctx .explicit_nhwc :
462+ top_grad_out1_halo = top_grad_out1_halo [:,1 :2 ,:,:]
463+ else :
464+ top_grad_out1_halo = top_grad_out1_halo [:,:,1 :2 ,:]
465+ inc .add_delay (10 )
466+ elif ctx .spatial_method != 3 :
467+ assert (False ), "spatial_method must be 1, 2 or 3"
461468
462469 with torch .cuda .stream (wgrad2_stream ):
463470 if ctx .spatial_group_size > 1 :
@@ -466,28 +473,62 @@ def backward(ctx, grad_o):
466473 wgrad2 = fast_bottleneck .backward_wgrad2 (ctx .explicit_nhwc , ctx .stride_1x1 , t_list , grads , grad_out2 )
467474
468475 # compute grad_out1 for internal cells
469- grad_out1 = fast_bottleneck .backward_grad_out1 (ctx .explicit_nhwc , ctx .stride_1x1 , t_list , grads , grad_out2 )
476+ if ctx .spatial_group_size <= 1 or ctx .spatial_method == 1 or ctx .spatial_method == 2 :
477+ grad_out1 = fast_bottleneck .backward_grad_out1 (ctx .explicit_nhwc , ctx .stride_1x1 , t_list , grads , grad_out2 )
478+ elif ctx .spatial_group_size > 1 and ctx .spatial_method == 3 :
479+ grad_out1 = fast_bottleneck .backward_grad_out1_mask (ctx .explicit_nhwc , ctx .stride_1x1 , t_list , grads , grad_out2 , ctx .thresholdTop , ctx .thresholdBottom )
470480
471481 # apply halo cells to grad_out1
472482 if ctx .spatial_group_size > 1 :
473483 w = t_list [2 ]
474484 z = t_list [4 ]
475485 relu1 = t_list [12 ]
476486 #print("w.shape = %s, z.shape = %s, relu1.shape = %s" % (str(list(w.shape)), str(list(z.shape)), str(list(relu1.shape))))
477- if ctx .spatial_group_rank > 0 :
478- torch .cuda .current_stream ().wait_stream (ctx .stream1 )
479- if ctx .explicit_nhwc :
480- grad_out1 [:,:1 ,:,:].copy_ (top_grad_out1_halo )
481- else :
482- grad_out1 [:,:,:1 ,:].copy_ (top_grad_out1_halo )
483- #print("ctx.spatial_group_rank = %d, apply grad_out1 top halo (grad_out1.shape = %s)" % (ctx.spatial_group_rank, str(list(grad_out1.shape))))
484- if ctx .spatial_group_rank < ctx .spatial_group_size - 1 :
485- torch .cuda .current_stream ().wait_stream (ctx .stream2 )
486- if ctx .explicit_nhwc :
487- grad_out1 [:,Hs - 1 :,:,:].copy_ (btm_grad_out1_halo )
488- else :
489- grad_out1 [:,:,Hs - 1 :,:].copy_ (btm_grad_out1_halo )
490- #print("ctx.spatial_group_rank = %d, apply grad_out1 btm halo (grad_out1.shape = %s)" % (ctx.spatial_group_rank, str(list(grad_out1.shape))))
487+ if ctx .spatial_method == 1 or ctx .spatial_method == 2 :
488+ if ctx .spatial_group_rank < ctx .spatial_group_size - 1 :
489+ torch .cuda .current_stream ().wait_stream (ctx .stream2 )
490+ if ctx .explicit_nhwc :
491+ grad_out1 [:,Hs - 1 :,:,:].copy_ (btm_grad_out1_halo )
492+ else :
493+ grad_out1 [:,:,Hs - 1 :,:].copy_ (btm_grad_out1_halo )
494+ #print("ctx.spatial_group_rank = %d, apply grad_out1 btm halo (grad_out1.shape = %s)" % (ctx.spatial_group_rank, str(list(grad_out1.shape))))
495+ if ctx .spatial_group_rank > 0 :
496+ torch .cuda .current_stream ().wait_stream (ctx .stream1 )
497+ if ctx .explicit_nhwc :
498+ grad_out1 [:,:1 ,:,:].copy_ (top_grad_out1_halo )
499+ else :
500+ grad_out1 [:,:,:1 ,:].copy_ (top_grad_out1_halo )
501+ #print("ctx.spatial_group_rank = %d, apply grad_out1 top halo (grad_out1.shape = %s)" % (ctx.spatial_group_rank, str(list(grad_out1.shape))))
502+ elif ctx .spatial_method == 3 :
503+ if ctx .spatial_group_rank < ctx .spatial_group_size - 1 :
504+ if ctx .explicit_nhwc :
505+ btm_relu_halo = relu1 [:,Hs - 1 :,:,:].clone ()
506+ btm_grad_out1 = grad_out1 [:,Hs - 1 :,:,:]
507+ else :
508+ btm_relu_halo = relu1 [:,:,Hs - 1 :,:].clone ()
509+ btm_grad_out1 = grad_out1 [:,:,Hs - 1 :,:]
510+ w1by3 = w [:,:1 ,:,:].clone ()
511+ ctx .stream1 .wait_stream (ctx .stream2 ) # wait for halo transfers to finish
512+ ctx .stream2 .wait_stream (torch .cuda .current_stream ()) # wait for backward_grad_out1_mask to finish before launching halo correction kernel
513+ with torch .cuda .stream (ctx .stream1 ):
514+ btm_grad_out1_halo = fast_bottleneck .backward_grad_out1_halo_corr (ctx .explicit_nhwc , ctx .stride_1x1 , t_list , w1by3 , grads , btm_halo , btm_relu_halo , btm_grad_out1 .clone ())
515+ btm_grad_out1 .copy_ (btm_grad_out1_halo )
516+ if ctx .spatial_group_rank > 0 :
517+ if ctx .explicit_nhwc :
518+ top_relu_halo = relu1 [:,:1 ,:,:].clone ()
519+ top_grad_out1 = grad_out1 [:,:1 ,:,:]
520+ else :
521+ top_relu_halo = relu1 [:,:,:1 ,:].clone ()
522+ top_grad_out1 = grad_out1 [:,:,:1 ,:]
523+ w1by3 = w [:,2 :,:,:].clone ()
524+ ctx .stream1 .wait_stream (torch .cuda .current_stream ()) # wait for backward_grad_out1_mask to finish before launching halo correction kernel
525+ with torch .cuda .stream (ctx .stream1 ):
526+ top_grad_out1_halo = fast_bottleneck .backward_grad_out1_halo_corr (ctx .explicit_nhwc , ctx .stride_1x1 , t_list , w1by3 , grads , top_halo , top_relu_halo , top_grad_out1 .clone ())
527+ top_grad_out1 .copy_ (top_grad_out1_halo )
528+ if ctx .spatial_group_rank < ctx .spatial_group_size - 1 :
529+ torch .cuda .current_stream ().wait_stream (ctx .stream2 ) # wait for halo correction to finish
530+ if ctx .spatial_group_rank > 0 :
531+ torch .cuda .current_stream ().wait_stream (ctx .stream1 )
491532
492533 fast_bottleneck .backward_rest (ctx .explicit_nhwc , ctx .stride_1x1 , t_list , grads , grad_out2 , grad_out1 , wgrad2 )
493534 torch .cuda .current_stream ().wait_stream (wgrad2_stream )
0 commit comments