

@copybara-service bot commented Jan 15, 2026

Shard the split RNGs in lifted `vmap` when `spmd_axis_name` is given and applies to the vmapped axis. Without this, `jax.vmap` raises `ValueError: Mapped away dimension of inputs passed to vmap should be sharded the same. Got inconsistent axis specs: None vs batch`, because the split RNGs (`split_rngs`) are replicated instead of being sharded along the mapped axis like the other inputs.


PiperOrigin-RevId: 856253490
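The idea behind the fix can be illustrated in plain JAX. The sketch below is not the lifted-`vmap` code from this change; it assumes a hypothetical 1-D mesh with an axis named `"batch"` and hypothetical helpers (`split_rngs_sharded`, `init_layer`, `batched`). It splits one key into per-example keys and, when an SPMD axis name is given, constrains the mapped (leading) dimension of those keys to be sharded over that mesh axis, so they carry the same spec as the other batched inputs instead of staying replicated.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Hypothetical 1-D mesh over all local devices, with its single axis named "batch".
# Assumes the batch size divides evenly across the devices.
mesh = Mesh(np.array(jax.devices()), ("batch",))


def split_rngs_sharded(key, num, spmd_axis_name=None):
    """Split `key` into `num` keys for the vmapped axis (hypothetical helper).

    If `spmd_axis_name` is given, the leading (mapped) dimension of the split
    keys is constrained to be sharded over that mesh axis instead of replicated.
    """
    keys = jax.random.split(key, num)
    if spmd_axis_name is not None:
        keys = jax.lax.with_sharding_constraint(
            keys, NamedSharding(mesh, P(spmd_axis_name)))
    return keys


def init_layer(key, x):
    # Per-example random weights: a stand-in for a module initialized under vmap.
    w = jax.random.normal(key, (x.shape[-1],))
    return x * w


@jax.jit
def batched(key, xs):
    # Keys and inputs share the same spec on the mapped dimension: P("batch").
    keys = split_rngs_sharded(key, xs.shape[0], spmd_axis_name="batch")
    xs = jax.lax.with_sharding_constraint(xs, NamedSharding(mesh, P("batch")))
    return jax.vmap(init_layer, spmd_axis_name="batch")(keys, xs)
```

For example, `batched(jax.random.PRNGKey(0), jnp.ones((8, 4)))` returns an `(8, 4)` array. The sharding constraint changes only where the keys live, not their values, so the split keys are identical to an unconstrained `jax.random.split`.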