Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 88914a5

Browse files
committed
Add halo correction kernel for bprop
1 parent 705aa35 commit 88914a5

3 files changed

Lines changed: 760 additions & 80 deletions

File tree

apex/contrib/bottleneck/bottleneck.py

Lines changed: 119 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -268,17 +268,6 @@ def forward(ctx, spatial_group_size, spatial_group_rank, spatial_communicator, s
268268
spatial_halo_exchanger.left_right_halo_exchange(out1[:,:,:1,:], out1[:,:,Hs-1:,:], top_out1_halo, btm_out1_halo)
269269
if spatial_method == 1:
270270
# overlap mid convolution with halo transfer
271-
if spatial_group_rank > 0:
272-
with torch.cuda.stream(stream1):
273-
if explicit_nhwc:
274-
top_fat_halo = torch.empty((N,3,W,C),dtype=out1.dtype,device=out1.device)
275-
top_fat_halo[:,:1,:,:].copy_(top_out1_halo)
276-
top_fat_halo[:,1:3,:,:].copy_(out1[:,:2,:,:])
277-
else:
278-
top_fat_halo = torch.empty((N,C,3,W),dtype=out1.dtype,device=out1.device)
279-
top_fat_halo[:,:,:1,:].copy_(top_out1_halo)
280-
top_fat_halo[:,:,1:3,:].copy_(out1[:,:,:2,:])
281-
top_out2 = fast_bottleneck.forward_out2_halo(explicit_nhwc, top_fat_halo, args)
282271
if spatial_group_rank < spatial_group_size-1:
283272
stream2.wait_stream(stream1)
284273
with torch.cuda.stream(stream2):
@@ -291,6 +280,17 @@ def forward(ctx, spatial_group_size, spatial_group_rank, spatial_communicator, s
291280
btm_fat_halo[:,:,0:2,:].copy_(out1[:,:,Hs-2:,:])
292281
btm_fat_halo[:,:,2:,:].copy_(btm_out1_halo)
293282
btm_out2 = fast_bottleneck.forward_out2_halo(explicit_nhwc, btm_fat_halo, args)
283+
if spatial_group_rank > 0:
284+
with torch.cuda.stream(stream1):
285+
if explicit_nhwc:
286+
top_fat_halo = torch.empty((N,3,W,C),dtype=out1.dtype,device=out1.device)
287+
top_fat_halo[:,:1,:,:].copy_(top_out1_halo)
288+
top_fat_halo[:,1:3,:,:].copy_(out1[:,:2,:,:])
289+
else:
290+
top_fat_halo = torch.empty((N,C,3,W),dtype=out1.dtype,device=out1.device)
291+
top_fat_halo[:,:,:1,:].copy_(top_out1_halo)
292+
top_fat_halo[:,:,1:3,:].copy_(out1[:,:,:2,:])
293+
top_out2 = fast_bottleneck.forward_out2_halo(explicit_nhwc, top_fat_halo, args)
294294
inc.add_delay(10)
295295
elif spatial_method != 2 and spatial_method != 3:
296296
assert(False), "spatial_method must be 1, 2 or 3"
@@ -329,13 +329,6 @@ def forward(ctx, spatial_group_size, spatial_group_rank, spatial_communicator, s
329329
# to wait for out2_mask to finish, but itself has to finish before
330330
# the first kernel of _forward_rest can launch.
331331
# At least we can overlap the two halo correction kernels.
332-
if spatial_group_rank > 0:
333-
stream1.wait_stream(torch.cuda.current_stream()) # wait for *_out2_mask to finish
334-
with torch.cuda.stream(stream1):
335-
w1by3 = args[2][:,:1,:,:].clone()
336-
top_out1_halo = top_out1_halo.clone()
337-
top_out2 = fast_bottleneck.forward_out2_halo_corr(explicit_nhwc, top_out1_halo, args, w1by3, top_out2_halo.clone())
338-
top_out2_halo.copy_(top_out2)
339332
if spatial_group_rank < spatial_group_size-1:
340333
stream2.wait_stream(torch.cuda.current_stream()) # wait for *_out2_mask to finish
341334
with torch.cuda.stream(stream2):
@@ -344,9 +337,16 @@ def forward(ctx, spatial_group_size, spatial_group_rank, spatial_communicator, s
344337
btm_out2 = fast_bottleneck.forward_out2_halo_corr(explicit_nhwc, btm_out1_halo, args, w1by3, btm_out2_halo.clone())
345338
btm_out2_halo.copy_(btm_out2)
346339
if spatial_group_rank > 0:
347-
torch.cuda.current_stream().wait_stream(stream1)
340+
stream1.wait_stream(torch.cuda.current_stream()) # wait for *_out2_mask to finish
341+
with torch.cuda.stream(stream1):
342+
w1by3 = args[2][:,:1,:,:].clone()
343+
top_out1_halo = top_out1_halo.clone()
344+
top_out2 = fast_bottleneck.forward_out2_halo_corr(explicit_nhwc, top_out1_halo, args, w1by3, top_out2_halo.clone())
345+
top_out2_halo.copy_(top_out2)
348346
if spatial_group_rank < spatial_group_size-1:
349347
torch.cuda.current_stream().wait_stream(stream2)
348+
if spatial_group_rank > 0:
349+
torch.cuda.current_stream().wait_stream(stream1)
350350

351351
fast_bottleneck.forward_rest(explicit_nhwc, stride_1x1, args, outputs)
352352
# save halos for backward pass
@@ -365,6 +365,8 @@ def forward(ctx, spatial_group_size, spatial_group_rank, spatial_communicator, s
365365
ctx.spatial_group_rank = spatial_group_rank
366366
ctx.spatial_halo_exchanger = spatial_halo_exchanger
367367
ctx.spatial_method = spatial_method
368+
ctx.thresholdTop = thresholdTop
369+
ctx.thresholdBottom = thresholdBottom
368370
ctx.stream1 = stream1
369371
ctx.stream2 = stream2
370372
ctx.stream3 = stream3
@@ -414,50 +416,55 @@ def backward(ctx, grad_o):
414416
with torch.cuda.stream(ctx.stream1):
415417
top_halo, btm_halo = ctx.spatial_halo_exchanger.left_right_halo_exchange(grad_out2[:,:1,:,:], grad_out2[:,Hs-1:,:,:])
416418
# copy halos to send buffer
417-
if ctx.spatial_group_rank < ctx.spatial_group_size-1:
418-
ctx.stream2.wait_stream(ctx.stream1)
419-
with torch.cuda.stream(ctx.stream2):
420-
if ctx.explicit_nhwc:
421-
btm_fat_halo = torch.empty((N,3,W,C),dtype=grad_out2.dtype,device=grad_out2.device)
422-
btm_relu_halo = torch.empty((N,3,W,C),dtype=grad_out2.dtype,device=grad_out2.device)
423-
btm_fat_halo[:,:2,:,:].copy_(grad_out2[:,Hs-2:,:,:])
424-
btm_fat_halo[:,2:,:,:].copy_(btm_halo)
425-
btm_relu_halo[:,:2,:,:].copy_(relu1[:,Hs-2:,:,:])
426-
btm_relu_halo[:,2:,:,:].zero_()
427-
else:
428-
btm_fat_halo = torch.empty((N,C,3,W),dtype=grad_out2.dtype,device=grad_out2.device)
429-
btm_relu_halo = torch.empty((N,C,3,W),dtype=grad_out2.dtype,device=grad_out2.device)
430-
btm_fat_halo[:,:,:2,:].copy_(grad_out2[:,:,Hs-2:,:])
431-
btm_fat_halo[:,:,2:,:].copy_(btm_halo)
432-
btm_relu_halo[:,:,:2,:].copy_(relu1[:,:,Hs-2:,:])
433-
btm_relu_halo[:,:,2:,:].zero_()
434-
btm_grad_out1_halo = fast_bottleneck.backward_grad_out1_halo(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, btm_fat_halo, btm_relu_halo)
435-
if ctx.explicit_nhwc:
436-
btm_grad_out1_halo = btm_grad_out1_halo[:,1:2,:,:]
437-
else:
438-
btm_grad_out1_halo = btm_grad_out1_halo[:,:,1:2,:]
439-
if ctx.spatial_group_rank > 0:
440-
with torch.cuda.stream(ctx.stream1):
441-
if ctx.explicit_nhwc:
442-
top_fat_halo = torch.empty((N,3,W,C),dtype=grad_out2.dtype,device=grad_out2.device)
443-
top_relu_halo = torch.empty((N,3,W,C),dtype=grad_out2.dtype,device=grad_out2.device)
444-
top_fat_halo[:,:1,:,:].copy_(top_halo)
445-
top_fat_halo[:,1:,:,:].copy_(grad_out2[:,:2,:,:])
446-
top_relu_halo[:,:1,:,:].zero_()
447-
top_relu_halo[:,1:,:,:].copy_(relu1[:,:2,:,:])
448-
else:
449-
top_fat_halo = torch.empty((N,C,3,W),dtype=grad_out2.dtype,device=grad_out2.device)
450-
top_relu_halo = torch.empty((N,C,3,W),dtype=grad_out2.dtype,device=grad_out2.device)
451-
top_fat_halo[:,:,:1,:].copy_(top_halo)
452-
top_fat_halo[:,:,1:,:].copy_(grad_out2[:,:,:2,:])
453-
top_relu_halo[:,:,:1,:].zero_()
454-
top_relu_halo[:,:,1:,:].copy_(relu1[:,:,:2,:])
455-
top_grad_out1_halo = fast_bottleneck.backward_grad_out1_halo(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, top_fat_halo, top_relu_halo)
456-
if ctx.explicit_nhwc:
457-
top_grad_out1_halo = top_grad_out1_halo[:,1:2,:,:]
458-
else:
459-
top_grad_out1_halo = top_grad_out1_halo[:,:,1:2,:]
460-
inc.add_delay(10)
419+
if ctx.spatial_method == 1 or ctx.spatial_method == 2:
420+
# 1 -> halo recompute approach
421+
# 2 -> wait for concatenated halos, then do single conv on full input (not implemented yet for bprop)
422+
if ctx.spatial_group_rank < ctx.spatial_group_size-1:
423+
ctx.stream2.wait_stream(ctx.stream1)
424+
with torch.cuda.stream(ctx.stream2):
425+
if ctx.explicit_nhwc:
426+
btm_fat_halo = torch.empty((N,3,W,C),dtype=grad_out2.dtype,device=grad_out2.device)
427+
btm_fat_halo[:,:2,:,:].copy_(grad_out2[:,Hs-2:,:,:])
428+
btm_fat_halo[:,2:,:,:].copy_(btm_halo)
429+
btm_fat_relu_halo = torch.empty((N,3,W,C),dtype=grad_out2.dtype,device=grad_out2.device)
430+
btm_fat_relu_halo[:,:2,:,:].copy_(relu1[:,Hs-2:,:,:])
431+
btm_fat_relu_halo[:,2:,:,:].zero_()
432+
else:
433+
btm_fat_halo = torch.empty((N,C,3,W),dtype=grad_out2.dtype,device=grad_out2.device)
434+
btm_fat_halo[:,:,:2,:].copy_(grad_out2[:,:,Hs-2:,:])
435+
btm_fat_halo[:,:,2:,:].copy_(btm_halo)
436+
btm_fat_relu_halo = torch.empty((N,C,3,W),dtype=grad_out2.dtype,device=grad_out2.device)
437+
btm_fat_relu_halo[:,:,:2,:].copy_(relu1[:,:,Hs-2:,:])
438+
btm_fat_relu_halo[:,:,2:,:].zero_()
439+
btm_grad_out1_halo = fast_bottleneck.backward_grad_out1_halo(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, btm_fat_halo, btm_fat_relu_halo)
440+
if ctx.explicit_nhwc:
441+
btm_grad_out1_halo = btm_grad_out1_halo[:,1:2,:,:]
442+
else:
443+
btm_grad_out1_halo = btm_grad_out1_halo[:,:,1:2,:]
444+
if ctx.spatial_group_rank > 0:
445+
with torch.cuda.stream(ctx.stream1):
446+
if ctx.explicit_nhwc:
447+
top_fat_halo = torch.empty((N,3,W,C),dtype=grad_out2.dtype,device=grad_out2.device)
448+
top_fat_halo[:,:1,:,:].copy_(top_halo)
449+
top_fat_halo[:,1:,:,:].copy_(grad_out2[:,:2,:,:])
450+
top_fat_relu_halo = torch.empty((N,3,W,C),dtype=grad_out2.dtype,device=grad_out2.device)
451+
top_fat_relu_halo[:,:1,:,:].zero_()
452+
top_fat_relu_halo[:,1:,:,:].copy_(relu1[:,:2,:,:])
453+
else:
454+
top_fat_halo = torch.empty((N,C,3,W),dtype=grad_out2.dtype,device=grad_out2.device)
455+
top_fat_halo[:,:,:1,:].copy_(top_halo)
456+
top_fat_halo[:,:,1:,:].copy_(grad_out2[:,:,:2,:])
457+
top_fat_relu_halo = torch.empty((N,C,3,W),dtype=grad_out2.dtype,device=grad_out2.device)
458+
top_fat_relu_halo[:,:,:1,:].zero_()
459+
top_fat_relu_halo[:,:,1:,:].copy_(relu1[:,:,:2,:])
460+
top_grad_out1_halo = fast_bottleneck.backward_grad_out1_halo(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, top_fat_halo, top_fat_relu_halo)
461+
if ctx.explicit_nhwc:
462+
top_grad_out1_halo = top_grad_out1_halo[:,1:2,:,:]
463+
else:
464+
top_grad_out1_halo = top_grad_out1_halo[:,:,1:2,:]
465+
inc.add_delay(10)
466+
elif ctx.spatial_method != 3:
467+
assert(False), "spatial_method must be 1, 2 or 3"
461468

462469
with torch.cuda.stream(wgrad2_stream):
463470
if ctx.spatial_group_size > 1:
@@ -466,28 +473,62 @@ def backward(ctx, grad_o):
466473
wgrad2 = fast_bottleneck.backward_wgrad2(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, grad_out2)
467474

468475
# compute grad_out1 for internal cells
469-
grad_out1 = fast_bottleneck.backward_grad_out1(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, grad_out2)
476+
if ctx.spatial_group_size <= 1 or ctx.spatial_method == 1 or ctx.spatial_method == 2:
477+
grad_out1 = fast_bottleneck.backward_grad_out1(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, grad_out2)
478+
elif ctx.spatial_group_size > 1 and ctx.spatial_method == 3:
479+
grad_out1 = fast_bottleneck.backward_grad_out1_mask(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, grad_out2, ctx.thresholdTop, ctx.thresholdBottom)
470480

471481
# apply halo cells to grad_out1
472482
if ctx.spatial_group_size > 1:
473483
w = t_list[2]
474484
z = t_list[4]
475485
relu1 = t_list[12]
476486
#print("w.shape = %s, z.shape = %s, relu1.shape = %s" % (str(list(w.shape)), str(list(z.shape)), str(list(relu1.shape))))
477-
if ctx.spatial_group_rank > 0:
478-
torch.cuda.current_stream().wait_stream(ctx.stream1)
479-
if ctx.explicit_nhwc:
480-
grad_out1[:,:1,:,:].copy_(top_grad_out1_halo)
481-
else:
482-
grad_out1[:,:,:1,:].copy_(top_grad_out1_halo)
483-
#print("ctx.spatial_group_rank = %d, apply grad_out1 top halo (grad_out1.shape = %s)" % (ctx.spatial_group_rank, str(list(grad_out1.shape))))
484-
if ctx.spatial_group_rank < ctx.spatial_group_size-1:
485-
torch.cuda.current_stream().wait_stream(ctx.stream2)
486-
if ctx.explicit_nhwc:
487-
grad_out1[:,Hs-1:,:,:].copy_(btm_grad_out1_halo)
488-
else:
489-
grad_out1[:,:,Hs-1:,:].copy_(btm_grad_out1_halo)
490-
#print("ctx.spatial_group_rank = %d, apply grad_out1 btm halo (grad_out1.shape = %s)" % (ctx.spatial_group_rank, str(list(grad_out1.shape))))
487+
if ctx.spatial_method == 1 or ctx.spatial_method == 2:
488+
if ctx.spatial_group_rank < ctx.spatial_group_size-1:
489+
torch.cuda.current_stream().wait_stream(ctx.stream2)
490+
if ctx.explicit_nhwc:
491+
grad_out1[:,Hs-1:,:,:].copy_(btm_grad_out1_halo)
492+
else:
493+
grad_out1[:,:,Hs-1:,:].copy_(btm_grad_out1_halo)
494+
#print("ctx.spatial_group_rank = %d, apply grad_out1 btm halo (grad_out1.shape = %s)" % (ctx.spatial_group_rank, str(list(grad_out1.shape))))
495+
if ctx.spatial_group_rank > 0:
496+
torch.cuda.current_stream().wait_stream(ctx.stream1)
497+
if ctx.explicit_nhwc:
498+
grad_out1[:,:1,:,:].copy_(top_grad_out1_halo)
499+
else:
500+
grad_out1[:,:,:1,:].copy_(top_grad_out1_halo)
501+
#print("ctx.spatial_group_rank = %d, apply grad_out1 top halo (grad_out1.shape = %s)" % (ctx.spatial_group_rank, str(list(grad_out1.shape))))
502+
elif ctx.spatial_method == 3:
503+
if ctx.spatial_group_rank < ctx.spatial_group_size-1:
504+
if ctx.explicit_nhwc:
505+
btm_relu_halo = relu1[:,Hs-1:,:,:].clone()
506+
btm_grad_out1 = grad_out1[:,Hs-1:,:,:]
507+
else:
508+
btm_relu_halo = relu1[:,:,Hs-1:,:].clone()
509+
btm_grad_out1 = grad_out1[:,:,Hs-1:,:]
510+
w1by3 = w[:,:1,:,:].clone()
511+
ctx.stream1.wait_stream(ctx.stream2) # wait for halo transfers to finish
512+
ctx.stream2.wait_stream(torch.cuda.current_stream()) # wait for backward_grad_out1_mask to finish before launching halo correction kernel
513+
with torch.cuda.stream(ctx.stream1):
514+
btm_grad_out1_halo = fast_bottleneck.backward_grad_out1_halo_corr(ctx.explicit_nhwc, ctx.stride_1x1, t_list, w1by3, grads, btm_halo, btm_relu_halo, btm_grad_out1.clone())
515+
btm_grad_out1.copy_(btm_grad_out1_halo)
516+
if ctx.spatial_group_rank > 0:
517+
if ctx.explicit_nhwc:
518+
top_relu_halo = relu1[:,:1,:,:].clone()
519+
top_grad_out1 = grad_out1[:,:1,:,:]
520+
else:
521+
top_relu_halo = relu1[:,:,:1,:].clone()
522+
top_grad_out1 = grad_out1[:,:,:1,:]
523+
w1by3 = w[:,2:,:,:].clone()
524+
ctx.stream1.wait_stream(torch.cuda.current_stream()) # wait for backward_grad_out1_mask to finish before launching halo correction kernel
525+
with torch.cuda.stream(ctx.stream1):
526+
top_grad_out1_halo = fast_bottleneck.backward_grad_out1_halo_corr(ctx.explicit_nhwc, ctx.stride_1x1, t_list, w1by3, grads, top_halo, top_relu_halo, top_grad_out1.clone())
527+
top_grad_out1.copy_(top_grad_out1_halo)
528+
if ctx.spatial_group_rank < ctx.spatial_group_size-1:
529+
torch.cuda.current_stream().wait_stream(ctx.stream2) # wait for halo correction to finish
530+
if ctx.spatial_group_rank > 0:
531+
torch.cuda.current_stream().wait_stream(ctx.stream1)
491532

492533
fast_bottleneck.backward_rest(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, grad_out2, grad_out1, wgrad2)
493534
torch.cuda.current_stream().wait_stream(wgrad2_stream)

apex/contrib/bottleneck/bottleneck_module_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,7 @@ def n_way_spatial(halex, gt_bottleneck, gt, explicit_nhwc, world_size, rank, fp3
161161
spatial_group_rank = rank
162162
spatial_communicator = None
163163
spatial_halo_exchanger = halex
164-
spatial_method = 2 # 1 -> overlap halo and main conv, 2 -> wait for halo, conv on padded x
164+
spatial_method = 3 # 1 -> overlap halo and main conv, 2 -> wait for halo, conv on padded x
165165
spatial_parallel_args = (spatial_group_size, spatial_group_rank, spatial_communicator, spatial_halo_exchanger, spatial_method)
166166
spatial_bottleneck = spatial_parallel_bottleneck(C, dtype, explicit_nhwc, gt_bottleneck, spatial_parallel_args)
167167

0 commit comments

Comments
 (0)