
codec fix: DDP logic and dead code revival logic #6284

Merged
ftshijt merged 3 commits into espnet:master from whr-a:pr_codec on Nov 3, 2025

Conversation

@whr-a (Contributor) commented Oct 28, 2025

What did you change?

  • Corrected the "dead code" revival logic within the EuclideanCodebook's EMA update.
  • Added DDP (Distributed Data Parallel) synchronization (all_reduce) for all codebook updates, including K-Means init, EMA statistics (cluster_size, embed_sum), and dead code sampling.

Why did you make this change?

  1. To fix an EMA bug: the previous dead-code revival logic was non-functional. When a dead code in self.embed was replaced with a new vector, its corresponding EMA state (self.embed_avg, self.cluster_size) was not reset. The new vector was therefore immediately overwritten by a stale value (computed from the old, dead state) in the same forward pass, preventing the codebook from recovering from collapse.
  2. To fix DDP training: previously, the EMA updates (for cluster_size and embed_sum) were computed from only the local batch on each worker, with no per-step synchronization. Synchronization happened only intermittently, during initialization (init_embed_) and code expiration (expire_codes_), where broadcast_tensors copied the buffers from rank 0 to all other workers and discarded the updates computed on the other ranks. This PR applies all_reduce to these statistics on every training step, so the EMA update is computed over the full global batch and the codebooks stay consistent across all workers.
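The two fixes can be sketched roughly as follows. All names here (ema_inplace, update_codebook, sample_replacements) are illustrative, not the actual ESPnet API; the real code is structured differently, and under DDP the replacement sampler would be a distributed one that returns identical vectors on every rank.

```python
import torch
import torch.distributed as dist


def ema_inplace(moving_avg, new, decay):
    # Standard EMA update: avg <- decay * avg + (1 - decay) * new
    moving_avg.data.mul_(decay).add_(new, alpha=1 - decay)


def sample_replacements(flatten, num):
    # Hypothetical helper: pick random batch vectors as replacements.
    idx = torch.randint(0, flatten.shape[0], (num,))
    return flatten[idx]


def update_codebook(embed, embed_avg, cluster_size, embed_onehot, flatten,
                    decay=0.99, threshold=1.0, eps=1e-5):
    embed_sum = flatten.t() @ embed_onehot   # (dim, codebook_size), local batch
    bins = embed_onehot.sum(0)               # (codebook_size,), local batch

    # Fix 2: sum the statistics over all DDP workers on *every* step, so the
    # EMA sees the full global batch instead of only the local one.
    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(embed_sum)
        dist.all_reduce(bins)

    ema_inplace(cluster_size, bins, decay)
    ema_inplace(embed_avg, embed_sum.t(), decay)

    n = cluster_size.sum()
    smoothed = (cluster_size + eps) / (n + cluster_size.numel() * eps) * n
    embed.data.copy_(embed_avg / smoothed.unsqueeze(1))

    # Fix 1: when a dead code is revived, its EMA state must be reset too,
    # otherwise the stale embed_avg / cluster_size would overwrite the fresh
    # vector on the very next update.
    dead = cluster_size < threshold
    if dead.any():
        new_vectors = sample_replacements(flatten, int(dead.sum()))
        embed.data[dead] = new_vectors
        embed_avg.data[dead] = new_vectors * threshold
        cluster_size.data[dead] = threshold
```

The key point of fix 1 is the last three assignments: the revived vector, its EMA accumulator, and its cluster count are all reset together.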

Is your PR small enough?

yes

Additional Context

The logic in https://github.com/cisco-open/espnet/blob/master/espnet2/gan_codec/shared/quantizer/modules/core_vq.py served as the primary reference.
The implementation was also compared against https://github.com/XiaomiMiMo/MiMo-Audio/blob/main/src/mimo_audio_tokenizer/quantization.py.
The code has been validated for correctness through training runs.

@dosubot (bot) added labels size:L (this PR changes 100-499 lines, ignoring generated files), Bug (bug should be fixed), and Codec on Oct 28, 2025
@mergify (bot) added the ESPnet2 label on Oct 28, 2025
@gemini-code-assist (bot) left a comment:

Code Review

This pull request introduces significant and necessary fixes for DDP training and the dead code revival logic within the EuclideanCodebook. The implementation of all_reduce to synchronize statistics across distributed workers is a crucial correction for DDP training. The fix for the EMA update logic, which now correctly resets state for revived codebook vectors, resolves a critical bug that prevented the codebook from recovering. The overall changes are well-reasoned and address the described issues effectively. I have two suggestions rated as 'high' severity to improve numerical stability and code simplicity in the new distributed logic.

Comment on lines +96 to +115

```python
    return [size if i == dim else s for i, s in enumerate(shape)]

    means = sample_vectors(samples, num_clusters)


def sample_multinomial(total_count, probs):
    device = probs.device
    probs = probs.cpu()

    total_count = probs.new_full((), total_count)
    remainder = probs.new_ones(())
    sample = torch.empty_like(probs, dtype=torch.long)

    num_probs = len(probs)

    for i, prob in enumerate(probs):
        is_last = i == (num_probs - 1)

        s = (
            torch.binomial(total_count, prob / remainder)
            if not is_last
            else total_count
```

Severity: high

The current implementation of sample_multinomial is functionally correct but overly complex and potentially fragile due to manual floating-point arithmetic for re-normalizing probabilities (prob / remainder). A much simpler, more robust, and idiomatic way to achieve this is by using torch.distributions.Multinomial. This will also avoid the unnecessary device transfer to CPU and back.

Suggested change

```diff
-        return [size if i == dim else s for i, s in enumerate(shape)]
-    means = sample_vectors(samples, num_clusters)
-    def sample_multinomial(total_count, probs):
-        device = probs.device
-        probs = probs.cpu()
-        total_count = probs.new_full((), total_count)
-        remainder = probs.new_ones(())
-        sample = torch.empty_like(probs, dtype=torch.long)
-        num_probs = len(probs)
-        for i, prob in enumerate(probs):
-            is_last = i == (num_probs - 1)
-            s = (
-                torch.binomial(total_count, prob / remainder)
-                if not is_last
-                else total_count
+    from torch.distributions import Multinomial
+    dist = Multinomial(total_count, probs=probs)
+    return dist.sample().to(dtype=torch.long)
```
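As a quick sanity check of the suggested replacement, here is a minimal standalone version (this mirrors the bot's suggestion, not the code that was eventually merged): Multinomial draws all per-bucket counts in one call, on the device of probs, and the counts always sum to total_count.

```python
import torch
from torch.distributions import Multinomial


def sample_multinomial(total_count, probs):
    # Draw per-bucket counts in a single call; no CPU round-trip and no
    # manual re-normalization of the remaining probability mass.
    dist = Multinomial(total_count, probs=probs)
    return dist.sample().to(dtype=torch.long)


counts = sample_multinomial(100, torch.tensor([0.5, 0.3, 0.2]))
```

Because Multinomial samples a complete count vector, the sum-to-total_count invariant holds exactly, which the manual binomial loop had to maintain by hand via the remainder bookkeeping.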

Comment on lines +184 to 185

```python
buckets = torch.argmax(dists, dim=-1)
bins = torch.bincount(buckets, minlength=num_clusters)
```

Severity: high

The current order of operations in the k-means update step (division then all_reduce) is mathematically correct but can be numerically unstable. Performing all_reduce on the summed vectors before dividing by the counts is generally more robust as it avoids potential precision loss from intermediate divisions, especially when cluster counts (bins) are large.

Suggested change

```diff
 buckets = torch.argmax(dists, dim=-1)
 bins = torch.bincount(buckets, minlength=num_clusters)
+all_reduce_fn(new_means)
+new_means = new_means / rearrange(bins_min_clamped, '... -> ... 1')
```
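A minimal simulation of the recommended ordering, using plain tensor sums in place of all_reduce across two hypothetical workers (the worker values below are made up for illustration):

```python
import torch

# Simulated per-worker local statistics for one cluster of dimension 2:
# each worker holds a local sum of assigned vectors and a local count.
sums = [torch.tensor([[10.0, 20.0]]), torch.tensor([[30.0, 40.0]])]
bins = [torch.tensor([2.0]), torch.tensor([6.0])]

# Recommended order: reduce the vector sums and the counts across workers
# first (here a plain sum stands in for all_reduce), then divide once.
global_sum = sums[0] + sums[1]
global_bins = bins[0] + bins[1]
means = global_sum / global_bins.unsqueeze(1)
```

Dividing once by the globally reduced counts avoids accumulating rounding error from intermediate per-worker divisions when counts are large.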

@sw005320 sw005320 requested a review from ftshijt October 28, 2025 15:24
@sw005320 sw005320 added this to the v.202512 milestone Oct 28, 2025
@codecov (bot) commented Oct 28, 2025

Codecov Report

❌ Patch coverage is 54.44444% with 41 lines in your changes missing coverage. Please review.
✅ Project coverage is 56.45%. Comparing base (27d4ab8) to head (b0be59b).
⚠️ Report is 59 commits behind head on master.

| Files with missing lines | Patch % | Lines |
|---|---|---|
| ...net2/gan_codec/shared/quantizer/modules/core_vq.py | 54.44% | 41 Missing ⚠️ |
Additional details and impacted files
```
@@            Coverage Diff             @@
##           master    #6284      +/-   ##
==========================================
- Coverage   56.49%   56.45%   -0.04%
==========================================
  Files         896      896
  Lines       84814    84881      +67
==========================================
+ Hits        47914    47923       +9
- Misses      36900    36958      +58
```
| Flag | Coverage Δ |
|---|---|
| test_integration_espnet2 | 46.83% <54.44%> (+0.02%) ⬆️ |
| test_integration_espnetez | 36.93% <ø> (ø) |
| test_python_espnet2 | 50.90% <54.44%> (-0.03%) ⬇️ |
| test_python_espnetez | 12.72% <0.00%> (-0.02%) ⬇️ |
| test_utils | 18.77% <ø> (ø) |


@ftshijt (Collaborator) left a comment:

Many thanks for the update and for catching the issue. I have two questions about this setup:

```python
self.embed.data.copy_(embed_normalized)
self.update_ema()
```

```python
self.expire_codes_(x)
```
@ftshijt (Collaborator):

If code expiration happens at a later stage, shall we do another all_reduce to keep all GPUs the same?

@whr-a (Contributor, Author):

I think it's OK:

  1. expire_codes_ determines dead codes by checking self.cluster_size < self.threshold_ema_dead_code, and self.cluster_size has just been synchronized across all GPUs, so the computed mask is identical on every rank.
  2. The replacement vectors are determined by self.sample_fn, which in the DDP environment is sample_vectors_distributed. This function is designed for distributed sampling and ensures that every GPU gets exactly the same new vectors. Although each GPU handles the dead codes independently, the resulting codebooks remain synchronized.
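The reasoning above can be sketched as follows. The names are hypothetical and simplified relative to the actual ESPnet code; sample_fn stands in for sample_vectors_distributed and is assumed to return identical vectors on every rank.

```python
import torch


def expire_codes_(embed, cluster_size, batch_samples, threshold, sample_fn):
    # cluster_size was just all_reduced this step, so this mask is
    # bitwise-identical on every rank.
    expired = cluster_size < threshold
    if not expired.any():
        return
    # sample_fn is assumed to return the same replacement vectors on every
    # rank, so the codebooks stay in sync without an extra all_reduce.
    new_vectors = sample_fn(batch_samples, embed.shape[0])
    embed.data = torch.where(expired.unsqueeze(1), new_vectors, embed.data)
```

Since both inputs to the replacement (the mask and the sampled vectors) are rank-invariant, running this independently on each GPU produces identical codebooks.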

@ftshijt (Collaborator):

Thanks for sharing the update. It looks good to me.

@ftshijt ftshijt merged commit d93495b into espnet:master Nov 3, 2025
34 of 35 checks passed
@ftshijt (Collaborator) commented Nov 3, 2025

Thanks for your great contribution!

@Fhrozen Fhrozen modified the milestones: v.202512, v.202511 Nov 14, 2025