Hi @jenkspt, I have been learning nanoGPT and reproducing it in JAX from scratch. Your repo has been a very helpful reference.
I encountered an issue with optax.apply_every and thought you might want to know. It turns out optax.apply_every is not equivalent to how nanoGPT updates, which is to accumulate the gradients first and then clip by global norm.
See this snippet:
import chex
import jax
import optax

# net, fit, MiniBatch, EXAMPLES and LABELS are defined earlier in the full script (omitted here).

optimizer = optax.chain(
    optax.clip_by_global_norm(0.2),
    optax.sgd(1e-4),
)
params = net.init(jax.random.PRNGKey(0), EXAMPLES)

new_params_single_batch = fit(
    optimizer,
    params,
    batches=[
        MiniBatch(image=EXAMPLES, label=LABELS),
    ],
)

print("checking equivalence of single batch and optax.MultiSteps")
new_params_gradient_accumulation = fit(
    optax.MultiSteps(optimizer, every_k_schedule=3),
    params,
    batches=[
        MiniBatch(image=EXAMPLES[0:3], label=LABELS[0:3]),
        MiniBatch(image=EXAMPLES[3:6], label=LABELS[3:6]),
        MiniBatch(image=EXAMPLES[6:9], label=LABELS[6:9]),
    ],
)
chex.assert_trees_all_close(
    new_params_single_batch,
    new_params_gradient_accumulation,
    atol=1e-7,
)

print("checking equivalence of single batch and optax.apply_every")
new_params_gradient_accumulation2 = fit(
    optax.chain(
        optax.clip_by_global_norm(0.2),
        optax.sgd(1e-4),
        optax.apply_every(3),
    ),
    params,
    batches=[
        MiniBatch(image=EXAMPLES[0:3], label=LABELS[0:3]),
        MiniBatch(image=EXAMPLES[3:6], label=LABELS[3:6]),
        MiniBatch(image=EXAMPLES[6:9], label=LABELS[6:9]),
    ],
)
chex.assert_trees_all_close(
    new_params_single_batch,
    new_params_gradient_accumulation2,
    atol=1e-7,
)

The output:

checking equivalence of single batch and optax.MultiSteps
checking equivalence of single batch and optax.apply_every
Traceback (most recent call last):
File "/home/costa/Documents/go/src/github.com/vwxyzjn/envpool-xla-cleanrl/optax_grad_accu_clip.py", line 109, in <module>
chex.assert_trees_all_close(
File "/home/costa/.cache/pypoetry/virtualenvs/envpool-xla-cleanrl-xwPMbtrF-py3.9/lib/python3.9/site-packages/chex/_src/asserts_internal.py", line 197, in _chex_assert_fn
host_assertion(*args, **kwargs)
File "/home/costa/.cache/pypoetry/virtualenvs/envpool-xla-cleanrl-xwPMbtrF-py3.9/lib/python3.9/site-packages/chex/_src/asserts_internal.py", line 157, in _static_assert
raise exception_type(error_msg)
AssertionError: [Chex] Assertion assert_trees_all_close failed: Trees 0 and 1 differ in leaves 'mlp/~/linear_0/b':
Not equal to tolerance rtol=1e-06, atol=1e-07
Error in value equality check: Values not approximately equal
Mismatched elements: 12 / 32 (37.5%)
Max absolute difference: 3.6430913e-07
Max relative difference: 1.4173905
x: array([-7.326746e-08, 0.000000e+00, 1.287373e-07, 0.000000e+00,
0.000000e+00, -1.868993e-07, -1.627599e-07, -7.771037e-08,
-6.023089e-07, 1.824948e-08, -9.969744e-08, 0.000000e+00,...
y: array([-1.193194e-07, 0.000000e+00, 1.980013e-07, 0.000000e+00,
0.000000e+00, -3.954974e-07, -2.911101e-07, -1.588895e-07,
-9.666180e-07, -4.372281e-08, -2.566443e-07, 0.000000e+00,...
Original dtypes: float32, float32.
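If I understand the optax source correctly, the difference comes from where the accumulation happens. optax.apply_every sits at the end of the chain, so each micro-batch gradient is clipped and scaled by the learning rate first, and only the resulting updates are summed and applied every k steps. optax.MultiSteps instead accumulates (averages, I believe) the raw gradients and runs the whole chain, clipping included, once on the accumulated gradient, which matches what nanoGPT does. A minimal sketch of the two orderings (the 0.2 / 1e-4 values just mirror the snippet above):

# Clips each micro-batch gradient, then accumulates the resulting updates
# -> not equivalent to clipping the accumulated gradient.
apply_every_style = optax.chain(
    optax.clip_by_global_norm(0.2),
    optax.sgd(1e-4),
    optax.apply_every(3),
)

# Accumulates the raw gradients over 3 steps, then clips and applies SGD once
# -> matches the single-batch update.
multisteps_style = optax.MultiSteps(
    optax.chain(
        optax.clip_by_global_norm(0.2),
        optax.sgd(1e-4),
    ),
    every_k_schedule=3,
)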
Empirically, this can have a significant impact on training as well. I am following nanoGPT's single-GPU setting, which accumulates gradients 40 times. As shown in the plot below, optax.apply_every is significantly more unstable than optax.MultiSteps.
Training might be more stable with fewer gradient accumulation steps, but it still feels like an issue...
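For reference, the MultiSteps variant in the plot is set up roughly like this (a sketch; the clipping value and AdamW hyperparameters here are placeholders, not the repo's exact settings):

GRAD_ACCUM_STEPS = 40  # nanoGPT-style single-GPU gradient accumulation

inner = optax.chain(
    optax.clip_by_global_norm(1.0),   # clips the *accumulated* gradient
    optax.adamw(learning_rate=6e-4),  # placeholder hyperparameters
)
optimizer = optax.MultiSteps(inner, every_k_schedule=GRAD_ACCUM_STEPS)

# opt_state = optimizer.init(params)
# updates, opt_state = optimizer.update(grads, opt_state, params)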