[train] refactor training #81
Conversation
Walkthrough

The changes centralize model training specification and configuration registration by moving them out of the individual model submodules and into a single central location.

Changes
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant User
    participant Trainer
    participant Model
    participant Tokenizer
    participant TrainSpecRegistry
    User->>Trainer: Initialize Trainer
    Trainer->>TrainSpecRegistry: Retrieve TrainSpec for model
    Trainer->>Tokenizer: Build tokenizer (conditionally pass special tokens)
    Trainer->>Model: Build and parallelize model
    Model->>Trainer: Return model (converted to float32)
    Trainer->>User: Ready for training steps
    User->>Trainer: Train step with batch
    Trainer->>Trainer: Move batch tensors to device
    Trainer->>Model: Forward pass with unpacked batch
    Model->>Trainer: Return outputs
    Trainer->>User: Compute loss, accuracy, and metrics
```
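To make the diagram concrete, here is a hedged Python sketch of that flow. Every name and signature in it (the `train_spec` fields, `outputs.loss`, the callable names) is an assumption for illustration, not the actual touchnet API:

```python
import torch

def training_loop(train_spec, job_config, batches, device_type: str = "cuda"):
    """Hypothetical outline of the diagram's flow; names are assumptions, not touchnet's API."""
    tokenizer = train_spec.build_tokenizer_fn(job_config)   # may conditionally receive special tokens
    model = train_spec.build_model_fn(job_config)
    model = train_spec.parallelize_fn(model)                # wrap/shard for distributed training
    model = model.to(torch.float32)                         # cast after parallelization for convergence

    for batch in batches:
        # Move batch tensors to the target device before the forward pass.
        batch = {k: (v.to(device_type) if torch.is_tensor(v) else v) for k, v in batch.items()}
        outputs = model(**batch)                            # unpacked batch goes straight to forward()
        loss = outputs.loss
        loss.backward()
        # ... optimizer step, accuracy and metric logging ...
    return tokenizer, model
```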
Possibly related PRs
Poem
✨ Finishing Touches
Actionable comments posted: 2
🧹 Nitpick comments (2)
touchnet/models/llama/__init__.py (1)
50-60: Add docstring and consider improving readability.

The function implementation is correct, but it needs documentation and could be slightly more readable.
```diff
 def get_num_params(model: torch.nn.Module, exclude_embedding: bool = False) -> int:
+    """Calculate the total number of parameters in a model.
+
+    Args:
+        model: The PyTorch model to analyze.
+        exclude_embedding: If True, exclude embedding layer parameters from the count.
+
+    Returns:
+        The total number of parameters.
+    """
     num_params = sum(p.numel() for p in model.parameters())
     if exclude_embedding:
         base_model_prefix = getattr(model, "base_model_prefix", "model")
         submodel = getattr(model, f"{base_model_prefix}")
-        num_params -= sum(
-            sum(p.numel() for p in m.parameters())
-            for m in submodel.children()
-            if isinstance(m, torch.nn.Embedding)
-        )
+        embedding_params = sum(
+            p.numel()
+            for m in submodel.children()
+            if isinstance(m, torch.nn.Embedding)
+            for p in m.parameters()
+        )
+        num_params -= embedding_params
     return num_params
```

🧰 Tools
🪛 Pylint (3.3.7)
[convention] 50-50: Missing function or method docstring
(C0116)
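For a quick sanity check of the suggested refactor, here is a toy usage sketch. It assumes `get_num_params` is importable from `touchnet.models.llama`; the `Toy` module is purely illustrative:

```python
import torch

from touchnet.models.llama import get_num_params  # assumed import path

class Toy(torch.nn.Module):
    """Tiny stand-in with a 'model' submodule so exclude_embedding has something to strip."""
    def __init__(self):
        super().__init__()
        self.model = torch.nn.Sequential()
        self.model.add_module("embed", torch.nn.Embedding(10, 4))  # 40 parameters
        self.model.add_module("proj", torch.nn.Linear(4, 4))       # 20 parameters

print(get_num_params(Toy()))                           # 60
print(get_num_params(Toy(), exclude_embedding=True))   # 20 (embedding excluded)
```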
touchnet/bin/train.py (1)
332-334: Simplify dictionary iteration.

Remove unnecessary `.keys()` when iterating over dictionary keys.

```diff
-    for key in batch.keys():
+    for key in batch:
         if batch[key] is not None and torch.is_tensor(batch[key]):
             batch[key] = batch[key].to(device_type)
```

🧰 Tools
🪛 Ruff (0.11.9)
332-332: Use `key in dict` instead of `key in dict.keys()`. Remove `.keys()`.

(SIM118)
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (5)
- touchnet/__init__.py (1 hunks)
- touchnet/bin/train.py (7 hunks)
- touchnet/models/__init__.py (0 hunks)
- touchnet/models/llama/__init__.py (1 hunks)
- touchnet/utils/metrics.py (0 hunks)
💤 Files with no reviewable changes (2)
- touchnet/models/__init__.py
- touchnet/utils/metrics.py
🧰 Additional context used
🪛 Pylint (3.3.7)
touchnet/models/llama/__init__.py
[convention] 50-50: Missing function or method docstring
(C0116)
touchnet/bin/train.py
[convention] 238-238: Line too long (110/100)
(C0301)
[convention] 240-240: Line too long (113/100)
(C0301)
[convention] 244-244: Line too long (104/100)
(C0301)
[convention] 245-245: Line too long (104/100)
(C0301)
[convention] 250-250: Line too long (111/100)
(C0301)
[convention] 251-251: Line too long (110/100)
(C0301)
[convention] 257-257: Line too long (103/100)
(C0301)
[convention] 258-258: Line too long (104/100)
(C0301)
[convention] 259-259: Line too long (119/100)
(C0301)
[convention] 345-345: Line too long (103/100)
(C0301)
[warning] 337-348: String statement has no effect
(W0105)
[convention] 425-425: Line too long (107/100)
(C0301)
[convention] 426-426: Line too long (104/100)
(C0301)
touchnet/__init__.py
[error] 6-6: Unable to import 'transformers'
(E0401)
[error] 7-7: Unable to import 'transformers.models.llama'
(E0401)
[error] 8-9: Unable to import 'transformers.models.qwen2_audio'
(E0401)
[warning] 11-11: Module import itself
(W0406)
🪛 Ruff (0.11.9)
touchnet/bin/train.py
332-332: Use `key in dict` instead of `key in dict.keys()`.
Remove `.keys()`
(SIM118)
🔇 Additional comments (5)
touchnet/__init__.py (2)
30-109: Excellent centralization of model training specifications!

Moving the training specifications from individual model modules to a central location improves maintainability and discoverability. The TODO comments clearly document the current limitations for the `qwen2_audio` and `kimi_audio` models.
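As a rough sketch of what a centralized registry of this kind can look like; the field names and signatures below are assumptions for illustration, not the actual definitions in `touchnet/__init__.py`:

```python
from dataclasses import dataclass
from typing import Callable, Dict

import torch

@dataclass(frozen=True)
class TrainSpec:
    # Hypothetical fields; the real TrainSpec may carry more (loss fn, parallelize fn, ...).
    name: str
    build_model_fn: Callable[..., torch.nn.Module]
    build_tokenizer_fn: Callable[..., object]

# Single registry in the package root instead of per-submodule registration.
_TRAIN_SPECS: Dict[str, TrainSpec] = {}

def register_train_spec(spec: TrainSpec) -> None:
    if spec.name in _TRAIN_SPECS:
        raise ValueError(f"TrainSpec '{spec.name}' is already registered")
    _TRAIN_SPECS[spec.name] = spec

def get_train_spec(name: str) -> TrainSpec:
    return _TRAIN_SPECS[name]
```

The `AutoConfig`/`AutoModelForCausalLM` registration mentioned in the next comment follows the same "register once, fail on duplicates" idea via `exist_ok=False`.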
112-115: Proper integration with the transformers library.

The model registration with `AutoConfig` and `AutoModelForCausalLM` follows the correct pattern, and `exist_ok=False` ensures no duplicate registrations.

touchnet/bin/train.py (3)
126-138: Good defensive programming for tokenizer construction.

The conditional checking for special token attributes before passing them to the tokenizer builder prevents potential AttributeErrors and makes the code more robust across different model configurations.
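A minimal sketch of this conditional forwarding, with hypothetical attribute and builder names (the actual ones in train.py may differ):

```python
def collect_special_token_kwargs(config) -> dict:
    """Gather optional special-token attributes only when the config defines them,
    so the tokenizer builder never receives attributes that do not exist."""
    kwargs = {}
    for attr in ("bos_token", "eos_token", "pad_token"):
        value = getattr(config, attr, None)
        if value is not None:
            kwargs[attr] = value
    return kwargs

# Hypothetical usage:
# tokenizer = build_tokenizer(tokenizer_path, **collect_special_token_kwargs(model_config))
```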
238-266: Excellent documentation of the parallelization process and good practice for training stability.

The detailed comments explaining DTensor transformations during parallelization are extremely helpful for understanding the complex process. Converting model parameters to float32 after parallelization is a good practice for better convergence (see the sketch after the tool annotations below).
🧰 Tools
🪛 Pylint (3.3.7)
[convention] 238-238: Line too long (110/100)
(C0301)
[convention] 240-240: Line too long (113/100)
(C0301)
[convention] 244-244: Line too long (104/100)
(C0301)
[convention] 245-245: Line too long (104/100)
(C0301)
[convention] 250-250: Line too long (111/100)
(C0301)
[convention] 251-251: Line too long (110/100)
(C0301)
[convention] 257-257: Line too long (103/100)
(C0301)
[convention] 258-258: Line too long (104/100)
(C0301)
[convention] 259-259: Line too long (119/100)
(C0301)
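A minimal sketch of the cast-to-float32-after-parallelization step praised in the 238-266 comment above; the `parallelize` callable is a hypothetical stand-in for the actual FSDP/TP wrapping in train.py:

```python
from typing import Callable

import torch

def prepare_model(
    model: torch.nn.Module,
    parallelize: Callable[[torch.nn.Module], torch.nn.Module] = lambda m: m,
) -> torch.nn.Module:
    # Parallelize first; the wrapping may replace parameters with DTensors.
    model = parallelize(model)
    # Then cast parameters and buffers to float32 for more stable convergence.
    return model.to(torch.float32)
```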
419-430: Clean batch unpacking implementation.

Using `.pop()` to extract values that shouldn't be passed to the model's forward call is a good practice. It makes the code cleaner and prevents unnecessary data from being passed downstream (a sketch of the pattern follows after the tool annotations below).

🧰 Tools
🪛 Pylint (3.3.7)
[convention] 425-425: Line too long (107/100)
(C0301)
[convention] 426-426: Line too long (104/100)
(C0301)
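A minimal sketch of the pop-then-unpack pattern from the 419-430 comment above; the popped key name is a hypothetical example, not the actual key used in train.py:

```python
import torch

def train_step(model: torch.nn.Module, batch: dict, device_type: str = "cuda"):
    # Move tensors to the target device (mirrors the loop reviewed at lines 332-334).
    for key in batch:
        if batch[key] is not None and torch.is_tensor(batch[key]):
            batch[key] = batch[key].to(device_type)

    # Pop entries that forward() should not receive ("num_valid_tokens" is hypothetical),
    # then unpack the rest directly into the call.
    num_valid_tokens = batch.pop("num_valid_tokens", None)
    outputs = model(**batch)
    return outputs, num_valid_tokens
```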
Summary by CodeRabbit
New Features
Improvements
Refactor