What is Momentum? Is there a clearer way for us to reason about the popular stochastic gradient methods that dominate deep learning?
Enter the AutoSGM framework.
- Momentum (as in Polyak's Heavy Ball, PHB, and Nesterov's Accelerated Gradient, NAG) is a point in the design space of a first-order lowpass filter (a rough sketch follows the note below).
- Moment estimation (as in Adaptive Moment Estimation, Adam) of the gradient and partial-correlation estimators are components of optimal iteration-dependent learning-rate functions.
- Smoothing the gradient via the first-order lowpass filter approximately smooths the loss function, suggesting that its primary function is regularization, not acceleration.
- In general, under first-order filtering of the gradient and a choice of iteration-dependent learning rate, the stochastic gradient learning dynamics are those of a first-order linear time (iteration) varying (LTV) filter at the parameter-change (step) level.
- The framework characterizes the stability properties of the stochastic gradient learning dynamics.
- Both standard weight-decay and decoupled weight-decay act as coupled damping terms at the parameter-change level.
Note
The parameter-change level is simply the path along which the stochastic gradient algorithm carries the actual parameter update during learning.
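A rough sketch of this filter view is below. It is a minimal illustration only, not the library's internals; smooth_grad, state, and alpha_t are made-up names for this sketch.

def smooth_grad(grad, state, beta=0.9, gamma=0.0, eta=None):
    # One step of a first-order lowpass filter with pole `beta` and zero `gamma`
    # acting on the raw gradient. PHB corresponds to gamma = 0; NAG to
    # gamma = beta / (1 + beta). This is only a sketch, not AutoSGM's code.
    eta = (1.0 - beta) if eta is None else eta            # input normalization (cf. eta_i below)
    y = beta * state.get("y", 0.0) + eta * (grad - gamma * state.get("u", 0.0))
    state["y"], state["u"] = y, grad                      # filter memory
    return y

# parameter-change (step) level, with an iteration-dependent learning rate alpha_t:
# w_next = w - alpha_t * smooth_grad(g_t, state)

With gamma = 0 and eta = 1 - beta, this reduces to the familiar exponential moving average of the gradient.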
This AutoSGM implementation lets researchers and developers experiment with several well-known variants of the stochastic gradient algorithm, such as Polyak's Heavy Ball (PHB), Nesterov's Accelerated Gradient (NAG), and Adam.
By treating these methods as points in a design space of first-order filters and iteration-dependent learning rates, the framework makes it easier to explore how changes in smoothing and learning rates affect training stability and performance.
Important
Common configurations are explained in Setups.
import math
from asgm import AutoSGM
# lr_cfg - setup learning-rate (lr) algorithm
# beta_cfg - setup filters: averaging, and grad smoothing
# rc_cfg - setup window (lr schedule)
# eps_cfg - setup numerical eps
# wd_cfg - setup weight decay
# training length
num_epochs = 1
iters_epoch = int(1e9)
# [gradient smoothing] lowpass filter's pole
beta = 0 # off
# Instance: Plain SGM, constant lr 1E-3, coupled weight-decay with wd constant 1E-4
opt0 = AutoSGM(
model.parameters(),
lr_cfg=(False, 1e-3, 0),
beta_cfg=(0.9999, 0.999, beta, 0.0, 0, True),
wd_cfg=(1e-4, 0),
eps_cfg=(1e-8, )
)
# We decide to:
# 1. use moment estimation as part of the lr,
# The rest of the configuration remains the same, except in lr_cfg.
# RMSProp
# Instance: Plain SGM, moment estimation, coupled weight-decay
opt1 = AutoSGM(
model.parameters(),
lr_cfg=(True, 1e-3, 0),
beta_cfg=(0.9999, 0.999, beta, 0.0, 0, True),
wd_cfg=(1e-4, 0),
eps_cfg=(1e-8, )
)
# 2. activate windowing, via cosine annealing as the lr schedule,
# The rest of the configuration remains the same, but we now need to set up rc_cfg.
# Instance: Plain SGM, moment estimation, standard cosine annealing, coupled weight-decay
opt2 = AutoSGM(
model.parameters(),
lr_cfg=(True, 1e-3, 0),
beta_cfg=(0.9999, 0.999, beta, 0.0, 0, True),
rc_cfg=(1, 0, 0, 2, 1, num_epochs, iters_epoch, 1, 0),
wd_cfg=(1e-4, 0),
eps_cfg=(1e-8, )
)
# 3. activate gradient smoothing,
# The rest of the configuration remains the same, except in beta_cfg.
beta = 0.9 # activated with 0 < beta < 1
# 4. switch to decoupled weight-decay,
# The rest of the configuration remains the same, except in wd_cfg
# we now use a wd constant of 1E-2, and the 0 is replaced with 1
# Instance: PHB, moment estimation, standard cosine annealing, decoupled weight-decay
opt3 = AutoSGM(
model.parameters(),
lr_cfg=(True, 1e-3, 0),
beta_cfg=(0.9999, 0.999, beta, 0.0, 0, True),
rc_cfg=(1, 0, 0, 2, 1, num_epochs, iters_epoch, 1, 0),
wd_cfg=(1e-2, 1),
eps_cfg=(1e-8, )
)
# 5. use NAG: adjust the gradient smoothing filter's zero location from 0 to beta / (1 + beta)
# The rest of the configuration remains the same, except in beta_cfg.
gamma = beta/(1+beta) # gamma ~ 0.47 for beta = 0.9
# Instance: NAG, moment estimation, standard cosine annealing, decoupled weight-decay
opt4 = AutoSGM(
model.parameters(),
lr_cfg=(True, 1e-3, 0),
beta_cfg=(0.9999, 0.999, beta, gamma, 0, True),
rc_cfg=(1, 0, 0, 2, 1, num_epochs, iters_epoch, 1, 0),
wd_cfg=(1e-2, 1),
eps_cfg=(1e-8,)
)
# 6. decide to use a different window as the lr schedule
# The rest of the configuration remains the same, except in rc_cfg.
# Instance: NAG, moment estimation, linear decay, decoupled weight-decay
opt5 = AutoSGM(
model.parameters(),
lr_cfg=(True, 1e-3, 0),
beta_cfg=(0.9999, 0.999, beta, gamma, 0, True),
rc_cfg=(2, 0, 0, 1, 1, num_epochs, iters_epoch, 1, 0),
wd_cfg=(1e-2, 1),
eps_cfg=(1e-8,)
)
# 7. use neither PHB nor NAG: set the filter's zero location to any value <= beta
# The rest of the configuration remains the same, except in beta_cfg.
gamma = 1 - math.sqrt(2*(1-beta)) # ~ 0.55
# 8. in the lr algorithm, in addition to moment estimation,
# enable partial-correlation (parcor) estimation,
# The rest of the configuration remains the same, except in lr_cfg.
optnum = 1 # or 2.
opt6 = AutoSGM(
model.parameters(),
lr_cfg=(True, 1e-3, optnum),
beta_cfg=(0.9999, 0.999, beta, gamma, 0, True),
rc_cfg=(2, 0, 0, 1, 1, num_epochs, iters_epoch, 1, 0),
wd_cfg=(1e-2, 1),
eps_cfg=(1e-8,)
)
# 9. change the lr numerator to a moment estimator,
# which allows a higher learning rate than using only moment estimation,
# so instead of 1e-3, we may try something close but bigger, like 1e-2.
# The rest of the configuration remains the same, except in lr_cfg.
optnum = 3
opt7 = AutoSGM(
model.parameters(),
lr_cfg=(True, 1e-2, optnum),
beta_cfg=(0.9999, 0.999, beta, gamma, 0, True),
rc_cfg=(2, 0, 0, 1, 1, num_epochs, iters_epoch, 1, 0),
wd_cfg=(1e-2, 1),
eps_cfg=(1e-8,)
)
Using Adam's fixed learning-rate numerator as a baseline (colored red) against the parcor learning-rate numerator (colored blue), we tested the AutoSGM framework on CIFAR-10 image-classification (ViT, ResNet) and language-modeling (GPT-2 on WikiText-103 and Shakespeare-char) tasks.
Results: the blue curves mostly outperformed the red curves across several zero locations.
Code and style in this repository are under active development. Feel free to raise an issue if you detect a bug or have any questions.
This section shows a minimal, easy-to-follow example of using the AutoSGM implementation with a PyTorch model and lists the most important configuration options.
from asgm import AutoSGM
import torch
# example model
model = torch.nn.Sequential(
torch.nn.Linear(10, 50),
torch.nn.ReLU(),
torch.nn.Linear(50, 1)
)
num_epochs = 1
iters_epoch = 100000
# Example: spawn an AutoSGM instance, with cosine annealing.
opt = AutoSGM(
model.parameters(),
lr_cfg=(True, 1e-2, 3), # setup learning-rate (lr) algorithm
beta_cfg=(0.9999, 0.999, 0.9, 0.5528, 0, True), # setup filters: averaging, and grad smoothing
rc_cfg=(1, 0, 0, 2, 1, num_epochs, iters_epoch, 1, 0), # setup window (lr schedule)
wd_cfg=(0.0, 1), # setup weight decay
eps_cfg=(1e-10,), # setup numerical eps
)
# update parameters in the training loop (after computing the loss and calling loss.backward())
opt.step()
opt.zero_grad()

Learning Rate (lr)
lr_cfg = (aoptlr, lr_init, num_lrc)
- aoptlr: bool. (True | False)
  - True: use the iteration-dependent ratio function of a moment estimator and a correlation estimator as the learning rate.
  - False: use lr_init as a constant learning rate.
- lr_init: float. trust-region constant used by the iteration-dependent ratio function when aoptlr=True.
- num_lrc: int. (0, 1, 2, 3, 4) selects an estimator for the learning-rate numerator
  - 0: denominator: moment estimator, numerator: unity (baseline).
  - 1: denominator: moment estimator, numerator: parcor estimator.
  - 2: denominator: moment estimator, numerator: Huberized parcor estimator.
  - 3: denominator: moment estimator, numerator: moment estimator.
  - 4: denominator: moment estimator, numerator: unity (with Huberized prefilter).
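For a quick illustration (reusing model and the AutoSGM import from the quickstart above, and assuming configuration groups left unspecified fall back to their defaults, which is an assumption of this sketch):

# constant learning rate of 1e-3: no iteration-dependent ratio function
opt_const = AutoSGM(model.parameters(), lr_cfg=(False, 1e-3, 0))
# iteration-dependent lr: moment-estimator denominator, parcor-estimator numerator
opt_parcor = AutoSGM(model.parameters(), lr_cfg=(True, 1e-3, 1))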
Filtering (gradient smoothing (lowpass regularization) and exponential moving averages (EMAs))
beta_cfg = (beta_n, beta_a, beta_i, gamma_i, eta_i, debias)
- beta_n: float. filter pole for the EMA used in correlation estimation (expects long-term memory, that is, a beta_n close to 1).
  - default is 0.9999
- beta_a: float. filter pole for the EMA used in moment estimation (expects long-term memory, that is, a beta_a close to 1).
  - default is 0.999
- beta_i: float. smoothing filter's pole (0 < beta_i < 1).
  - default is 0.9
- gamma_i: float. smoothing filter's zero (less than the pole).
  - HB: set as gamma_i=0 (default)
  - NAG: set as gamma_i=beta_i/(1+beta_i)
- eta_i: float. input normalization constant for the smoothing filter.
  - note: with the default eta_i=0, the code sets it to eta_i=1-beta_i.
- debias: bool. (True | False)
  - True: debias the filter output for improved transient behavior (default).
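For example (values illustrative; model, AutoSGM, and default handling assumed as in the previous sketch), PHB and NAG differ only in where the smoothing filter's zero gamma_i is placed:

beta_i = 0.9
phb_cfg = (0.9999, 0.999, beta_i, 0.0, 0, True)                    # zero at 0 -> PHB
nag_cfg = (0.9999, 0.999, beta_i, beta_i / (1 + beta_i), 0, True)  # zero at beta_i/(1+beta_i) -> NAG
opt = AutoSGM(model.parameters(), lr_cfg=(True, 1e-3, 0), beta_cfg=nag_cfg)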
Weight-decay (Regularization)
wd_cfg = (wd_cte, wd_lvl)
- wd_cte: float. weight-decay constant.
- wd_lvl: int. decoupling level (0, 1)
  - 0: weight decay coupled to the gradient before smoothing.
  - 1: parameter-level decoupling of weight decay from gradient smoothing.
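Mirroring the walkthrough above (same assumptions as the previous sketches), coupled and decoupled weight decay would look like:

coupled_wd = (1e-4, 0)    # wd constant 1e-4, added to the gradient before smoothing
decoupled_wd = (1e-2, 1)  # wd constant 1e-2, applied at the parameter-change level
opt = AutoSGM(model.parameters(), lr_cfg=(True, 1e-3, 0), wd_cfg=decoupled_wd)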
Epsilon
eps_cfg = (eps,)
- eps: float. small positive constant used for numerical stability (avoids division by zero).
Windowing (Trust-region variable used with the iteration-dependent ratio function, also called learning-rate schedules)
rc_cfg = (rcm, inseq, x, n, m, tau, spe, cfact, e)
- rcm: int. window function (schedule type)
  - 0: inactive, flat,
  - 1: raised-cosine, n >= 1; standard cosine annealing is (n=2, m=1); recommended: small n, n < 4,
  - 2: polynomial; linear decay is (n=1, m=1), triangular is (n=1, m=0),
  - 3: exponential (parameterized),
  - 4: simple-polynomial, n >= 1; recommended: 1 < n < 20,
  - 5: logistic-sigmoid, mid-inflection, n > 0; recommended: small n, less than 10,
  - 6: logistic-sigmoid, late-inflection, n > 0; recommended: small n, less than 10.
- inseq: int. input sequence type
  - 0: uniform/rectangular,
  - 1: kronecker/randomized ordering.
- x: float. minimum fraction of unity (default is 0).
- n: shaping parameter for the window function rcm
  - order: int | float if rcm in {1, 2, 4, 5, 6},
  - decay rate: float if rcm in {3}.
- m: int. half window shape or full window envelope.
  - 1: half, meaning flat-top -> anneal,
  - 0: full, meaning warm-up -> flat-top -> anneal.
- tau >= 1: int. number of epochs (used in computing the window length, depending on cfact). tau=1 means one epoch. tau should be set to 1 if training is counted only in iterations.
- spe: int. steps per epoch (iterations in one epoch).
  - For instance: tau=1, spe=10000 means 10000 iterations in one epoch.
- cfact: int. (0, 1, 2, 3, 4). window's step unit
  - 0: step unit is per-epoch,
  - 1: step unit is per-iteration,
  - 2: step unit is per-iteration in each epoch (resets at each epoch).
- e: float. flat-top coverage fraction of the window, 0 <= e < 1.
  - Default is 0.
  - For instance: e=0.5 and m=1 means the first half of the window is flat-top constant, and annealing is observed in the other half of the window.
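As a concrete illustration (tuple values mirror the walkthrough above; same assumptions as the previous sketches), a standard cosine-annealing window and a linear-decay window could be configured as:

num_epochs, iters_epoch = 10, 1000
# raised-cosine (rcm=1), n=2, m=1: standard cosine annealing, stepped per iteration (cfact=1)
cosine_cfg = (1, 0, 0, 2, 1, num_epochs, iters_epoch, 1, 0)
# polynomial (rcm=2), n=1, m=1: linear decay
linear_cfg = (2, 0, 0, 1, 1, num_epochs, iters_epoch, 1, 0)
opt = AutoSGM(model.parameters(), lr_cfg=(True, 1e-3, 0), rc_cfg=cosine_cfg)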
If something doesn't behave as you expect, open an issue in the repository.