Terry/parallelize spk emb extraction #6227
Conversation
This script extracts speaker embeddings in parallel using specified toolkits. It includes argument parsing, model loading, and batch processing for audio files.
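As a rough sketch of how work can be split across workers for parallel extraction (the names `chunk_scp` and `run_parallel` and the round-robin split are illustrative assumptions, not the script's actual API):

```python
# Hypothetical sketch: split scp entries across processes and merge results.
# The real extract_spk_embed_parallel.py may partition work differently.
from concurrent.futures import ProcessPoolExecutor


def chunk_scp(entries, n_jobs):
    """Split a list of (utt_id, wav_path) pairs into n_jobs round-robin chunks."""
    return [entries[i::n_jobs] for i in range(n_jobs)]


def run_parallel(entries, extract_fn, n_jobs=4):
    """Run extract_fn (chunk -> {utt_id: embedding}) in worker processes."""
    results = {}
    with ProcessPoolExecutor(max_workers=n_jobs) as pool:
        for part in pool.map(extract_fn, chunk_scp(entries, n_jobs)):
            results.update(part)
    return results
```

Note that `extract_fn` must be picklable (a top-level function) for `ProcessPoolExecutor` to dispatch it.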
Fix the conditional check for use_spk_embed.
This script sorts standard Kaldi data files and aligns embedding SCP files based on provided directories.
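A minimal sketch of the sort-and-align idea, assuming Kaldi-style `utt-id value` lines keyed on the first field (this is not the actual `sort_spk_embed_scp.sh`, just an illustration of the approach):

```shell
# Hypothetical sketch: sort the data file, then reorder/filter the embedding
# scp to match it, so utterance IDs line up entry-for-entry.
set -eu
dir=$(mktemp -d)
cat > "${dir}/utt2spk" <<'EOF'
utt2 spkB
utt1 spkA
EOF
cat > "${dir}/spk_embed.scp" <<'EOF'
utt1 dump/embs.ark:10
utt2 dump/embs.ark:99
EOF
sort -k1,1 "${dir}/utt2spk" > "${dir}/utt2spk.sorted"
# Keep only (and order by) the utterances present in the sorted data file.
awk 'NR==FNR {emb[$1]=$0; next} ($1 in emb) {print emb[$1]}' \
    "${dir}/spk_embed.scp" "${dir}/utt2spk.sorted" > "${dir}/spk_embed.scp.sorted"
cat "${dir}/spk_embed.scp.sorted"
```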
Code Review
This pull request introduces parallel processing for speaker embedding extraction, which is a great improvement for performance. The changes include a new Python script for parallel extraction, a utility script for sorting embedding files, and an update to the TTS recipe to use it.
My review focuses on ensuring the correctness and efficiency of the new parallel implementation. I've identified a critical issue with incorrect arguments being passed to the new sorting script in tts.sh. Additionally, in the parallel extraction script, there's a significant performance issue with the rawnet implementation which is not properly batched, and a race condition in the audio loading logic. Addressing these points will make the parallel implementation more robust and efficient.
```python
elif self.toolkit == "rawnet":
    outs = []
    for w in wav_list:
        if torch.is_tensor(w):
            w = w.detach().cpu().numpy()
        outs.append(
            self._rawnet_extract_embd(np.asarray(w, dtype=np.float32))
        )
    return np.asarray(outs, dtype=np.float32)
```
The current implementation for rawnet in extract_batch processes utterances sequentially within a loop. This underutilizes the GPU and negates the benefits of batching, creating a performance bottleneck. A true batched implementation should be used, where segments from all utterances in wav_list are collected and processed in a single forward pass on the GPU.
```python
elif self.toolkit == "rawnet":
    n_samples = 48000
    n_segments = 10
    all_segments = []
    utterance_segment_counts = []
    for w in wav_list:
        if torch.is_tensor(w):
            w = w.detach().cpu().numpy()
        audio = np.asarray(w, dtype=np.float32)
        if len(audio.shape) > 1:
            raise ValueError(
                "RawNet3 supports mono input only. "
                f"Input data has a shape of {audio.shape}."
            )
        if len(audio) < n_samples:
            shortage = n_samples - len(audio) + 1
            audio = np.pad(audio, (0, shortage), "wrap")
        startframe = np.linspace(0, len(audio) - n_samples, num=n_segments)
        segments = [
            audio[int(asf) : int(asf) + n_samples] for asf in startframe
        ]
        all_segments.extend(segments)
        utterance_segment_counts.append(len(segments))
    if not all_segments:
        return np.array([], dtype=np.float32)
    stacked_segments = torch.from_numpy(
        np.stack(all_segments, axis=0).astype(np.float32)
    ).to(self.device)
    with torch.no_grad():
        output = self.model(stacked_segments)
    output_np = output.detach().cpu().numpy()
    embs = []
    start_idx = 0
    for count in utterance_segment_counts:
        end_idx = start_idx + count
        utterance_embs = output_np[start_idx:end_idx]
        embs.append(utterance_embs.mean(axis=0))
        start_idx = end_idx
    return np.asarray(embs, dtype=np.float32)
```
This is because RawNet3 does not accept a batch as a single input. I tried the suggested approach during development and ran into errors. Please let me know if my understanding is incorrect or if there is a better approach.
Thanks for clarification. I think the current solution is fine.
egs2/TEMPLATE/asr1/pyscripts/utils/extract_spk_embed_parallel.py
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Added a lock for thread-safe access to resamplers.
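A minimal sketch of what a lock-guarded resampler cache could look like (the class name and the `factory` callable are assumptions for illustration; the PR's actual locking code may differ):

```python
import threading


class ResamplerCache:
    """Hypothetical sketch of a thread-safe resampler cache: one resampler
    per (orig_sr, target_sr) pair, created lazily under a lock so that
    concurrent workers never construct or mutate the cache at the same time."""

    def __init__(self, factory):
        # factory: e.g. lambda o, t: torchaudio.transforms.Resample(o, t)
        self._factory = factory
        self._cache = {}
        self._lock = threading.Lock()

    def get(self, orig_sr, target_sr):
        key = (orig_sr, target_sr)
        with self._lock:
            if key not in self._cache:
                self._cache[key] = self._factory(orig_sr, target_sr)
            return self._cache[key]
```

Holding the lock around both the lookup and the insert avoids the race where two threads each create (and partially initialize) a resampler for the same rate pair.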
Codecov Report
✅ All modified and coverable lines are covered by tests.

```
@@           Coverage Diff            @@
##           master   #6227      +/-  ##
==========================================
+ Coverage   20.63%   55.82%  +35.19%
==========================================
  Files          95      889     +794
  Lines       10347    84275  +73928
==========================================
+ Hits         2135    47049  +44914
- Misses       8212    37226  +29014
==========================================
```
View full report in Codecov by Sentry.
This pull request introduces a new utility script to ensure that speaker embedding SCP files are sorted and aligned with the main data files, and integrates this sorting step into the TTS data preparation pipeline. This helps prevent mismatches between speaker embeddings and utterance IDs during training and evaluation. Speaker embedding sorting and integration:
```python
# for speaker in tqdm(spk2utt):
#     spk_embeddings = list()
#     for utt in spk2utt[speaker]:
#         in_sr, wav = wav_scp[utt]
#         # Speaker Embedding
#         embeds = spk_embed_extractor(wav, in_sr)
#         writer_utt[utt] = np.squeeze(embeds)
#         spk_embeddings.append(embeds)
#
#     # Speaker Normalization
#     embeds = np.mean(np.stack(spk_embeddings, 0), 0)
#     writer_spk[speaker] = embeds
# Build flat work list (speaker, utt)
```
What is your intention in keeping these commented-out lines?
If they are needed, please explain; if not, you can delete them.
Can you make a test script?
ftshijt
left a comment
Thanks @ZhuoyanTao for supporting the parallel version!
```diff
@@ -0,0 +1,401 @@
+#!/usr/bin/env python3
```
Could you please also update tts.sh correspondingly to use the parallel version (or let the user select which version to use)?
Removed commented-out code for speaker embeddings processing.
Any update?
This script runs sequential and parallel speaker embedding extraction, creates synthetic audio data, and compares the outputs.
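A hedged sketch of the comparison step, assuming both runs produce a dict from utterance ID to embedding vector (the function name and tolerance are illustrative, not the test script's actual API):

```python
import numpy as np


def compare_embeddings(seq, par, atol=1e-5):
    """Check that sequential and parallel runs cover the same utterances
    and produce numerically close embeddings for each one."""
    assert set(seq) == set(par), "utterance ID sets differ"
    return all(np.allclose(seq[u], par[u], atol=atol) for u in seq)
```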
for more information, see https://pre-commit.ci
Added support for parallel speaker embedding extraction with configurable parameters.
I have added a .sh test script, which tests on a dummy dataset. The script passes with no errors. Please let me know if any further modifications are necessary. Thanks!
@ftshijt, can you check his update?
ftshijt
left a comment
Many thanks for your update! The overall design is great and details are good, except for the minor comment below. We can merge it after the minor comment is resolved.
egs2/TEMPLATE/tts1/tts.sh
```shell
if "${use_spk_embed}"; then
    log "Fixing order of speaker-embed scp to match text"
    scripts/utils/sort_spk_embed_scp.sh "${dumpdir}" "${spk_embed_tag}"
fi
```
Is there a specific reason to separate the "fix order" and "extraction" steps into different stages? It might be better to include them together in a single stage (to avoid issues for people who choose not to run stage 4).
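As a sketch of the combined-stage layout (the stage numbering and variable names here are assumptions based on this thread, not the actual tts.sh):

```shell
# Hypothetical fragment: run extraction and order-fixing in the same stage,
# so skipping a later stage can never leave an unsorted scp behind.
if [ "${stage}" -le 3 ] && [ "${stop_stage}" -ge 3 ]; then
    if "${use_spk_embed}"; then
        # ... run extract_spk_embed_parallel.py here ...
        log "Fixing order of speaker-embed scp to match text"
        scripts/utils/sort_spk_embed_scp.sh "${dumpdir}" "${spk_embed_tag}"
    fi
fi
```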
…ence of spk emb files to avoid confusing bugs. I added the use_spk_embed check at the end instead of inside each case during extraction, due to the multiple branches of logic in stage 3.
What did you change?
Parallelize speaker embedding extraction for TTS recipes
Why did you make this change?
Speaker embedding extraction used to run sequentially and took 40+ hours for datasets with many speakers.
Is your PR small enough?
yes
Additional Context