Add vocabulary tiling to reduce redundant memory #2242
Conversation
Force-pushed b9ad751 to c1277e3
Force-pushed f231602 to 6eb981c
🤖 Hi @gobbleturk, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.
Force-pushed bf90996 to 9cfd8ef
src/MaxText/layers/models.py (Outdated)
@@ -297,6 +311,19 @@ def no_op(self, *args, **kwargs):
  def init_cache(self, cache_size: int, batch_size: int, dtype=jnp.float32):
    return True

  def logits_from_hidden_states(self, hidden_states, deterministic):
I'm confused - is this function defined twice, here and above?
I will only keep one of them. Sorry for the confusion.
src/MaxText/layers/models.py (Outdated)
@@ -410,6 +441,20 @@ class ZeroOneTransformer(nn.Module):
  def setup(self):
    self.model = transformer_as_linen(self.config, self.mesh, self.quant, self.model_mode)

  def logits_from_hidden_states(self, hidden_states, deterministic):
three times?
I will remove them. This is because there are three types of Transformers defined in models.py, but only one of them actually gets used.
gobbleturk left a comment
Thanks Nuojin!
Force-pushed 9cfd8ef to 82250c2
🤖 Hi @RissyRan, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.
📋 Review Summary
This pull request introduces vocabulary tiling, a memory-saving optimization that computes the cross-entropy loss in chunks to reduce peak memory usage. The implementation is well-structured, with the core logic encapsulated in maxtext_utils.py and comprehensive tests to ensure correctness across various sharding configurations.
🔍 General Feedback
- The addition of thorough unit tests for different parallelism strategies is excellent and ensures the reliability of this new feature.
- The code is clean and the changes are well-integrated into the existing structure.
- The TODOs for future optimizations are noted and will be important for maximizing the benefits of this feature.
richjames0 left a comment
PTAL at Gemini's nits but otherwise LGTM!
Force-pushed 82250c2 to ea86e5a
Force-pushed ea86e5a to 7775ebf
Description
This PR introduces vocabulary tiling, a new feature for MaxText designed to significantly reduce peak memory consumption during training.
This optimization is particularly beneficial for two key scenarios:
The core idea of vocabulary tiling is to avoid explicitly materializing the full final logits tensor. Instead, the logits activation is chunked (or "tiled") along the batch-sequence dimension.
As illustrated in the diagram below, the forward and backward passes are repeated num_vocab_tiling times. In each iteration, a small slice of the logits is computed, used to calculate the loss, and the vector-Jacobian product is immediately backpropagated. This iterative process avoids holding the complete, memory-intensive logits tensor.

Figure 1: The vocabulary tiling process. The forward and backward passes are repeated for each tile, preventing the full logits tensor from being stored in memory.
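For intuition, here is a minimal, hypothetical JAX sketch of the same idea; it is not MaxText's actual implementation (which lives in maxtext_utils.py), and the function and argument names are illustrative. Each tile's logits and loss are computed under jax.checkpoint, so the backward pass recomputes that tile's logits rather than storing them, and a scan accumulates the total loss.

```python
import jax
import jax.numpy as jnp


def tiled_cross_entropy(hidden_states, embedding, targets, num_tiles):
  """Mean cross-entropy computed one tile at a time (illustrative sketch).

  hidden_states: [batch * seq, hidden]; embedding: [hidden, vocab];
  targets: [batch * seq] integer token ids.
  """
  h_tiles = hidden_states.reshape(num_tiles, -1, hidden_states.shape[-1])
  t_tiles = targets.reshape(num_tiles, -1)

  # jax.checkpoint (remat) makes the backward pass recompute each tile's
  # logits instead of saving them, so only one tile-sized logits slice is
  # live at any point during the forward or backward pass.
  @jax.checkpoint
  def tile_loss(h, t):
    logits = h @ embedding                               # [tile, vocab]
    logprobs = jax.nn.log_softmax(logits, axis=-1)
    return -jnp.sum(jnp.take_along_axis(logprobs, t[:, None], axis=-1))

  def body(total, tile):
    h, t = tile
    return total + tile_loss(h, t), None

  total, _ = jax.lax.scan(body, jnp.zeros((), jnp.float32), (h_tiles, t_tiles))
  return total / targets.size
```

The sketch trades a second computation of each tile's logits during the backward pass for a much smaller peak activation footprint, which is the same trade-off the diagram above describes.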
For a more in-depth technical explanation, please see the design document.
Doc: go/maxtext-vocab-tiling
FIXES: b/429255841
Tests
Correctness Tests
Test losses and embedding-table gradient differences in MaxText/tests/vocab_tiling_test.py with a 1% relative error tolerance across several configurations, including the case with a separate output projection (logits_via_embedding=False).

Performance Tests
See doc.
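As an illustration of what such a consistency check involves (a hypothetical sketch, not the actual vocab_tiling_test.py), the tiled loss and the embedding-table gradient can be compared against an untiled baseline within the 1% relative tolerance:

```python
import jax
import jax.numpy as jnp
import numpy as np


def xent(embedding, hidden, targets):
  """Baseline: mean cross-entropy over the full logits tensor."""
  logprobs = jax.nn.log_softmax(hidden @ embedding, axis=-1)
  return -jnp.mean(jnp.take_along_axis(logprobs, targets[:, None], axis=-1))


def tiled_xent(embedding, hidden, targets, num_tiles):
  """Same loss computed over num_tiles chunks of the batch-seq dimension."""
  h = hidden.reshape(num_tiles, -1, hidden.shape[-1])
  t = targets.reshape(num_tiles, -1)
  return jnp.mean(jax.vmap(lambda hc, tc: xent(embedding, hc, tc))(h, t))


key = jax.random.PRNGKey(0)
hidden = jax.random.normal(key, (32, 16))        # [batch * seq, hidden]
embedding = jax.random.normal(key, (16, 128))    # [hidden, vocab]
targets = jax.random.randint(key, (32,), 0, 128)

loss_ref, grad_ref = jax.value_and_grad(xent)(embedding, hidden, targets)
loss_tiled, grad_tiled = jax.value_and_grad(tiled_xent)(embedding, hidden, targets, 4)

np.testing.assert_allclose(loss_tiled, loss_ref, rtol=1e-2)   # losses agree within 1%
np.testing.assert_allclose(grad_tiled, grad_ref, rtol=1e-2)   # embedding grads agree within 1%
```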
Checklist
Before submitting this PR, please make sure (put X in square brackets):