feat: Native CGGR support for SFTTrainer (closes #3884) #3891
Wilbatronic wants to merge 9 commits into unslothai:main
Conversation
- Add unsloth/cggr/ module with:
  - router.py: TruncatedRouter for lightweight difficulty scoring
  - bridge.py: CGGRUnslothBridge for trainer patching and label masking
  - __init__.py: Conditional imports with graceful fallback

Key optimizations:
- inference_mode() instead of no_grad() for faster execution
- Fully vectorized label masking (no Python loops)
- log_softmax for numerically stable entropy computation
- Efficient fused operations in scoring
- Leverages existing -100 ignore_index in Fast_CrossEntropyLoss
- Zero kernel changes required
- Optional dependency: pip install cggr

Usage:

```python
from unsloth.cggr import CGGRUnslothBridge
CGGRUnslothBridge.patch_trainer(trainer, min_tokens_ratio=0.25)
```

Expected ~1.5-2x backward pass speedup by only computing gradients for hard tokens.
- benchmark_cggr.py: Full training comparison (requires unsloth_zoo)
- microbenchmark_cggr.py: Component-level performance tests
- benchmark_backward_savings.py: Demonstrates label masking savings

Tested on RTX 3060 12GB; backward pass improvements scale with the masking ratio.
Benchmark results (SmolLM2-135M on RTX 3060 12GB):

Same memory budget (key metric):
- Baseline: batch=8, 7507 tok/s, 9.68 GB
- CGGR: batch=32, 11905 tok/s, 8.61 GB
- Throughput speedup: 1.59x (+58.6%)

Same batch:
- Baseline: 550 ms/step, 9.68 GB
- CGGR: 434 ms/step, 3.19 GB
- Per-step speedup: 1.27x
- Memory savings: 67%

Files:
- benchmark_throughput.py: Same-memory-budget comparison (key!)
- benchmark_smollm.py: Same-batch comparison
Summary of Changes

Hello @Wilbatronic, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request integrates Confidence-Gated Gradient Routing (CGGR) into the SFTTrainer, a technique designed to optimize the training process of large language models. By selectively computing gradients only for the most 'challenging' tokens, it significantly reduces memory footprint and computational overhead. This allows users to train with larger batch sizes, leading to substantial throughput improvements and more efficient utilization of GPU resources, particularly on consumer hardware.
Code Review
This pull request introduces native support for Confidence-Gated Gradient Routing (CGGR), which is an excellent feature for improving training throughput. The implementation is well-designed, with a clear separation of concerns and good use of GPU-native operations to maintain performance. I've provided a couple of suggestions to further improve the code's robustness and performance by avoiding an unnecessary CPU-GPU synchronization and handling configuration options more strictly.
```python
if self.scoring == "entropy":
    # High entropy = uncertain = hard
    # Use log_softmax for numerical stability (single fused kernel)
    log_probs = F.log_softmax(logits, dim = -1)
    probs = log_probs.exp()
    scores = -torch.sum(probs * log_probs, dim = -1)
elif self.scoring == "margin":
    # Small margin between top-2 = hard
    # topk is efficient - only partial sort needed
    top2 = torch.topk(logits, k = 2, dim = -1).values
    scores = -(top2[..., 0] - top2[..., 1])  # Negative margin (high = hard)
elif self.scoring == "loss":
    # High loss = hard - directly compute per-token loss
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()
    scores = F.cross_entropy(
        shift_logits.view(-1, shift_logits.size(-1)),
        shift_labels.view(-1),
        reduction = "none",
        ignore_index = -100,
    ).view(shift_labels.shape)
    # Pad to match original sequence length
    scores = F.pad(scores, (0, 1), value = 0)
else:  # combined - efficient fused computation
    # Compute log_softmax once (fused kernel)
    log_probs = F.log_softmax(logits, dim = -1)
    probs = log_probs.exp()

    # Entropy from log_probs (reuse computation)
    entropy = -torch.sum(probs * log_probs, dim = -1)

    # Margin from topk
    top2 = torch.topk(logits, k = 2, dim = -1).values
    margin = top2[..., 0] - top2[..., 1]

    # Normalize and combine - use in-place operations where possible
    entropy_mean = entropy.mean()
    entropy_std = entropy.std() + 1e-10
    margin_mean = margin.mean()
    margin_std = margin.std() + 1e-10

    # Combined score: high entropy OR small margin = hard
    scores = (entropy - entropy_mean) / entropy_std - (margin - margin_mean) / margin_std
```
This change makes the scoring strategy selection more robust by explicitly checking for each supported method and raising a ValueError for any unsupported ones. The current implementation uses a final else block that would catch any invalid strategy name (e.g., a typo) and default to the 'combined' logic, which could lead to unexpected behavior.
```python
if self.scoring == "entropy":
    # High entropy = uncertain = hard
    # Use log_softmax for numerical stability (single fused kernel)
    log_probs = F.log_softmax(logits, dim=-1)
    probs = log_probs.exp()
    scores = -torch.sum(probs * log_probs, dim=-1)
elif self.scoring == "margin":
    # Small margin between top-2 = hard
    # topk is efficient - only partial sort needed
    top2 = torch.topk(logits, k=2, dim=-1).values
    scores = -(top2[..., 0] - top2[..., 1])  # Negative margin (high = hard)
elif self.scoring == "loss":
    # High loss = hard - directly compute per-token loss
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()
    scores = F.cross_entropy(
        shift_logits.view(-1, shift_logits.size(-1)),
        shift_labels.view(-1),
        reduction="none",
        ignore_index=-100,
    ).view(shift_labels.shape)
    # Pad to match original sequence length
    scores = F.pad(scores, (0, 1), value=0)
elif self.scoring == "combined":  # combined - efficient fused computation
    # Compute log_softmax once (fused kernel)
    log_probs = F.log_softmax(logits, dim=-1)
    probs = log_probs.exp()
    # Entropy from log_probs (reuse computation)
    entropy = -torch.sum(probs * log_probs, dim=-1)
    # Margin from topk
    top2 = torch.topk(logits, k=2, dim=-1).values
    margin = top2[..., 0] - top2[..., 1]
    # Normalize and combine - use in-place operations where possible
    entropy_mean = entropy.mean()
    entropy_std = entropy.std() + 1e-10
    margin_mean = margin.mean()
    margin_std = margin.std() + 1e-10
    # Combined score: high entropy OR small margin = hard
    scores = (entropy - entropy_mean) / entropy_std - (margin - margin_mean) / margin_std
else:
    raise ValueError(f"Unsupported scoring strategy: {self.scoring}")
```

```python
scores_for_threshold.masked_fill_(~valid_mask, float("-inf"))

# Count valid tokens per sequence
valid_counts = valid_mask.sum(dim = 1)  # [batch]
```
The .item() call here introduces a CPU-GPU synchronization point within the training loop. While the performance impact might be small, it's best to avoid unnecessary syncs in performance-critical code. A single-element tensor can be used directly in a boolean context, so the .item() call is not needed to check if max_valid > 0.
```python
valid_counts = valid_mask.sum(dim = 1)  # [batch]
max_valid = valid_counts.max()
```
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 38db6c4943
```python
for layer in self.layers:
    layer_outputs = layer(
        hidden_states,
        attention_mask = mask_input,
        position_ids = position_ids,
        use_cache = False,
```
Avoid passing position_ids to GPT-2 blocks
When the router is built from model.transformer.h (GPT‑2 path from _get_layers), each layer is a GPT2Block whose forward signature does not accept position_ids; passing position_ids here will raise a TypeError as soon as a GPT‑2 model is used with CGGR. This makes the new feature unusable for GPT‑2 style models even though they are explicitly supported by the layer/embedding discovery logic.
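One possible fix is to filter keyword arguments against each layer's forward signature before calling it, so GPT-2 blocks never receive position_ids. This is an illustrative sketch only — the call_layer helper and the toy layer classes are hypothetical, not code from the PR:

```python
import inspect

def call_layer(layer, hidden_states, **kwargs):
    # Hypothetical helper (not in the PR): inspect the layer's forward
    # signature and pass only the keyword arguments it accepts.
    params = inspect.signature(layer.forward).parameters
    accepted = {k: v for k, v in kwargs.items() if k in params}
    return layer.forward(hidden_states, **accepted)

class LlamaLikeLayer:
    # LLaMA-style decoder blocks accept position_ids.
    def forward(self, hidden_states, attention_mask=None, position_ids=None, use_cache=False):
        return (hidden_states,)

class Gpt2LikeLayer:
    # GPT-2 blocks have no position_ids parameter; passing it
    # unconditionally raises a TypeError.
    def forward(self, hidden_states, attention_mask=None, use_cache=False):
        return (hidden_states,)

for layer in (LlamaLikeLayer(), Gpt2LikeLayer()):
    out = call_layer(
        layer,
        "hidden_states",
        attention_mask=None,
        position_ids=[0, 1, 2, 3],
        use_cache=False,
    )
    # Both calls succeed; position_ids is silently dropped on the GPT-2 path.
```

Alternatively, the router could detect the GPT-2 path once at construction time and embed positions itself, but signature filtering keeps the per-layer loop model-agnostic.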
Hey @Wilbatronic, do you happen to have comparison benchmarks on a bigger model? Perhaps something in the 4B-8B range?

No, sadly I don't have the compute to test that.
Summary
This PR implements native support for Confidence-Gated Gradient Routing (CGGR). This technique enables a selective backward pass by identifying "easy" tokens early in the model and masking their labels, allowing the backward pass to focus only on informative tokens.
By leveraging CGGR, users can achieve significant throughput increases (tested up to 1.59x) on consumer hardware by utilizing the saved memory to scale batch sizes beyond what is normally possible.
Key Features
Efficiency: Keeps only the top N% hardest tokens for gradient computation based on entropy or margin scoring.
Zero Kernel Changes: Directly compatible with Unsloth's optimized Fast_CrossEntropyLoss. By setting easy tokens to ignore_index=-100, the existing kernels naturally skip gradient computation for those tokens.
Zero Extra Memory: The TruncatedRouter uses the first few layers of the existing model and shares weights with the parent, consuming no additional VRAM.
GPU-Native Optimization: All statistics and masking logic are vectorized on the GPU to avoid CPU-GPU synchronization bottlenecks during the training loop.
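The last three points fit together as follows: easy tokens get their labels set to -100, and cross-entropy then contributes exactly zero gradient for them. The sketch below is illustrative only — mask_easy_labels is a hypothetical helper written for this explanation, not the bridge's actual API:

```python
import torch
import torch.nn.functional as F

def mask_easy_labels(scores, labels, min_tokens_ratio = 0.25):
    # Hedged sketch, not the PR's implementation: keep only the hardest
    # fraction of valid tokens per sequence; mask the rest to -100.
    # Fully vectorized -- no .item() calls, so no CPU-GPU syncs.
    valid_mask = labels != -100
    scores = scores.masked_fill(~valid_mask, float("-inf"))
    k = (valid_mask.sum(dim = 1).float() * min_tokens_ratio).clamp(min = 1).long()
    sorted_scores, _ = scores.sort(dim = 1, descending = True)
    thresholds = sorted_scores.gather(1, (k - 1).unsqueeze(1))  # k-th hardest score
    keep = scores >= thresholds
    return torch.where(keep, labels, torch.full_like(labels, -100))

# Tokens masked to -100 contribute exactly zero gradient:
logits = torch.randn(1, 4, 16, requires_grad = True)
labels = torch.tensor([[5, 6, 7, -100]])
scores = torch.tensor([[3.0, 1.0, 2.0, 0.0]])  # higher = harder

masked = mask_easy_labels(scores, labels, min_tokens_ratio = 0.34)
# With 3 valid tokens and ratio 0.34, only the hardest token (position 0)
# keeps its label: masked == [[5, -100, -100, -100]].
loss = F.cross_entropy(logits.view(-1, 16), masked.view(-1), ignore_index = -100)
loss.backward()
# logits.grad rows for masked positions are all zero.
```

This is why no kernel changes are needed: any cross-entropy implementation honoring ignore_index, including Fast_CrossEntropyLoss, produces zero gradients for the masked positions for free.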
Performance (Validated on RTX 3060 12GB)
Tested using SmolLM2-135M at an equal memory budget (~10GB VRAM):
Baseline (Batch 8): 8.3k tokens/sec at 9.13 GB
CGGR (Batch 32): 10.5k tokens/sec at 10.23 GB
Result: +27% Throughput increase by enabling 4x larger batch size.
Note: Higher gains are expected with larger models and longer sequence lengths where memory is the primary bottleneck.
Usage
CGGR can be enabled with a single line after SFTTrainer initialization:
```python
from unsloth.cggr import CGGRUnslothBridge

trainer = SFTTrainer(model=model, ...)
CGGRUnslothBridge.patch_trainer(
    trainer,
    min_tokens_ratio=0.25,  # Keep 25% hardest tokens
    num_router_layers=2,    # Use first 2 layers for routing
    warmup_steps=10         # Train normally for 10 steps first
)
```