Adams is a next-generation optimizer that blends the simplicity of element-wise methods with the stability benefits of matrix-aware regularization. It updates neural network parameters in both 1D (per-element) and 2D (per-matrix) ways, staying fast and easy to parallelize like Adam/AdamW while improving stability and generalization.
- Stable: no loss spikes observed; no gradient clipping required.
- Fast & scalable: element-wise updates + one rank-1 spectral decay step per matrix; easily parallelizable.
- Simple: no `epsilon` hyperparameter; truly scale-invariant per-parameter update.
Small second-moment estimates are a major source of instability and loss spikes in Adam-like methods. Adams replaces the usual preconditioned step with a bounded update based on `atan2` (sketched below), which:
- naturally bounds the step size,
- removes the need for the `epsilon` hyperparameter,
- yields true scale invariance of the update.
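As a rough sketch of the idea (not the exact Adams update rule; the function below and any scaling constants are illustrative assumptions), compare the classic Adam step with the bounded `atan2` form:

```python
import torch

def adam_vs_atan2_step(m_hat: torch.Tensor, v_hat: torch.Tensor, eps: float = 1e-8):
    """Contrast the classic Adam step with a bounded atan2-style step.

    Illustrative sketch only; the exact form and scaling inside Adams may differ.
    """
    # Classic Adam: unbounded when v_hat is tiny, and needs eps for safety.
    adam_step = m_hat / (v_hat.sqrt() + eps)

    # atan2 form: since sqrt(v_hat) >= 0, the result lies in [-pi/2, pi/2],
    # so no eps floor is needed. It is also scale-invariant: if gradients are
    # rescaled by c > 0, m_hat scales by c and sqrt(v_hat) scales by c, and
    # atan2(c * m, c * sqrt(v)) == atan2(m, sqrt(v)).
    atan2_step = torch.atan2(m_hat, v_hat.sqrt())
    return adam_step, atan2_step
```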
For matrix parameters:
- We compute a one-step power iteration with persistent state (the same idea as PyTorch's `spectral_norm`) to approximate the top singular triplet $(u_1, \sigma_1, v_1)$.
- The decay term applied per update step is $\sqrt{M}\, u_1 \sigma_1 v_1^\top$ (the gradient of $\tfrac{1}{2}\sigma_1^2$, scaled by $\sqrt{M}$ to match the RMS of $W$).
- This helps control activation scales and mitigates instabilities tied to large spectral norms.
Efficiency: the spectral step adds only two GEMV operations per matrix per update, comparable to a handful of extra element-wise ops. In typical FSDP/ZeRO setups the full weight matrix is available during forward/backward, so this integrates cleanly at scale.
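A minimal sketch of what this spectral step could look like (illustrative only, not the actual Adams implementation; the function name, buffer handling, and the decoupled application of the decay are assumptions):

```python
import torch

@torch.no_grad()
def spectral_decay_step(W: torch.Tensor, u: torch.Tensor, lr: float, wd: float) -> torch.Tensor:
    """One-step power iteration with a persistent left-vector buffer `u`,
    followed by a rank-1 spectral decay step. Illustrative sketch only."""
    M, _ = W.shape
    # Two GEMVs per update, as noted above.
    v = torch.nn.functional.normalize(W.t() @ u, dim=0)  # GEMV 1: refresh right vector
    Wv = W @ v                                           # GEMV 2
    sigma = Wv.norm()                                    # sigma_1 ~= ||W v_1||
    u.copy_(Wv / sigma)                                  # persist left vector for next step
    # Decoupled decay: the gradient of 0.5 * sigma_1^2 w.r.t. W is sigma_1 * u_1 v_1^T,
    # scaled by sqrt(M) to match the RMS of W.
    W.sub_(lr * wd * (M ** 0.5) * sigma * torch.outer(u, v))
    return sigma

# Persistent state: one unit vector per matrix, kept across steps, e.g.
# u = torch.nn.functional.normalize(torch.randn(W.shape[0], device=W.device), dim=0)
```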
Recent reports suggest that fully matrix-based optimizers (e.g., Muon) can be hard to implement/parallelize broadly and often show modest end-to-end benefits on large models (~1.1x or less), despite strong stability. Meanwhile, the dominant optimizer Adam is simple and fast but prone to instability and loss spikes.
Adams asks: Can we keep Adam’s speed and simplicity while gaining matrix-level stability?
Install with `pip install adams-torch`. You don't need to manually broadcast parameters or all-reduce gradients; multi-GPU usage matches single-GPU usage. Fully compatible with `torch.compile`.
FSDP is not supported yet. Contributions welcome.
```python
import os

import torch
import torch.distributed as dist

from adams import Adams_ZeRO  # main optimizer


def init():
    # Initialize distributed training if launched via torchrun/torch.distributed
    if "LOCAL_RANK" in os.environ:
        dist.init_process_group(backend="nccl")
        torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = Model().to(device)  # your nn.Module

    # Spectral decay applies to matrix-shaped params.
    # scalar_vector_weight_decay applies standard decoupled L2 to 0D/1D params.
    optimizer = Adams_ZeRO(
        model.parameters(),
        lr=3e-4,
        weight_decay=0.1,                # spectral decay for matrices
        scalar_vector_weight_decay=0.1,  # L2 for scalars/vectors
        betas=(0.9, 0.95),
    )

    # Sync any internal buffers across ranks if required by your setup.
    optimizer.broadcast_buffers(model.buffers())
    return model, optimizer


@torch.compile  # Optional: works with torch.compile
def train_step(model, optimizer, batch):
    loss = model(batch)  # forward; compute your loss
    loss.backward()      # backward
    optimizer.step()     # no gradient clipping needed
    optimizer.zero_grad(set_to_none=True)
    return loss
```

As with fully matrix-based optimizers (e.g., Muon), some care is needed in how parameters are shaped and grouped:
- Non-matrix parameters: disable the matrix-based part (spectral decay) for parameters that are scalars, vectors, or collections of vectors (e.g., LayerNorm, Embedding, output head) by setting `param.use_spectral_decay = False`. Adams applies a separate decoupled L2 term to these, controlled by `scalar_vector_weight_decay` (default `0.1`).
- Batched matrices: parameters that are conceptually multiple matrices concatenated along leading dimensions (e.g., attention QKV projections) should be given shape `(B, M, N)`; Adams treats all dimensions except the last two as batch dimensions (see the sketch after this list). In our experiments, each attention head's Q, K, and V are treated as separate projection matrices, e.g., the QKV projection of an 8-head MHA layer holds 24 matrices (8 heads × 3).
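An illustrative setup under these conventions (the model and shapes below are placeholders; only `use_spectral_decay` and `scalar_vector_weight_decay` come from Adams itself):

```python
import torch
import torch.nn as nn

d_model, n_heads, head_dim = 512, 8, 64

# Placeholder model: one matrix-shaped Linear and one vector-shaped LayerNorm.
model = nn.Sequential(nn.Linear(d_model, d_model), nn.LayerNorm(d_model))

# Disable spectral decay on non-matrix parameters (LayerNorm, Embedding, etc.);
# Adams then applies the decoupled L2 controlled by scalar_vector_weight_decay.
for module in model.modules():
    if isinstance(module, (nn.LayerNorm, nn.Embedding)):
        for p in module.parameters():
            p.use_spectral_decay = False

# Batched matrices: store QKV as a batch of per-head projection matrices,
# e.g. 8 heads x (q, k, v) = 24 matrices of shape (head_dim, d_model),
# kept as a single parameter of shape (24, head_dim, d_model).
qkv_proj = nn.Parameter(torch.randn(3 * n_heads, head_dim, d_model) * d_model ** -0.5)
```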
- Hyperparameters: start with AdamW-like settings; the bounded update removes `epsilon`. Adams also tolerates much larger weight decay (e.g., `1.0`), which can improve generalization (see the example after this list).
- Stability: the bounded step and spectral decay together target the sources of spikes linked to tiny second moments and large spectral norms.
- Generalization & adversarial robustness: spectral regularization is widely observed to improve both, and Adams adopts a lightweight decoupled form.
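For example, a configuration with aggressive spectral decay might look like the following (a sketch only; `1.0` is the value suggested above, and the other values simply mirror the earlier usage example):

```python
from adams import Adams_ZeRO

# AdamW-like settings, but no epsilon argument and a much larger weight decay.
optimizer = Adams_ZeRO(
    model.parameters(),              # `model` as in the usage example above
    lr=3e-4,
    betas=(0.9, 0.95),
    weight_decay=1.0,                # spectral decay for matrices
    scalar_vector_weight_decay=0.1,  # L2 for scalars/vectors
)
```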
- Scaling Exponents Across Parameterizations and Optimizers
- Adaptive Preconditioners Trigger Loss Spikes in Adam
- Muon: An optimizer for the hidden layers of neural networks
- Spectral Norm Regularization for Improving the Generalizability of Deep Learning
- Thinking from spectral norm gradient to new weight decay
Apache-2.0