-
Notifications
You must be signed in to change notification settings - Fork 1.5k
Expand file tree
/
Copy pathasp.py
More file actions
472 lines (432 loc) · 21.1 KB
/
asp.py
File metadata and controls
472 lines (432 loc) · 21.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
import types
import torch
from .sparse_masklib import create_mask
from .permutation_lib import Permutation
# torchvision is optional. ``torchvision_imported`` records whether the import
# succeeded so that later code can decide whether torchvision-specific layer
# types may be referenced.
try:
    import torchvision
except ImportError:
    print("[ASP][Warning] torchvision cannot be imported.")
    torchvision_imported = False
else:
    torchvision_imported = True
import os
import time
def eligible_modules(model, whitelist_layer_types, allowed_layer_names, disallowed_layer_names):
    """Collect the ``(name, module)`` pairs of *model* that may be sparsified.

    A module yielded by ``model.named_modules()`` is kept when all of the
    following hold:

    * it is an instance of ``whitelist_layer_types`` (a type or tuple of types,
      as accepted by ``isinstance``),
    * its name is not in ``disallowed_layer_names``,
    * ``allowed_layer_names`` is ``None`` or contains its name.

    The result is a list in ``named_modules()`` order.
    """

    def is_eligible(layer_name, layer):
        # Guard clauses mirror the evaluation order of the checks: type first,
        # then the deny-list, then the optional allow-list.
        if not isinstance(layer, whitelist_layer_types):
            return False
        if layer_name in disallowed_layer_names:
            return False
        return allowed_layer_names is None or layer_name in allowed_layer_names

    return [
        (layer_name, layer)
        for layer_name, layer in model.named_modules()
        if is_eligible(layer_name, layer)
    ]
class ASP:
    """Class-level helper that masks eligible weights of a single model.

    All state is stored on the class and every entry point is a classmethod,
    so at most one model and one optimizer can be managed per process:
    ``init_model_for_pruning`` and ``init_optimizer_for_pruning`` assert that
    nothing has been registered before.
    """

    __model = None  # model registered by init_model_for_pruning
    __verbosity = 0  # 0 errors .. 3 debug (see init_model_for_pruning)
    __optimizer = None  # optimizer whose step() has been patched
    # One (module_name, module, p_name, p, mask, pruned) tuple per sparsified parameter.
    __sparse_parameters = []
    __calculate_mask = None  # callable: tensor -> boolean keep-mask
    __allow_permutation = True
    # One (module_name, module, p_name, tensor) tuple per parameter visited for permutation.
    __all_parameters = []
    __save_permutation_graph = False
    __permutation_output_dir = ""

    @classmethod
    def init_model_for_pruning(
        cls,
        model,
        mask_calculator="m4n2_1d",
        verbosity=3,
        whitelist=[
            torch.nn.Linear,
            torch.nn.Conv1d,
            torch.nn.Conv2d,
            torch.nn.Conv3d,
            torch.nn.MultiheadAttention,
        ],
        allowed_layer_names=None,
        disallowed_layer_names=[],
        allow_recompute_mask=False,
        custom_layer_dict={},
        allow_permutation=True,
    ):
        """Call this method to modify your model to take advantage of sparse matrix multiplication.
        Note that this call alone only augments the model with additional buffers needed for sparse MMA,
        it does not enable use of sparse MMA.

        If you are starting with a fresh model:

        model = ...
        ASP.init_model_for_pruning(model, mask_calculator, ...)
        if (training) ASP.init_optimizer_for_pruning(optimizer)
        ASP.compute_sparse_masks() // sparsity is off by default, call when you want to enable it.

        If you are starting from a checkpoint:

        model = ...
        ASP.init_model_for_pruning(model, mask_calculator, ...)
        torch.load(...)
        if (training) ASP.init_optimizer_for_pruning(optimizer)

        Arguments:
          model                    The model
          mask_calculator          Either callable that computes mask given a tensor OR pattern string for sparse mask lib.
          verbosity                Integer controlling verbosity level.
                                   0 -> Only errors.
                                   1 -> Errors and warnings.
                                   2 -> Errors, warnings and info.
                                   3 -> Errors, warnings, info and debug.
          whitelist                Module types approved for sparsity. The list is copied, never modified.
          allowed_layer_names      If not None, only layer names that appear in this list are considered for sparsity.
          disallowed_layer_names   If not [], only layer names that do not appear in this list are considered for sparsity.
          allow_recompute_mask     If True, stores pruned values so that dense weights can be restored.
                                   Pruned weights are stored in CPU memory, hence this option does not increase GPU memory usage.
          custom_layer_dict        Dictionary of additional layer parameters to sparsify. e.g. {CustomLinear: ['weight']}
          allow_permutation        If True, allow the input channel permutation to ease the influence of weight pruning.

        Raises:
          AssertionError           If ASP was already initialized, or a whitelisted module type has no
                                   known list of parameters to sparsify.

        [Future] Support for allow_recompute_mask can be removed, it is not part of sparse inference recipe.
        """
        assert cls.__model is None, "ASP has been initialized already."
        cls.__model = model
        cls.__verbosity = verbosity
        cls.__allow_permutation = allow_permutation

        if isinstance(mask_calculator, str):

            def create_mask_from_pattern(param):
                return create_mask(param, mask_calculator).bool()

            cls.__calculate_mask = create_mask_from_pattern
        else:
            cls.__calculate_mask = mask_calculator  # user defined function

        # Table of (module type -> names of the parameters that get sparsified).
        # Add one entry for each module type that can be sparsified.
        sparse_parameter_list = {
            torch.nn.Linear: ["weight"],
            torch.nn.Conv1d: ["weight"],
            torch.nn.Conv2d: ["weight"],
            torch.nn.Conv3d: ["weight"],
            torch.nn.modules.linear.NonDynamicallyQuantizableLinear: ["weight"],
            torch.nn.MultiheadAttention: [
                "q_proj_weight",
                "k_proj_weight",
                "v_proj_weight",
                "in_proj_weight",
            ],
        }
        if torchvision_imported:
            print(
                "[ASP] torchvision is imported, can work with the MaskRCNN/KeypointRCNN from torchvision."
            )
            torchvision_version = str(torchvision.__version__)
            torchvision_version_major = int(torchvision_version.split(".")[0])
            torchvision_version_minor = int(torchvision_version.split(".")[1])
            # Torchvision remove APIs that were deprecated before 0.8 (#5386) in 0.12.0,
            # torchvision.ops.misc.Conv2d is removed, so only reference it on older versions.
            if torchvision_version_major == 0 and torchvision_version_minor < 12:
                sparse_parameter_list[torchvision.ops.misc.Conv2d] = ["weight"]

        # Work on a private copy: extending the caller's list (or the shared
        # default list) in place would leak custom layer types into later calls.
        whitelist = list(whitelist)
        if custom_layer_dict:
            # Update default list to include user supplied custom (layer type : parameter tensor),
            # make sure this tensor type is something ASP knows how to prune
            sparse_parameter_list.update(custom_layer_dict)
            whitelist += list(custom_layer_dict.keys())

        for module_type in whitelist:
            assert module_type in sparse_parameter_list, (
                "Module %s :: Don't know how to sparsify module." % (module_type,)
            )

        def sparse_parameter_names(module):
            # Modules are selected with isinstance(), so subclasses of a
            # whitelisted type reach this point too. Walk the MRO so they
            # resolve to the entry of their closest registered base class;
            # an exact type match is always found first.
            for klass in type(module).__mro__:
                if klass in sparse_parameter_list:
                    return sparse_parameter_list[klass]
            raise KeyError(type(module))

        # find all sparse modules, extract sparse parameters and decorate
        def add_sparse_attributes(module_name, module):
            sparse_parameters = sparse_parameter_names(module)
            for p_name, p in module.named_parameters():
                if p_name in sparse_parameters and p.requires_grad:
                    # check for NVIDIA's TC compatibility: we check along the horizontal direction.
                    # FP32: User defines FP32 and APEX internally uses FP16 math.
                    # FP16: For Conv2d dim= K x CRS; we prune along C.
                    if p.dtype in (torch.float32, torch.float16) and (
                        (p.size()[0] % 8) != 0 or (p.size()[1] % 16) != 0
                    ):
                        print(
                            "[ASP] Auto skipping pruning %s::%s of size=%s and type=%s for sparsity"
                            % (module_name, p_name, str(p.size()), str(p.dtype))
                        )
                        continue

                    if cls.__verbosity >= 3:
                        print(
                            "[ASP] Sparsifying %s::%s of size=%s and type=%s for sparsity"
                            % (module_name, p_name, str(p.size()), str(p.dtype))
                        )

                    # The mask starts all-True: sparsity stays off until
                    # compute_sparse_masks() is called.
                    mask = torch.ones_like(p).bool()
                    buffname = p_name.split(".")[-1]  # buffer names cannot contain "."
                    module.register_buffer("__%s_mma_mask" % buffname, mask)
                    if allow_recompute_mask:
                        pruned = torch.zeros_like(p).cpu()
                        module.register_buffer("__%s_mma_pruned_p" % buffname, pruned)
                    else:
                        pruned = None
                    cls.__sparse_parameters.append((module_name, module, p_name, p, mask, pruned))
                else:
                    if cls.__verbosity >= 3:
                        print(
                            "[ASP] Not sparsifying %s::%s of size=%s and type=%s"
                            % (module_name, p_name, str(p.size()), str(p.dtype))
                        )

        for name, sparse_module in eligible_modules(
            model, tuple(whitelist), allowed_layer_names, disallowed_layer_names
        ):
            add_sparse_attributes(name, sparse_module)

        if allow_permutation:
            # find all named modules, extract parameters and decorate, used for offline permutation in K dim
            state_dict = None  # built lazily, only if a BatchNorm2d is encountered
            for module_name, module in model.named_modules():
                module_type_str = str(type(module)).split("'")[1]
                if (
                    module_type_str == "torch.nn.modules.container.Sequential"
                    or module_type_str.startswith("torchvision.models")
                ):
                    # filter out the 'torch.nn.modules.container.Sequential' type and the whole model,
                    # like 'torchvision.models.vgg.VGG'
                    continue
                for p_name, p in module.named_parameters():
                    cls.__all_parameters.append((module_name, module, p_name, p))
                if module_type_str == "torch.nn.modules.batchnorm.BatchNorm2d":
                    # need to get the running_mean and running_var from model.state_dict(),
                    # as they are not the learnable parameters
                    if state_dict is None:
                        state_dict = model.state_dict()
                    for stat_name in ("running_mean", "running_var"):
                        param_key = module_name + "." + stat_name
                        if param_key in state_dict:
                            cls.__all_parameters.append(
                                (module_name, module, stat_name, state_dict[param_key])
                            )

            # add the __permutation_output_dir field to save the intermediate results for permutation
            cls.__permutation_output_dir = "."
            # Set the corresponding params from ASP class to the Permutation class
            permutation_verbosity = 5
            Permutation.set_permutation_params_from_asp(
                cls.__model,
                cls.__sparse_parameters,
                cls.__all_parameters,
                permutation_verbosity,
            )
            # Set the identical random seed for all GPUs to make sure the same results generated in permutation search
            Permutation.set_identical_seed()

    @classmethod
    def already_init_asp_model(cls):
        """Call this method to check whether ASP has been initialized already.

        Returns True when a model has been registered, False otherwise.
        """
        if cls.__model is None:
            if cls.__verbosity >= 3:
                print("[ASP] ASP has not been initialized.")
            return False
        else:
            if cls.__verbosity >= 3:
                print("[ASP] ASP has been initialized already.")
            return True

    @classmethod
    def init_optimizer_for_pruning(cls, optimizer):
        """Call this method to monkey patch optimizer step function so that masks can be applied to
        gradients and weights during training.
        You must call init_model_for_pruning(...) before calling init_optimizer_for_pruning(...)

        Raises:
          AssertionError   If an optimizer was already patched, or the model was not initialized first.
        """
        assert cls.__optimizer is None, "ASP has initialized optimizer already."
        assert cls.__calculate_mask is not None, (
            "Called ASP.init_optimizer_for_pruning before ASP.init_model_for_pruning."
        )

        # store pointer to original optimizer step method
        cls.__optimizer = optimizer
        cls.__optimizer.__step = optimizer.step

        def __step(opt_self, *args, **kwargs):
            # prune gradients before step method
            with torch.no_grad():
                for _, _, _, p, mask, _ in cls.__sparse_parameters:
                    if p.grad is not None:  # thx pjudd
                        p.grad.mul_(mask)
            # call original optimizer step method
            rval = opt_self.__step(*args, **kwargs)
            # prune parameters after step method
            with torch.no_grad():
                for _, _, _, p, mask, _ in cls.__sparse_parameters:
                    p.mul_(mask)
            return rval

        cls.__optimizer.step = types.MethodType(__step, cls.__optimizer)

    @classmethod
    def compute_sparse_masks(cls):
        """Call this method to enable sparsity.
        If init(...) was called with allow_recompute_mask=False AND sparsity is disabled, pruned field can be None.

        Raises:
          AssertionError   If masks are being recomputed but pruned values were not stored
                           (allow_recompute_mask == False).
        """
        with torch.no_grad():
            if cls.__allow_permutation:
                # Step 1: use the Torch.FX library to build the graph
                # Step 2: permutation search with the customized kernel
                # The simplest without user intervention:
                # A. try to import with the distributed mode of the original model
                # B. if meet the error, import with the none-distributed mode of the original model
                start_time_permute = time.perf_counter()

                # Only the ``.module`` lookup is guarded. An AttributeError raised
                # inside the permutation search itself must propagate instead of
                # silently triggering a second search on the wrapper model.
                try:
                    permutation_target = cls.__model.module
                    permuted_message = "\n[compute_sparse_masks] permuted the (distributed) model."
                except AttributeError:
                    permutation_target = cls.__model
                    permuted_message = "\n[compute_sparse_masks] permuted the model."

                successful_permutation = Permutation.permute_model(
                    permutation_target,
                    dump_fx_graph=cls.__save_permutation_graph,
                    save_dumped_fx_graph=os.path.join(
                        cls.__permutation_output_dir,
                        "model_offline_permutation_graph.json",
                    ),
                )

                if successful_permutation:
                    print(permuted_message)
                    duration_build_offline_permutation_graph = (
                        time.perf_counter() - start_time_permute
                    )
                    print(
                        "[compute_sparse_masks] Take {:.4f} seconds to find and apply permutations.".format(
                            duration_build_offline_permutation_graph
                        )
                    )

            for module_name, module, p_name, p, mask, pruned in cls.__sparse_parameters:
                if mask.sum() < mask.numel():  # when recalculating masks
                    # restore dense parameter if allow_recompute_mask is enabled
                    assert pruned is not None, (
                        "Unable to restore dense parameter because allow_recompute_mask == False"
                    )
                    # Move the stowed values to wherever the parameter lives
                    # (works for CPU models and for any CUDA device).
                    p.add_(pruned.to(p.device))

                mask.set_(cls.__calculate_mask(p))

                if pruned is not None:  # stow away pruned weights to cpu
                    pruned.set_((p * (~mask)).cpu())

                # in-place multiplication, so pruned weights are 0-values,
                # hence checkpoint will have 0s for pruned weights
                p.mul_(mask)
                if cls.__verbosity >= 2:
                    print(
                        "[ASP] Enabled %.2f%% sparsity for %s::%s of size=%s and type=%s with magnitude %s"
                        % (
                            100.0 - 100.0 * mask.sum() / mask.numel(),
                            module_name,
                            p_name,
                            str(p.size()),
                            str(p.dtype),
                            torch.sum(torch.abs(p)),
                        )
                    )

    @classmethod
    def restore_pruned_weights(cls):
        """Call this method to disable sparsity and restore all weights.
        This will only work if init(...) was called with allow_recompute_mask=True.

        Raises:
          AssertionError   If a parameter is currently masked but its pruned values were not stored.
        """
        with torch.no_grad():
            for module_name, module, p_name, p, mask, pruned in cls.__sparse_parameters:
                if mask.sum() < mask.numel():
                    assert pruned is not None, (
                        "Unable to restore dense parameter because allow_recompute_mask == False"
                    )
                    # Add the stowed values back on the parameter's own device.
                    p.add_(pruned.to(p.device))
                    mask.fill_(1)
                    pruned.zero_()
                    if cls.__verbosity >= 2:
                        print(
                            "[ASP] Disabled sparsity for %s::%s (dense weights restored)"
                            % (module_name, p_name)
                        )

    @classmethod
    def is_sparsity_enabled(cls):
        """Call this method to determine if sparsity is enabled in the model.
        The typical use case is right after checkpoint has been loaded.

        Returns False when every mask is fully set (dense), True when every mask
        keeps exactly half of its elements.

        Raises:
          AssertionError   If the masks are neither all dense nor all half-set.
        """
        total, sp100, sp50 = 0, 0, 0
        for module_name, module, p_name, p, mask, pruned in cls.__sparse_parameters:
            total += 1
            mask_sum = mask.sum()
            mask_numel = mask.numel()
            if mask_sum == mask_numel:
                sp100 += 1
            elif mask_sum * 2 == mask_numel:
                sp50 += 1

        assert total == sp100 or total == sp50, "Inconsistent model sparsity"
        if total == sp100:
            return False
        elif total == sp50:
            return True

    @classmethod
    def prune_trained_model(cls, model, optimizer):
        """Convenience wrapper running the three setup steps with fixed settings.

        Adds mask buffers to the model (init_model_for_pruning), augments the
        optimizer (init_optimizer_for_pruning) and computes the masks
        (compute_sparse_masks).
        """
        cls.init_model_for_pruning(
            model,
            mask_calculator="m4n2_1d",
            verbosity=2,
            whitelist=[torch.nn.Linear, torch.nn.Conv2d, torch.nn.MultiheadAttention],
            allow_recompute_mask=False,
        )
        cls.init_optimizer_for_pruning(optimizer)
        cls.compute_sparse_masks()

    @classmethod
    def set_permutation_saving_params(
        cls,
        allow_permutation=True,
        save_permutation_graph=False,
        permutation_output_dir=".",
    ):
        """This function is used to set the permutation saving related parameters in ASP class and inside of the Permutation class."""
        print("\n[ASP][set_permutation_saving_param] Set permutation saving related parameters")
        print("\n[set_permutation_saving_param] Set permutation saving related parameters")
        cls.__allow_permutation = allow_permutation
        print(
            "[set_permutation_saving_param]\t Allow permutation: {}".format(cls.__allow_permutation)
        )
        cls.__save_permutation_graph = save_permutation_graph
        print(
            "[set_permutation_saving_param]\t Save permutation graphs: {}".format(
                cls.__save_permutation_graph
            )
        )
        cls.__permutation_output_dir = permutation_output_dir
        print(
            "[set_permutation_saving_param]\t Permutation graphs saving dir: {}".format(
                cls.__permutation_output_dir
            )
        )

        # Forward the same settings to the Permutation class.
        Permutation.set_permutation_saving_params(
            allow_permutation, save_permutation_graph, permutation_output_dir
        )