
Conversation

Collaborator

@NuojCheng NuojCheng commented Aug 26, 2025

Description

This PR introduces vocabulary tiling, a new feature for MaxText designed to significantly reduce peak memory consumption during training.

This optimization is particularly beneficial for two key scenarios:

  1. Training large-vocabulary models on memory-constrained hardware: it enables training models such as Gemma 2B and Gemma-2 2B, which have large vocabularies, on devices with limited memory.
  2. Training with long sequence lengths: it avoids storing the entire logits tensor in memory for the final loss computation, making long-context training more feasible.

The core idea of vocabulary tiling is to avoid explicitly materializing the full final logits tensor. Instead, the logits activation is chunked (or "tiled") along the batch-sequence dimension.

As illustrated in the diagram below, the forward and backward passes are repeated num_vocab_tiling times. In each iteration, a small slice of the logits is computed and used to calculate the loss, and the vector-Jacobian product (VJP) is immediately backpropagated. This iterative process avoids holding the complete, memory-intensive logits tensor; a minimal sketch of the loop follows the figure.

Figure 1: The vocabulary tiling process. The forward and backward passes are repeated for each tile, preventing the full logits tensor from being stored in memory.
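
To make the loop concrete, here is a minimal, self-contained JAX sketch of the idea: the loss is accumulated tile by tile, and each tile's VJP is computed while only that tile's logits are alive. This is an illustration only, not MaxText's actual implementation; names such as `tiled_loss_and_grads`, `unembed`, and the shapes in the usage example are hypothetical.

```python
import jax
import jax.numpy as jnp


def tiled_loss_and_grads(hidden, unembed, targets, num_tiles):
  """Cross-entropy loss without materializing the full [tokens, vocab] logits tensor.

  hidden:  [tokens, emb]  final hidden states, flattened over batch and sequence
  unembed: [emb, vocab]   output projection / (tied) embedding table
  targets: [tokens]       integer target token ids
  """
  tokens = hidden.shape[0]
  chunk = tokens // num_tiles  # assume tokens divides evenly, for simplicity

  def chunk_loss(h, w, t):
    logits = h @ w                               # [chunk, vocab]: only one tile is live
    logp = jax.nn.log_softmax(logits, axis=-1)
    return -jnp.sum(jnp.take_along_axis(logp, t[:, None], axis=-1))

  loss = 0.0
  g_hidden = jnp.zeros_like(hidden)
  g_unembed = jnp.zeros_like(unembed)
  for i in range(num_tiles):
    sl = slice(i * chunk, (i + 1) * chunk)
    # Forward pass + immediate backward pass (VJP) for this tile only.
    loss_i, (gh_i, gw_i) = jax.value_and_grad(chunk_loss, argnums=(0, 1))(
        hidden[sl], unembed, targets[sl])
    loss += loss_i
    g_hidden = g_hidden.at[sl].set(gh_i)
    g_unembed += gw_i
  return loss / tokens, (g_hidden, g_unembed)


# Tiny usage example with made-up shapes.
key = jax.random.PRNGKey(0)
hidden = jax.random.normal(key, (8, 16))       # 8 tokens, 16-dim hidden states
unembed = jax.random.normal(key, (16, 1000))   # vocabulary of 1000 tokens
targets = jnp.arange(8) % 1000
loss, (g_h, g_w) = tiled_loss_and_grads(hidden, unembed, targets, num_tiles=4)
```

In the real feature the chunked loss is wired into the training step (and the Python loop would typically be a `jax.lax` scan under `jit`), but the memory argument is the same: at any point only one tile of logits of shape [tokens / num_vocab_tiling, vocab] exists, instead of the full [tokens, vocab] tensor.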

For a more in-depth technical explanation, please see the design document.

Doc: go/maxtext-vocab-tiling
FIXES: b/429255841

Tests

Correctness Tests

MaxText/tests/vocab_tiling_test.py checks loss and embedding-table gradient differences against a 1% relative error tolerance in the following cases (an illustrative comparison sketch follows the list):

  • Default non-tied embedding (logits_via_embedding=False).
  • Tied embedding and default sharding (FSDP).
  • Data parallelism.
  • Tensor parallelism.
  • Context parallelism.
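
For reference, a check of this kind usually reduces to comparing the tiled run's loss and gradients against an untiled baseline under a relative tolerance. The snippet below is an illustrative sketch of that pattern, not the actual test code; `assert_close`, `loss_tiled`, `loss_baseline`, `grads_tiled`, and `grads_baseline` are hypothetical names.

```python
import numpy as np

def assert_close(tiled, baseline, rtol=1e-2):
  # 1% relative error tolerance, matching the tolerance stated above.
  np.testing.assert_allclose(np.asarray(tiled), np.asarray(baseline), rtol=rtol)

# Compare the scalar losses and every leaf of the embedding-table gradient pytrees, e.g.:
#   assert_close(loss_tiled, loss_baseline)
#   jax.tree_util.tree_map(assert_close, grads_tiled, grads_baseline)
```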

Performance Tests

See the design doc linked above.

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed.

@NuojCheng NuojCheng force-pushed the chengnuojin-seq-tiling branch 6 times, most recently from b9ad751 to c1277e3 on August 29, 2025 at 22:37
@NuojCheng NuojCheng changed the title [Draft] Add sequence tiling to reduce redundant memory [Draft] Add vocabulary tiling to reduce redundant memory Sep 4, 2025
@NuojCheng NuojCheng force-pushed the chengnuojin-seq-tiling branch 9 times, most recently from f231602 to 6eb981c on September 8, 2025 at 03:52
@NuojCheng NuojCheng marked this pull request as ready for review September 8, 2025 03:54
@github-actions

🤖 Hi @gobbleturk, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.

@NuojCheng NuojCheng force-pushed the chengnuojin-seq-tiling branch from bf90996 to 9cfd8ef on September 25, 2025 at 20:18
@@ -297,6 +311,19 @@ def no_op(self, *args, **kwargs):
def init_cache(self, cache_size: int, batch_size: int, dtype=jnp.float32):
return True

def logits_from_hidden_states(self, hidden_states, deterministic):

Collaborator

i'm confused - is this function defined twice? here and above

Collaborator Author

I will only keep one of them; sorry for the confusion.

@@ -410,6 +441,20 @@ class ZeroOneTransformer(nn.Module):
def setup(self):
self.model = transformer_as_linen(self.config, self.mesh, self.quant, self.model_mode)

def logits_from_hidden_states(self, hidden_states, deterministic):

Collaborator

three times?

Collaborator Author

I will remove them. This is because three Transformer classes are defined in model.py, but only one of them is actually used.

Collaborator

@gobbleturk gobbleturk left a comment

Thanks Nuojin!

@github-actions

🤖 Hi @RissyRan, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.

@github-actions github-actions bot left a comment

📋 Review Summary

This pull request introduces vocabulary tiling, a memory-saving optimization that computes the cross-entropy loss in chunks to reduce peak memory usage. The implementation is well-structured, with the core logic encapsulated in maxtext_utils.py and comprehensive tests to ensure correctness across various sharding configurations.

🔍 General Feedback

  • The addition of thorough unit tests for different parallelism strategies is excellent and ensures the reliability of this new feature.
  • The code is clean and the changes are well-integrated into the existing structure.
  • The TODOs for future optimizations are noted and will be important for maximizing the benefits of this feature.

Collaborator

@richjames0 richjames0 left a comment

PTAL at Gemini's nits but otherwise LGTM!

@NuojCheng NuojCheng force-pushed the chengnuojin-seq-tiling branch from 82250c2 to ea86e5a on September 26, 2025 at 17:01
@NuojCheng NuojCheng force-pushed the chengnuojin-seq-tiling branch from ea86e5a to 7775ebf on September 26, 2025 at 17:23
@copybara-service copybara-service bot merged commit e9266c8 into main Sep 26, 2025
27 checks passed
@copybara-service copybara-service bot deleted the chengnuojin-seq-tiling branch September 26, 2025 19:50