Thanks for visiting codestin.com
Credit goes to github.com

Skip to content

Commit 0bd620e

Browse files
committed
Optimize peer memory halo exchange kernel
Use an alternating double-buffer scheme to allow push-only communication. Remove cooperative group syncs and memory fences.
1 parent 3229126 commit 0bd620e

4 files changed

Lines changed: 401 additions & 408 deletions

File tree

apex/contrib/bottleneck/halo_exchangers.py

Lines changed: 16 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -88,28 +88,37 @@ def left_right_halo_exchange(self, left_output_halo, right_output_halo, left_inp
8888
inc.left_right_halo_exchange_inplace(self.handle, self.left_rank, self.right_rank, left_output_halo, right_output_halo, left_input_halo, right_input_halo)
8989

9090
class HaloExchangerPeer(HaloExchanger):
91-
def __init__(self, ranks, rank_in_group, peer_pool, explicit_nhwc, numSM=1):
91+
def __init__(self, ranks, rank_in_group, peer_pool, explicit_nhwc, numSM=0):
9292
super(HaloExchangerPeer, self).__init__(ranks, rank_in_group)
9393
self.diagnostics = False
9494
self.explicit_nhwc = explicit_nhwc
9595
self.numSM = numSM
9696
self.peer_pool = peer_pool
97-
self.signals = peer_pool.allocate_peer_tensors([2,4], torch.int32, False, False)
98-
self.signals[self.rank_in_group].zero_()
97+
98+
def _allocate_peer_tensor(self, halo):
99+
100+
# Compute size in bytes
101+
# Note: Pad buffer so each CUDA block gets required buffer size
102+
size = 4 * halo.numel() * halo.element_size()
103+
size_per_block = 128 * 2 * 16 # 128 threads each require two 128b buffers
104+
size = (size + size_per_block - 1) // size_per_block * size_per_block
105+
106+
# Construct dtype peer buffer with desired size
107+
shape = [1, 1, 1, size // halo.element_size()]
108+
return self.peer_pool.allocate_peer_tensors(shape, halo.dtype, False, True)
99109

100110
def left_right_halo_exchange(self, left_output_halo, right_output_halo, left_input_halo=None, right_input_halo=None):
101111
inplace = False if left_input_halo is None and right_input_halo is None else True
102112
if not inplace:
103113
left_input_halo = torch.empty_like(right_output_halo)
104114
right_input_halo = torch.empty_like(left_output_halo)
105115
channels_last = left_output_halo.is_contiguous(memory_format=torch.channels_last) and not self.explicit_nhwc
106-
left_tx = self.peer_pool.allocate_peer_tensors(list(left_output_halo.shape), left_output_halo.dtype, channels_last, True)
107-
right_tx = self.peer_pool.allocate_peer_tensors(list(right_output_halo.shape), right_output_halo.dtype, channels_last, True)
116+
left_tx = self._allocate_peer_tensor(left_input_halo)
117+
right_tx = self._allocate_peer_tensor(right_input_halo)
108118
pm.push_pull_halos_1d(
109-
self.diagnostics, self.explicit_nhwc, self.numSM,
119+
self.diagnostics, self.explicit_nhwc, self.numSM, self.rank_in_group,
110120
self.left_zero, left_output_halo, left_tx[self.rank_in_group], right_tx[self.wrap_around_left_rank_in_group], left_input_halo,
111121
self.right_zero, right_output_halo, right_tx[self.rank_in_group], left_tx[self.wrap_around_right_rank_in_group], right_input_halo,
112-
self.signals[self.wrap_around_left_rank_in_group], self.signals[self.wrap_around_right_rank_in_group], self.signals[self.rank_in_group]
113122
)
114123
if not inplace:
115124
return left_input_halo, right_input_halo

0 commit comments

Comments
 (0)