🪙 [Experimental] Support GSPO-token #3820
Thanks for this. Since GSPO-token is a generalization of vanilla GSPO, I suggest we fully transition to GSPO-token instead of supporting both versions. Consequently, we would rename/remove …
trl/trainer/grpo_trainer.py
elif self.importance_sampling_level == 'sequence_token':
    # GSPO-token: sg[si(θ)] * πθ(yi,t)/sg[πθ(yi,t)]
    seq_level_log_weight = (log_ratio * completion_mask).sum(-1) / completion_mask.sum(-1).clamp(min=1.0)
    seq_level_log_weight = seq_level_log_weight.unsqueeze(-1).detach() # Stop gradient
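The hunk above computes only the sequence-level factor sg[s_i(θ)]. A minimal PyTorch sketch of how it could be combined with the token factor πθ/sg[πθ] in log space follows; the helper `gspo_token_log_weights` and the argument names `per_token_logps` / `old_per_token_logps` are hypothetical, and padding is assumed to be marked by zeros in `completion_mask`:

```python
import torch


def gspo_token_log_weights(
    per_token_logps: torch.Tensor,      # log pi_theta(y_{i,t} | x, y_{i,<t}), shape (B, T)
    old_per_token_logps: torch.Tensor,  # log pi_theta_old(y_{i,t} | x, y_{i,<t}), shape (B, T)
    completion_mask: torch.Tensor,      # 1.0 for completion tokens, 0.0 for padding, shape (B, T)
) -> torch.Tensor:
    """Hypothetical helper: per-token log importance weights for GSPO-token."""
    log_ratio = per_token_logps - old_per_token_logps
    # log s_i(theta): masked mean of the token log-ratios, shape (B,)
    seq_level_log_weight = (log_ratio * completion_mask).sum(-1) / completion_mask.sum(-1).clamp(min=1.0)
    # sg[s_i(theta)]: cut the gradient, then add a token axis so it broadcasts over (B, T)
    seq_level_log_weight = seq_level_log_weight.detach().unsqueeze(-1)
    # log pi_theta - sg[log pi_theta] is zero in value but carries the per-token gradient
    return per_token_logps - per_token_logps.detach() + seq_level_log_weight
```

In the forward pass `per_token_logps - per_token_logps.detach()` is exactly zero, so every token's weight equals the sequence weight; in the backward pass the same term sends a gradient of s_i(θ) to each token's log-probability.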
Suggested change:
- seq_level_log_weight = seq_level_log_weight.unsqueeze(-1).detach() # Stop gradient
+ seq_level_log_weight = seq_level_log_weight.detach().unsqueeze(-1) # Stop gradient
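The two orderings are numerically equivalent: `detach()` and `unsqueeze()` commute, so either way the result is a `(B, 1)` tensor with the same values and no gradient. A quick check on a random stand-in tensor (not real trainer data):

```python
import torch

x = torch.randn(3, requires_grad=True)  # stand-in for the per-sequence log weight, shape (B,)

a = x.unsqueeze(-1).detach()  # original ordering
b = x.detach().unsqueeze(-1)  # suggested ordering

# Both are (B, 1) tensors with identical values, detached from the autograd graph.
print(a.shape, b.shape, a.requires_grad, b.requires_grad)
```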
(log_ratio * completion_mask).sum(-1) / completion_mask.sum(-1).clamp(min=1.0)
This operation is shared by GSPO and GSPO-token, so it would be good to compute it once into a single variable under a condition like
if self.importance_sampling_level != 'token'
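One way to follow this suggestion is sketched below, assuming the levels are named `"token"`, `"sequence"` (vanilla GSPO) and `"sequence_token"` as in the hunk; the helper `compute_log_importance_weights` is hypothetical, not TRL's code:

```python
import torch


def compute_log_importance_weights(
    per_token_logps: torch.Tensor,
    old_per_token_logps: torch.Tensor,
    completion_mask: torch.Tensor,
    importance_sampling_level: str,
) -> torch.Tensor:
    """Hypothetical helper: dispatch on the level, computing the sequence-level term once."""
    log_ratio = per_token_logps - old_per_token_logps
    if importance_sampling_level == "token":
        return log_ratio  # one weight per token, shape (B, T)

    # Shared by GSPO and GSPO-token: masked mean of the token log-ratios, shape (B, 1)
    seq_level_log_weight = (log_ratio * completion_mask).sum(-1) / completion_mask.sum(-1).clamp(min=1.0)
    seq_level_log_weight = seq_level_log_weight.unsqueeze(-1)

    if importance_sampling_level == "sequence":
        return seq_level_log_weight  # GSPO: broadcast over tokens by the caller
    if importance_sampling_level == "sequence_token":
        # GSPO-token: sg[s_i] * pi_theta / sg[pi_theta], written in log space
        return per_token_logps - per_token_logps.detach() + seq_level_log_weight.detach()
    raise ValueError(f"Unknown importance_sampling_level: {importance_sampling_level!r}")
```

The final `ValueError` is a defensive fallback; it becomes unreachable if the value is validated up front.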
Makes sense. Should we then move the check for invalid importance_sampling_level values into the model parameter initialization?
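A minimal sketch of validating the value once, at construction time, assuming the option lives on a dataclass-style config; `ImportanceSamplingArgs` and `_VALID_IMPORTANCE_SAMPLING_LEVELS` are hypothetical names, not TRL's:

```python
from dataclasses import dataclass

# Hypothetical list of accepted values, assumed for illustration
_VALID_IMPORTANCE_SAMPLING_LEVELS = ("token", "sequence", "sequence_token")


@dataclass
class ImportanceSamplingArgs:
    """Hypothetical config fragment; only the validation pattern matters here."""

    importance_sampling_level: str = "token"

    def __post_init__(self) -> None:
        # Fail fast when the config is built, not later inside the loss computation.
        if self.importance_sampling_level not in _VALID_IMPORTANCE_SAMPLING_LEVELS:
            raise ValueError(
                f"importance_sampling_level must be one of {_VALID_IMPORTANCE_SAMPLING_LEVELS}, "
                f"got {self.importance_sampling_level!r}"
            )
```

With the check here, the loss code can branch on the level without re-validating it on every step.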
Agreed on keeping GSPO-token. Should we retain this parameter for backward compatibility with previous usage, or introduce a new parameter instead? Which approach is better?
@qgallouedec @lewtun @edbeeching @kashif If you have any concerns or suggestions, please feel free to let me know. Thank you very much in advance!
IMO it should be removed. However, since it has already been released as part of TRL v0.20, we may need to keep it for backward compatibility. I can't make that call myself, so I'll leave the decision to someone else.
Thanks for the contribution, and apologies for the delay in reviewing it. After reading the paper, I don't think this PR fully implements GSPO-token. This variant is most relevant when you have a fine-grained advantage, i.e. when …
Sure. As I mentioned, when there is no fine-grained advantage, the GSPO-token gradient is equivalent to that of the original (sequence-level GSPO) implementation. However, should we implement this algorithm in advance, either to accommodate possible future fine-grained advantages or to make it easier for downstream users to implement their own custom fine-grained advantages? Either way, thank you for your review.
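The equivalence claim can be checked numerically. The sketch below assumes a simplified surrogate (no clipping, no KL term, token-mean per sequence, then mean over sequences); `gspo_style_loss` and `grad_of` are hypothetical helpers. With the same advantage on every token the GSPO and GSPO-token gradients match; with per-token advantages they generally differ:

```python
import torch


def gspo_style_loss(per_token_logps, old_per_token_logps, mask, advantages, level):
    """Hypothetical unclipped surrogate: -(token-mean of w_{i,t} * A_{i,t}), averaged over sequences."""
    log_ratio = per_token_logps - old_per_token_logps
    # log s_i(theta): masked mean of token log-ratios, shape (B, 1)
    seq = ((log_ratio * mask).sum(-1) / mask.sum(-1).clamp(min=1.0)).unsqueeze(-1)
    if level == "sequence":
        log_w = seq  # GSPO: one weight per sequence
    else:
        log_w = per_token_logps - per_token_logps.detach() + seq.detach()  # GSPO-token
    per_token_loss = -log_w.exp() * advantages  # broadcasts to (B, T)
    return ((per_token_loss * mask).sum(-1) / mask.sum(-1).clamp(min=1.0)).mean()


def grad_of(level, advantages, logps, old, mask):
    """Gradient of the surrogate with respect to the new per-token log-probabilities."""
    lp = logps.clone().requires_grad_(True)
    gspo_style_loss(lp, old, mask, advantages, level).backward()
    return lp.grad


torch.manual_seed(0)
logps = torch.randn(3, 6)
old = logps + 0.1 * torch.randn(3, 6)
mask = torch.ones(3, 6)
mask[0, 4:] = 0.0  # first completion is two tokens shorter (padding)

seq_adv = torch.randn(3, 1).expand(3, 6)  # same advantage for every token of a sequence
tok_adv = torch.randn(3, 6)               # fine-grained, per-token advantages

same_seq = torch.allclose(grad_of("sequence", seq_adv, logps, old, mask),
                          grad_of("sequence_token", seq_adv, logps, old, mask), atol=1e-6)
same_tok = torch.allclose(grad_of("sequence", tok_adv, logps, old, mask),
                          grad_of("sequence_token", tok_adv, logps, old, mask), atol=1e-6)
print(same_seq, same_tok)
```

With `tok_adv`, GSPO effectively collapses each sequence's advantages to their masked mean, while GSPO-token lets each token keep its own advantage in the gradient.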
Any feedback before I close this PR, or should we go ahead and merge it? |
Please leave it open; we are working hard to provide fast reviews for all PRs 🙏
Co-authored-by: LeonEricsson
Co-authored-by: Quentin Gallouédec
Adds support for GSPO-token, as described in Section 4.3 of the GSPO paper.
Related issue: #3811
GSPO
$w_{i}^{\mathrm{GSPO}} = \left[ \frac{\pi_{\theta}(y_i \mid x)}{\pi_{\theta_{\mathrm{old}}}(y_i \mid x)} \right]^{\frac{1}{|y_i|}} = \exp(\frac{1}{|y_i|} \sum_{t=1}^{|y_i|} \log \frac{\pi_{\theta}(y_{i, t} \mid x, y_{i, <t})}{\pi_{\theta_{\mathrm{old}}}(y_{i, t} \mid x, y_{i, <t})})$
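The two forms of $w_i^{\mathrm{GSPO}}$ can be checked on toy numbers. The probabilities below are made up for illustration; the log form is the quantity that the masked mean `(log_ratio * completion_mask).sum(-1) / completion_mask.sum(-1)` computes before exponentiation:

```python
import math

# Made-up per-token probabilities for one completion y_i of length 3
new_probs = [0.50, 0.20, 0.90]  # pi_theta(y_{i,t} | x, y_{i,<t})
old_probs = [0.40, 0.25, 0.80]  # pi_theta_old(y_{i,t} | x, y_{i,<t})

# Left form: length-normalized (geometric-mean) sequence ratio
seq_ratio = math.prod(new_probs) / math.prod(old_probs)
w_geo = seq_ratio ** (1 / len(new_probs))

# Right form: exp of the mean token log-ratio (avoids under/overflow on long sequences)
w_log = math.exp(sum(math.log(p / q) for p, q in zip(new_probs, old_probs)) / len(new_probs))

print(w_geo, w_log)
```

Working in log space is why implementations typically average token log-ratios rather than multiply raw probabilities.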
GSPO-token
$w_{i, t}^{\text{GSPO-token}} = \mathrm{sg}\left[w_i^{\mathrm{GSPO}}\right] \cdot \frac{\pi_{\theta}(y_{i, t} \mid x, y_{i, < t})}{\mathrm{sg}\left[\pi_{\theta}(y_{i, t} \mid x, y_{i, < t})\right]}$
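Because $\pi_\theta / \mathrm{sg}[\pi_\theta]$ equals 1 in value, the GSPO-token weight takes the same value as the GSPO weight at every token, but its gradient flows through the individual token. Writing $\hat{A}_{i,t}$ for the advantage of token $t$ in $y_i$ and ignoring clipping, a short derivation sketch:

$$
w_{i,t}^{\text{GSPO-token}} \overset{\text{value}}{=} w_i^{\mathrm{GSPO}}, \qquad
\nabla_\theta w_{i,t}^{\text{GSPO-token}} = \mathrm{sg}\left[w_i^{\mathrm{GSPO}}\right] \cdot \frac{\nabla_\theta \pi_{\theta}(y_{i, t} \mid x, y_{i, < t})}{\mathrm{sg}\left[\pi_{\theta}(y_{i, t} \mid x, y_{i, < t})\right]} = w_i^{\mathrm{GSPO}} \, \nabla_\theta \log \pi_{\theta}(y_{i, t} \mid x, y_{i, < t})
$$

$$
\nabla_\theta \frac{1}{|y_i|} \sum_{t=1}^{|y_i|} w_{i,t}^{\text{GSPO-token}} \hat{A}_{i,t}
= w_i^{\mathrm{GSPO}} \, \frac{1}{|y_i|} \sum_{t=1}^{|y_i|} \hat{A}_{i,t} \, \nabla_\theta \log \pi_{\theta}(y_{i, t} \mid x, y_{i, < t})
$$

From the exp form of $w_i^{\mathrm{GSPO}}$, $\nabla_\theta w_i^{\mathrm{GSPO}} = w_i^{\mathrm{GSPO}} \frac{1}{|y_i|} \sum_{t=1}^{|y_i|} \nabla_\theta \log \pi_\theta(y_{i,t} \mid x, y_{i,<t})$, so setting $\hat{A}_{i,t} = \hat{A}_i$ for all $t$ gives

$$
\nabla_\theta \frac{1}{|y_i|} \sum_{t=1}^{|y_i|} w_{i,t}^{\text{GSPO-token}} \hat{A}_{i} = \hat{A}_i \, \nabla_\theta w_i^{\mathrm{GSPO}}
$$

that is, the GSPO gradient for sequence $y_i$.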