feat: Native CGGR support for SFTTrainer (closes #3884)#3891

Open
Wilbatronic wants to merge 9 commits into unslothai:main from MinimaML:feature/cggr-integration

Conversation

@Wilbatronic

Summary
This PR implements native support for Confidence-Gated Gradient Routing (CGGR). This technique enables a selective backward pass by identifying "easy" tokens early in the model and masking their labels, allowing the backward pass to focus only on informative tokens.

By leveraging CGGR, users can achieve significant throughput increases (tested up to 1.59x) on consumer hardware by utilizing the saved memory to scale batch sizes beyond what is normally possible.

Key Features
Efficiency: Keeps only the top N% hardest tokens for gradient computation based on entropy or margin scoring.
Zero Kernel Changes: Directly compatible with Unsloth's optimized Fast_CrossEntropyLoss. By setting easy tokens to ignore_index=-100, the existing kernels naturally skip gradient computation for those tokens.
Zero Extra Memory: The TruncatedRouter uses the first few layers of the existing model and shares weights with the parent, consuming no additional VRAM.
GPU-Native Optimization: All statistics and masking logic are vectorized on the GPU to avoid CPU-GPU synchronization bottlenecks during the training loop.
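To make the masking mechanism above concrete, here is a minimal sketch (illustrative only, not the PR's actual implementation; `mask_easy_tokens` and `keep_ratio` are hypothetical names) of selecting the highest-entropy tokens and setting the rest to `ignore_index=-100` so a cross-entropy loss that honors the ignore index skips them:

```python
import torch
import torch.nn.functional as F

def mask_easy_tokens(logits, labels, keep_ratio = 0.25, ignore_index = -100):
    """Keep only the highest-entropy (hardest) tokens; mask the rest.

    logits: [batch, seq, vocab], labels: [batch, seq]
    Masked positions get ignore_index, so a loss honoring ignore_index
    skips them in both forward and backward.
    """
    log_probs = F.log_softmax(logits, dim = -1)          # numerically stable
    entropy = -(log_probs.exp() * log_probs).sum(-1)     # [batch, seq]

    valid = labels != ignore_index
    scores = entropy.masked_fill(~valid, float("-inf"))  # never pick padding

    k = max(1, int(keep_ratio * labels.shape[1]))
    topk_idx = scores.topk(k, dim = -1).indices          # hardest k per sequence

    keep = torch.zeros_like(valid)
    keep.scatter_(1, topk_idx, True)

    return labels.masked_fill(~keep, ignore_index)
```

All operations stay on the GPU as tensor ops, matching the sync-free design the feature list describes.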
Performance (Validated on RTX 3060 12GB)
Tested using SmolLM2-135M at an equal memory budget (~10GB VRAM):

Baseline (Batch 8): 8.3k tokens/sec at 9.13 GB
CGGR (Batch 32): 10.5k tokens/sec at 10.23 GB
Result: +27% throughput increase by enabling a 4x larger batch size.
Note: Higher gains are expected with larger models and longer sequence lengths where memory is the primary bottleneck.

Usage
CGGR can be enabled with a single line after SFTTrainer initialization:

```python
from unsloth.cggr import CGGRUnslothBridge

trainer = SFTTrainer(model = model, ...)
CGGRUnslothBridge.patch_trainer(
    trainer,
    min_tokens_ratio = 0.25,  # Keep 25% hardest tokens
    num_router_layers = 2,    # Use first 2 layers for routing
    warmup_steps = 10,        # Train normally for 10 steps first
)
```

Wilbatronic and others added 9 commits January 14, 2026 17:48
- Add unsloth/cggr/ module with:
  - router.py: TruncatedRouter for lightweight difficulty scoring
  - bridge.py: CGGRUnslothBridge for trainer patching and label masking
  - __init__.py: Conditional imports with graceful fallback

Key optimizations:
  - inference_mode() instead of no_grad() for faster execution
  - Fully vectorized label masking (no Python loops)
  - log_softmax for numerically stable entropy computation
  - Efficient fused operations in scoring

- Leverages existing -100 ignore_index in Fast_CrossEntropyLoss
- Zero kernel changes required
- Optional dependency: pip install cggr

Usage:
  from unsloth.cggr import CGGRUnslothBridge
  CGGRUnslothBridge.patch_trainer(trainer, min_tokens_ratio=0.25)

Expected ~1.5-2x backward pass speedup by only computing gradients for hard tokens.
- benchmark_cggr.py: Full training comparison (requires unsloth_zoo)
- microbenchmark_cggr.py: Component-level performance tests
- benchmark_backward_savings.py: Demonstrates label masking savings

Tested on RTX 3060 12GB, showing backward pass improvements scale with masking ratio.
Benchmark results (SmolLM2-135M on RTX 3060 12GB):

SAME MEMORY BUDGET (key metric):
- Baseline: batch=8, 7507 tok/s, 9.68GB
- CGGR:     batch=32, 11905 tok/s, 8.61GB
- THROUGHPUT SPEEDUP: 1.59x (+58.6%)

SAME BATCH:
- Baseline: 550ms/step, 9.68GB
- CGGR:     434ms/step, 3.19GB
- Per-step speedup: 1.27x
- Memory savings: 67%

Files:
- benchmark_throughput.py: Same-memory-budget comparison (key!)
- benchmark_smollm.py: Same-batch comparison
@gemini-code-assist

Summary of Changes

Hello @Wilbatronic, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request integrates Confidence-Gated Gradient Routing (CGGR) into the SFTTrainer, a technique designed to optimize the training process of large language models. By selectively computing gradients only for the most 'challenging' tokens, it significantly reduces memory footprint and computational overhead. This allows users to train with larger batch sizes, leading to substantial throughput improvements and more efficient utilization of GPU resources, particularly on consumer hardware.

Highlights

  • Native CGGR Integration: Introduces native support for Confidence-Gated Gradient Routing (CGGR) into Unsloth, enabling selective backpropagation by identifying and masking 'easy' tokens.
  • Performance Enhancement: Achieves significant throughput increases (up to 1.59x tested) by allowing larger batch sizes due to reduced memory consumption from selective gradient computation.
  • Memory Efficiency: The TruncatedRouter used for difficulty scoring shares weights with the main model, consuming zero additional VRAM.
  • Seamless Compatibility: Directly compatible with Unsloth's optimized Fast_CrossEntropyLoss and requires no kernel changes, as easy tokens are set to ignore_index=-100.
  • GPU-Native Optimization: All statistics and masking logic are vectorized and processed on the GPU to prevent CPU-GPU synchronization bottlenecks.


@gemini-code-assist bot left a comment

Code Review

This pull request introduces native support for Confidence-Gated Gradient Routing (CGGR), which is an excellent feature for improving training throughput. The implementation is well-designed, with a clear separation of concerns and good use of GPU-native operations to maintain performance. I've provided a couple of suggestions to further improve the code's robustness and performance by avoiding an unnecessary CPU-GPU synchronization and handling configuration options more strictly.

Comment on lines +112 to +154
        if self.scoring == "entropy":
            # High entropy = uncertain = hard
            # Use log_softmax for numerical stability (single fused kernel)
            log_probs = F.log_softmax(logits, dim = -1)
            probs = log_probs.exp()
            scores = -torch.sum(probs * log_probs, dim = -1)
        elif self.scoring == "margin":
            # Small margin between top-2 = hard
            # topk is efficient - only partial sort needed
            top2 = torch.topk(logits, k = 2, dim = -1).values
            scores = -(top2[..., 0] - top2[..., 1])  # Negative margin (high = hard)
        elif self.scoring == "loss":
            # High loss = hard - directly compute per-token loss
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            scores = F.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
                reduction = "none",
                ignore_index = -100,
            ).view(shift_labels.shape)
            # Pad to match original sequence length
            scores = F.pad(scores, (0, 1), value = 0)
        else:  # combined - efficient fused computation
            # Compute log_softmax once (fused kernel)
            log_probs = F.log_softmax(logits, dim = -1)
            probs = log_probs.exp()

            # Entropy from log_probs (reuse computation)
            entropy = -torch.sum(probs * log_probs, dim = -1)

            # Margin from topk
            top2 = torch.topk(logits, k = 2, dim = -1).values
            margin = top2[..., 0] - top2[..., 1]

            # Normalize and combine - use in-place operations where possible
            entropy_mean = entropy.mean()
            entropy_std = entropy.std() + 1e-10
            margin_mean = margin.mean()
            margin_std = margin.std() + 1e-10

            # Combined score: high entropy OR small margin = hard
            scores = (entropy - entropy_mean) / entropy_std - (
medium

This change makes the scoring strategy selection more robust by explicitly checking for each supported method and raising a ValueError for any unsupported ones. The current implementation uses a final else block that would catch any invalid strategy name (e.g., a typo) and default to the 'combined' logic, which could lead to unexpected behavior.

        if self.scoring == "entropy":
            # High entropy = uncertain = hard
            # Use log_softmax for numerical stability (single fused kernel)
            log_probs = F.log_softmax(logits, dim=-1)
            probs = log_probs.exp()
            scores = -torch.sum(probs * log_probs, dim=-1)
        elif self.scoring == "margin":
            # Small margin between top-2 = hard
            # topk is efficient - only partial sort needed
            top2 = torch.topk(logits, k=2, dim=-1).values
            scores = -(top2[..., 0] - top2[..., 1])  # Negative margin (high = hard)
        elif self.scoring == "loss":
            # High loss = hard - directly compute per-token loss
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            scores = F.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
                reduction="none",
                ignore_index=-100,
            ).view(shift_labels.shape)
            # Pad to match original sequence length
            scores = F.pad(scores, (0, 1), value=0)
        elif self.scoring == "combined":  # combined - efficient fused computation
            # Compute log_softmax once (fused kernel)
            log_probs = F.log_softmax(logits, dim=-1)
            probs = log_probs.exp()
            
            # Entropy from log_probs (reuse computation)
            entropy = -torch.sum(probs * log_probs, dim=-1)
            
            # Margin from topk
            top2 = torch.topk(logits, k=2, dim=-1).values
            margin = top2[..., 0] - top2[..., 1]
            
            # Normalize and combine - use in-place operations where possible
            entropy_mean = entropy.mean()
            entropy_std = entropy.std() + 1e-10
            margin_mean = margin.mean()
            margin_std = margin.std() + 1e-10
            
            # Combined score: high entropy OR small margin = hard
            scores = (entropy - entropy_mean) / entropy_std - (margin - margin_mean) / margin_std
        else:
            raise ValueError(f"Unsupported scoring strategy: {self.scoring}")

        scores_for_threshold.masked_fill_(~valid_mask, float("-inf"))

        # Count valid tokens per sequence
        valid_counts = valid_mask.sum(dim = 1)  # [batch]

medium

The .item() call here introduces a CPU-GPU synchronization point within the training loop. While the performance impact might be small, it's best to avoid unnecessary syncs in performance-critical code. A single-element tensor can be used directly in a boolean context, so the .item() call is not needed to check if max_valid > 0.

Suggested change

        valid_counts = valid_mask.sum(dim = 1)  # [batch]
        max_valid = valid_counts.max()
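As an illustration of the fully sync-free pattern this review is pointing toward (hypothetical names, not the PR's code), the whole per-sequence thresholding can be kept in tensor operations, with no `.item()` call and no Python branch on a device tensor:

```python
import torch

def threshold_hardest(scores, valid_mask, keep_ratio = 0.25):
    """Build a keep-mask of the hardest tokens without any .item() calls.

    scores: [batch, seq] difficulty scores; valid_mask: [batch, seq] bool.
    Everything stays on the device as tensor ops, so there is no
    CPU-GPU synchronization point inside the training loop.
    """
    scores = scores.masked_fill(~valid_mask, float("-inf"))
    valid_counts = valid_mask.sum(dim = 1)                      # [batch]
    # Tokens to keep per sequence: at least 1, at most the valid count
    k = (valid_counts.float() * keep_ratio).long().clamp(min = 1)
    k = torch.minimum(k, valid_counts.clamp(min = 1))
    # Sort descending and gather the k-th score as a per-row threshold
    sorted_scores, _ = scores.sort(dim = 1, descending = True)
    thresh = sorted_scores.gather(1, (k - 1).unsqueeze(1))      # [batch, 1]
    return (scores >= thresh) & valid_mask
```

Note that even `if max_valid > 0:` on a tensor forces a device-to-host copy when the interpreter evaluates the boolean, so staying branchless as above is the stricter fix.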

@chatgpt-codex-connector bot left a comment

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 38db6c4943


Comment on lines +157 to +162
        for layer in self.layers:
            layer_outputs = layer(
                hidden_states,
                attention_mask = mask_input,
                position_ids = position_ids,
                use_cache = False,

P2: Avoid passing position_ids to GPT-2 blocks

When the router is built from model.transformer.h (GPT‑2 path from _get_layers), each layer is a GPT2Block whose forward signature does not accept position_ids; passing position_ids here will raise a TypeError as soon as a GPT‑2 model is used with CGGR. This makes the new feature unusable for GPT‑2 style models even though they are explicitly supported by the layer/embedding discovery logic.
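One way to address this finding (illustrative only; not part of the PR, and `call_layer_compat` is a hypothetical helper) is to filter keyword arguments against each layer's forward signature before calling it:

```python
import inspect

def call_layer_compat(layer, hidden_states, **kwargs):
    """Call a transformer block, dropping kwargs its forward() does not accept.

    GPT2Block.forward has no position_ids parameter, while Llama-style
    blocks expect one; filtering by signature supports both families.
    Note: a forward() that takes **kwargs will accept everything.
    """
    params = inspect.signature(layer.forward).parameters
    accepted = {k: v for k, v in kwargs.items() if k in params}
    return layer(hidden_states, **accepted)
```

This keeps the router loop a single call site instead of branching on model architecture names.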


@Datta0
Collaborator

Datta0 commented Jan 19, 2026

Hey @Wilbatronic do you happen to have comparison benchmarks on a bigger model? Perhaps something of 4B-8B ish size?

@Wilbatronic
Author

Sadly no, I don't have the compute to test that.
