Codec codebase bug fixes: detach() in RVQ residual and target_bandwidth in inference #6268
ftshijt merged 2 commits into espnet:master
Code Review
This pull request introduces two important bug fixes for the codec codebase. The first fix correctly applies .detach() when calculating the residual in the Residual Vector Quantizer, which is crucial for stable training. The second fix enables passing target_bandwidth during inference.
The changes for both fixes are well-implemented. However, I've identified a potential issue in the inference methods of both DAC and SoundStream models. Passing **kwargs directly to the generator's encode method could cause a TypeError if unexpected arguments are provided. I've suggested a small refactoring to call the model's own encode method instead, which handles keyword arguments safely and improves code reuse. Overall, this is a good set of fixes.
          """
    -     codec = self.generator.encode(x)
    +     codec = self.generator.encode(x, **kwargs)
Calling self.generator.encode(x, **kwargs) directly is risky as it will raise a TypeError if kwargs contains any key other than target_bw. The inference method's signature allows for any keyword arguments, but DACGenerator.encode is more restrictive. It's safer to call self.encode(x, **kwargs), which correctly filters the keyword arguments, ensuring only target_bw is passed along. This also improves code reuse.
Suggested change:
    -     codec = self.generator.encode(x, **kwargs)
    +     codec = self.encode(x, **kwargs)
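The failure mode described in this comment can be reproduced with a small stand-alone sketch. `ToyGenerator` and `ToyCodec` are hypothetical stand-ins, not the ESPnet classes; the sketch assumes, as the review states, that the generator's `encode` accepts only `target_bw` and that the model-level `encode` picks that key out of `**kwargs`.

```python
class ToyGenerator:
    """Hypothetical stand-in for a generator whose encode() only accepts target_bw."""

    def encode(self, x, target_bw=None):
        return {"frames": list(x), "target_bw": target_bw}


class ToyCodec:
    """Hypothetical stand-in for the model wrapper discussed in the review."""

    def __init__(self):
        self.generator = ToyGenerator()

    def encode(self, x, **kwargs):
        # Forward only the key the generator understands; ignore everything else.
        target_bw = kwargs.get("target_bw", None)
        return self.generator.encode(x, target_bw=target_bw)

    def inference_direct(self, x, **kwargs):
        # Mirrors the flagged pattern: every keyword is passed straight through.
        return self.generator.encode(x, **kwargs)

    def inference_safe(self, x, **kwargs):
        # Mirrors the suggested pattern: reuse the filtering encode() above.
        return self.encode(x, **kwargs)


codec = ToyCodec()
safe = codec.inference_safe([0.1, 0.2], target_bw=6.0, unrelated_flag=True)

try:
    codec.inference_direct([0.1, 0.2], target_bw=6.0, unrelated_flag=True)
    direct_error = None
except TypeError as exc:
    direct_error = exc
```

Both paths behave identically when only `target_bw` is supplied; they differ only when a caller passes an extra keyword, which is the case the suggestion guards against.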
          """
    -     codec = self.generator.encode(x)
    +     codec = self.generator.encode(x, **kwargs)
Calling self.generator.encode(x, **kwargs) directly can lead to a TypeError if kwargs includes keys not expected by SoundStreamGenerator.encode (which only accepts target_bw). To prevent potential crashes and improve code reuse, it's better to call self.encode(x, **kwargs). The self.encode method is designed to safely handle arbitrary keyword arguments by extracting only the relevant ones.
Suggested change:
    -     codec = self.generator.encode(x, **kwargs)
    +     codec = self.encode(x, **kwargs)
Codecov Report
❌ Patch coverage is
Additional details and impacted files
@@ Coverage Diff @@
## master #6268 +/- ##
===========================================
+ Coverage 46.53% 56.77% +10.24%
===========================================
Files 542 889 +347
Lines 49601 84363 +34762
===========================================
+ Hits 23080 47899 +24819
- Misses 26521 36464 +9943
Oh, this sounds critical.
Thanks for the fixes! They look great to me.
- detach() should be applied when computing the residual in RVQ. Refer to Cisco's fix (core_vq.py) and Moshi's implementation (core_vq.py).
- Enable passing target_bandwidth during inference.
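Why the detach() matters for the RVQ residual can be illustrated with a minimal PyTorch sketch. It is not the ESPnet code: the helper names are hypothetical, and it assumes each quantizer layer returns a straight-through output of the form `x + (q - x).detach()`. Under that assumption, subtracting the quantized output without detaching makes the two gradient paths through the input cancel, so the next layer's residual carries no gradient back to the encoder.

```python
import torch


def straight_through_quantize(x, codebook):
    """Nearest-codebook lookup with a straight-through estimator (assumed layer behaviour)."""
    # Squared distances between every input vector and every codebook entry: (N, K).
    dists = (x.unsqueeze(1) - codebook.unsqueeze(0)).pow(2).sum(dim=-1)
    q = codebook[dists.argmin(dim=-1)]
    # Forward value is q; the backward pass treats the layer as identity in x.
    return x + (q - x).detach()


def next_residual(residual, quantized, use_detach):
    """Residual handed to the next quantizer layer (hypothetical helper)."""
    if use_detach:
        return residual - quantized.detach()
    return residual - quantized


def residual_grad(x, codebook, use_detach):
    """Gradient of the squared norm of the next residual with respect to x."""
    quantized = straight_through_quantize(x, codebook)
    residual = next_residual(x, quantized, use_detach)
    (grad,) = torch.autograd.grad(residual.pow(2).sum(), x)
    return grad


torch.manual_seed(0)
codebook = torch.randn(8, 4)
x = torch.randn(5, 4, requires_grad=True)

grad_without = residual_grad(x, codebook, use_detach=False)
grad_with = residual_grad(x, codebook, use_detach=True)
print(grad_without.abs().max().item())  # the two paths through x cancel
print(grad_with.abs().max().item())     # gradient reaches x via the residual
```

In this toy setup the detached variant yields the gradient `2 * (x - q)`, while the non-detached variant yields zeros, which is consistent with the review's remark that the fix is important for training.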