-
Notifications
You must be signed in to change notification settings - Fork 1.5k
Expand file tree
/
Copy pathbatch_norm.py
More file actions
468 lines (436 loc) · 13.7 KB
/
batch_norm.py
File metadata and controls
468 lines (436 loc) · 13.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
import torch
import numpy as np
from torch.nn.modules.batchnorm import _BatchNorm
import bnp
class bn_NHWC_impl(torch.autograd.Function):
    """Autograd function around the ``bnp`` batch-norm kernels (``bn_*_nhwc``).

    The kernel names suggest channels-last (NHWC) activations; the exact
    layout/dtype requirements are defined by the ``bnp`` extension -- TODO
    confirm. Gradients are returned for ``x``, ``s`` (scale) and ``b`` (bias)
    only; every other forward input receives ``None``.
    """

    @staticmethod
    def forward(
        ctx,
        x,
        s,
        b,
        rm,
        riv,
        mini_m,
        mini_riv,
        ret_cta,
        mom,
        epsilon,
        fuse_relu,
        is_train,
        bn_group,
        my_data,
        pair_data,
        magic,
        pair_data2,
        pair_data3,
        fwd_occup,
        fwd_grid_x,
        bwd_occup,
        bwd_grid_x,
        multi_stream,
    ):
        """Run the ``bnp`` training or eval forward kernel.

        Args:
            x: input activations.
            s, b: scale (weight) and bias.
            rm, riv: running-mean / running-variance buffers.
            mini_m, mini_riv: minibatch-statistic buffers passed to the
                training kernel and saved for backward.
            ret_cta: scratch buffer passed to the kernels.
            mom, epsilon: momentum and epsilon passed to the kernels.
            fuse_relu: fuse-ReLU flag passed to the kernels.
            is_train: True selects ``bn_fwd_nhwc`` (and enables backward);
                False selects ``bn_fwd_eval_nhwc``.
            bn_group: number of ranks in the batch-norm group.
            my_data, pair_data, pair_data2, pair_data3: data pointers for this
                rank and its peers, or ``None`` when unused.
            magic: IntTensor passed to the training kernels.
            fwd_occup, fwd_grid_x: launch settings for the training forward.
            bwd_occup, bwd_grid_x: launch settings stored for backward.
            multi_stream: flag passed to the training kernels.

        Returns:
            The result of the selected ``bnp`` forward kernel (presumably the
            normalized activations).
        """
        if not is_train:
            # The eval kernel receives the running statistics but none of the
            # minibatch buffers or launch settings. Nothing is saved on ctx,
            # so backward() is not supported for this call.
            return bnp.bn_fwd_eval_nhwc(
                x, s, b, rm, riv, ret_cta, bn_group, mom, epsilon, fuse_relu
            )

        ctx.save_for_backward(x, s, b, rm, riv, mini_m, mini_riv)
        # Non-tensor state that bn_bwd_nhwc needs later.
        ctx.epsilon = epsilon
        ctx.momentum = mom
        ctx.ret_cta = ret_cta
        ctx.fuse_relu = fuse_relu
        ctx.my_data = my_data
        ctx.pair_data = pair_data
        ctx.magic = magic
        ctx.pair_data2 = pair_data2
        ctx.pair_data3 = pair_data3
        ctx.bn_group = bn_group
        ctx.bwd_occup = bwd_occup
        ctx.bwd_grid_x = bwd_grid_x
        ctx.multi_stream = multi_stream
        return bnp.bn_fwd_nhwc(
            x,
            s,
            b,
            rm,
            riv,
            mini_m,
            mini_riv,
            ret_cta,
            mom,
            epsilon,
            fuse_relu,
            my_data,
            pair_data,
            pair_data2,
            pair_data3,
            bn_group,
            magic,
            fwd_occup,
            fwd_grid_x,
            multi_stream,
        )

    @staticmethod
    def backward(ctx, grad_y):
        """Return gradients for (x, s, b) and ``None`` for the 20 other inputs.

        Raises:
            RuntimeError: if the forward pass ran with ``is_train=False``
                (nothing was saved for backward).
        """
        # `saved_tensors` replaces the deprecated `saved_variables` alias.
        saved = ctx.saved_tensors
        if not saved:
            raise RuntimeError(
                "bn_NHWC_impl: backward requires the forward pass to run with is_train=True"
            )
        x, s, b, rm, riv, mini_m, mini_riv = saved
        dx, dscale, dbias = bnp.bn_bwd_nhwc(
            x,
            grad_y,
            s,
            b,
            rm,
            riv,
            mini_m,
            mini_riv,
            ctx.ret_cta,
            ctx.momentum,
            ctx.epsilon,
            ctx.fuse_relu,
            ctx.my_data,
            ctx.pair_data,
            ctx.pair_data2,
            ctx.pair_data3,
            ctx.bn_group,
            ctx.magic,
            ctx.bwd_occup,
            ctx.bwd_grid_x,
            ctx.multi_stream,
        )
        # One entry per forward() input: 3 gradients + 20 non-differentiable.
        return (dx, dscale, dbias) + (None,) * 20
class bn_addrelu_NHWC_impl(torch.autograd.Function):
    """Autograd function around the ``bnp`` add+ReLU batch-norm kernels.

    Takes a second input ``z`` alongside ``x``. Gradients are returned for
    ``x``, ``z``, ``s`` (scale) and ``b`` (bias); every other forward input
    receives ``None``. The exact activation layout/dtype is defined by the
    ``bnp`` extension (kernel names suggest NHWC) -- TODO confirm.
    """

    @staticmethod
    def forward(
        ctx,
        x,
        z,
        s,
        b,
        rm,
        riv,
        mini_m,
        mini_riv,
        grid_dim_y,
        ret_cta,
        mom,
        epsilon,
        is_train,
        bn_group,
        my_data,
        pair_data,
        magic,
        pair_data2,
        pair_data3,
        fwd_occup,
        fwd_grid_x,
        bwd_occup,
        bwd_grid_x,
        multi_stream,
    ):
        """Run the ``bnp`` add+ReLU training or eval forward kernel.

        Args:
            x, z: the two input tensors.
            s, b: scale (weight) and bias.
            rm, riv: running-mean / running-variance buffers.
            mini_m, mini_riv: minibatch-statistic buffers passed to the
                training kernel and saved for backward.
            grid_dim_y: used only to size the training-mode bitmask.
            ret_cta: scratch buffer passed to the kernels.
            mom, epsilon: momentum and epsilon passed to the kernels.
            is_train: True selects ``bn_addrelu_fwd_nhwc`` (and enables
                backward); False selects ``bn_addrelu_fwd_eval_nhwc``.
            bn_group: number of ranks in the batch-norm group.
            my_data, pair_data, pair_data2, pair_data3: data pointers for this
                rank and its peers, or ``None`` when unused.
            magic: IntTensor passed to the training kernels.
            fwd_occup, fwd_grid_x: launch settings for the training forward.
            bwd_occup, bwd_grid_x: launch settings stored for backward.
            multi_stream: flag passed to the training kernels.

        Returns:
            The result of the selected ``bnp`` forward kernel.
        """
        if not is_train:
            # Eval kernel: no minibatch buffers, bitmask or launch settings.
            # Nothing is saved on ctx, so backward() is not supported here.
            return bnp.bn_addrelu_fwd_eval_nhwc(
                x, z, s, b, rm, riv, ret_cta, bn_group, mom, epsilon
            )

        # Uninitialized int32 CUDA buffer of ceil(numel / 32) * 2 * grid_dim_y
        # words. Forward fills it and backward consumes it (presumably the
        # ReLU mask -- confirm against bnp).
        bitmask = torch.cuda.IntTensor(((x.numel() + 31) // 32) * 2 * grid_dim_y)
        ctx.save_for_backward(x, s, b, rm, riv, mini_m, mini_riv, bitmask)
        # Non-tensor state that bn_addrelu_bwd_nhwc needs later.
        ctx.epsilon = epsilon
        ctx.momentum = mom
        ctx.ret_cta = ret_cta
        ctx.my_data = my_data
        ctx.pair_data = pair_data
        ctx.magic = magic
        ctx.pair_data2 = pair_data2
        ctx.pair_data3 = pair_data3
        ctx.bn_group = bn_group
        ctx.bwd_occup = bwd_occup
        ctx.bwd_grid_x = bwd_grid_x
        ctx.multi_stream = multi_stream
        return bnp.bn_addrelu_fwd_nhwc(
            x,
            z,
            s,
            b,
            rm,
            riv,
            mini_m,
            mini_riv,
            bitmask,
            ret_cta,
            mom,
            epsilon,
            my_data,
            pair_data,
            pair_data2,
            pair_data3,
            bn_group,
            magic,
            fwd_occup,
            fwd_grid_x,
            multi_stream,
        )

    @staticmethod
    def backward(ctx, grad_y):
        """Return gradients for (x, z, s, b) and ``None`` for the 20 other inputs.

        Raises:
            RuntimeError: if the forward pass ran with ``is_train=False``
                (nothing was saved for backward).
        """
        # `saved_tensors` replaces the deprecated `saved_variables` alias.
        saved = ctx.saved_tensors
        if not saved:
            raise RuntimeError(
                "bn_addrelu_NHWC_impl: backward requires the forward pass to run with is_train=True"
            )
        x, s, b, rm, riv, mini_m, mini_riv, bitmask = saved
        dx, dz, dscale, dbias = bnp.bn_addrelu_bwd_nhwc(
            x,
            grad_y,
            s,
            b,
            rm,
            riv,
            mini_m,
            mini_riv,
            bitmask,
            ctx.ret_cta,
            ctx.momentum,
            ctx.epsilon,
            ctx.my_data,
            ctx.pair_data,
            ctx.pair_data2,
            ctx.pair_data3,
            ctx.bn_group,
            ctx.magic,
            ctx.bwd_occup,
            ctx.bwd_grid_x,
            ctx.multi_stream,
        )
        # One entry per forward() input: 4 gradients + 20 non-differentiable.
        return (dx, dz, dscale, dbias) + (None,) * 20
class BatchNorm2d_NHWC(_BatchNorm):
    """Batch-norm module backed by the ``bnp`` extension kernels.

    Args:
        num_features: number of channels (passed to ``_BatchNorm``).
        fuse_relu: fuse-ReLU flag passed to the kernels; must be True to call
            ``forward`` with a second input ``z``.
        bn_group: number of ranks forming one batch-norm group. 1 (the
            default) disables distributed batch norm. 2, 4 or 8 open CUDA IPC
            buffers of peer ranks and require an initialized
            ``torch.distributed`` process group.
        max_cta_per_sm: upper bound on CTAs per SM (training kernels only).
        cta_launch_margin: CTAs subtracted from the full-device grid size
            (training kernels only).
        multi_stream: set to True when using BatchNorm2d_NHWC simultaneously
            with multiple streams.
    """

    # bn_group -> sync steps. Peers are paired as rank ^ 1, ^ 2, ^ 4, so at
    # most 8 ranks per group are supported.
    _BN_SYNC_STEPS = {2: 1, 4: 2, 8: 3}

    def __init__(
        self,
        num_features,
        fuse_relu=False,
        bn_group=1,
        max_cta_per_sm=2,
        cta_launch_margin=12,
        multi_stream=False,
    ):
        super().__init__(num_features)
        self.fuse_relu = fuse_relu
        self.multi_stream = multi_stream
        # Uninitialized CUDA buffers handed to the kernels as mini_m / mini_riv
        # (plain attributes, not registered module buffers).
        self.minibatch_mean = torch.cuda.FloatTensor(num_features)
        self.minibatch_riv = torch.cuda.FloatTensor(num_features)
        # Default (1) keeps distributed batch norm disabled.
        self.bn_group = bn_group
        self.max_cta_per_sm = max_cta_per_sm  # used only in training fwd and bwd
        self.cta_launch_margin = cta_launch_margin  # used only in training fwd and bwd
        # Peer-exchange state; filled in by _init_peer_buffers when bn_group > 1.
        self.my_data = None
        self.pair_data = None
        self.pair_data2 = None
        self.pair_data3 = None
        self.local_rank = 0
        self.magic = torch.IntTensor([0])

        # CTAs per SM for each kernel, capped by the user-supplied limit.
        assert max_cta_per_sm > 0  # won't be able to do much with 0 CTAs :)
        self.fwd_occupancy = min(bnp.bn_fwd_nhwc_occupancy(), max_cta_per_sm)
        self.bwd_occupancy = min(bnp.bn_bwd_nhwc_occupancy(), max_cta_per_sm)
        self.addrelu_fwd_occupancy = min(bnp.bn_addrelu_fwd_nhwc_occupancy(), max_cta_per_sm)
        self.addrelu_bwd_occupancy = min(bnp.bn_addrelu_bwd_nhwc_occupancy(), max_cta_per_sm)

        # Grid x dimension: every SM of the current device at the chosen
        # occupancy, minus the launch margin, but never fewer than 1 CTA.
        mp_count = torch.cuda.get_device_properties(None).multi_processor_count
        self.fwd_grid_dim_x = max(mp_count * self.fwd_occupancy - cta_launch_margin, 1)
        self.bwd_grid_dim_x = max(mp_count * self.bwd_occupancy - cta_launch_margin, 1)
        self.addrelu_fwd_grid_dim_x = max(
            mp_count * self.addrelu_fwd_occupancy - cta_launch_margin, 1
        )
        self.addrelu_bwd_grid_dim_x = max(
            mp_count * self.addrelu_bwd_occupancy - cta_launch_margin, 1
        )
        # ceil(num_features / 64); also sizes the add+ReLU bitmask.
        self.grid_dim_y = (num_features + 63) // 64

        # Allocate scratch space used by the implementation.
        # TODO: scratch space that is not supposed to be exposed at user code. We only need one time initialization, the
        # same buffer could be reused in future iterations. Currently we exposed it here instead of requesting new
        # buffer from cache allocator to avoid unnecessary initialization at future iterations.
        self.ret_cta = torch.cuda.ByteTensor(8192).fill_(0)

        # FIXME: turn pair handles into an array
        if bn_group > 1:
            self._init_peer_buffers(bn_group)

    def _init_peer_buffers(self, bn_group):
        """Allocate this rank's IPC buffer and open the buffers of its peers."""
        assert bn_group in self._BN_SYNC_STEPS, "bn_group must be 1, 2, 4 or 8"
        local_rank = torch.distributed.get_rank()
        world_size = torch.distributed.get_world_size()
        assert world_size >= bn_group
        assert world_size % bn_group == 0
        bn_sync_steps = self._BN_SYNC_STEPS[bn_group]

        self.ipc_buffer = torch.cuda.ByteTensor(bnp.get_buffer_size(bn_sync_steps))
        self.my_data = bnp.get_data_ptr(self.ipc_buffer)
        # we are walking on very thin ice here by utilizing internal `_share_cuda_()`
        self.storage = self.ipc_buffer.storage()
        self.share_cuda = self.storage._share_cuda_()
        internal_cuda_mem = self.share_cuda
        # internal_cuda_mem[1]: ipc_mem_handle
        my_handle = torch.cuda.ByteTensor(np.frombuffer(internal_cuda_mem[1], dtype=np.uint8))
        # internal_cuda_mem[3]: offset
        my_offset = torch.cuda.IntTensor([internal_cuda_mem[3]])
        handles_l = self._all_gather(my_handle, world_size)
        offsets_l = self._all_gather(my_offset, world_size)

        # Peers are local_rank XOR 1, then XOR 2 and XOR 4 for larger groups.
        # Handles are assigned only after the remote pointer opened, so
        # __del__ never closes a handle that was not opened.
        self.pair_handle, self.pair_data = self._open_peer(handles_l, offsets_l, local_rank ^ 1)
        if bn_group > 2:
            self.pair_handle2, self.pair_data2 = self._open_peer(
                handles_l, offsets_l, local_rank ^ 2
            )
        if bn_group > 4:
            self.pair_handle3, self.pair_data3 = self._open_peer(
                handles_l, offsets_l, local_rank ^ 4
            )
        # FIXME: get magic value into C code and eliminate from here
        self.magic = torch.IntTensor([2])
        self.local_rank = local_rank

    @staticmethod
    def _all_gather(tensor, world_size):
        """All-gather a 1-D ``tensor``; return one row (view) per rank."""
        gathered = torch.empty(
            world_size,
            tensor.size(0),
            dtype=tensor.dtype,
            device=tensor.device,
        )
        rows = list(gathered.unbind(0))
        torch.distributed.all_gather(rows, tensor)
        return rows

    @staticmethod
    def _open_peer(handles, offsets, peer):
        """Open ``peer``'s IPC buffer; return (CPU handle, remote data pointer)."""
        handle = handles[peer].cpu().contiguous()
        offset = offsets[peer].cpu()
        return handle, bnp.get_remote_data_ptr(handle, offset)

    def forward(self, x, z=None):
        """Apply batch norm to ``x``; with ``z`` given, use the add+ReLU kernels.

        Passing ``z`` requires the module to be built with ``fuse_relu=True``.
        ``self.training`` selects between the training and eval kernels.
        """
        if z is not None:
            assert self.fuse_relu, "BatchNorm2d_NHWC: passing z requires fuse_relu=True"
            return bn_addrelu_NHWC_impl.apply(
                x,
                z,
                self.weight,
                self.bias,
                self.running_mean,
                self.running_var,
                self.minibatch_mean,
                self.minibatch_riv,
                self.grid_dim_y,
                self.ret_cta,
                self.momentum,
                self.eps,
                self.training,
                self.bn_group,
                self.my_data,
                self.pair_data,
                self.magic,
                self.pair_data2,
                self.pair_data3,
                self.addrelu_fwd_occupancy,
                self.addrelu_fwd_grid_dim_x,
                self.addrelu_bwd_occupancy,
                self.addrelu_bwd_grid_dim_x,
                self.multi_stream,
            )
        return bn_NHWC_impl.apply(
            x,
            self.weight,
            self.bias,
            self.running_mean,
            self.running_var,
            self.minibatch_mean,
            self.minibatch_riv,
            self.ret_cta,
            self.momentum,
            self.eps,
            self.fuse_relu,
            self.training,
            self.bn_group,
            self.my_data,
            self.pair_data,
            self.magic,
            self.pair_data2,
            self.pair_data3,
            self.fwd_occupancy,
            self.fwd_grid_dim_x,
            self.bwd_occupancy,
            self.bwd_grid_dim_x,
            self.multi_stream,
        )

    def __del__(self):
        """Close every peer IPC handle that __init__ successfully opened."""
        # getattr with a default: if __init__ failed part-way, the missing
        # handles were never opened and there is nothing to close.
        for name in ("pair_handle", "pair_handle2", "pair_handle3"):
            handle = getattr(self, name, None)
            if handle is not None:
                bnp.close_remote_data(handle)