diff --git a/apex/amp/handle.py b/apex/amp/handle.py index 41617deee..d7409dc8e 100644 --- a/apex/amp/handle.py +++ b/apex/amp/handle.py @@ -118,8 +118,8 @@ def scale_loss(loss, if should_skip: optimizer_step = optimizer.step def skip_step(): - maybe_print("Gradient overflow. Skipping step, loss scaler {} reducing " + - "loss scale to {}".format(loss_id, loss_scaler.loss_scale())) + maybe_print(("Gradient overflow. Skipping step, loss scaler {} reducing " + + "loss scale to {}").format(loss_id, loss_scaler.loss_scale())) optimizer.step = optimizer_step optimizer.step = skip_step