Unexpected behavior of optax.apply_every #2

@vwxyzjn

Description

Hi @jenkspt, I have been learning nanoGPT and reproducing it in JAX from scratch. Your repo has been a very helpful reference.

I encountered an issue with optax.apply_every and thought you might want to know: optax.apply_every is not equivalent to the way nanoGPT performs its updates, which is to accumulate gradients first and then clip by global norm before the optimizer step.

See this snippet and the output that follows it:

# Baseline: clip by global norm, then SGD, applied to one full batch of 9 examples.
optimizer = optax.chain(
    optax.clip_by_global_norm(0.2),
    optax.sgd(1e-4),
)
params = net.init(jax.random.PRNGKey(0), EXAMPLES)

new_params_single_batch = fit(
    optimizer,
    params,
    batches=[
        MiniBatch(image=EXAMPLES, label=LABELS),
    ],
)

# Gradient accumulation with optax.MultiSteps over three mini-batches of 3.
new_params_gradient_accumulation = fit(
    optax.MultiSteps(optimizer, every_k_schedule=3),
    params,
    batches=[
        MiniBatch(image=EXAMPLES[0:3], label=LABELS[0:3]),
        MiniBatch(image=EXAMPLES[3:6], label=LABELS[3:6]),
        MiniBatch(image=EXAMPLES[6:9], label=LABELS[6:9]),
    ],
)

# This assertion passes: MultiSteps matches the single-batch update.
chex.assert_trees_all_close(
    new_params_single_batch,
    new_params_gradient_accumulation,
    atol=1e-7,
)


# Gradient accumulation with optax.apply_every at the end of the chain.
new_params_gradient_accumulation2 = fit(
    optax.chain(
        optax.clip_by_global_norm(0.2),
        optax.sgd(1e-4),
        optax.apply_every(3),
    ),
    params,
    batches=[
        MiniBatch(image=EXAMPLES[0:3], label=LABELS[0:3]),
        MiniBatch(image=EXAMPLES[3:6], label=LABELS[3:6]),
        MiniBatch(image=EXAMPLES[6:9], label=LABELS[6:9]),
    ],
)

# This assertion fails: apply_every does not match the single-batch update.
chex.assert_trees_all_close(
    new_params_single_batch,
    new_params_gradient_accumulation2,
    atol=1e-7,
)
Output:

checking equivalence of single batch and optax.MultiSteps
checking equivalence of single batch and optax.apply_every
Traceback (most recent call last):
  File "/home/costa/Documents/go/src/github.com/vwxyzjn/envpool-xla-cleanrl/optax_grad_accu_clip.py", line 109, in <module>
    chex.assert_trees_all_close(
  File "/home/costa/.cache/pypoetry/virtualenvs/envpool-xla-cleanrl-xwPMbtrF-py3.9/lib/python3.9/site-packages/chex/_src/asserts_internal.py", line 197, in _chex_assert_fn
    host_assertion(*args, **kwargs)
  File "/home/costa/.cache/pypoetry/virtualenvs/envpool-xla-cleanrl-xwPMbtrF-py3.9/lib/python3.9/site-packages/chex/_src/asserts_internal.py", line 157, in _static_assert
    raise exception_type(error_msg)
AssertionError: [Chex] Assertion assert_trees_all_close failed:  Trees 0 and 1 differ in leaves 'mlp/~/linear_0/b': 
Not equal to tolerance rtol=1e-06, atol=1e-07
Error in value equality check: Values not approximately equal
Mismatched elements: 12 / 32 (37.5%)
Max absolute difference: 3.6430913e-07
Max relative difference: 1.4173905
 x: array([-7.326746e-08,  0.000000e+00,  1.287373e-07,  0.000000e+00,
        0.000000e+00, -1.868993e-07, -1.627599e-07, -7.771037e-08,
       -6.023089e-07,  1.824948e-08, -9.969744e-08,  0.000000e+00,...
 y: array([-1.193194e-07,  0.000000e+00,  1.980013e-07,  0.000000e+00,
        0.000000e+00, -3.954974e-07, -2.911101e-07, -1.588895e-07,
       -9.666180e-07, -4.372281e-08, -2.566443e-07,  0.000000e+00,... 
Original dtypes: float32, float32.
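
The snippet leaves out the setup. The helpers it relies on look roughly like this; the names net, fit, MiniBatch, EXAMPLES, and LABELS are from my script, but the bodies below are a simplified sketch, not the exact code:

import jax
import jax.numpy as jnp
import haiku as hk
import optax
import chex  # used by the assertions in the snippet above
from typing import NamedTuple

class MiniBatch(NamedTuple):
    image: jnp.ndarray
    label: jnp.ndarray

# Small haiku MLP (the traceback's 'mlp/~/linear_0/b' comes from hk.nets.MLP).
def forward(x):
    return hk.nets.MLP([32, 10])(x)

net = hk.without_apply_rng(hk.transform(forward))

EXAMPLES = jnp.ones((9, 64))               # placeholder inputs
LABELS = jnp.zeros((9,), dtype=jnp.int32)  # placeholder labels

def loss_fn(params, batch):
    logits = net.apply(params, batch.image)
    return optax.softmax_cross_entropy_with_integer_labels(logits, batch.label).mean()

def fit(optimizer, params, batches):
    # Plain training loop: one optimizer.update per mini-batch.
    opt_state = optimizer.init(params)
    for batch in batches:
        grads = jax.grad(loss_fn)(params, batch)
        updates, opt_state = optimizer.update(grads, opt_state, params)
        params = optax.apply_updates(params, updates)
    return params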

Empirically, this can have a significant impact on training as well. I am following nanoGPT's single-GPU setting, which accumulates gradients over 40 steps. As shown below, optax.apply_every is significantly more unstable than optax.MultiSteps.

[image: training curves comparing optax.MultiSteps and optax.apply_every]

While training might be more stable with fewer gradient accumulation steps, it still feels like an issue...
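
If I read the optax source correctly, the root cause is where the accumulation happens. optax.MultiSteps averages the raw gradients over the k mini-steps and only then runs the average through clip_by_global_norm and SGD, whereas apply_every at the end of the chain clips and scales each mini-batch gradient individually and then sums the already-transformed updates. A tiny standalone example (toy gradients, not taken from my script) shows the two paths producing different parameters:

import jax.numpy as jnp
import optax

params = {'w': jnp.zeros(3)}
# Three toy "mini-batch" gradients (made up for illustration).
grads = [
    {'w': jnp.array([1.0, 2.0, 3.0])},
    {'w': jnp.array([0.5, -1.0, 2.0])},
    {'w': jnp.array([-2.0, 0.0, 1.0])},
]
inner = optax.chain(optax.clip_by_global_norm(0.2), optax.sgd(1e-4))

def run(optimizer):
    opt_state = optimizer.init(params)
    p = params
    for g in grads:
        updates, opt_state = optimizer.update(g, opt_state, p)
        p = optax.apply_updates(p, updates)
    return p

# (a) MultiSteps: average the raw gradients, then clip + SGD once.
p_multi = run(optax.MultiSteps(inner, every_k_schedule=3))
# (b) apply_every: clip + SGD each mini-batch gradient, then sum the updates.
p_every = run(optax.chain(optax.clip_by_global_norm(0.2),
                          optax.sgd(1e-4),
                          optax.apply_every(3)))
print(p_multi['w'], p_every['w'])  # different values

So with apply_every the 0.2 clip is applied to each mini-batch gradient on its own, and the k transformed updates are summed rather than averaged, which would also explain the instability I see at 40 accumulation steps.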
