codec fix: DDP logic and dead code revival logic #6284
Conversation
Code Review
This pull request introduces significant and necessary fixes for DDP training and the dead code revival logic within the EuclideanCodebook. The implementation of all_reduce to synchronize statistics across distributed workers is a crucial correction for DDP training. The fix for the EMA update logic, which now correctly resets state for revived codebook vectors, resolves a critical bug that prevented the codebook from recovering. The overall changes are well-reasoned and address the described issues effectively. I have two suggestions rated as 'high' severity to improve numerical stability and code simplicity in the new distributed logic.
```python
return [size if i == dim else s for i, s in enumerate(shape)]
# ...
means = sample_vectors(samples, num_clusters)
# ...
def sample_multinomial(total_count, probs):
    device = probs.device
    probs = probs.cpu()

    total_count = probs.new_full((), total_count)
    remainder = probs.new_ones(())
    sample = torch.empty_like(probs, dtype=torch.long)

    num_probs = len(probs)

    for i, prob in enumerate(probs):
        is_last = i == (num_probs - 1)

        s = (
            torch.binomial(total_count, prob / remainder)
            if not is_last
            else total_count
```
The current implementation of sample_multinomial is functionally correct but overly complex and potentially fragile due to manual floating-point arithmetic for re-normalizing probabilities (prob / remainder). A much simpler, more robust, and idiomatic way to achieve this is by using torch.distributions.Multinomial. This will also avoid the unnecessary device transfer to CPU and back.
Suggested change:

```diff
 return [size if i == dim else s for i, s in enumerate(shape)]
 means = sample_vectors(samples, num_clusters)
-def sample_multinomial(total_count, probs):
-    device = probs.device
-    probs = probs.cpu()
-    total_count = probs.new_full((), total_count)
-    remainder = probs.new_ones(())
-    sample = torch.empty_like(probs, dtype=torch.long)
-    num_probs = len(probs)
-    for i, prob in enumerate(probs):
-        is_last = i == (num_probs - 1)
-        s = (
-            torch.binomial(total_count, prob / remainder)
-            if not is_last
-            else total_count
+from torch.distributions import Multinomial
+dist = Multinomial(total_count, probs=probs)
+return dist.sample().to(dtype=torch.long)
```
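For reference, a minimal standalone sketch (assuming PyTorch is installed) of the suggested approach: `torch.distributions.Multinomial` stays on the tensor's original device and draws integer counts that sum to `total_count`, with no `.cpu()` round-trip. The input tensor here is illustrative.

```python
import torch
from torch.distributions import Multinomial

# Hypothetical stand-in for the codebook's cluster probabilities.
probs = torch.tensor([0.1, 0.2, 0.3, 0.4])
total_count = 100

# One multinomial draw; counts come back as floats, so cast to long
# just as the suggested replacement does.
sample = Multinomial(total_count, probs=probs).sample().to(dtype=torch.long)
```

`sample` has the same shape as `probs`, and its entries always sum to `total_count`.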
```python
buckets = torch.argmax(dists, dim=-1)
bins = torch.bincount(buckets, minlength=num_clusters)
```
The current order of operations in the k-means update step (division then all_reduce) is mathematically correct but can be numerically unstable. Performing all_reduce on the summed vectors before dividing by the counts is generally more robust as it avoids potential precision loss from intermediate divisions, especially when cluster counts (bins) are large.
Suggested change:

```diff
 buckets = torch.argmax(dists, dim=-1)
 bins = torch.bincount(buckets, minlength=num_clusters)
+all_reduce_fn(new_means)
+new_means = new_means / rearrange(bins_min_clamped, '... -> ... 1')
```
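To make the ordering concrete, here is a hedged single-process sketch of the reduce-before-divide pattern: per-cluster sums and counts are accumulated first, reduced across workers, and divided exactly once at the end. The helper name `kmeans_update` is hypothetical, and `all_reduce_fn` defaults to a no-op stand-in for `dist.all_reduce`.

```python
import torch

def kmeans_update(samples, means, all_reduce_fn=lambda t: t):
    # Assign each sample to its nearest mean.
    buckets = torch.cdist(samples, means).argmin(dim=-1)
    num_clusters = means.shape[0]
    bins = torch.bincount(buckets, minlength=num_clusters)

    # Accumulate per-cluster sums first ...
    new_means = torch.zeros_like(means)
    new_means.index_add_(0, buckets, samples)

    # ... reduce the sums and the counts across workers (no-op here) ...
    all_reduce_fn(new_means)
    all_reduce_fn(bins)

    # ... and divide exactly once, after the reduction.
    bins_min_clamped = bins.clamp(min=1)
    return new_means / bins_min_clamped.unsqueeze(-1)

samples = torch.tensor([[0.0, 0.0], [2.0, 2.0], [10.0, 10.0]])
means = torch.tensor([[1.0, 1.0], [9.0, 9.0]])
updated = kmeans_update(samples, means)  # [[1., 1.], [10., 10.]]
```

In the real DDP path, passing `dist.all_reduce` as `all_reduce_fn` sums both buffers over all ranks before the single division.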
Codecov Report

❌ Patch coverage is

Additional details and impacted files:

```diff
@@            Coverage Diff             @@
##           master    #6284      +/-   ##
==========================================
- Coverage   56.49%   56.45%    -0.04%
==========================================
  Files         896      896
  Lines       84814    84881      +67
==========================================
+ Hits        47914    47923       +9
- Misses      36900    36958      +58
```
Flags with carried forward coverage won't be shown. View full report in Codecov by Sentry.
ftshijt left a comment:
Many thanks for the update and for catching the issue. I have two questions about this setup:
```python
self.embed.data.copy_(embed_normalized)
self.update_ema()

self.expire_codes_(x)
```
If `expire_codes_` runs at a later stage, shall we do another `all_reduce` to keep all GPUs the same?
I think it's OK:

- `expire_codes_` determines dead codes by checking `self.cluster_size < self.threshold_ema_dead_code`, and `self.cluster_size` has just been synchronized across all GPUs, so the calculated mask is identical.
- The replacement vectors are determined by `self.sample_fn`, which in the DDP environment is `sample_vectors_distributed`. This function is designed for distributed sampling and ensures every GPU gets the exact same new vectors. Though each GPU handles the dead codes independently, the resulting codebook remains synchronized.
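A small illustration of the first point: once `cluster_size` has been `all_reduce`d, every rank evaluates the same comparison on identical inputs and therefore builds the same dead-code mask. The tensors below are illustrative, not from the PR.

```python
import torch

threshold_ema_dead_code = 2.0
# After all_reduce, every rank holds this identical cluster_size buffer.
synced_cluster_size = torch.tensor([0.5, 10.0, 1.9, 7.0])

# Each rank computes the mask independently, from identical inputs ...
mask_rank0 = synced_cluster_size < threshold_ema_dead_code
mask_rank1 = synced_cluster_size < threshold_ema_dead_code

# ... so the set of codes chosen for revival agrees on every rank.
assert torch.equal(mask_rank0, mask_rank1)
```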
Thanks for sharing the update. It looks good to me.
Thanks for your great contribution!
**What did you change?**

- Fixed `EuclideanCodebook`'s EMA update.
- Added distributed synchronization (`all_reduce`) for all codebook updates, including K-Means init, EMA statistics (`cluster_size`, `embed_sum`), and dead code sampling.

**Why did you make this change?**

- When a dead code in `self.embed` was replaced with a new vector, its corresponding EMA state (`self.embed_avg`, `self.cluster_size`) was not reset. This caused the new vector to be immediately overwritten by a stale value (calculated from the old, dead state) in the same forward pass, preventing the codebook from recovering from collapse.
- The EMA statistics (`cluster_size` and `embed_sum`) were calculated using only the local batch on each worker, with no step-by-step synchronization. Synchronization only occurred intermittently during initialization (`init_embed_`) and code expiration (`expire_codes_`), where `broadcast_tensors` was used to copy the buffers from Rank 0 to all other workers, discarding the updates computed by other ranks. This PR changes the logic to use `all_reduce` on these statistics on every training step, ensuring the EMA update is calculated using the full global batch and keeping codebooks consistent across all workers.

**Is your PR small enough?**

Yes.
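A hedged sketch of the revival fix described above: when a dead code is replaced, its EMA state is reset as well, so the fresh vector survives the next EMA step instead of being pulled back toward the stale average. Names follow the PR description (`embed`, `embed_avg`, `cluster_size`); the helper `revive_dead_codes` and the reset value are illustrative, not the actual implementation.

```python
import torch

def revive_dead_codes(embed, embed_avg, cluster_size, new_vectors, mask,
                      reset_cluster_size=1.0):
    # Replace dead codebook entries with freshly sampled vectors.
    embed[mask] = new_vectors[mask]
    # Also reset the EMA numerator/denominator for the revived entries,
    # so the next EMA update keeps the new vector instead of restoring
    # the stale average computed from the old, dead state.
    embed_avg[mask] = new_vectors[mask] * reset_cluster_size
    cluster_size[mask] = reset_cluster_size

embed = torch.zeros(3, 2)
embed_avg = torch.full((3, 2), 9.0)       # stale EMA numerator
cluster_size = torch.full((3,), 9.0)      # stale EMA denominator
new_vectors = torch.full((3, 2), 5.0)
mask = torch.tensor([True, False, True])  # codes 0 and 2 are dead

revive_dead_codes(embed, embed_avg, cluster_size, new_vectors, mask)
# embed_avg / cluster_size now reproduces the new vector for revived codes.
```

Without the reset, `embed_avg / cluster_size` for a revived entry would still reflect the dead state and overwrite the replacement on the next forward pass, which is exactly the bug the PR describes.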
**Additional Context**
The logic in https://github.com/cisco-open/espnet/blob/master/espnet2/gan_codec/shared/quantizer/modules/core_vq.py was used as the main reference.
It was also compared against https://github.com/XiaomiMiMo/MiMo-Audio/blob/main/src/mimo_audio_tokenizer/quantization.py.
The code has been validated for correctness through training.