Thanks for visiting codestin.com
Credit goes to github.com

Skip to content

Commit 3f87614

Browse files
author
mcarilli
authored
WIP: Handle arbitrary combinations of optimizers/models/losses (NVIDIA#232)
* Refactor to allow more flexible treatment of multiple optimizers/models/losses
* Adding _process_optimizers.py
* Created L0 tests (now passing).
* fix: minor print typo (NVIDIA#234)
* make L1 results easier to read
* L0 multiple model/optimizer/loss test fleshed out
* Adding test that master params remain synced across distributed processes
* Docstring updates
* Docstring updates
1 parent 214fda4 commit 3f87614

22 files changed

Lines changed: 1611 additions & 216 deletions

apex/amp/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -2,4 +2,4 @@
22
register_half_function, register_float_function, register_promote_function
33
from .handle import scale_loss, disable_casts
44
from .frontend import initialize
5-
from ._amp_state import master_params
5+
from ._amp_state import master_params, _amp_state

apex/amp/_amp_state.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -17,6 +17,7 @@
1717
class AmpState(object):
1818
def __init__(self):
1919
self.hard_override=False
20+
self.allow_incoming_model_not_fp32 = False
2021
self.verbosity=1
2122

2223

apex/amp/_initialize.py

Lines changed: 14 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -6,6 +6,7 @@
66
from ._amp_state import _amp_state, warn_or_err, container_abcs
77
from .handle import disable_casts
88
from .scaler import LossScaler
9+
from ._process_optimizer import _process_optimizer
910
from apex.fp16_utils import convert_network
1011
from ..fp16_utils import FP16_Optimizer as FP16_Optimizer_general
1112
from ..optimizers import FP16_Optimizer as FP16_Optimizer_for_fused
@@ -122,7 +123,7 @@ def wrap_fused_adam(optimizer, properties):
122123
return FP16_Optimizer_for_fused(optimizer, static_loss_scale=properties.loss_scale)
123124

124125

125-
def _initialize(models, optimizers, properties):
126+
def _initialize(models, optimizers, properties, num_losses=1):
126127
from apex.parallel import DistributedDataParallel as apex_DDP
127128
from .amp import init as amp_init
128129

@@ -146,7 +147,8 @@ def _initialize(models, optimizers, properties):
146147

147148
check_models(models)
148149

149-
check_params_fp32(models)
150+
if not _amp_state.allow_incoming_model_not_fp32:
151+
check_params_fp32(models)
150152

151153
check_optimizers(optimizers)
152154

@@ -181,21 +183,16 @@ def new_fwd(*args, **kwargs):
181183
for optimizer in optimizers:
182184
optimizer.load_state_dict(optimizer.state_dict())
183185

184-
if properties.master_weights:
185-
for i, optimizer in enumerate(optimizers):
186-
if isinstance(optimizer, FusedAdam):
187-
optimizers[i] = wrap_fused_adam(optimizer, properties)
188-
if properties.loss_scale == "dynamic":
189-
optimizers[i] = FP16_Optimizer_general(optimizer,
190-
dynamic_loss_scale=True,
191-
verbose=False)
192-
else:
193-
optimizers[i] = FP16_Optimizer_general(optimizer,
194-
static_loss_scale=properties.loss_scale,
195-
verbose=False)
196-
else:
197-
for optimizer in optimizers:
198-
optimizer.loss_scaler = LossScaler(properties.loss_scale)
186+
for i, optimizer in enumerate(optimizers):
187+
# Still need to special case this for the first pass
188+
if isinstance(optimizer, FusedAdam):
189+
optimizers[i] = wrap_fused_adam(optimizer, properties)
190+
else:
191+
optimizers[i] = _process_optimizer(optimizer, properties)
192+
193+
_amp_state.loss_scalers = []
194+
for _ in range(num_losses):
195+
_amp_state.loss_scalers.append(LossScaler(properties.loss_scale))
199196

200197
if properties.patch_torch_functions:
201198
# handle is unused here. It's accessible later through a global value anyway.

0 commit comments

Comments
 (0)