-
Notifications
You must be signed in to change notification settings - Fork 1.5k
Expand file tree
/
Copy pathhalo_exchangers.py
More file actions
276 lines (256 loc) · 10.9 KB
/
halo_exchangers.py
File metadata and controls
276 lines (256 loc) · 10.9 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
import torch
import nccl_p2p_cuda as inc
import peer_memory_cuda as pm
# Communication-free halo exchanger (see HaloExchangerNoComm below).
# NB! That exchanger does not exchange halos with neighbors as it should; it merely swaps the inputs.
# NB! It is only useful for performance testing.
# NB! Do not use it for actual production runs.
class HaloExchanger(object):
    """Base class: records the rank topology that every halo exchanger needs.

    Arguments:
        ranks: list of global ranks taking part in the spatial split, in order.
        rank_in_group: position of this process within ``ranks``.

    Attributes set here:
        wrap_around_left_rank_in_group / wrap_around_right_rank_in_group:
            indices into ``ranks`` of the neighbors, wrapping around the ends
            of the group.
        left_rank / right_rank: entries of ``ranks`` for the neighbors without
            wrap-around, or -1 when there is no neighbor on that side.
        left_zero / right_zero: True for the first / last member of the group,
            i.e. the side whose input halo has no real neighbor.
    """

    def __init__(self, ranks, rank_in_group):
        # Three side CUDA streams; not used by the exchangers in this section,
        # presumably kept for callers/subclasses (TODO confirm).
        self.stream1, self.stream2, self.stream3 = (
            torch.cuda.Stream() for _ in range(3)
        )
        size = len(ranks)
        self.group_size = size
        self.ranks = ranks
        self.rank_in_group = rank_in_group
        # Periodic neighbors (indices into ``ranks``).
        self.wrap_around_left_rank_in_group = (rank_in_group - 1 + size) % size
        self.wrap_around_right_rank_in_group = (rank_in_group + 1) % size
        # Non-periodic neighbors; -1 marks a group boundary.
        has_left = rank_in_group > 0
        has_right = rank_in_group < size - 1
        self.left_rank = ranks[rank_in_group - 1] if has_left else -1
        self.left_zero = bool(rank_in_group == 0)
        self.right_rank = ranks[rank_in_group + 1] if has_right else -1
        self.right_zero = bool(rank_in_group == size - 1)
class HaloExchangerNoComm(HaloExchanger):
    """Communication-free stand-in exchanger, for performance testing only.

    It does NOT talk to any neighbor: it just swaps this rank's own output
    halos into the input slots. Never use it for production runs.
    """

    def __init__(self, ranks, rank_in_group):
        super().__init__(ranks, rank_in_group)

    def left_right_halo_exchange(
        self,
        left_output_halo,
        right_output_halo,
        left_input_halo=None,
        right_input_halo=None,
    ):
        """Swap the two output halos.

        Out-of-place (``left_input_halo is None``): returns
        ``(right_output_halo, left_output_halo)`` - the same tensor objects,
        no copy is made.
        In-place: copies ``right_output_halo`` into ``left_input_halo`` and
        ``left_output_halo`` into ``right_input_halo``; returns None.
        """
        if left_input_halo is not None:
            left_input_halo.copy_(right_output_halo)
            right_input_halo.copy_(left_output_halo)
            return None
        return right_output_halo, left_output_halo
class HaloExchangerAllGather(HaloExchanger):
    """Halo exchanger built on a single ``torch.distributed.all_gather``.

    Every rank contributes both of its output halos; each rank then picks the
    halos it needs from its wrap-around neighbors' slots and zeroes the input
    halo on a group boundary (``left_zero`` / ``right_zero``).
    """

    def __init__(self, ranks, rank_in_group, comm):
        super().__init__(ranks, rank_in_group)
        # comm must be an NCCL process group created with
        # torch.distributed.new_group(ranks=ranks).
        self.comm = comm

    def left_right_halo_exchange(
        self,
        left_output_halo,
        right_output_halo,
        left_input_halo=None,
        right_input_halo=None,
    ):
        """Exchange halos with the left and right neighbors.

        The halo tensors are unpacked as 4-D ``(N, Hh, W, C)``, with the halo
        extent along dim 1 taken from ``left_output_halo``.

        Out-of-place (``left_input_halo is None``): returns
        ``(left_input, right_input)`` as views into the gathered buffer.
        In-place: writes into ``left_input_halo`` / ``right_input_halo`` and
        returns None.
        """
        n, hh, w, c = left_output_halo.shape
        opts = dict(dtype=left_output_halo.dtype, device=left_output_halo.device)

        # Send buffer: [left output halo | right output halo] along dim 1.
        packed = torch.empty((n, 2 * hh, w, c), **opts)
        packed[:, :hh, :, :].copy_(left_output_halo)
        packed[:, hh:, :, :].copy_(right_output_halo)

        # Receive buffer viewed as group_size slots of 2*hh rows each.
        gathered = torch.empty((n, 2 * hh * self.group_size, w, c), **opts)
        slots = [
            gathered[:, 2 * hh * r : 2 * hh * (r + 1), :, :]
            for r in range(self.group_size)
        ]
        # NOTE(review): ``no_copy`` is not a parameter of upstream PyTorch's
        # all_gather; presumably supported by the PyTorch build this targets.
        torch.distributed.all_gather(slots, packed, group=self.comm, no_copy=True)

        # Our left input is the left neighbor's right output halo, and vice versa.
        from_left = slots[self.wrap_around_left_rank_in_group][:, hh:, :, :]
        from_right = slots[self.wrap_around_right_rank_in_group][:, :hh, :, :]

        if left_input_halo is None:
            if self.left_zero:
                from_left.zero_()
            if self.right_zero:
                from_right.zero_()
            return from_left, from_right

        if self.left_zero:
            left_input_halo.zero_()
        else:
            left_input_halo.copy_(from_left)
        if self.right_zero:
            right_input_halo.zero_()
        else:
            right_input_halo.copy_(from_right)
        return None
class HaloExchangerSendRecv(HaloExchanger):
    """Halo exchanger using the point-to-point ops of ``nccl_p2p_cuda``.

    Construction broadcasts an NCCL unique id from global rank 0 over the
    default process group, so it is a collective: every rank of the default
    group has to construct this object.
    """

    def __init__(self, ranks, rank_in_group):
        super().__init__(ranks, rank_in_group)
        uid = inc.get_unique_nccl_id(1).cuda()
        # All ranks adopt rank 0's id so they join the same NCCL communicator.
        torch.distributed.broadcast(uid, 0)
        uid = uid.cpu()
        me = torch.distributed.get_rank()
        print("%d :: nccl_id = %s" % (me, str(uid)))
        # This creates a second global NCCL communicator next to the one made by
        # torch.distributed.init_process_group("nccl"): that one is a protected
        # member of torch.distributed and cannot be reached from this class.
        # TODO: Figure out a way to avoid creating a second global communicator
        assert me == self.ranks[self.rank_in_group], (
            "ranks[%d](%d) != torch.distributed.get_rank()(%d)"
            % (self.rank_in_group, self.ranks[self.rank_in_group], me)
        )
        self.handle = inc.init_nccl_comm(uid, me, torch.distributed.get_world_size())

    def left_right_halo_exchange(
        self,
        left_output_halo,
        right_output_halo,
        left_input_halo=None,
        right_input_halo=None,
    ):
        """Exchange halos with ``left_rank`` / ``right_rank`` (-1 = no neighbor).

        Out-of-place (``left_input_halo is None``): returns the two received
        halos as a tuple ``(left_input, right_input)``.
        In-place: the receive buffers are filled and None is returned.
        """
        if left_input_halo is not None:
            inc.left_right_halo_exchange_inplace(
                self.handle,
                self.left_rank,
                self.right_rank,
                left_output_halo,
                right_output_halo,
                left_input_halo,
                right_input_halo,
            )
            return None
        received_left, received_right = inc.left_right_halo_exchange(
            self.handle,
            self.left_rank,
            self.right_rank,
            left_output_halo,
            right_output_halo,
        )
        return received_left, received_right
class HaloExchangerPeer(HaloExchanger):
    """Halo exchanger that pushes/pulls halos through peer-accessible memory.

    Uses ``peer_memory_cuda.push_pull_halos_1d`` with scratch buffers obtained
    from ``peer_pool.allocate_peer_tensors``.

    Arguments:
        ranks, rank_in_group: rank topology, see the base class.
        peer_pool: object providing ``allocate_peer_tensors`` (presumably a
            peer-memory pool shared by the ranks in ``ranks`` - TODO confirm).
        explicit_nhwc: forwarded to the push/pull kernel.
        numSM: forwarded to the push/pull kernel (default 0).
    """

    def __init__(self, ranks, rank_in_group, peer_pool, explicit_nhwc, numSM=0):
        super(HaloExchangerPeer, self).__init__(ranks, rank_in_group)
        self.diagnostics = False
        self.explicit_nhwc = explicit_nhwc
        self.numSM = numSM
        self.peer_pool = peer_pool

    def _allocate_peer_tensor(self, halo):
        """Allocate a peer scratch buffer of ``halo.dtype`` sized for ``halo``.

        The buffer holds 4x the halo's bytes, rounded up to a multiple of the
        per-CUDA-block requirement, and has shape [1, 1, 1, elements].
        """
        # Size in bytes. NOTE(review): the reason for the 4x factor is not
        # visible here.
        size = 4 * halo.numel() * halo.element_size()
        # Pad buffer so each CUDA block gets required buffer size:
        # 128 threads each require two 128-bit (16-byte) buffers.
        size_per_block = 128 * 2 * 16
        size = (size + size_per_block - 1) // size_per_block * size_per_block
        shape = [1, 1, 1, size // halo.element_size()]
        return self.peer_pool.allocate_peer_tensors(shape, halo.dtype, False, True)

    def left_right_halo_exchange(
        self,
        left_output_halo,
        right_output_halo,
        left_input_halo=None,
        right_input_halo=None,
    ):
        """Exchange halos with the left and right neighbors via peer memory.

        Out-of-place when both input halos are None: fresh receive buffers are
        allocated (``empty_like`` the opposite output halo) and returned as
        ``(left_input_halo, right_input_halo)``.
        In-place when both are given: results are written into them and None
        is returned.

        Raises:
            ValueError: if exactly one of the input halos is given (previously
                this failed with an AttributeError after allocating a buffer).
        """
        inplace = left_input_halo is not None or right_input_halo is not None
        if inplace and (left_input_halo is None or right_input_halo is None):
            raise ValueError(
                "left_input_halo and right_input_halo must both be given or both be None"
            )
        if not inplace:
            left_input_halo = torch.empty_like(right_output_halo)
            right_input_halo = torch.empty_like(left_output_halo)
        # One scratch buffer set per direction; they are indexed below by this
        # rank and by a wrap-around neighbor (presumably one entry per peer).
        left_tx = self._allocate_peer_tensor(left_input_halo)
        right_tx = self._allocate_peer_tensor(right_input_halo)
        pm.push_pull_halos_1d(
            self.diagnostics,
            self.explicit_nhwc,
            self.numSM,
            self.rank_in_group,
            self.left_zero,
            left_output_halo,
            left_tx[self.rank_in_group],
            right_tx[self.wrap_around_left_rank_in_group],
            left_input_halo,
            self.right_zero,
            right_output_halo,
            right_tx[self.rank_in_group],
            left_tx[self.wrap_around_right_rank_in_group],
            right_input_halo,
        )
        if not inplace:
            return left_input_halo, right_input_halo
        return None
# Class that combines input volume with halos from neighbors (1d).
class HaloPadder:
    """Pad a 4-D activation with halos from its neighbors along one spatial dim.

    The halo exchange runs on ``stream1`` and the copy of the interior on
    ``stream2``; call ``wait()`` before consuming the returned tensor.
    """

    def __init__(self, halo_ex):
        # halo_ex is called as halo_ex(left_output, right_output, left_input,
        # right_input) - presumably a bound left_right_halo_exchange of one of
        # the exchangers (TODO confirm).
        self.halo_ex = halo_ex
        self.stream1 = torch.cuda.Stream()
        self.stream2 = torch.cuda.Stream()

    @staticmethod
    def _slice_dim(t, dim, start, stop):
        """Return the view ``t[..., start:stop, ...]`` with the slice on ``dim`` (4-D)."""
        index = [slice(None)] * 4
        index[dim] = slice(start, stop)
        return t[tuple(index)]

    def __call__(self, y, half_halo, explicit_nhwc, H_split):
        """Return ``y`` padded by ``half_halo`` on both sides of the split dim.

        Arguments:
            y: 4-D tensor, (N, H, W, C) when ``explicit_nhwc`` else (N, C, H, W).
            half_halo: number of halo rows/columns added on each side.
            explicit_nhwc: selects the layout above; the padded result is
                allocated contiguous for explicit NHWC, channels_last otherwise.
            H_split: pad along H if True, along W otherwise.

        Returns:
            The padded tensor. It is filled asynchronously on side streams, so
            call ``wait()`` before reading it.

        Raises:
            ValueError: if ``y`` is not 4-D.
        """
        if y.dim() != 4:
            raise ValueError("expected a 4-D tensor, got %d dims" % y.dim())
        if explicit_nhwc:
            # (N, H, W, C)
            dim = 1 if H_split else 2
            memory_format = torch.contiguous_format
        else:
            # (N, C, H, W)
            dim = 2 if H_split else 3
            memory_format = torch.channels_last

        extent = y.shape[dim]
        padded_shape = list(y.shape)
        padded_shape[dim] = extent + 2 * half_halo
        # torch.empty takes the size positionally (or as ``size=``); the former
        # ``shape=`` keyword raised a TypeError.
        ypad = torch.empty(
            padded_shape,
            dtype=y.dtype,
            device=y.device,
            memory_format=memory_format,
        )
        yleft = self._slice_dim(ypad, dim, None, half_halo)
        ymid = self._slice_dim(ypad, dim, half_halo, extent + half_halo)
        yright = self._slice_dim(ypad, dim, extent + half_halo, extent + 2 * half_halo)
        oleft = self._slice_dim(y, dim, None, half_halo)
        oright = self._slice_dim(y, dim, extent - half_halo, None)

        # y was produced on the current stream: order both side streams after
        # it so they never read y (or touch ypad) before it is ready.
        current_stream = torch.cuda.current_stream()
        self.stream1.wait_stream(current_stream)
        self.stream2.wait_stream(current_stream)
        with torch.cuda.stream(self.stream1):
            self.halo_ex(oleft, oright, yleft, yright)
        with torch.cuda.stream(self.stream2):
            ymid.copy_(y)
        return ypad

    def wait(self):
        """Make the current stream wait for the halo exchange and interior copy."""
        current_stream = torch.cuda.current_stream()
        current_stream.wait_stream(self.stream1)
        current_stream.wait_stream(self.stream2)