From 88f99cac7c2cd0ea19f10c72436217bd737b83b6 Mon Sep 17 00:00:00 2001 From: Weiqing Wang Date: Wed, 8 Oct 2025 18:45:04 -0700 Subject: [PATCH 01/29] initiate PR 01 for MT-Parakeet Signed-off-by: Weiqing Wang --- ...ech_to_text_multitalker_streaming_infer.py | 450 ++++ .../asr/data/audio_to_text_lhotse_speaker.py | 104 + nemo/collections/asr/models/__init__.py | 1 + .../asr/models/multitalker_asr_models.py | 132 ++ .../asr/models/sortformer_diar_models.py | 26 +- .../asr/modules/sortformer_modules.py | 20 + nemo/collections/asr/parts/mixins/__init__.py | 1 + nemo/collections/asr/parts/mixins/mixins.py | 2 + .../parts/mixins/multitalker_asr_mixins.py | 271 +++ .../collections/asr/parts/mixins/streaming.py | 2 + .../asr/parts/utils/asr_multispeaker_utils.py | 613 +++++- .../asr/parts/utils/data_simulation_utils.py | 141 +- .../asr/parts/utils/diarization_utils.py | 816 ++++++-- .../parts/utils/multispk_transcribe_utils.py | 1806 +++++++++++++++++ nemo/collections/common/data/lhotse/cutset.py | 20 + 15 files changed, 4135 insertions(+), 270 deletions(-) create mode 100644 examples/asr/asr_cache_aware_streaming/speech_to_text_multitalker_streaming_infer.py create mode 100644 nemo/collections/asr/data/audio_to_text_lhotse_speaker.py create mode 100644 nemo/collections/asr/models/multitalker_asr_models.py create mode 100644 nemo/collections/asr/parts/mixins/multitalker_asr_mixins.py create mode 100644 nemo/collections/asr/parts/utils/multispk_transcribe_utils.py diff --git a/examples/asr/asr_cache_aware_streaming/speech_to_text_multitalker_streaming_infer.py b/examples/asr/asr_cache_aware_streaming/speech_to_text_multitalker_streaming_infer.py new file mode 100644 index 000000000000..937d1c29d1fd --- /dev/null +++ b/examples/asr/asr_cache_aware_streaming/speech_to_text_multitalker_streaming_infer.py @@ -0,0 +1,450 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +from dataclasses import dataclass, is_dataclass, field +from typing import Optional, Union, List, Tuple, Dict, Any + +import torch +import os +import pytorch_lightning as pl +from omegaconf import OmegaConf +from omegaconf import open_dict +from lhotse.dataset.collation import collate_matrices + + +import nemo.collections.asr as nemo_asr +from nemo.collections.asr.parts.utils.rnnt_utils import Hypothesis +from nemo.collections.asr.parts.utils.streaming_utils import CacheAwareStreamingAudioBuffer + +from copy import deepcopy +from nemo.collections.asr.parts.utils.diarization_utils import read_seglst, OnlineEvaluation +from nemo.utils import logging + +from nemo.collections.asr.models.sortformer_diar_models import SortformerEncLabelModel +from nemo.core.config import hydra_runner + +from nemo.collections.asr.parts.utils.multispk_transcribe_utils import SpeakerTaggedASR, get_multi_talker_samples_from_manifest +from nemo.collections.asr.parts.utils.speaker_utils import ( +audio_rttm_map as get_audio_rttm_map, +rttm_to_labels, +) +from nemo.collections.asr.parts.utils.diarization_utils import ( +print_sentences, +get_color_palette, +write_txt, +) +from nemo.collections.asr.data.audio_to_diar_label import get_frame_targets_from_rttm, extract_frame_info_from_rttm + + +from typing import List, Optional +from dataclasses import dataclass +from collections import OrderedDict +import itertools + +import time +from functools import wraps +import math + +@dataclass +class DiarizationConfig: + # Required configs + diar_model_path: Optional[str] = None # Path to a .nemo file + diar_pretrained_name: Optional[str] = None # Name of a pretrained model + max_num_of_spks: Optional[int] = 4 + parallel_speaker_strategy: bool = True + + # General configs + session_len_sec: float = -1 # End-to-end diarization session length in seconds + num_workers: int = 8 + random_seed: Optional[int] = None # seed number going to be used in seed_everything() + log: bool = True # If True, log will be printed + + # Streaming diarization configs + streaming_mode: bool = True # If True, streaming diarization will be used. + spkcache_len: int = 188 + spkcache_refresh_rate: int = 0 + fifo_len: int = 188 + chunk_len: int = 0 + chunk_left_context: int = 0 + chunk_right_context: int = 0 + + # If `cuda` is a negative number, inference will be on CPU only. + cuda: Optional[int] = None + allow_mps: bool = False # allow to select MPS device (Apple Silicon M-series GPU) + matmul_precision: str = "highest" # Literal["highest", "high", "medium"] + + # ASR Configs + asr_model: Optional[str] = None + device: str = 'cuda' + audio_file: Optional[str] = None + manifest_file: Optional[str] = None + use_amp: bool = True + debug_mode: bool = False + batch_size: int = 32 + chunk_size: int = -1 + shift_size: int = -1 + left_chunks: int = 2 + online_normalization: bool = False + output_path: Optional[str] = None + pad_and_drop_preencoded: bool = False + set_decoder: Optional[str] = None # ["ctc", "rnnt"] + att_context_size: Optional[list] = None + generate_realtime_scripts: bool = True + + word_window: int = 50 + sent_break_sec: float = 30.0 + fix_prev_words_count: int = 5 + update_prev_words_sentence: int = 5 + left_frame_shift: int = -1 + right_frame_shift: int = 0 + min_sigmoid_val: float = 1e-2 + discarded_frames: int = 8 + print_time: bool = True + print_sample_indices: List[int] = field(default_factory=lambda: [0]) + colored_text: bool = True + real_time_mode: bool = False + print_path: str = "./" + + ignored_initial_frame_steps: int = 5 + verbose: bool = False + + feat_len_sec: float = 0.01 + finetune_realtime_ratio: float = 0.01 + + spk_supervision: str = "diar" # ["diar", "rttm"] + binary_diar_preds: bool = False + + +def format_time(seconds): + minutes = math.floor(seconds / 60) + sec = seconds % 60 + return f"{minutes}:{sec:05.2f}" + +def calc_drop_extra_pre_encoded(asr_model, step_num, pad_and_drop_preencoded): + # for the first step there is no need to drop any tokens after the downsampling as no caching is being used + if step_num == 0 and not pad_and_drop_preencoded: + return 0 + else: + return asr_model.encoder.streaming_cfg.drop_extra_pre_encoded + +def add_delay_for_real_time(cfg, chunk_audio, session_start_time, feat_frame_count, loop_end_time, loop_start_time): + """ + Add artificial delay for real-time mode by calculating the time difference between + the current time and the session start time.. + + Args: + cfg (DiarizationConfig): The configuration object. + """ + time_diff = max(0, (time.time() - session_start_time) - feat_frame_count * cfg.feat_len_sec) + eta_min_sec = format_time(time.time() - session_start_time) + logging.info(f"[ REAL TIME MODE ] min:sec - {eta_min_sec} " + f"Time difference for real-time mode: {time_diff:.4f} seconds") + time.sleep(max(0, (chunk_audio.shape[-1] - cfg.discarded_frames)*cfg.feat_len_sec - + (loop_end_time - loop_start_time) - time_diff * cfg.finetune_realtime_ratio)) + + +def write_seglst_file(seglst_dict_list, output_path): + if len(seglst_dict_list) == 0: + raise ValueError("seglst_dict_list is empty. No transcriptions were generated.") + with open(output_path, 'w') as f: + f.write(json.dumps(seglst_dict_list, indent=4) + '\n') + logging.info(f"Saved the transcriptions of the streaming inference in\n:{output_path}") + +def launch_serial_streaming( + cfg, + asr_model, + diar_model, + streaming_buffer, + pad_and_drop_preencoded=False, +): + streaming_buffer_iter = iter(streaming_buffer) + + multispk_asr_streamer = SpeakerTaggedASR(cfg, asr_model, diar_model) + feat_frame_count = 0 + + session_start_time = time.time() + for step_num, (chunk_audio, chunk_lengths) in enumerate(streaming_buffer_iter): + drop_extra_pre_encoded = calc_drop_extra_pre_encoded(asr_model, step_num, pad_and_drop_preencoded) + loop_start_time = time.time() + with torch.inference_mode(): + with autocast: + with torch.no_grad(): + multispk_asr_streamer.perform_serial_streaming_stt_spk( + step_num=step_num, + chunk_audio=chunk_audio, + chunk_lengths=chunk_lengths, + is_buffer_empty=streaming_buffer.is_buffer_empty(), + drop_extra_pre_encoded=drop_extra_pre_encoded, + ) + + feat_frame_count += (chunk_audio.shape[-1] - cfg.discarded_frames) + if cfg.real_time_mode: + add_delay_for_real_time(cfg, + chunk_audio=chunk_audio, + session_start_time=session_start_time, + feat_frame_count=feat_frame_count, + loop_end_time=time.time(), + loop_start_time=loop_start_time + ) + return multispk_asr_streamer + +def launch_parallel_streaming( + cfg, + asr_model, + diar_model, + streaming_buffer, + pad_and_drop_preencoded=False, + ): + streaming_buffer_iter = iter(streaming_buffer) + multispk_asr_streamer = SpeakerTaggedASR(cfg, asr_model, diar_model) + + for step_num, (chunk_audio, chunk_lengths) in enumerate(streaming_buffer_iter): + # logging.info(f"Step ID: {step_num}") + with torch.inference_mode(): + with autocast: + with torch.no_grad(): + drop_extra_pre_encoded = calc_drop_extra_pre_encoded(asr_model, step_num, pad_and_drop_preencoded) + multispk_asr_streamer.perform_parallel_streaming_stt_spk( + step_num=step_num, + chunk_audio=chunk_audio, + chunk_lengths=chunk_lengths, + is_buffer_empty=streaming_buffer.is_buffer_empty(), + drop_extra_pre_encoded=drop_extra_pre_encoded, + ) + return multispk_asr_streamer + + +@hydra_runner(config_name="DiarizationConfig", schema=DiarizationConfig) +def main(cfg: DiarizationConfig) -> Union[DiarizationConfig]: + for key in cfg: + cfg[key] = None if cfg[key] == 'None' else cfg[key] + + if is_dataclass(cfg): + cfg = OmegaConf.structured(cfg) + + if cfg.random_seed: + pl.seed_everything(cfg.random_seed) + + if cfg.diar_model_path is None and cfg.diar_pretrained_name is None: + raise ValueError("Both cfg.diar_model_path and cfg.pretrained_name cannot be None!") + if cfg.audio_file is None and cfg.manifest_file is None: + raise ValueError("Both cfg.audio_file and cfg.manifest_file cannot be None!") + + # setup GPU + torch.set_float32_matmul_precision(cfg.matmul_precision) + if cfg.cuda is None: + if torch.cuda.is_available(): + device = [0] # use 0th CUDA device + accelerator = 'gpu' + map_location = torch.device('cuda:0') + elif cfg.allow_mps and hasattr(torch.backends, "mps") and torch.backends.mps.is_available(): + device = [0] + accelerator = 'mps' + map_location = torch.device('mps') + else: + device = 1 + accelerator = 'cpu' + map_location = torch.device('cpu') + else: + device = [cfg.cuda] + accelerator = 'gpu' + map_location = torch.device(f'cuda:{cfg.cuda}') + + if cfg.diar_model_path.endswith(".ckpt"): + diar_model = SortformerEncLabelModel.load_from_checkpoint(checkpoint_path=cfg.diar_model_path, + map_location=map_location, strict=False) + elif cfg.diar_model_path.endswith(".nemo"): + diar_model = SortformerEncLabelModel.restore_from(restore_path=cfg.diar_model_path, + map_location=map_location) + else: + raise ValueError("cfg.diar_model_path must end with.ckpt or.nemo!") + + # Model setup for inference + trainer = pl.Trainer(devices=device, accelerator=accelerator) + diar_model.set_trainer(trainer) + diar_model._cfg.test_ds.session_len_sec = cfg.session_len_sec + diar_model._cfg.test_ds.manifest_filepath = cfg.manifest_file + diar_model._cfg.test_ds.batch_size = cfg.batch_size + diar_model._cfg.test_ds.num_workers = cfg.num_workers + diar_model.setup_test_data(test_data_config=diar_model._cfg.test_ds) + diar_model = diar_model.eval() + + # Steaming mode setup + diar_model.streaming_mode = cfg.streaming_mode + diar_model.sortformer_modules.chunk_len = cfg.chunk_len + diar_model.sortformer_modules.spkcache_len = cfg.spkcache_len + diar_model.sortformer_modules.chunk_left_context = cfg.chunk_left_context + diar_model.sortformer_modules.chunk_right_context = cfg.chunk_right_context + diar_model.sortformer_modules.fifo_len = cfg.fifo_len + diar_model.sortformer_modules.log = cfg.log + diar_model.sortformer_modules.spkcache_refresh_rate = cfg.spkcache_refresh_rate + + if cfg.audio_file is not None and cfg.manifest_file is not None: + logging.warning("Both audio_file and manifest_file are specified. audio_file will be used with top priority.") + input_type = "audio_file" + elif cfg.audio_file is not None: + logging.info("audio_file is specified. Using audio_file as input.") + input_type = "audio_file" + elif cfg.manifest_file is not None: + logging.info("manifest_file is specified. Using manifest_file as input.") + input_type = "manifest_file" + else: + raise ValueError("One of audio_file or manifest_file must be specified!") + + if cfg.asr_model.endswith('.nemo'): + logging.info(f"Using local ASR model from {cfg.asr_model}") + asr_model = nemo_asr.models.ASRModel.restore_from(restore_path=cfg.asr_model) + else: + logging.info(f"Using NGC cloud ASR model {cfg.asr_model}") + asr_model = nemo_asr.models.ASRModel.from_pretrained(model_name=cfg.asr_model) + + logging.info(asr_model.encoder.streaming_cfg) + if cfg.set_decoder is not None: + if hasattr(asr_model, "cur_decoder"): + asr_model.change_decoding_strategy(decoder_type=cfg.set_decoder) + else: + raise ValueError("Decoder cannot get changed for non-Hybrid ASR models.") + + if cfg.att_context_size is not None: + if hasattr(asr_model.encoder, "set_default_att_context_size"): + asr_model.encoder.set_default_att_context_size(att_context_size=cfg.att_context_size) + else: + raise ValueError("Model does not support multiple lookaheads.") + + global autocast + autocast = torch.amp.autocast(asr_model.device.type, enabled=cfg.use_amp) + + # Initialize to avoid "possibly used before assignment" error + multispk_asr_streamer = None + + # configure the decoding config + decoding_cfg = asr_model.cfg.decoding + with open_dict(decoding_cfg): + decoding_cfg.strategy = "greedy" + decoding_cfg.preserve_alignments = False + if hasattr(asr_model, 'joint'): # if an RNNT model + decoding_cfg.greedy.max_symbols = 10 + decoding_cfg.fused_batch_size = -1 + asr_model.change_decoding_strategy(decoding_cfg) + + asr_model = asr_model.to(cfg.device) + asr_model.eval() + + # chunk_size is set automatically for models trained for streaming. + # For models trained for offline mode with full context, we need to pass the chunk_size explicitly. + if cfg.chunk_size > 0: + if cfg.shift_size < 0: + shift_size = cfg.chunk_size + else: + shift_size = cfg.shift_size + asr_model.encoder.setup_streaming_params( + chunk_size=cfg.chunk_size, left_chunks=cfg.left_chunks, shift_size=shift_size + ) + + # In streaming, offline normalization is not feasible as we don't have access to the + # whole audio at the beginning When online_normalization is enabled, the normalization + # of the input features (mel-spectrograms) are done per step It is suggested to train + # the streaming models without any normalization in the input features. + if cfg.online_normalization: + if asr_model.cfg.preprocessor.normalize not in ["per_feature", "all_feature"]: + logging.warning( + "online_normalization is enabled but the model has" + "no normalization in the feature extration part, so it is ignored." + ) + online_normalization = False + else: + online_normalization = True + + else: + online_normalization = False + + if cfg.audio_file is not None: + # Stream a single audio file + samples = [{'audio_filepath': cfg.audio_file,}] + streaming_buffer = CacheAwareStreamingAudioBuffer( + model=asr_model, + online_normalization=online_normalization, + pad_and_drop_preencoded=cfg.pad_and_drop_preencoded, + ) + cfg.batch_size = len(samples) + streaming_buffer.append_audio_file(audio_filepath=cfg.audio_file, stream_id=-1) + if cfg.parallel_speaker_strategy: + multispk_asr_streamer = launch_serial_streaming( + cfg=cfg, + asr_model=asr_model, + diar_model=diar_model, + streaming_buffer=streaming_buffer, + pad_and_drop_preencoded=cfg.pad_and_drop_preencoded, + ) + + else: + multispk_asr_streamer = launch_serial_streaming( + cfg=cfg, + asr_model=asr_model, + diar_model=diar_model, + streaming_buffer=streaming_buffer, + ) + else: + # Stream audio files in a manifest file in batched mode + feat_per_sec = round(asr_model.cfg.preprocessor.window_stride * asr_model.cfg.encoder.subsampling_factor, 2) + samples, rttms_mask_mats = get_multi_talker_samples_from_manifest(cfg, manifest_file=cfg.manifest_file, feat_per_sec=feat_per_sec, max_spks=cfg.max_num_of_spks) + cfg.batch_size = len(samples) + # Note: rttms_mask_mats contains PyTorch tensors, so we pass it directly instead of storing in config + if cfg.spk_supervision == "rttm": + diar_model.add_rttms_mask_mats(rttms_mask_mats, device=asr_model.device) + + logging.info(f"Loaded {len(samples)} from the manifest at {cfg.manifest_file}.") + + streaming_buffer = CacheAwareStreamingAudioBuffer( + model=asr_model, + online_normalization=online_normalization, + pad_and_drop_preencoded=cfg.pad_and_drop_preencoded, + ) + + for sample_idx, sample in enumerate(samples): + streaming_buffer.append_audio_file(sample['audio_filepath'], stream_id=-1) + logging.info(f'Added this sample to the buffer: {sample["audio_filepath"]}') + + if (sample_idx + 1) % cfg.batch_size == 0 or sample_idx == len(samples) - 1: + logging.info(f"Starting to stream samples {sample_idx - len(streaming_buffer) + 1} to {sample_idx}...") + if cfg.parallel_speaker_strategy: + multispk_asr_streamer = launch_parallel_streaming( + cfg=cfg, + asr_model=asr_model, + diar_model=diar_model, + streaming_buffer=streaming_buffer, + pad_and_drop_preencoded=cfg.pad_and_drop_preencoded, + ) + else: + multispk_asr_streamer = launch_serial_streaming( + cfg=cfg, + asr_model=asr_model, + diar_model=diar_model, + streaming_buffer=streaming_buffer, + ) + streaming_buffer.reset_buffer() + + if cfg.output_path is not None and multispk_asr_streamer is not None: + if cfg.parallel_speaker_strategy: + multispk_asr_streamer.generate_seglst_dicts_from_parallel_streaming(samples=samples) + write_seglst_file(seglst_dict_list=multispk_asr_streamer.instance_manager.seglst_dict_list, + output_path=cfg.output_path) + else: + multispk_asr_streamer.generate_seglst_dicts_from_serial_streaming(samples=samples) + write_seglst_file(seglst_dict_list=multispk_asr_streamer.instance_manager.seglst_dict_list, + output_path=cfg.output_path) + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/nemo/collections/asr/data/audio_to_text_lhotse_speaker.py b/nemo/collections/asr/data/audio_to_text_lhotse_speaker.py new file mode 100644 index 000000000000..16a4ee0463f5 --- /dev/null +++ b/nemo/collections/asr/data/audio_to_text_lhotse_speaker.py @@ -0,0 +1,104 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import re +import random +from typing import Dict, Optional, Tuple +import soundfile + +import torch.utils.data +from lhotse.cut import MixedCut, MonoCut, MixTrack, PaddingCut +from lhotse.dataset import AudioSamples +from lhotse.dataset.collation import collate_vectors, collate_matrices +from lhotse.utils import compute_num_samples +from lhotse import SupervisionSet, SupervisionSegment, MonoCut, Recording, CutSet, AudioSource + +import numpy as np + +from nemo.collections.asr.data.audio_to_text_lhotse import TokenizerWrapper +from nemo.collections.common.tokenizers.aggregate_tokenizer import AggregateTokenizer +from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec +from nemo.core.neural_types import AudioSignal, LabelsType, LengthsType, NeuralType + +from nemo.collections.asr.parts.utils.asr_multispeaker_utils import ( + speaker_to_target, + get_hidden_length_from_sample_length, +) + +class LhotseSpeechToTextSpkBpeDataset(torch.utils.data.Dataset): + """ + This dataset is based on BPE datasets from audio_to_text.py. It has the same functionality of LhotseSpeechToTextBpeDataset but also yield speaker target tensor. + Unlike native NeMo datasets, Lhotse dataset defines only the mapping from + a CutSet (meta-data) to a mini-batch with PyTorch tensors. + Specifically, it performs tokenization, I/O, augmentation, and feature extraction (if any). + Managing data, sampling, de-duplication across workers/nodes etc. is all handled + by Lhotse samplers instead. + """ + + @property + def output_types(self) -> Optional[Dict[str, NeuralType]]: + return { + 'audio_signal': NeuralType(('B', 'T'), AudioSignal()), + 'a_sig_length': NeuralType(tuple('B'), LengthsType()), + 'transcripts': NeuralType(('B', 'T'), LabelsType()), + 'transcript_length': NeuralType(tuple('B'), LengthsType()), + 'spk_targets': NeuralType(('B','T'), LabelsType()), + 'bg_spk_targets': NeuralType(('B','T'), LabelsType()), + } + + def __init__(self, cfg, tokenizer): + super().__init__() + self.tokenizer = TokenizerWrapper(tokenizer) + self.load_audio = AudioSamples(fault_tolerant=True, num_workers=8) + self.cfg = cfg + self.num_speakers = self.cfg.get('num_speakers', 4) + self.num_sample_per_mel_frame = self.cfg.get('num_sample_per_mel_frame', 160) + self.num_mel_frame_per_asr_frame = self.cfg.get('num_mel_frame_per_asr_frame', 8) + self.fixed_spk_id = self.cfg.get('fixed_spk_id', None) + self.inference_mode = self.cfg.get('inference_mode', False) + + def __getitem__(self, cuts) -> Tuple[torch.Tensor, ...]: + + audio, audio_lens, cuts = self.load_audio(cuts) + + tokens = [] + spk_targets = [] + bg_spk_targets = [] + + if self.inference_mode: + speaker_targets = [speaker_to_target(cut, self.num_sample_per_mel_frame, self.num_mel_frame_per_asr_frame) for cut in cuts] + spk_targets = collate_matrices(speaker_targets, padding_value=0) + return audio, audio_lens, None, None, spk_targets + + for idx, cut in enumerate(cuts): + + speaker_targets, texts = speaker_to_target(cut, self.num_sample_per_mel_frame, self.num_mel_frame_per_asr_frame, return_text=True) + speaker_targets = speaker_targets.transpose(0, 1)[:len(texts)] + + target_speaker_id = random.choice(range(len(texts))) + non_target_speaker_ids = [i for i in range(len(texts)) if i != target_speaker_id] + text = texts[target_speaker_id] + speaker_target = speaker_targets[target_speaker_id] + bg_speaker_target = speaker_targets[non_target_speaker_ids].sum(dim=0) > 0 + + tokens.append(torch.as_tensor(self.tokenizer(text, cut.supervisions[0].language))) + spk_targets.append(speaker_target) + bg_spk_targets.append(bg_speaker_target) + + token_lens = torch.tensor([t.size(0) for t in tokens], dtype=torch.long) + tokens = collate_vectors(tokens, padding_value=0) + spk_targets = collate_vectors(spk_targets, padding_value=0) + bg_spk_targets = collate_vectors(bg_spk_targets, padding_value=0) + + return audio, audio_lens, tokens, token_lens, spk_targets, bg_spk_targets \ No newline at end of file diff --git a/nemo/collections/asr/models/__init__.py b/nemo/collections/asr/models/__init__.py index 34dead15b33d..11fd592b0f40 100644 --- a/nemo/collections/asr/models/__init__.py +++ b/nemo/collections/asr/models/__init__.py @@ -42,3 +42,4 @@ SpeechEncDecSelfSupervisedModel, ) from nemo.collections.asr.models.transformer_bpe_models import EncDecTransfModelBPE +from nemo.collections.asr.models.multitalker_asr_models import EncDecMultiTalkerRNNTBPEModel diff --git a/nemo/collections/asr/models/multitalker_asr_models.py b/nemo/collections/asr/models/multitalker_asr_models.py new file mode 100644 index 000000000000..ea7a66eb15a7 --- /dev/null +++ b/nemo/collections/asr/models/multitalker_asr_models.py @@ -0,0 +1,132 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import os +from typing import Any, Dict, List, Optional +import torch +import torch.nn.functional as F +from omegaconf import DictConfig, ListConfig, OmegaConf, open_dict +from pytorch_lightning import Trainer + +from nemo.collections.asr.data.audio_to_text_lhotse_speaker import LhotseSpeechToTextSpkBpeDataset + +from nemo.collections.asr.parts.mixins import ( + TranscribeConfig, + TranscriptionReturnType, +) +from nemo.collections.asr.parts.mixins.multitalker_asr_mixins import SpeakerKernelMixin + +from nemo.collections.asr.models.rnnt_bpe_models import EncDecRNNTBPEModel +from nemo.collections.common.data.lhotse import get_lhotse_dataloader_from_config +from nemo.collections.asr.parts.utils.rnnt_utils import Hypothesis + + +class EncDecMultiTalkerRNNTBPEModel(EncDecRNNTBPEModel, SpeakerKernelMixin): + """Base class for encoder decoder RNNT-based models with subword tokenization.""" + + def __init__(self, cfg: DictConfig, trainer: Trainer = None): + super().__init__(cfg=cfg, trainer=trainer) + # Initialize speaker kernel functionality from mixin + self._init_speaker_kernel_config(cfg) + + def _setup_dataloader_from_config(self, config: Optional[Dict]): + if config.get("use_lhotse"): + # Use open_dict to allow dynamic key addition + with open_dict(config): + config.global_rank = self.global_rank + config.world_size = self.world_size + + return get_lhotse_dataloader_from_config( + config, + global_rank=self.global_rank, + world_size=self.world_size, + dataset=LhotseSpeechToTextSpkBpeDataset(cfg = config, tokenizer=self.tokenizer,), + ) + + def training_step(self, batch, batch_nb): + """Training step with speaker targets.""" + signal, signal_len, transcript, transcript_len, *additional_args = batch + spk_targets, bg_spk_targets = additional_args + + self.set_speaker_targets(spk_targets, bg_spk_targets) + + batch = (signal, signal_len, transcript, transcript_len) + + return super().training_step(batch, batch_nb) + + def validation_pass(self, batch, batch_idx, dataloader_idx=0): + """Validation pass with speaker targets.""" + signal, signal_len, transcript, transcript_len, *additional_args = batch + spk_targets, bg_spk_targets = additional_args + + self.set_speaker_targets(spk_targets, bg_spk_targets) + + batch = (signal, signal_len, transcript, transcript_len) + + return super().validation_pass(batch, batch_idx, dataloader_idx) + + def _transcribe_forward(self, batch: Any, trcfg: TranscribeConfig): + """Transcribe forward with speaker targets.""" + signal, signal_len, transcript, transcript_len, *additional_args = batch + spk_targets, bg_spk_targets = additional_args + + self.set_speaker_targets(spk_targets, bg_spk_targets) + + batch = (signal, signal_len, transcript, transcript_len) + + return super()._transcribe_forward(batch, trcfg) + + def _setup_transcribe_dataloader(self, config: Dict) -> 'torch.utils.data.DataLoader': + """ + Setup function for a temporary data loader which wraps the provided audio file. + + Args: + config: A python dictionary which contains the following keys: + paths2audio_files: (a list) of paths to audio files. The files should be relatively short fragments. \ + Recommended length per file is between 5 and 25 seconds. + batch_size: (int) batch size to use during inference. \ + Bigger will result in better throughput performance but would use more memory. + temp_dir: (str) A temporary directory where the audio manifest is temporarily + stored. + + Returns: + A pytorch DataLoader for the given audio file(s). + """ + if 'dataset_manifest' in config: + manifest_filepath = config['dataset_manifest'] + batch_size = config['batch_size'] + else: + manifest_filepath = os.path.join(config['temp_dir'], 'manifest.json') + batch_size = min(config['batch_size'], len(config['paths2audio_files'])) + + dl_config = { + 'manifest_filepath': manifest_filepath, + 'sample_rate': self.preprocessor._sample_rate, + 'batch_size': batch_size, + 'shuffle': False, + 'num_workers': config.get('num_workers', min(batch_size, os.cpu_count() - 1)), + 'pin_memory': True, + 'use_lhotse': True, + 'use_bucketing': False, + 'channel_selector': config.get('channel_selector', None), + 'inference_mode': self.cfg.test_ds.get('inference_mode', True), + 'fixed_spk_id': config.get('fixed_spk_id', None) + } + + if config.get("augmentor"): + dl_config['augmentor'] = config.get("augmentor") + + temporary_datalayer = self._setup_dataloader_from_config(config=DictConfig(dl_config)) + + return temporary_datalayer \ No newline at end of file diff --git a/nemo/collections/asr/models/sortformer_diar_models.py b/nemo/collections/asr/models/sortformer_diar_models.py index 537991f81128..8d3953ff8302 100644 --- a/nemo/collections/asr/models/sortformer_diar_models.py +++ b/nemo/collections/asr/models/sortformer_diar_models.py @@ -128,6 +128,19 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): self.max_batch_dur = self._cfg.get("max_batch_dur", 20000) self.concat_and_pad_script = torch.jit.script(self.sortformer_modules.concat_and_pad) + self.rttms_mask_mats: List[torch.Tensor] = None # Used when GT diarization needs to be tested. + + def add_rttms_mask_mats(self, rttms_mask_mats, device: torch.device): + """ + Check if the rttms_mask_mats is empty then add it to the list + + Args: + rttms_mask_mats (List[torch.Tensor]): List of PyTorch tensors containing the rttms mask matrices. + """ + if self.rttms_mask_mats is None: + self.rttms_mask_mats = rttms_mask_mats.to(device) + else: + raise ValueError(f"{self.rttms_mask_mats.shape}: rttms_mask_mats already exist but new one is being added.") def _init_loss_weights(self): pil_weight = self._cfg.get("pil_weight", 0.0) @@ -304,7 +317,8 @@ def forward_infer(self, emb_seq, emb_seq_length): """ encoder_mask = self.sortformer_modules.length_to_mask(emb_seq_length, emb_seq.shape[1]) trans_emb_seq = self.transformer_encoder(encoder_states=emb_seq, encoder_mask=encoder_mask) - preds = self.sortformer_modules.forward_speaker_sigmoids(trans_emb_seq) + _preds = self.sortformer_modules.forward_speaker_sigmoids(trans_emb_seq) + preds = _preds * encoder_mask.unsqueeze(-1) return preds def _diarize_forward(self, batch: Any): @@ -704,6 +718,7 @@ def forward_streaming_step( processed_signal_length, streaming_state, total_preds, + drop_extra_pre_encoded=0, left_offset=0, right_offset=0, ): @@ -744,6 +759,10 @@ def forward_streaming_step( chunk_pre_encode_embs, chunk_pre_encode_lengths = self.encoder.pre_encode( x=processed_signal, lengths=processed_signal_length ) + # To match the output of the ASR model, we need to drop the extra pre-encoded embeddings + if drop_extra_pre_encoded > 0: + chunk_pre_encode_embs = chunk_pre_encode_embs[:, drop_extra_pre_encoded:, :] + chunk_pre_encode_lengths = chunk_pre_encode_lengths - drop_extra_pre_encoded if self.async_streaming: spkcache_fifo_chunk_pre_encode_embs, spkcache_fifo_chunk_pre_encode_lengths = ( @@ -811,7 +830,6 @@ def _get_aux_train_evaluations(self, preds, targets, target_lens) -> dict: Returns: (dict): A dictionary containing the following training metrics. """ - targets = targets.to(preds.dtype) if preds.shape[1] < targets.shape[1]: logging.info( f"WARNING! preds has less frames than targets ({preds.shape[1]} < {targets.shape[1]}). " @@ -884,7 +902,6 @@ def _get_aux_validation_evaluations(self, preds, targets, target_lens) -> dict: Returns: val_metrics (dict): A dictionary containing the following validation metrics """ - targets = targets.to(preds.dtype) if preds.shape[1] < targets.shape[1]: logging.info( f"WARNING! preds has less frames than targets ({preds.shape[1]} < {targets.shape[1]}). " @@ -1016,7 +1033,6 @@ def _get_aux_test_batch_evaluations(self, batch_idx: int, preds, targets, target target_lens (torch.Tensor): Lengths of target sequences. Shape: (batch_size,) """ - targets = targets.to(preds.dtype) if preds.shape[1] < targets.shape[1]: logging.info( f"WARNING! preds has less frames than targets ({preds.shape[1]} < {targets.shape[1]}). " @@ -1125,4 +1141,4 @@ def diarize( num_workers=num_workers, verbose=verbose, override_config=override_config, - ) + ) \ No newline at end of file diff --git a/nemo/collections/asr/modules/sortformer_modules.py b/nemo/collections/asr/modules/sortformer_modules.py index 8a45b385568a..4b06c0e2978d 100644 --- a/nemo/collections/asr/modules/sortformer_modules.py +++ b/nemo/collections/asr/modules/sortformer_modules.py @@ -55,6 +55,26 @@ class StreamingSortformerState: mean_sil_emb = None n_sil_frames = None + def to(self, device): + if self.spkcache is not None: + self.spkcache = self.spkcache.to(device) + if self.spkcache_lengths is not None: + self.spkcache_lengths = self.spkcache_lengths.to(device) + if self.spkcache_preds is not None: + self.spkcache_preds = self.spkcache_preds.to(device) + if self.fifo is not None: + self.fifo = self.fifo.to(device) + if self.fifo_lengths is not None: + self.fifo_lengths = self.fifo_lengths.to(device) + if self.fifo_preds is not None: + self.fifo_preds = self.fifo_preds.to(device) + if self.spk_perm is not None: + self.spk_perm = self.spk_perm.to(device) + if self.mean_sil_emb is not None: + self.mean_sil_emb = self.mean_sil_emb.to(device) + if self.n_sil_frames is not None: + self.n_sil_frames = self.n_sil_frames.to(device) + class SortformerModules(NeuralModule, Exportable): """ diff --git a/nemo/collections/asr/parts/mixins/__init__.py b/nemo/collections/asr/parts/mixins/__init__.py index 02378bd9d282..7ea8ca2e1584 100644 --- a/nemo/collections/asr/parts/mixins/__init__.py +++ b/nemo/collections/asr/parts/mixins/__init__.py @@ -14,6 +14,7 @@ from nemo.collections.asr.parts.mixins.asr_adapter_mixins import ASRAdapterModelMixin from nemo.collections.asr.parts.mixins.interctc_mixin import InterCTCMixin +from nemo.collections.asr.parts.mixins.multitalker_asr_mixins import SpeakerKernelMixin from nemo.collections.asr.parts.mixins.mixins import ( ASRAdapterModelMixin, ASRBPEMixin, diff --git a/nemo/collections/asr/parts/mixins/mixins.py b/nemo/collections/asr/parts/mixins/mixins.py index 1f4c7406fb6b..af973be3cc4c 100644 --- a/nemo/collections/asr/parts/mixins/mixins.py +++ b/nemo/collections/asr/parts/mixins/mixins.py @@ -602,6 +602,7 @@ def conformer_stream_step( drop_extra_pre_encoded: int = None, return_transcription: bool = True, return_log_probs: bool = False, + bypass_pre_encode: bool = False, ): """ It simulates a forward step with caching for streaming purposes. @@ -657,6 +658,7 @@ def conformer_stream_step( cache_last_channel_len=cache_last_channel_len, keep_all_outputs=keep_all_outputs, drop_extra_pre_encoded=drop_extra_pre_encoded, + bypass_pre_encode=bypass_pre_encode, ) if isinstance(self, asr_models.EncDecCTCModel) or ( diff --git a/nemo/collections/asr/parts/mixins/multitalker_asr_mixins.py b/nemo/collections/asr/parts/mixins/multitalker_asr_mixins.py new file mode 100644 index 000000000000..9c4f6eea8109 --- /dev/null +++ b/nemo/collections/asr/parts/mixins/multitalker_asr_mixins.py @@ -0,0 +1,271 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Dict, List, Optional +import torch +import torch.nn as nn +from abc import ABC, abstractmethod +from omegaconf import ListConfig + +from nemo.utils import logging + +__all__ = ['SpeakerKernelMixin'] + +def get_spk_kernel_class( + spk_kernel_type, + input_size, + d_model, + dropout=0.5 +): + if spk_kernel_type == 'ff': + return nn.Sequential(nn.Linear(input_size, d_model), nn.ReLU(), nn.Dropout(dropout), nn.Linear(d_model, input_size)) + elif spk_kernel_type == 'conv2d': + return + elif spk_kernel_type == 'mha': + return + +class SpeakerKernelMixin(ABC): + """ + Mixin class for models that need speaker kernel functionality. + + This mixin provides: + - Speaker kernel initialization + - Hook attachment for applying speaker kernels at specific encoder layers + - Support for both active and background speaker kernels + + Models using this mixin should have the following config parameters: + - spk_kernel_type: Type of speaker kernel ('mask', 'concat', 'sinusoidal') + - spk_kernel_layers: List of layer indices where to apply speaker kernels + - add_bg_spk_kernel: Whether to add background speaker kernels + """ + + def _init_speaker_kernel_config(self, cfg): + """ + Initialize speaker kernel configuration from model config. + + Args: + cfg: Model configuration containing speaker kernel parameters + """ + # Speaker kernel config + self.spk_kernel_type = cfg.get('spk_kernel_type', None) + self.spk_kernel_layers = cfg.get('spk_kernel_layers', [0]) + self.add_bg_spk_kernel = cfg.get('add_bg_spk_kernel', True) + + # Initialize speaker target containers + self.spk_targets = None + if self.add_bg_spk_kernel: + self.bg_spk_targets = None + + # Initialize speaker kernels + self._init_spk_kernel() + + def _init_spk_kernel(self): + """Initialize speaker kernel modules and register them to encoder layers.""" + if not isinstance(self.spk_kernel_layers, ListConfig): + if self.spk_kernel_type is not None: + raise ValueError(f"spk_kernel_layers must be a list, got {type(self.spk_kernel_layers)}") + return + + # Initialize speaker kernels for each specified layer + hidden_size = self.cfg.model_defaults.enc_hidden + self.spk_kernels = torch.nn.ModuleDict() + if self.add_bg_spk_kernel: + self.bg_spk_kernels = torch.nn.ModuleDict() + + # Create kernel for each layer index + for layer_idx in self.spk_kernel_layers: + self.spk_kernels[str(layer_idx)] = get_spk_kernel_class( + spk_kernel_type=self.spk_kernel_type, + input_size=hidden_size, + d_model=self.cfg.encoder.d_model, + dropout=0.5 + ) + if self.add_bg_spk_kernel: + self.bg_spk_kernels[str(layer_idx)] = get_spk_kernel_class( + spk_kernel_type=self.spk_kernel_type, + input_size=hidden_size, + d_model=self.cfg.encoder.d_model, + dropout=0.5 + ) + + if self.spk_kernels: + logging.info(f"Initialized speaker kernels for layers: {list(self.spk_kernels.keys())}") + self._attach_spk_kernel_hooks() + else: + logging.info("No speaker kernels initialized") + + def _attach_spk_kernel_hooks(self): + """ + Attach speaker kernel hooks to encoder layers. + Speaker kernels will inject the speaker information into the encoder layers. + """ + # Only attach hooks if not already attached + if hasattr(self, 'encoder_hooks'): + return + + self.encoder_hooks = [] + for layer_idx, kernel in self.spk_kernels.items(): + idx = int(layer_idx) + + if idx == 0: + hook = self.encoder.layers[idx].register_forward_pre_hook( + self._get_spk_kernel_hook_pre_layer(layer_idx), with_kwargs=True + ) + + if idx > 0: + # Attach a post-hook after each layer from 0 to 16. + # Since idx > 0, we attach to layer idx-1. + hook = self.encoder.layers[idx - 1].register_forward_hook( + self._get_spk_kernel_hook_post_layer(layer_idx) + ) + self.encoder_hooks.append(hook) + + def _get_spk_kernel_hook_pre_layer(self, layer_idx: str): + """ + Returns a hook function for applying speaker kernel transformation. + + Args: + layer_idx (str): Index of the layer to apply the kernel + + Returns: + callable: Hook function that applies speaker kernel + """ + + def hook_fn(module, args, kwargs): + # Pre-hooks with with_kwargs=True must return a (new_args, new_kwargs) tuple. + # The input tensor is passed as a keyword argument, so we find it in 'kwargs'. + + if 'x' in kwargs: + x = kwargs['x'] + x_spk = self.spk_kernels[layer_idx](self.mask_with_speaker_targets(x, self.spk_targets)) + # residual connection + x = x + x_spk + if self.add_bg_spk_kernel: + x_bg_spk = self.bg_spk_kernels[layer_idx](self.mask_with_speaker_targets(x, self.bg_spk_targets)) + x = x + x_bg_spk + kwargs['x'] = x + elif args: + # Fallback in case the call signature ever changes + x, *rest = args + x_spk = self.spk_kernels[layer_idx](self.mask_with_speaker_targets(x, self.spk_targets)) + # residual connection + x = x + x_spk + if self.add_bg_spk_kernel: + x_bg_spk = self.bg_spk_kernels[layer_idx](self.mask_with_speaker_targets(x, self.bg_spk_targets)) + x = x + x_bg_spk + args = (x, *rest) + + return args, kwargs + + return hook_fn + + def _get_spk_kernel_hook_post_layer(self, layer_idx: str): + """ + Returns a hook function for applying speaker kernel transformation. + + Args: + layer_idx (str): Index of the layer to apply the kernel + + Returns: + callable: Hook function that applies speaker kernel + """ + def hook_fn(module, input, output): + if self.spk_targets is None: + return output + + if isinstance(output, tuple): + x, *cache = output + else: + x = output + + x_spk = self.spk_kernels[layer_idx](self.mask_with_speaker_targets(x, self.spk_targets)) + # residual connection + x = x + x_spk + + if self.add_bg_spk_kernel: + x_bg_spk = self.bg_spk_kernels[layer_idx](self.mask_with_speaker_targets(x, self.bg_spk_targets)) + x = x + x_bg_spk + + if isinstance(output, tuple): + return (x, *cache) + return x + + return hook_fn + + def _cleanup_speaker_kernel_hooks(self): + """ + Clean up speaker kernel hooks to prevent memory leaks. + Can be called during model cleanup or when switching between modes. + """ + if hasattr(self, 'encoder_hooks'): + for hook in self.encoder_hooks: + try: + hook.remove() + except Exception as e: + logging.warning(f"Failed to remove speaker kernel hook: {e}") + delattr(self, 'encoder_hooks') + logging.info("Speaker kernel hooks cleaned up") + + def set_speaker_targets(self, spk_targets: Optional[torch.Tensor] = None, + bg_spk_targets: Optional[torch.Tensor] = None): + """ + Set speaker targets for the model. + + Args: + spk_targets: Main speaker targets tensor + bg_spk_targets: Background speaker targets tensor + """ + self.spk_targets = spk_targets + if self.add_bg_spk_kernel: + self.bg_spk_targets = bg_spk_targets + + def clear_speaker_targets(self): + """Clear speaker targets.""" + self.spk_targets = None + if self.add_bg_spk_kernel: + self.bg_spk_targets = None + + def solve_length_mismatch(self, x: torch.Tensor, mask: torch.Tensor): + """ + Solve length mismatch between x and mask. + """ + if mask is None: + mask = torch.ones_like(x[:, :, 0]) + logging.warning(f"Mask is None, triggering single speaker mode and assigning all ones with shape: {mask.shape}") + + if mask.shape[1] < x.shape[1]: + # pad zero to the left + mask = torch.nn.functional.pad(mask, (x.shape[1] - mask.shape[1], 0), mode='constant', value=1) + + if mask.shape[1] > x.shape[1]: + mask = mask[:, -x.shape[1]:] + + return mask + + def mask_with_speaker_targets(self, x: torch.Tensor, spk_targets: torch.Tensor): + """ + Mask the input with speaker targets. + """ + mask = self.solve_length_mismatch(x, spk_targets) + x_spk = x * mask.unsqueeze(2) + return x_spk + + def concat_with_speaker_targets(self, x: torch.Tensor, spk_targets: torch.Tensor): + """ + Concatenate the input with speaker targets. + """ + mask = self.solve_length_mismatch(x, spk_targets) + x_spk = x * mask.unsqueeze(2) + return x_spk + \ No newline at end of file diff --git a/nemo/collections/asr/parts/mixins/streaming.py b/nemo/collections/asr/parts/mixins/streaming.py index d6fd0b9b354b..7a2d921af115 100644 --- a/nemo/collections/asr/parts/mixins/streaming.py +++ b/nemo/collections/asr/parts/mixins/streaming.py @@ -47,6 +47,7 @@ def cache_aware_stream_step( cache_last_channel_len=None, keep_all_outputs=True, drop_extra_pre_encoded=None, + bypass_pre_encode=False, ): if self.streaming_cfg is None: self.setup_streaming_params() @@ -65,6 +66,7 @@ def cache_aware_stream_step( cache_last_channel=cache_last_channel, cache_last_time=cache_last_time, cache_last_channel_len=cache_last_channel_len, + bypass_pre_encode=bypass_pre_encode, ) encoder_output = self.streaming_post_process(encoder_output, keep_all_outputs=keep_all_outputs) diff --git a/nemo/collections/asr/parts/utils/asr_multispeaker_utils.py b/nemo/collections/asr/parts/utils/asr_multispeaker_utils.py index 009a93b18d95..b0b75212f62c 100644 --- a/nemo/collections/asr/parts/utils/asr_multispeaker_utils.py +++ b/nemo/collections/asr/parts/utils/asr_multispeaker_utils.py @@ -11,13 +11,235 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import re import math -from typing import Optional, Union +import json +import random +import logging +import itertools +from copy import deepcopy +from cytoolz import groupby +import time +from collections import defaultdict + +import numpy as np +import soundfile +from tqdm import tqdm +from scipy.stats import norm + +import torch.utils.data +from lhotse.cut.set import mix +from lhotse.cut import Cut, CutSet, MixedCut, MonoCut, MixTrack +from lhotse import SupervisionSet, SupervisionSegment, dill_enabled, AudioSource, Recording +from lhotse.utils import uuid4, compute_num_samples, ifnone +from lhotse.lazy import LazyIteratorChain, LazyJsonlIterator +from nemo.collections.asr.data.data_simulation import MultiSpeakerSimulator +from nemo.collections.asr.parts.utils.data_simulation_utils import read_rir_manifest +from typing import Optional, Union, List, Tuple, Dict, Any + +from omegaconf import OmegaConf +from dataclasses import dataclass, field +from typing import List, Optional + +import soundfile as sf +import os +@dataclass +class SessionConfig: + num_speakers: int = 1 + num_sessions: int = 1 + session_length: int = 15 + session_length_range: List[int] = field(default_factory=lambda: [10, 40]) + +@dataclass +class SessionParams: + max_audio_read_sec: float = 20.0 + sentence_length_params: List[float] = field(default_factory=lambda: [0.4, 0.05]) + dominance_var: float = 0.11 + min_dominance: float = 0.05 + turn_prob: float = 0.875 + min_turn_prob: float = 0.5 + mean_silence: float = 0.15 + mean_silence_var: float = 0.01 + per_silence_var: int = 900 + per_silence_min: float = 0.0 + per_silence_max: float = -1.0 + mean_overlap: float = 0.1 + mean_overlap_var: float = 0.01 + per_overlap_var: int = 900 + per_overlap_min: float = 0.0 + per_overlap_max: float = -1.0 + start_window: bool = True + window_type: str = "hamming" + window_size: float = 0.02 + start_buffer: float = 0.0 + split_buffer: float = 0.01 + release_buffer: float = 0.0 + normalize: bool = True + normalization_type: str = "equal" + normalization_var: float = 0.1 + min_volume: float = 0.75 + max_volume: float = 1.25 + end_buffer: float = 0.5 + random_offset: bool = True + +@dataclass +class OutputConfig: + output_dir: str = "" + output_filename: str = "multispeaker_session" + overwrite_output: bool = True + output_precision: int = 3 + +@dataclass +class BackgroundNoise: + add_bg: bool = True + background_manifest: Optional[str] = None + rir_manifest: Optional[str] = None + num_noise_files: int = 10 + snr: int = 60 + snr_min: Optional[float] = None + snr_max: Optional[float] = None + +@dataclass +class SegmentAugmentor: + add_seg_aug: bool = False + gain_prob: float = 0.5 + min_gain_dbfs: float = -10.0 + max_gain_dbfs: float = 10.0 + +@dataclass +class SessionAugmentor: + add_sess_aug: bool = False + white_noise_prob: float = 1.0 + min_white_noise_level: int = -90 + max_white_noise_level: int = -46 + +@dataclass +class SpeakerEnforcement: + enforce_num_speakers: bool = True + enforce_time: List[float] = field(default_factory=lambda: [0.25, 0.75]) + +@dataclass +class SegmentManifest: + window: float = 0.5 + shift: float = 0.25 + step_count: int = 50 + deci: int = 3 + +@dataclass +class RIRGeneration: + use_rir: bool = False + toolkit: str = "pyroomacoustics" + room_sz: List[List[int]] = field(default_factory=lambda: [[2, 3], [2, 3], [2, 3]]) + pos_src: List[List[List[float]]] = field(default_factory=lambda: [[[0.5, 1.5]] * 3] * 4) + noise_src_pos: List[float] = field(default_factory=lambda: [1.5, 1.5, 2]) + num_channels: int = 2 + pos_rcv: List[List[List[float]]] = field(default_factory=lambda: [[[0.5, 1.5]] * 3] * 2) + orV_rcv: Optional[List[List[float]]] = None + mic_pattern: str = "omni" + abs_weights: List[float] = field(default_factory=lambda: [0.9] * 6) + T60: float = 0.1 + att_diff: float = 15.0 + att_max: float = 60.0 + +@dataclass +class DataSimConfig: + """Configuration for data simulation.""" + manifest_filepath: str = "" + sr: int = 16000 + random_seed: int = 42 + multiprocessing_chunksize: int = 10000 + session_config: SessionConfig = field(default_factory=SessionConfig) + session_params: SessionParams = field(default_factory=SessionParams) + outputs: OutputConfig = field(default_factory=OutputConfig) + background_noise: BackgroundNoise = field(default_factory=BackgroundNoise) + background_manifest: str = "" + segment_augmentor: SegmentAugmentor = field(default_factory=SegmentAugmentor) + session_augmentor: SessionAugmentor = field(default_factory=SessionAugmentor) + speaker_enforcement: SpeakerEnforcement = field(default_factory=SpeakerEnforcement) + segment_manifest: SegmentManifest = field(default_factory=SegmentManifest) + rir_generation: RIRGeneration = field(default_factory=RIRGeneration) + +@dataclass +class MultiSpeakerSimulatorConfig: + data_simulator: DataSimConfig = field(default_factory=DataSimConfig) + +class Segment: + def __init__(self, start, end, speaker_id, text): + self.start = start + self.end = end + self.speaker_id = speaker_id + self.text = text + + def __str__(self): + return f"Segment(start={self.start}, end={self.end}, speaker_id={self.speaker_id}, text=\"{self.text}\")" + +class SegList: + def __init__(self, segments: List[Segment] = None, seglst_filepath: str = None): + if segments is not None: + self.segments = segments + elif seglst_filepath is not None: + self._load_seglst(seglst_filepath) + else: + raise ValueError("Either segments or seglst_filepath must be provided") + + def _load_seglst(self, seglst_filepath: str|list[str]): + if isinstance(seglst_filepath, str): + with open(seglst_filepath, 'r', encoding='utf-8') as f: + seglst = json.load(f) + self.segments = [ + Segment(seg['start_time'], seg['end_time'], seg['speaker'], seg['words']) for seg in seglst + ] + elif isinstance(seglst_filepath, list): + for seglst_file in seglst_filepath: + with open(seglst_file, 'r', encoding='utf-8') as f: + seglst = json.load(f) + segments = [ + Segment(seg['start_time'], seg['end_time'], seg['speaker'], seg['words']) for seg in seglst + ] + self.segments.extend(segments) + else: + raise ValueError("seglst_filepath must be a string or a list of strings") + self.sort() + + def __len__(self): + return len(self.segments) + + def __getitem__(self, idx): + return self.segments[idx] + + def __iter__(self): + return iter(self.segments) + + def sort(self): + self.segments.sort(key=lambda x: x.start) + + def get_segments(self, min_duration: float, max_duration: float): + + duration = random.uniform(min_duration, max_duration) + + first_segment_idx = random.randint(0, len(self) - 1) + segments = [self[first_segment_idx]] + + offset = self[first_segment_idx].start + for i in range(first_segment_idx + 1, len(self)): + if self[i].end - offset <= duration: + segments.append(self[i]) + else: + break + + return segments + + def get_text_from_segments(self, segments: list[Segment], speaker_token_style='<|spltoken*|>', speaker_token_position='sot'): + text = '' + speakers = set([segment.speaker_id for segment in segments]) + speaker2start = {spk_id: min(segment.start for segment in segments if segment.speaker_id == spk_id) for spk_id in speakers} + sorted_speakers = sorted(speakers, key=lambda x: speaker2start[x]) + speaker2token = {spk: speaker_token_style.replace('*', str(i)) for i, spk in enumerate(sorted_speakers)} + for segment in segments: + text += f'{speaker2token[segment.speaker_id]} ' + text += segment.text + return text.strip() -import torch -from lhotse import SupervisionSet -from lhotse.cut import MixedCut, MonoCut def find_first_nonzero(mat: torch.Tensor, max_cap_val=-1, thres: float = 0.5) -> torch.Tensor: @@ -195,7 +417,6 @@ def find_segments_from_rttm( end_before (float): The end time before which segments are selected. adjust_offset (bool): Whether to adjust the offset of the segments. tolerance (float): The tolerance for time matching. 0.001 by default. - Returns: segments (List[SupervisionSegment]): A list of SupervisionSegment instances. """ @@ -313,47 +534,38 @@ def get_hidden_length_from_sample_length( mel_frame_count = math.ceil(num_samples / num_sample_per_mel_frame) hidden_length = math.ceil(mel_frame_count / num_mel_frame_per_asr_frame) return int(hidden_length) - + def speaker_to_target( a_cut, - num_speakers: int = 4, - num_sample_per_mel_frame: int = 160, - num_mel_frame_per_asr_frame: int = 8, - spk_tar_all_zero: bool = False, + num_sample_per_mel_frame: int = 160, + num_mel_frame_per_asr_frame: int = 8, boundary_segments: bool = False, soft_label: bool = False, - ignore_num_spk_mismatch: bool = True, soft_thres: float = 0.5, -): - """ - Get rttm samples corresponding to one cut, generate speaker mask numpy.ndarray with shape - (num_speaker, hidden_length). This function is needed for speaker diarization with ASR model trainings. + ignore_num_spk_mismatch: bool = True, + return_text: bool = False, + ): + ''' + Get rttm samples corresponding to one cut, generate speaker mask numpy.ndarray with shape (num_speaker, hidden_length) + This function is needed for speaker diarization with ASR model trainings. Args: - a_cut (MonoCut, MixedCut): - Lhotse Cut instance which is MonoCut or MixedCut instance. - num_speakers (int): - Max number of speakers for all cuts ("mask" dim0), 4 by default - num_sample_per_mel_frame (int): - Number of sample per mel frame, sample_rate / 1000 * window_stride, 160 by default (10ms window stride) - num_mel_frame_per_asr_frame (int): - Encoder subsampling_factor, 8 by default - spk_tar_all_zero (Tensor): - Set to True gives all zero "mask" - boundary_segments (bool): - Set to True to include segments containing the boundary of the cut, - False by default for multi-speaker ASR training - soft_label (bool): - Set to True to use soft label that enables values in [0, 1] range, - False by default and leads to binary labels. - ignore_num_spk_mismatch (bool): - This is a temporary solution to handle speaker mismatch. Will be removed in the future. - + a_cut (MonoCut, MixedCut): Lhotse Cut instance which is MonoCut or MixedCut instance. + num_speakers (int): max number of speakers for all cuts ("mask" dim0), 4 by default + num_sample_per_mel_frame (int): number of sample per mel frame, sample_rate / 1000 * window_stride, 160 by default (10ms window stride) + num_mel_frame_per_asr_frame (int): encoder subsampling_factor, 8 by default + boundary_segments (bool): set to True to include segments containing the boundary of the cut, False by default for multi-speaker ASR training + soft_label (bool): set to True to use soft label that enables values in [0, 1] range, False by default and leads to binary labels. + soft_thres (float): the threshold for the soft label, 0.5 by default. + ignore_num_spk_mismatch (bool): This is a temporary solution to handle speaker mismatch. Will be removed in the future. + return_text (bool): set to True to return the text of the speakers (if it is available), False by default. + Returns: - mask (Tensor): Speaker mask with shape (num_speaker, hidden_lenght) - """ + mask (Tensor): speaker mask with shape (num_speaker, hidden_lenght) + ''' # get cut-related segments from rttms + # basename = os.path.basename(a_cut.rttm_filepath).replace('.rttm', '') if isinstance(a_cut, MixedCut): cut_list = [track.cut for track in a_cut.tracks if isinstance(track.cut, MonoCut)] offsets = [track.offset for track in a_cut.tracks if isinstance(track.cut, MonoCut)] @@ -362,19 +574,26 @@ def speaker_to_target( offsets = [0] else: raise ValueError(f"Unsupported cut type type{a_cut}: only MixedCut and MonoCut are supported") - + segments_total = [] - for i, cut in enumerate(cut_list): - rttms = SupervisionSet.from_rttm(cut.rttm_filepath) - if boundary_segments: # segments with seg_start < total_end and seg_end > total_start are included - segments_iterator = find_segments_from_rttm( - recording_id=cut.recording_id, rttms=rttms, start_after=cut.start, end_before=cut.end, tolerance=0.0 - ) - else: # segments with seg_start > total_start and seg_end < total_end are included - segments_iterator = rttms.find( - recording_id=cut.recording_id, start_after=cut.start, end_before=cut.end, adjust_offset=True - ) + for i, cut in enumerate(cut_list): + if cut.custom.get('rttm_filepath', None): + rttms = SupervisionSet.from_rttm(cut.rttm_filepath) + elif cut.supervisions: + rttms = SupervisionSet(cut.supervisions) + else: + logging.warning(f"No rttm or supervisions found for cut {cut.id}") + continue + + start = cut.offset if hasattr(cut, 'offset') else cut.start + end = start + cut.duration + recording_id = rttms[0].recording_id if len(rttms) > 0 else cut.recording_id + if boundary_segments: # segments with seg_start < total_end and seg_end > total_start are included + segments_iterator = find_segments_from_rttm(recording_id=recording_id, rttms=rttms, start_after=start, end_before=end, tolerance=0.0) + else: # segments with seg_start > total_start and seg_end < total_end are included + segments_iterator = rttms.find(recording_id=recording_id, start_after=start, end_before=end, adjust_offset=True) #, tolerance=0.0) + for seg in segments_iterator: if seg.start < 0: seg.duration += seg.start @@ -383,37 +602,293 @@ def speaker_to_target( seg.duration -= seg.end - cut.duration seg.start += offsets[i] segments_total.append(seg) - # apply arrival time sorting to the existing segments - segments_total.sort(key=lambda rttm_sup: rttm_sup.start) + segments_total.sort(key = lambda rttm_sup: rttm_sup.start) seen = set() seen_add = seen.add speaker_ats = [s.speaker for s in segments_total if not (s.speaker in seen or seen_add(s.speaker))] - - speaker_to_idx_map = {spk: idx for idx, spk in enumerate(speaker_ats)} - if len(speaker_to_idx_map) > num_speakers and not ignore_num_spk_mismatch: # raise error if number of speakers - raise ValueError( - f"Number of speakers {len(speaker_to_idx_map)} is larger than " - f"the maximum number of speakers {num_speakers}" - ) - + + speaker_to_idx_map = { + spk: idx + for idx, spk in enumerate(speaker_ats) + } + num_speakers = len(speaker_ats) + # initialize mask matrices (num_speaker, encoder_hidden_len) - feat_per_sec = int(a_cut.sampling_rate / num_sample_per_mel_frame) # 100 by default - num_samples = get_hidden_length_from_sample_length( - a_cut.num_samples, num_sample_per_mel_frame, num_mel_frame_per_asr_frame - ) - if spk_tar_all_zero: - frame_mask = torch.zeros((num_samples, num_speakers)) - else: - frame_mask = get_mask_from_segments( - segments_total, a_cut, speaker_to_idx_map, num_speakers, feat_per_sec, ignore_num_spk_mismatch - ) + feat_per_sec = int(a_cut.sampling_rate / num_sample_per_mel_frame) # 100 by default + num_samples = get_hidden_length_from_sample_length(a_cut.num_samples, num_sample_per_mel_frame, num_mel_frame_per_asr_frame) + frame_mask = get_mask_from_segments(segments_total, a_cut, speaker_to_idx_map, num_speakers, feat_per_sec, ignore_num_spk_mismatch) soft_mask = get_soft_mask(frame_mask, num_samples, num_mel_frame_per_asr_frame) if soft_label: mask = soft_mask else: mask = (soft_mask > soft_thres).float() + + if return_text: + speaker2text = defaultdict(list) + for seg in segments_total: + speaker2text[seg.speaker].append(seg.text) + texts = [' '.join(speaker2text[speaker]) for speaker in speaker_ats] + return mask, texts + else: + return mask - return mask +def read_seglst(seglst_filepath: str, session_id: Optional[str] = None): + """ + Read the seglst file and return a list of segments. + """ + with open(seglst_filepath, 'r', encoding='utf-8') as f: + seglst = json.load(f) + return [ + SupervisionSegment( + id=f'{seg["session_id"]}-sup{i:05d}', + recording_id=seg['session_id'] if session_id is None else session_id, + start=float(seg['start_time']), + duration=float(seg['end_time']) - float(seg['start_time']), + text=seg['words'], + speaker=seg['speaker'] + ) for i, seg in enumerate(seglst) + ] + +class MultiSpeakerMixtureGenerator(): + """ + This class is used to simulate multi-speaker audio data, + which can be used for multi-speaker ASR and speaker diarization training. + """ + def __init__( + self, + manifest_filepath, + sample_rate, + simulator_type, + min_duration=0.1, + max_duration=50.0, + min_delay=0.5, + random_seed=42, + num_speakers=2, + global_rank=0, + world_size=1, + ): + """ + Args: + cuts (CutSet): The cutset that contains single-speaker audio cuts. + Please make sure that the cuts have the 'speaker_id' attribute. + num_speakers (int): The number of speakers in the simulated audio. + We only simulate the samples with the fixed number of speakers. + The variation of the number of speakers is controlled by the weights in Lhotse dataloader config. + simulator_type (str): The type of simulator to use. + - 'lsmix': LibriSpeechMix-style training sample. + - 'meeting': Meeting-style training sample. + - 'conversation': Conversation-style training sample. + speaker_distribution (list): The distribution of speakers in the simulated audio. + The length of the list is the maximum number of speakers. + The list elements are the weights for each speaker. + min_delay (float): The minimum delay between speakers + to avoid the same starting time for multiple speakers. + """ + self.random_seed = random_seed + self.global_rank = global_rank + self.world_size = world_size + + self.manifest_filepath = manifest_filepath + self.manifests = list(LazyJsonlIterator(manifest_filepath)) + self.sample_rate = sample_rate + + self.min_duration = min_duration + self.max_duration = max_duration + self.min_delay = min_delay + self.simulator_type = simulator_type + self.max_speakers = num_speakers + + print("====== simulator_type", simulator_type) + + type2simulator = { + 'lsmix': self.LibriSpeechMixSimulator, + 'mixture_loader': self.MultiSpeakerMixtureLoader + } + + self.simulator = type2simulator[simulator_type] + + if simulator_type == 'lsmix': + self.spk2manifests = groupby(lambda x: x["speaker_id"], self.manifests) + self.speaker_ids = list(self.spk2manifests.keys()) + + self.count = 0 + + def __iter__(self): + return self + + def __next__(self): + self.count += 1 + return self.simulator() + + def LibriSpeechMixSimulator(self): + """ + This function simulates a LibriSpeechMix-style training sample. + Ref: + Paper: https://arxiv.org/abs/2003.12687 + Github: https://github.com/NaoyukiKanda/LibriSpeechMix + """ + # Sample the speakers + sampled_speaker_ids = random.sample(self.speaker_ids, self.max_speakers) + # Sample the cuts for each speaker + mono_cuts = [] + for speaker_id in sampled_speaker_ids: + manifest = random.choice(self.spk2manifests[speaker_id]) + mono_cuts.append(self._json_to_cut(manifest)) + mono_cuts[-1].supervisions.append( + SupervisionSegment( + id=uuid4(), + recording_id=uuid4(), + start=0.0, + duration=mono_cuts[-1].duration, + text=mono_cuts[-1].custom['text'], + speaker=speaker_id + ) + ) + + tracks = [] + offset = 0.0 + for speaker_id, mono_cut in zip(sampled_speaker_ids, mono_cuts): + tracks.append(MixTrack(cut=deepcopy(mono_cut), type=type(mono_cut), offset=offset)) + offset += random.uniform(self.min_delay, mono_cut.duration) + + mixed_cut = MixedCut(id='lsmix_' + '_'.join([track.cut.id for track in tracks]) + '_' + str(uuid4()), tracks=tracks) + + return mixed_cut + + def MultiSpeakerMixtureLoader(self): + """ + Load a multi-speaker mixture from the manifest, + and generate a mixed cut with a random duration. + The timestamps and transcript are from the seglst file, + where the format is: + { + "session_id": "session_id", + "speaker": "speaker_id", + "words": "transcript", + "start_time": "start_time", + "end_time": "end_time", + "duration": "duration" + ... + } + Supervisions are generated from the seglst file and sorted by start time. + """ + + manifest = random.choice(self.manifests) + audio_filepath = manifest['audio_filepath'] + seglst_filepath = manifest['seglst_filepath'] + + supervisions = read_seglst(seglst_filepath, session_id=manifest['session_id']) + supervisions = sorted(supervisions, key=lambda x: x.start) + + segment_offset, segment_duration = self._get_offset_and_duration(supervisions) + + json_dict = { + 'audio_filepath': audio_filepath, + 'duration': segment_duration, + 'offset': segment_offset, + 'supervisions': find_segments_from_rttm(recording_id=supervisions[0].recording_id, rttms=SupervisionSet(supervisions), start_after=segment_offset, end_before=segment_offset + segment_duration, adjust_offset=False) + } + cut = self._json_to_cut(json_dict) + + return cut + + def _get_offset_and_duration(self, supervisions): + """ + Get a random offset and duration of the segment. + supervisions should be sorted by start time + """ + non_overlap_supervisions_indices = self._get_non_overlap_supervisions_indices(supervisions) + # find the start and the end of the segment + start_idx = random.choice(non_overlap_supervisions_indices) + end_idx = start_idx + offset = supervisions[start_idx].start + for i in range(start_idx + 1, len(supervisions)): + end_idx = i + if supervisions[i].end - offset <= self.min_duration: + pass + else: + if i in non_overlap_supervisions_indices: + break + segment_offset = offset + segment_duration = supervisions[end_idx].end - offset + + return segment_offset, segment_duration + + def _get_non_overlap_supervisions_indices(self, supervisions): + """ + Get the indices of the non-overlapping supervisions. + supervisions should be sorted by start time + """ + non_overlap_supervisions_indices = [] + max_end = -1 + for i in range(len(supervisions)): + if supervisions[i].start >= max_end: + non_overlap_supervisions_indices.append(i) + max_end = max(max_end, supervisions[i].end) + return non_overlap_supervisions_indices + + def _json_to_cut(self, json_dict): + """ + Convert a json dictionary to a Cut instance. + """ + audio_path = json_dict["audio_filepath"] + duration = json_dict["duration"] + offset = json_dict.get("offset", 0.0) + supervisions = json_dict.get("supervisions", []) + cut = self._create_cut( + audio_path=audio_path, offset=offset, duration=duration, sampling_rate=json_dict.get("sampling_rate", None), + ) + # Note that start=0 and not start=offset because supervision's start if relative to the + # start of the cut; and cut.start is already set to offset + + if json_dict.get("text") is not None and json_dict.get("text") != "": + cut_text = json_dict.get("text") + else: + cut_text = " ".join(json_dict.get("words", [])) + if cut_text == " ": + cut_text = "" + + cut.supervisions.extend(supervisions) + cut.custom = json_dict + cut.duration = duration + return cut + + def _create_cut( + self, + audio_path: str, + offset: float, + duration: float, + sampling_rate: int | None = None, + channel: int = 0, + ) -> Cut: + + recording = self._create_recording(audio_path, duration, sampling_rate) + cut = recording.to_cut() + if isinstance(cut.channel, list) and len(cut.channel) > 1: + cut.channel = [channel] + if offset is not None: + cut = cut.truncate(offset=offset, duration=duration, preserve_id=True) + cut.id = f"{cut.id}-{round(offset * 1e2):06d}-{round(duration * 1e2):06d}" + return cut + + def _create_recording( + self, + audio_path: str, + duration: float, + sampling_rate: int | None = None, + ) -> Recording: + if sampling_rate is not None: + # TODO(pzelasko): It will only work with single-channel audio in the current shape. + return Recording( + id=audio_path, + sources=[AudioSource(type="file", channels=[0], source=audio_path)], + sampling_rate=sampling_rate, + num_samples=compute_num_samples(duration, sampling_rate), + duration=duration, + channel_ids=[0], + ) + else: + return Recording.from_file(audio_path) \ No newline at end of file diff --git a/nemo/collections/asr/parts/utils/data_simulation_utils.py b/nemo/collections/asr/parts/utils/data_simulation_utils.py index 66b21c2478a0..d90e8a604231 100644 --- a/nemo/collections/asr/parts/utils/data_simulation_utils.py +++ b/nemo/collections/asr/parts/utils/data_simulation_utils.py @@ -318,6 +318,7 @@ def get_background_noise( background_noise_snr: float, seed: int, device: torch.device, + sr: float = 16000, ): """ Augment with background noise (inserting ambient background noise up to the desired SNR for the full clip). @@ -343,9 +344,11 @@ def get_background_noise( power_array=power_array, snr_min=snr_min, snr_max=snr_max, background_noise_snr=background_noise_snr ) running_len_samples = 0 + noise_segment_list = [] + last_mixed_cut_offset = 0 + file_id = np.random.randint(len(noise_samples)) while running_len_samples < len_array: # build background audio stream (the same length as the full file) - file_id = np.random.randint(len(noise_samples)) audio_file, sr, audio_manifest = read_audio_from_buffer( audio_manifest=noise_samples[file_id], buffer_dict=audio_read_buffer_dict, @@ -353,6 +356,16 @@ def get_background_noise( device=device, read_subset=False, ) + # noise_segment_list.append( + noise_manifest_dict = copy.deepcopy(audio_manifest) + noise_manifest_dict['duration'] = float(min(len(audio_file), len_array - running_len_samples-1) / sr) + noise_manifest_dict['offset'] = 0 + noise_manifest_dict['volume'] = 1.0 + noise_manifest_dict['mixed_cut_offset'] = last_mixed_cut_offset + last_mixed_cut_offset += noise_manifest_dict['duration'] + + noise_segment_list.append(noise_manifest_dict) + if running_len_samples + len(audio_file) < len_array: end_audio_file = running_len_samples + len(audio_file) else: @@ -368,7 +381,7 @@ def get_background_noise( bg_array[running_len_samples:end_audio_file] = scaled_audio_file running_len_samples = end_audio_file - return bg_array, desired_snr + return bg_array, desired_snr, noise_segment_list def get_random_offset_index( @@ -442,12 +455,15 @@ def get_speaker_ids(sess_idx: int, speaker_samples: dict, permutated_speaker_ind speaker_ids (list): List of speaker IDs """ all_speaker_ids = list(speaker_samples.keys()) - idx_list = permutated_speaker_inds[sess_idx, :] + # Measure the length of permutated_speaker_inds and mod the sess_idx number so that + # sess_idx is always less than the length of permutated_speaker_inds + sess_idx_circular = sess_idx % permutated_speaker_inds.shape[0] + idx_list = permutated_speaker_inds[sess_idx_circular, :] speaker_ids = [all_speaker_ids[i] for i in idx_list] return speaker_ids -def build_speaker_samples_map(manifest: dict) -> dict: +def build_speaker_samples_map(manifest: dict, tqdm_bar: bool = False) -> dict: """ Build a dictionary for mapping speaker ID to their list of samples @@ -456,8 +472,9 @@ def build_speaker_samples_map(manifest: dict) -> dict: Dictionary mapping speaker ID to their list of samples """ speaker_samples = defaultdict(list) - logging.info("Building speaker to samples map...") - for sample in tqdm(manifest, total=len(manifest)): + # logging.info("Building speaker to samples map...") + for sample in tqdm(manifest, total=len(manifest), disable=not tqdm_bar): + # for sample in manifest: speaker_id = sample['speaker_id'] speaker_samples[speaker_id].append(sample) return speaker_samples @@ -492,6 +509,25 @@ def read_noise_manifest(add_bg: bool, background_manifest: str): return noise_manifest +def read_rir_manifest(rir_manifest: str): + """ + Read the rir manifest file and sample the rir manifest. + """ + # if isinstance(rir_manifest, str): + # rir_manifest_list = [rir_manifest] + # elif isinstance(rir_manifest, list): + # rir_manifest_list = rir_manifest + rir_manifest_list = [rir_manifest] + rir_loaded_list = [] + for manifest_file in rir_manifest_list: + # try: + if os.path.exists(manifest_file): + rir_loaded_list.extend(read_manifest(manifest_file)) + # except: + # import ipdb; ipdb.set_trace() + return rir_loaded_list + + def get_speaker_samples(speaker_ids: List[str], speaker_samples: dict) -> Dict[str, list]: """ Get a list of the samples for each of the specified speakers. @@ -587,6 +623,7 @@ def get_split_points_in_alignments( splits = [] for i in range(len(words)): if words[i] == "" and i != 0 and i != len(words) - 1: + # if words[i] == "" and i != 0 and i != len(words) - 1: silence_length = alignments[i] - alignments[i - 1] if silence_length > 2 * split_buffer: # split utterance on silence new_end = alignments[i - 1] + split_buffer @@ -659,7 +696,7 @@ def _init_file_write(self): Initialize file writing arguments """ self._file_base_str = "synthetic" - self._file_types = ["wav", "rttm", "json", "ctm", "txt", "meta"] + self._file_types = ["wav", "rttm", "json", "noise", "ctm", "txt", "meta"] self._annotation_types = ["rttm", "json", "ctm"] def _init_filelist_lists(self): @@ -678,7 +715,7 @@ def init_annotation_lists(self): self.annote_lists[file_type] = [] def create_new_rttm_entry( - self, words: List[str], alignments: List[float], start: int, end: int, speaker_id: int + self, words: List[str], alignments: List[float], start: int, end: int, speaker_id: int, add_split_buffer: bool = False ) -> List[str]: """ @@ -703,11 +740,19 @@ def create_new_rttm_entry( if ( silence_length > 2 * self._params.data_simulator.session_params.split_buffer ): # split utterance on silence - new_end = start + alignments[i - 1] + self._params.data_simulator.session_params.split_buffer + new_end = start + alignments[i - 1] + silence_duration = alignments[i] - alignments[i - 1] + + # new_end = start + alignments[i - 1] + self._params.data_simulator.session_params.split_buffer + + # import ipdb; ipdb.set_trace() + # if add_split_buffer: # add split buffer if specified in config + # new_end += self._params.data_simulator.session_params.split_buffer t_stt = round(float(new_start), self._params.data_simulator.outputs.output_precision) t_end = round(float(new_end), self._params.data_simulator.outputs.output_precision) rttm_list.append(f"{t_stt} {t_end} {speaker_id}") - new_start = start + alignments[i] - self._params.data_simulator.session_params.split_buffer + new_start = start + alignments[i] + # new_start = start + alignments[i] - self._params.data_simulator.session_params.split_buffer t_stt = round(float(new_start), self._params.data_simulator.outputs.output_precision) t_end = round(float(end), self._params.data_simulator.outputs.output_precision) @@ -753,9 +798,55 @@ def create_new_json_entry( "uem_filepath": None, } return meta + + def create_ctm_entry_from_segment_list( + self, source_segment_list, session_name: str, speaker_id: int, start: int + ) -> List[str]: + """ + Create new CTM entry (to write to output ctm file) + + Args: + words (list): List of words in the current audio file. + alignments (list): List of alignments (timestamps) for the current audio file. + session_name (str): Current session name. + speaker_id (int): LibriSpeech speaker ID for the current entry. + start (int): Current start of the audio file being inserted. + + Returns: + arr (list): List of ctm entries + """ + arr = [] + start = float(round(start, self._params.data_simulator.outputs.output_precision)) + + for seg_dict in source_segment_list: + words = seg_dict["words"] + alignments = seg_dict["alignments"] + start_offset = seg_dict["mixed_cut_offset"] + alignment_offset = alignments[0] + for i in range(len(words)): + word = words[i] + if ( + word != "" + ): # note that using the current alignments the first word is always empty, so there is no error from indexing the array with i-1 + # prev_align = 0 if i == 0 else alignments[i - 1] + # align1 = round(float(prev_align + start), self._params.data_simulator.outputs.output_precision) + align1 = round(float(start_offset + alignments[i] - alignment_offset), self._params.data_simulator.outputs.output_precision) + align2 = round(float(start_offset + alignments[i+1] - alignment_offset - align1), self._params.data_simulator.outputs.output_precision) + text = get_ctm_line( + source=session_name, + channel=1, + start_time=align1, + duration=align2, + token=word, + conf=None, + type_of_token='lex', + speaker=speaker_id, + ) + arr.append((align1, text)) + return arr def create_new_ctm_entry( - self, words: List[str], alignments: List[float], session_name: str, speaker_id: int, start: int + self, words: List[str], alignments: List[float], session_name: str, speaker_id: int, start: int, ) -> List[str]: """ Create new CTM entry (to write to output ctm file) @@ -770,8 +861,9 @@ def create_new_ctm_entry( Returns: arr (list): List of ctm entries """ - arr = [] + arr, word_and_ts_list = [], [] start = float(round(start, self._params.data_simulator.outputs.output_precision)) + for i in range(len(words)): word = words[i] if ( @@ -779,7 +871,10 @@ def create_new_ctm_entry( ): # note that using the current alignments the first word is always empty, so there is no error from indexing the array with i-1 prev_align = 0 if i == 0 else alignments[i - 1] align1 = round(float(prev_align + start), self._params.data_simulator.outputs.output_precision) + # align1 = round(float(start), self._params.data_simulator.outputs.output_precision) align2 = round(float(alignments[i] - prev_align), self._params.data_simulator.outputs.output_precision) + # align2 = round(float(alignments[i+1] - start), self._params.data_simulator.outputs.output_precision) + end_time = round(align1 + align2, self._params.data_simulator.outputs.output_precision) text = get_ctm_line( source=session_name, channel=1, @@ -789,9 +884,13 @@ def create_new_ctm_entry( conf=None, type_of_token='lex', speaker=speaker_id, + output_precision=self._params.data_simulator.outputs.output_precision, ) + # if word == "it" and align1 == 3.169: + # import ipdb; ipdb.set_trace() + word_and_ts_list.append((word, align1, end_time)) arr.append((align1, text)) - return arr + return arr, word_and_ts_list def add_to_filename_lists(self, basepath: str, filename: str): """ @@ -835,6 +934,20 @@ def write_annotation_files(self, basepath: str, filename: str, meta_data: dict): write_text(os.path.join(basepath, filename + '.txt'), self.annote_lists['ctm']) write_manifest(os.path.join(basepath, filename + '.meta'), [meta_data]) + def write_annotation_rttm_and_ctm(self, basepath: str, filename: str): + """ + Write all annotation files: RTTM, JSON, CTM, TXT, and META. + + Args: + basepath (str): Basepath for output files. + filename (str): Base filename for all output files. + meta_data (dict): Metadata for the current session. + rttm_list (list): List of RTTM entries. + json_list (list): List of JSON entries. + ctm_list (list): List of CTM entries. + """ + labels_to_rttmfile(self.annote_lists['rttm'], os.path.join(basepath, filename), self._params.data_simulator.outputs.output_dir) + write_ctm(os.path.join(basepath, filename + '.ctm'), self.annote_lists['ctm']) class SpeechSampler(object): """ @@ -1139,4 +1252,4 @@ def sample_noise_manifest(self, noise_manifest: dict) -> list: selected_noise_ids = np.random.choice(range(len(noise_manifest)), num_noise_files, replace=False) for k in selected_noise_ids: sampled_noise_manifest.append(noise_manifest[k]) - return sampled_noise_manifest + return sampled_noise_manifest \ No newline at end of file diff --git a/nemo/collections/asr/parts/utils/diarization_utils.py b/nemo/collections/asr/parts/utils/diarization_utils.py index f0b951eb89d3..82f6a081c602 100644 --- a/nemo/collections/asr/parts/utils/diarization_utils.py +++ b/nemo/collections/asr/parts/utils/diarization_utils.py @@ -16,13 +16,14 @@ import csv import json import os -from collections import OrderedDict as od +from collections import defaultdict, OrderedDict as od from datetime import datetime -from typing import Dict, List, Tuple +from typing import Dict, List, Tuple, Optional import numpy as np +from pyannote.metrics.diarization import DiarizationErrorRate -from nemo.collections.asr.metrics.der import concat_perm_word_error_rate +from nemo.collections.asr.metrics.der import concat_perm_word_error_rate, calculate_session_cpWER from nemo.collections.asr.metrics.wer import word_error_rate from nemo.collections.asr.models import ClusteringDiarizer from nemo.collections.asr.parts.utils.speaker_utils import ( @@ -31,6 +32,8 @@ labels_to_rttmfile, rttm_to_labels, write_rttm2manifest, + generate_diarization_output_lines, + labels_to_pyannote_object, ) from nemo.utils import logging @@ -44,6 +47,24 @@ __all__ = ['OfflineDiarWithASR'] +def get_color_palette() -> Dict[str, str]: + return { + 'speaker_0': '\033[1;32m', + 'speaker_1': '\033[1;34m', + 'speaker_2': '\033[1;36m', + 'speaker_3': '\033[1;31m', + 'speaker_4': '\033[1;35m', + 'speaker_5': '\033[1;30m', + 'speaker_6': '\033[1;37m', + 'speaker_7': '\033[1;30m', + 'speaker_8': '\033[1;33m', + 'speaker_9': '\033[0;34m', + 'white': '\033[0;37m', + 'black': '\033[0;30m', + } + + + def dump_json_to_file(file_path: str, session_trans_dict: dict): """ Write a json file from the session_trans_dict dictionary. @@ -71,6 +92,33 @@ def write_txt(w_path: str, val: str): with open(w_path, "w") as output: output.write(val + '\n') +def init_session_trans_dict(uniq_id: str, n_spk: int): + """ + Initialize json (in dictionary variable) formats for session level result and Gecko style json. + + Returns: + (dict): Session level result dictionary variable + """ + return od( + { + 'status': 'initialized', + 'session_id': uniq_id, + 'transcription': '', + 'speaker_count': n_spk, + 'words': [], + 'sentences': [], + } + ) + +def init_session_gecko_dict(): + """ + Initialize a dictionary format for Gecko style json. + + Returns: + (dict): + Gecko style json dictionary. + """ + return od({'schemaVersion': 2.0, 'monologues': []}) def convert_ctm_to_text(ctm_file_path: str) -> Tuple[List[str], str]: """ @@ -187,6 +235,38 @@ def convert_word_dict_seq_to_ctm( return ctm_lines + def break_transcript_lines(self, string_out: str, params: Dict[str, str], max_chars_in_line: int = 90) -> str: + """ + Break the lines in the transcript. + + Args: + string_out (str): + Input transcript with speaker labels + max_chars_in_line (int): + Maximum characters in each line + + Returns: + return_string_out (str): + String variable containing line breaking + """ + color_str_len = len('\033[1;00m') if self.params['colored_text'] else 0 + split_string_out = string_out.split('\n') + return_string_out = [] + for org_chunk in split_string_out: + buffer = [] + if len(org_chunk) - color_str_len > max_chars_in_line: + color_str = org_chunk[:color_str_len] if color_str_len > 0 else '' + for i in range(color_str_len, len(org_chunk), max_chars_in_line): + trans_str = org_chunk[i : i + max_chars_in_line] + if len(trans_str.strip()) > 0: + c_trans_str = color_str + trans_str + buffer.append(c_trans_str) + return_string_out.extend(buffer) + else: + return_string_out.append(org_chunk) + return_string_out = '\n'.join(return_string_out) + return return_string_out + def get_total_result_dict( der_results: Dict[str, Dict[str, float]], wer_results: Dict[str, Dict[str, float]], csv_columns: List[str], ): @@ -258,7 +338,553 @@ def get_num_of_spk_from_labels(labels: List[str]) -> int: spk_set = [x.split(' ')[-1].strip() for x in labels] return len(set(spk_set)) +def read_seglst(seglst_filepath, round_digits=3, return_rttm=False, sort_by_start_time=False): + """ + Read a seglst file and return the speaker & text information dictionary. + + Args: + seglst_filepath: path to the seglst file + seglst format: + [ + { + "session_id": "Bed008", + "words": "alright so i'm i should read all of these numbers", + "speaker": "me045", + "start_time": "53.814", + "end_time": "56.753" + } + ] + round_digits (int): number of digits to round the timestamps + return_rttm (bool): Whether to return RTTM lines + + Returns: + seglst_dict (dict): + A dictionary containing speaker and text information for each segment. + rttm_lines (list): + A list containing RTTM lines. + """ + rttm_lines = [] + seglst = [] + with open(seglst_filepath, 'r') as f: + seglst_lines = json.loads(f.read()) + + for idx, line in enumerate(seglst_lines): + spk, start, end = line['speaker'], float(line['start_time']), float(line['end_time']) + dur = round(end - start, round_digits) + + if return_rttm: + rttm_line_str = f'SPEAKER {line["session_id"]} 1 {start:.3f} {end-start:.3f} {spk} ' + rttm_lines.append(rttm_line_str) + seglst.append( + { + 'session_id': line['session_id'], + 'speaker': spk, + 'words': line['words'], + 'start_time': start, + 'end_time': end, + 'duration': dur, + } + ) + if sort_by_start_time: + seglst = sorted(seglst, key=lambda x: (x['start_time'], x['end_time'])) + if return_rttm: + return seglst, rttm_lines + return seglst + +def convert_seglst(seglst, all_speakers): + ''' + convert the seglst to a format that can be used for scoring + + Args: + seglst (list): list of seglst dictionaries + all_speakers (list): list of all active speakers + Returns: + timestamps: (list of list) + [ + [[st1, et1], [st2, et2]], # timestamps list for speaker 1 + [[st1, et1], ...], # timestamps list for speaker 2 + ...] + words (list[[s1], [s2], [s3], [s4]]): list of words for each speaker 1 to 4 + ''' + + timestamps = [[] for _ in all_speakers] + words = ['' for _ in all_speakers] + + spk2id = {spk: idx for idx, spk in enumerate(all_speakers)} + seglst = sorted(seglst, key=lambda x: (x['start_time'], x['end_time'])) + for seg in seglst: + timestamps[spk2id[seg['speaker']]].append((seg['start_time'], seg['end_time'])) + words[spk2id[seg['speaker']]] += seg['words'] + ' ' + + return timestamps, words + + +def get_session_trans_dict(uniq_id, word_dict_seq_list, diar_labels): + n_spk = get_num_of_spk_from_labels(diar_labels) + session_trans_dict = init_session_trans_dict(uniq_id=uniq_id, n_spk=n_spk) + gecko_dict = init_session_gecko_dict() + word_seq_list, audacity_label_words = [], [] + start_point, end_point, speaker = diar_labels[0].split() + prev_speaker = speaker + + sentences, terms_list = [], [] + sentence = {'speaker': speaker, 'start_time': start_point, 'end_time': end_point, 'text': ''} + + for k, word_dict in enumerate(word_dict_seq_list): + word, speaker = word_dict['word'], word_dict['speaker'] + word_seq_list.append(word) + start_point, end_point = word_dict['start_time'], word_dict['end_time'] + if speaker != prev_speaker: + if len(terms_list) != 0: + gecko_dict['monologues'].append( + {'speaker': {'name': None, 'id': prev_speaker}, 'terms': terms_list} + ) + terms_list = [] + + # remove trailing space in text + sentence['text'] = sentence['text'].strip() + + # store last sentence + sentences.append(sentence) + + # start construction of a new sentence + sentence = {'speaker': speaker, 'start_time': start_point, 'end_time': end_point, 'text': ''} + else: + # correct the ending time + sentence['end_time'] = end_point + + stt_sec, end_sec = start_point, end_point + terms_list.append({'start': stt_sec, 'end': end_sec, 'text': word, 'type': 'WORD'}) + + # add current word to sentence + sentence['text'] += word.strip() + ' ' + + audacity_label_words.append(get_audacity_label(word, stt_sec, end_sec, speaker)) + prev_speaker = speaker + + session_trans_dict['words'] = word_dict_seq_list + + # note that we need to add the very last sentence. + sentence['text'] = sentence['text'].strip() + sentences.append(sentence) + + # Speaker independent transcription + session_trans_dict['transcription'] = ' '.join(word_seq_list) + # add sentences to transcription information dict + session_trans_dict['sentences'] = sentences + gecko_dict['monologues'].append({'speaker': {'name': None, 'id': speaker}, 'terms': terms_list}) + return session_trans_dict, gecko_dict, audacity_label_words, sentences + + +def print_sentences(sentences: List[Dict[str, float]], + color_palette: Dict[str, str], + params: Dict[str, bool]) -> None: + """ + Print a transcript with speaker labels and timestamps. + + Args: + sentences (list): + List containing sentence-level dictionaries. + + Returns: + string_out (str): + String variable containing transcript and the corresponding speaker label. + """ + # init output + string_out = '' + # time_color = color_palette.get('black', '\033[0;30m') + time_color = color_palette.get('white', '\033[0;30m') + + for sentence in sentences: + # extract info + speaker = sentence['speaker'] + start_point = sentence['start_time'] + end_point = sentence['end_time'] + if 'text' in sentence: + text = sentence['text'] + elif 'words' in sentence: + text = sentence['words'] + else: + raise ValueError(f"text or words not in sentence: {sentence}") + + if params.get('colored_text', False): + color = color_palette.get(speaker, '\033[0;37m') + else: + color = '' + + # cast timestamp to the correct format + datetime_offset = 16 * 3600 + if float(start_point) > 3600: + time_str = '%H:%M:%S.%f' + else: + time_str = '%M:%S.%f' + start_point, end_point = max(float(start_point), 0), max(float(end_point), 0) + start_point_str = datetime.fromtimestamp(start_point - datetime_offset).strftime(time_str)[:-4] + end_point_str = datetime.fromtimestamp(end_point - datetime_offset).strftime(time_str)[:-4] + + if params.get('print_time', False): + time_str = f'[{start_point_str}-{end_point_str}] ' + else: + time_str = '' + + # string out concatenation + speaker = speaker.replace("speaker_", "[ Speaker-") + " ]" + string_out += f'{time_color}{time_str}{color}{speaker} {text}\n' + + return string_out + +def read_seglst(seglst_filepath, round_digits=3, return_rttm=False, sort_by_start_time=False, sort_by_end_time=False): + """ + Read a seglst file and return the speaker & text information dictionary. + + Args: + seglst_filepath: path to the seglst file + seglst format: + [ + { + "session_id": "Bed008", + "words": "alright so i'm i should read all of these numbers", + "speaker": "me045", + "start_time": "53.814", + "end_time": "56.753" + } + ] + round_digits (int): number of digits to round the timestamps + return_rttm (bool): Whether to return RTTM lines + + Returns: + seglst_dict (dict): + A dictionary containing speaker and text information for each segment. + rttm_lines (list): + A list containing RTTM lines. + """ + rttm_lines = [] + seglst = [] + with open(seglst_filepath, 'r') as f: + seglst_lines = json.loads(f.read()) + + for idx, line in enumerate(seglst_lines): + spk, start, end = line['speaker'], float(line['start_time']), float(line['end_time']) + dur = round(end - start, round_digits) + + if return_rttm: + rttm_line_str = f'SPEAKER {line["session_id"]} 1 {start:.3f} {end-start:.3f} {spk} ' + rttm_lines.append(rttm_line_str) + seglst.append( + { + 'session_id': line['session_id'], + 'speaker': spk, + 'words': line['words'], + 'start_time': start, + 'end_time': end, + 'duration': dur, + } + ) + if sort_by_start_time and sort_by_end_time: + raise ValueError("Cannot sort by both start and end time") + if sort_by_start_time: + seglst = sorted(seglst, key=lambda x: (x['start_time'], x['end_time'])) + if sort_by_end_time: + seglst = sorted(seglst, key=lambda x: (x['end_time'], x['start_time'])) + if return_rttm: + return seglst, rttm_lines + return seglst + +def convert_seglst(seglst, all_speakers): + ''' + convert the seglst to a format that can be used for scoring + + Args: + seglst (list): list of seglst dictionaries + all_speakers (list): list of all active speakers + Returns: + timestamps: (list of list) + [ + [[st1, et1], [st2, et2]], # timestamps list for speaker 1 + [[st1, et1], ...], # timestamps list for speaker 2 + ...] + words (list[[s1], [s2], [s3], [s4]]): list of words for each speaker 1 to 4 + ''' + + timestamps = [[] for _ in all_speakers] + words = ['' for _ in all_speakers] + + spk2id = {spk: idx for idx, spk in enumerate(all_speakers)} + seglst = sorted(seglst, key=lambda x: (x['start_time'], x['end_time'])) + for seg in seglst: + timestamps[spk2id[seg['speaker']]].append((seg['start_time'], seg['end_time'])) + words[spk2id[seg['speaker']]] += seg['words'] + ' ' + + return timestamps, words + +def chunk_seglst( + seglst : List[Dict], + chunk_size: float = 10.0 +): + ''' + Get chunked timestamps and words for each speaker + + Args: + seglst (list): list of seglst dictionaries + chunk_size (float): chunk size in seconds + + Returns: + chunk_id2timestamps (dict): dictionary of chunk_id to list of timestamps + speakers (set): set of all speakers + session_id (str): session id + ''' + chunk_id2timestamps = defaultdict(list) + speakers = set() + session_ids = set() + + for segment in seglst: + session_id = segment['session_id'] + start_time = segment['start_time'] + end_time = segment['end_time'] + + # Determine interval bounds + chunk_start = int(start_time // chunk_size) + chunk_end = int(end_time // chunk_size) + + # Split and assign the segment across overlapping intervals + words = segment['words'] + for chunk_idx in range(chunk_start, chunk_end + 1): + chunk_start_time = chunk_idx * chunk_size + chunk_end_time = (chunk_idx + 1) * chunk_size + + # Calculate the adjusted start and end times for the split segment + segment_start = max(start_time, chunk_start_time) + segment_end = min(end_time, chunk_end_time) + + # Create a split segment and add it to the corresponding interval + split_segment = { + 'session_id': session_id, + 'speaker': segment['speaker'], + 'words': words, + 'start_time': segment_start, + 'end_time': segment_end, + 'duration': segment_end - segment_start, + } + words = "" + chunk_id2timestamps[chunk_idx].append(split_segment) + speakers.add(segment['speaker']) + session_ids.add(session_id) + + assert len(session_ids) <= 1, "All segments should belong to the same session" + + if len(session_ids) == 0: + session_id = None + else: + session_id = session_ids.pop() + + return chunk_id2timestamps, speakers, session_id + +# def streaming_evaluation( +# ref_seglst: List[Dict], +# ref_rttm_labels: List[str], +# hyp_seglst: List[Dict], +# collar: float = 0.25, +# ignore_overlap: bool = False, +# verbose: bool = True, +# chunk_size: float = 10.0, +# ): +# """ +# Perform streaming evaluation of diarization and ASR for one session + +# Args: +# ref_seglst (list): list of reference seglst dictionaries +# hyp_seglst (list): list of hypothesis seglst dictionaries +# collar (float): collar for DER calculation +# ignore_overlap (bool): whether to ignore overlapping segments +# verbose (bool): whether to print verbose output +# chunk_size (float): how frequently to chunk and evaluate the session +# """ +# max_duration = max([seg['end_time'] for seg in ref_seglst + hyp_seglst]) +# max_idx = int(max_duration // chunk_size) + 1 + +# chunked_ref_seglst, ref_speakers, ref_session_id = chunk_seglst(ref_seglst, chunk_size=chunk_size) +# chunked_hyp_seglst, hyp_speakers, hyp_session_id = chunk_seglst(hyp_seglst, chunk_size=chunk_size) + +# if ref_session_id is None: +# ref_session_id = hyp_session_id + +# assert ref_session_id == hyp_session_id, "Session IDs of reference and hypothesis should match" + +# # Only care about the sessions in reference only +# session_id = ref_session_id +# ref_speaker_words = defaultdict(list) +# hyp_speaker_words = defaultdict(list) + +# der_metric = DiarizationErrorRate(collar=2 * collar, skip_overlap=ignore_overlap) +# cpwer_metric = calculate_session_cpWER +# der_list, cpwer_list = [], [] +# for chunk_idx in range(max_idx): +# ref_seglst = chunked_ref_seglst[chunk_idx] +# hyp_seglst = chunked_hyp_seglst[chunk_idx] + +# if len(ref_speaker_words) == 0: +# ref_speaker_words = ['' for _ in ref_speakers] +# if len(hyp_speaker_words) == 0: +# hyp_speaker_words = ['' for _ in hyp_speakers] +# if self.ref_rttm_labels is not None: +# ref_labels = self.ref_rttm_labels +# else: +# ref_speaker_timestamps, ref_speaker_word = convert_seglst(ref_seglst, ref_speakers) +# ref_labels = generate_diarization_output_lines(speaker_timestamps=ref_speaker_timestamps, model_spk_num=len(ref_speakers)) +# hyp_speaker_timestamps, hyp_speaker_word = convert_seglst(hyp_seglst, hyp_speakers) + +# hyp_labels = generate_diarization_output_lines(speaker_timestamps=hyp_speaker_timestamps, model_spk_num=len(hyp_speakers)) +# reference = labels_to_pyannote_object(ref_labels, uniq_name=session_id) +# hypothesis = labels_to_pyannote_object(hyp_labels, uniq_name=session_id) + +# for idx, speaker in enumerate(ref_speakers): +# ref_speaker_words[idx] += ref_speaker_word[idx] +# for idx, speaker in enumerate(hyp_speakers): +# hyp_speaker_words[idx] += hyp_speaker_word[idx] + +# der_met = der_metric(reference, hypothesis) +# cpWER, min_perm_hyp_trans, ref_trans = cpwer_metric(ref_speaker_words, hyp_speaker_words) + +# if verbose: +# logging.info(f"Session ID: {session_id} Chunk ID: {chunk_idx} from {chunk_idx*chunk_size}s to {(chunk_idx+1)*chunk_size}s") +# logging.info(f"DER: {abs(der_metric)*100:.2f}%, cpWER: {cpWER*100:.2f}%") + +# der_list.append(abs(der_metric) * 100) +# cpwer_list.append(cpWER) + +# return der_list, cpwer_list + +class OnlineEvaluation: + """ + A class designed for performing online evaluation of diarization and ASR. + + Attributes: + ref_seglst (list): + List of reference seglst dictionaries + hyp_seglst (list): + List of hypothesis seglst dictionaries + collar (float): + Collar for DER calculation + ignore_overlap (bool): + Whether to ignore overlapping segments + verbose (bool): + Whether to print verbose output + """ + + def __init__(self, + ref_seglst: List[Dict], + ref_rttm_labels: List[str], + hyp_seglst: Optional[List[Dict]] = None, + collar: float = 0.25, + ignore_overlap: bool = False, + verbose: bool = True, + ): + self.ref_seglst = ref_seglst + self.ref_rttm_labels = ref_rttm_labels + self.hyp_seglst = hyp_seglst + self.collar = collar + self.ignore_overlap = ignore_overlap + self.verbose = verbose + self.der_list = [] + self.cpwer_list = [] + # current index of the reference seglst + self.current_idx = 0 + + def evaluate_inloop(self, hyp_seglst, end_step_time=0.0): + """ + Evaluate the diarization and ASR performance at each step. + + Args: + hyp_seglst (list): list of hypothesis seglst dictionaries from start to end_step_time + end_step_time (float): end time of the current step + """ + is_update = False + if end_step_time > self.ref_seglst[self.current_idx]['end_time']: + self.current_idx += 1 + is_update = True + ref_seglst = self.ref_seglst[:self.current_idx] + der_cumul, cpwer_cumul= self.evaluate(ref_seglst, hyp_seglst, chunk_size=-1, verbose=False) + der, cpwer = der_cumul[-1], cpwer_cumul[-1] + if self.verbose: + logging.info(f"Session ID: {self.ref_seglst[0]['session_id']} from 0.0s to {end_step_time:.3f}s") + logging.info(f"DER: {der:.2f}%, cpWER: {cpwer:.2f}%") + self.der_list.append(der) + self.cpwer_list.append(cpwer) + else: + is_update = False + if len(self.der_list) > 0 and len(self.cpwer_list) > 0: + der, cpwer = self.der_list[-1], self.cpwer_list[-1] + else: + der, cpwer = 400.0, 100.0 + return der, cpwer, is_update + + def evaluate_outofloop(self, chunk_size=10.0): + """ + Evaluate the diarization and ASR performance for the entire session. + + Args: + chunk_size (float): chunk size in seconds, will report DER and cpWER from start and end of each chunk + """ + return self.evaluate(self.ref_seglst, self.hyp_seglst, chunk_size=chunk_size) + + def evaluate(self, ref_seglst, hyp_seglst, chunk_size=10.0, verbose=True): + max_duration = max([seg['end_time'] for seg in ref_seglst + hyp_seglst]) + if chunk_size == -1: + chunk_size = max_duration + 1 + max_idx = int(max_duration // chunk_size) + 1 + + chunked_ref_seglst, ref_speakers, ref_session_id = chunk_seglst(ref_seglst, chunk_size=chunk_size) + chunked_hyp_seglst, hyp_speakers, hyp_session_id = chunk_seglst(hyp_seglst, chunk_size=chunk_size) + + if hyp_session_id is None: + hyp_session_id = ref_session_id + + assert ref_session_id == hyp_session_id, "Session IDs of reference and hypothesis should match" + + # Only care about the sessions in reference only + session_id = ref_session_id + ref_speaker_words = defaultdict(list) + hyp_speaker_words = defaultdict(list) + + der_metric = DiarizationErrorRate(collar=2 * self.collar, skip_overlap=self.ignore_overlap) + cpwer_metric = calculate_session_cpWER + der_list, cpwer_list = [], [] + for chunk_idx in range(max_idx): + ref_seglst = chunked_ref_seglst[chunk_idx] + hyp_seglst = chunked_hyp_seglst[chunk_idx] + + if len(ref_speaker_words) == 0: + ref_speaker_words = ['' for _ in ref_speakers] + if len(hyp_speaker_words) == 0: + hyp_speaker_words = ['' for _ in hyp_speakers] + hyp_speaker_timestamps, hyp_speaker_word = convert_seglst(hyp_seglst, hyp_speakers) + ref_speaker_timestamps, ref_speaker_word = convert_seglst(ref_seglst, ref_speakers) + + ref_labels = generate_diarization_output_lines(speaker_timestamps=ref_speaker_timestamps, model_spk_num=len(ref_speakers)) + hyp_labels = generate_diarization_output_lines(speaker_timestamps=hyp_speaker_timestamps, model_spk_num=len(hyp_speakers)) + reference = labels_to_pyannote_object(ref_labels, uniq_name=session_id) + hypothesis = labels_to_pyannote_object(hyp_labels, uniq_name=session_id) + + for idx, speaker in enumerate(ref_speakers): + ref_speaker_words[idx] += ref_speaker_word[idx] + for idx, speaker in enumerate(hyp_speakers): + hyp_speaker_words[idx] += hyp_speaker_word[idx] + + der_instance = der_metric(reference, hypothesis) + # Normalize the text + for spk_idx in range(len(hyp_speaker_words)): + hyp_speaker_words[spk_idx] = hyp_speaker_words[spk_idx].translate(str.maketrans('', '', string.punctuation)).lower() + cpWER, min_perm_hyp_trans, ref_trans = cpwer_metric(ref_speaker_words, hyp_speaker_words) + + if verbose: + logging.info(f"Session ID: {session_id} Chunk ID: {chunk_idx} from 0.0s to {(chunk_idx+1)*chunk_size}s") + logging.info(f"DER: {abs(der_metric)*100:.2f}%, cpWER: {cpWER*100:.2f}%") + der_list.append(abs(der_metric) * 100) + cpwer_list.append(cpWER * 100) + + return der_list, cpwer_list + class OfflineDiarWithASR: """ A class designed for performing ASR and diarization together. @@ -320,25 +946,9 @@ def __init__(self, cfg_diarizer): self.make_file_lists() - self.color_palette = self.get_color_palette() + self.color_palette = get_color_palette() self.csv_columns = self.get_csv_columns() - @staticmethod - def get_color_palette() -> Dict[str, str]: - return { - 'speaker_0': '\033[1;32m', - 'speaker_1': '\033[1;34m', - 'speaker_2': '\033[1;30m', - 'speaker_3': '\033[1;31m', - 'speaker_4': '\033[1;35m', - 'speaker_5': '\033[1;36m', - 'speaker_6': '\033[1;37m', - 'speaker_7': '\033[1;30m', - 'speaker_8': '\033[1;33m', - 'speaker_9': '\033[0;34m', - 'white': '\033[0;37m', - } - @staticmethod def get_csv_columns() -> List[str]: return [ @@ -387,34 +997,6 @@ def _load_realigning_LM(self): logging.info(f"Loading LM for realigning: {self.realigning_lm_params['arpa_language_model']}") return arpa.loadf(self.realigning_lm_params['arpa_language_model'])[0] - def _init_session_trans_dict(self, uniq_id: str, n_spk: int): - """ - Initialize json (in dictionary variable) formats for session level result and Gecko style json. - - Returns: - (dict): Session level result dictionary variable - """ - return od( - { - 'status': 'initialized', - 'session_id': uniq_id, - 'transcription': '', - 'speaker_count': n_spk, - 'words': [], - 'sentences': [], - } - ) - - def _init_session_gecko_dict(self): - """ - Initialize a dictionary format for Gecko style json. - - Returns: - (dict): - Gecko style json dictionary. - """ - return od({'schemaVersion': 2.0, 'monologues': []}) - def _save_VAD_labels_list(self, word_ts_dict: Dict[str, Dict[str, List[float]]]): """ Take the non_speech labels from logit output. The logit output is obtained from @@ -865,61 +1447,8 @@ def _make_json_output( ] } """ - word_seq_list, audacity_label_words = [], [] - start_point, end_point, speaker = diar_labels[0].split() - prev_speaker = speaker - - sentences, terms_list = [], [] - sentence = {'speaker': speaker, 'start_time': start_point, 'end_time': end_point, 'text': ''} - - n_spk = get_num_of_spk_from_labels(diar_labels) logging.info(f"Creating results for Session: {uniq_id} n_spk: {n_spk} ") - session_trans_dict = self._init_session_trans_dict(uniq_id=uniq_id, n_spk=n_spk) - gecko_dict = self._init_session_gecko_dict() - - for k, word_dict in enumerate(word_dict_seq_list): - word, speaker = word_dict['word'], word_dict['speaker'] - word_seq_list.append(word) - start_point, end_point = word_dict['start_time'], word_dict['end_time'] - if speaker != prev_speaker: - if len(terms_list) != 0: - gecko_dict['monologues'].append( - {'speaker': {'name': None, 'id': prev_speaker}, 'terms': terms_list} - ) - terms_list = [] - - # remove trailing space in text - sentence['text'] = sentence['text'].strip() - - # store last sentence - sentences.append(sentence) - - # start construction of a new sentence - sentence = {'speaker': speaker, 'start_time': start_point, 'end_time': end_point, 'text': ''} - else: - # correct the ending time - sentence['end_time'] = end_point - - stt_sec, end_sec = start_point, end_point - terms_list.append({'start': stt_sec, 'end': end_sec, 'text': word, 'type': 'WORD'}) - - # add current word to sentence - sentence['text'] += word.strip() + ' ' - - audacity_label_words.append(get_audacity_label(word, stt_sec, end_sec, speaker)) - prev_speaker = speaker - - session_trans_dict['words'] = word_dict_seq_list - - # note that we need to add the very last sentence. - sentence['text'] = sentence['text'].strip() - sentences.append(sentence) - gecko_dict['monologues'].append({'speaker': {'name': None, 'id': speaker}, 'terms': terms_list}) - - # Speaker independent transcription - session_trans_dict['transcription'] = ' '.join(word_seq_list) - # add sentences to transcription information dict - session_trans_dict['sentences'] = sentences + session_trans_dict, gecko_dict, audacity_label_words, sentences = get_session_trans_dict(uniq_id, word_dict_seq_list, diar_labels) self._write_and_log(uniq_id, session_trans_dict, audacity_label_words, gecko_dict, sentences) return session_trans_dict @@ -1162,37 +1691,7 @@ def write_session_level_result_in_csv( except IOError: logging.info("I/O error has occurred while writing a csv file.") - def _break_lines(self, string_out: str, max_chars_in_line: int = 90) -> str: - """ - Break the lines in the transcript. - Args: - string_out (str): - Input transcript with speaker labels - max_chars_in_line (int): - Maximum characters in each line - - Returns: - return_string_out (str): - String variable containing line breaking - """ - color_str_len = len('\033[1;00m') if self.params['colored_text'] else 0 - split_string_out = string_out.split('\n') - return_string_out = [] - for org_chunk in split_string_out: - buffer = [] - if len(org_chunk) - color_str_len > max_chars_in_line: - color_str = org_chunk[:color_str_len] if color_str_len > 0 else '' - for i in range(color_str_len, len(org_chunk), max_chars_in_line): - trans_str = org_chunk[i : i + max_chars_in_line] - if len(trans_str.strip()) > 0: - c_trans_str = color_str + trans_str - buffer.append(c_trans_str) - return_string_out.extend(buffer) - else: - return_string_out.append(org_chunk) - return_string_out = '\n'.join(return_string_out) - return return_string_out def _write_and_log( self, @@ -1218,9 +1717,9 @@ def _write_and_log( List containing sentence dictionary """ # print the sentences in the .txt output - string_out = self.print_sentences(sentences) + string_out = print_sentences(sentences, color_palette=self.color_palette, params=self.params) if self.params['break_lines']: - string_out = self._break_lines(string_out) + string_out = break_transcript_lines(string_out, params=self.params) session_trans_dict["status"] = "success" ctm_lines_list = convert_word_dict_seq_to_ctm(session_trans_dict['words']) @@ -1256,51 +1755,4 @@ def print_errors(der_results: Dict[str, Dict[str, float]], wer_results: Dict[str \nWER : {wer_results['total']['average_WER']:.4f}" ) else: - logging.info(DER_info) - - def print_sentences(self, sentences: List[Dict[str, float]]): - """ - Print a transcript with speaker labels and timestamps. - - Args: - sentences (list): - List containing sentence-level dictionaries. - - Returns: - string_out (str): - String variable containing transcript and the corresponding speaker label. - """ - # init output - string_out = '' - - for sentence in sentences: - # extract info - speaker = sentence['speaker'] - start_point = sentence['start_time'] - end_point = sentence['end_time'] - text = sentence['text'] - - if self.params['colored_text']: - color = self.color_palette.get(speaker, '\033[0;37m') - else: - color = '' - - # cast timestamp to the correct format - datetime_offset = 16 * 3600 - if float(start_point) > 3600: - time_str = '%H:%M:%S.%f' - else: - time_str = '%M:%S.%f' - start_point, end_point = max(float(start_point), 0), max(float(end_point), 0) - start_point_str = datetime.fromtimestamp(start_point - datetime_offset).strftime(time_str)[:-4] - end_point_str = datetime.fromtimestamp(end_point - datetime_offset).strftime(time_str)[:-4] - - if self.params['print_time']: - time_str = f'[{start_point_str} - {end_point_str}] ' - else: - time_str = '' - - # string out concatenation - string_out += f'{color}{time_str}{speaker}: {text}\n' - - return string_out + logging.info(DER_info) \ No newline at end of file diff --git a/nemo/collections/asr/parts/utils/multispk_transcribe_utils.py b/nemo/collections/asr/parts/utils/multispk_transcribe_utils.py new file mode 100644 index 000000000000..fd47d5ae8cea --- /dev/null +++ b/nemo/collections/asr/parts/utils/multispk_transcribe_utils.py @@ -0,0 +1,1806 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os, json +from typing import Optional, List, Tuple, Dict, Any +from collections import OrderedDict +from copy import deepcopy + +import torch +from omegaconf import DictConfig + +from nemo.collections.asr.parts.utils.rnnt_utils import Hypothesis + +from nemo.collections.asr.parts.utils.diarization_utils import read_seglst, OnlineEvaluation +from nemo.utils import logging + +from nemo.collections.asr.models.sortformer_diar_models import SortformerEncLabelModel + +from nemo.collections.asr.parts.utils.speaker_utils import ( +audio_rttm_map as get_audio_rttm_map, +rttm_to_labels, +get_uniqname_from_filepath, +) +from nemo.collections.asr.parts.utils.diarization_utils import ( +print_sentences, +get_color_palette, +write_txt, +) +from nemo.collections.asr.data.audio_to_diar_label import get_frame_targets_from_rttm, extract_frame_info_from_rttm +from nemo.collections.asr.modules.sortformer_modules import StreamingSortformerState + + +from lhotse.dataset.collation import collate_matrices +import itertools + +import time +from functools import wraps + +def measure_eta(func): + """ + Measure the time taken to execute the function and print the ETA. + + Args: + func (callable): The function to measure the ETA of. + + Returns: + callable: The wrapped function. + """ + @wraps(func) + def wrapper(*args, **kwargs): + start_time = time.time() # Record the start time + result = func(*args, **kwargs) # Execute the function + end_time = time.time() # Record the end time + eta = end_time - start_time # Calculate the elapsed time + logging.info(f"[ Step-{kwargs['step_num']} ] for '{func.__name__}': {eta:.4f} seconds") # Print the ETA + return result # Return the original function's result + return wrapper + +def get_multi_talker_samples_from_manifest(cfg, manifest_file: str, feat_per_sec: float, max_spks: int): + """ + Get the multi-talker samples from the manifest file and save it to a list named 'samples'. + Also, save the rttm mask matrix to a list named 'rttms_mask_mats'. + + Args: + cfg (DictConfig): The configuration object. + manifest_file (str): The path to the manifest file. + feat_per_sec (float): The number of features per second. + max_spks (int): The maximum number of speakers. + + Returns: + samples (list): The list of samples. + rttms_mask_mats (list): The list of rttm mask matrices. + """ + samples, rttms_mask_mats = [], [] + with open(manifest_file, 'r', encoding='utf-8') as f: + for line_num, line in enumerate(f): + item = json.loads(line) + if 'audio_filepath' not in item: + raise KeyError(f"Line {line_num}: 'audio_filepath' missing") + if 'duration' not in item: + raise KeyError(f"Line {line_num}: 'duration' missing") + samples.append(item) + if cfg.spk_supervision == "rttm": + rttm_path = samples[-1]['rttm_filepath'] + if not rttm_path: + raise ValueError(f"Line {line_num}: rttm_filepath required when spk_supervision='rttm'") + if not os.path.exists(rttm_path): + raise FileNotFoundError(f"Line {line_num}: RTTM file not found: {rttm_path}") + + with open(rttm_path, 'r', encoding='utf-8') as f: + rttm_lines = f.readlines() + rttm_timestamps, _ = extract_frame_info_from_rttm(0, samples[-1]['duration'], rttm_lines) + rttm_mat = get_frame_targets_from_rttm( + rttm_timestamps=rttm_timestamps, + offset=0, + duration=samples[-1]['duration'], + round_digits=3, + feat_per_sec=round(float(1 / feat_per_sec), 2), + max_spks=max_spks, + ) + rttms_mask_mats.append(rttm_mat) + samples[-1]['duration'] = None + if 'offset' not in item: + samples[-1]['offset'] = 0 + + if len(rttms_mask_mats) > 0: + rttms_mask_mats = collate_matrices(rttms_mask_mats) + else: + rttms_mask_mats = None + return samples, rttms_mask_mats + + +def setup_diarization_model(cfg: DictConfig, map_location: Optional[str] = None) -> SortformerEncLabelModel: + """Setup model from cfg and return diarization model and model name for next step""" + if cfg.diar_model_path.endswith(".ckpt"): + diar_model = SortformerEncLabelModel.load_from_checkpoint(checkpoint_path=cfg.diar_model_path, + map_location=map_location, strict=False) + model_name = os.path.splitext(os.path.basename(cfg.diar_model_path))[0] + elif cfg.diar_model_path.endswith(".nemo"): + diar_model = SortformerEncLabelModel.restore_from(restore_path=cfg.diar_model_path, + map_location=map_location) + model_name = os.path.splitext(os.path.basename(cfg.diar_model_path))[0] + elif cfg.diar_pretrained_name.startswith("nvidia/"): + diar_model = SortformerEncLabelModel.from_pretrained(cfg.diar_pretrained_name) + model_name = os.path.splitext(os.path.basename(cfg.diar_pretrained_name))[0] + else: + raise ValueError("cfg.diar_model_path must end with.ckpt or.nemo!") + return diar_model, model_name + +def write_seglst(output_filepath: str, seglst_list: list) -> None: + """ + Write the segmentation list to a file. + + Args: + output_filepath (str): The path to the output file. + seglst_list (list): The list of segmentation lists. + """ + with open(output_filepath, "w", encoding="utf-8") as f: + f.write(json.dumps(seglst_list, indent=2) + "\n") + +def get_new_sentence_dict( + speaker: str, + start_time: float, + end_time: float, + text: str, + session_id: Optional[str] = None, +) -> dict: + """ + Get a new SegLST style sentence dictionary variable. + + Args: + speaker (str): The speaker of the sentence. + start_time (float): The start time of the sentence. + end_time (float): The end time of the sentence. + text (str): The text of the sentence. + session_id (Optional[str]): The session id of the sentence. + + Returns: + Dict[str, Any]: A new SegLST style sentence dictionary variable. + """ + return { + 'speaker': speaker, + 'start_time': start_time, + 'end_time': end_time, + 'words': text.lstrip(), + 'session_id': session_id, + } + + +def calc_drop_extra_pre_encoded(asr_model: SortformerEncLabelModel, step_num: int, pad_and_drop_preencoded: bool): + """ + Calculate the number of extra tokens to drop after the downsampling. + + Args: + asr_model (SortformerEncLabelModel): The ASR model. + step_num (int): The step number. + pad_and_drop_preencoded (bool): Whether to pad and drop the extra pre-encoded tokens. + + Returns: + int: The number of extra tokens to drop. + """ + # for the first step there is no need to drop any tokens + # after the downsampling as no caching is being used + if step_num == 0 and not pad_and_drop_preencoded: + return 0 + else: + return asr_model.encoder.streaming_cfg.drop_extra_pre_encoded + +def fix_frame_time_step( + cfg: Any, + new_tokens: List[str], + new_words: List[str], + frame_inds_seq: List[int] + ) -> List[int]: + """ + Adjust the frame indices sequence to match the length of new tokens. + + This function handles mismatches between the number of tokens and the frame indices sequence. + It adjusts the frame_inds_seq to ensure it has the same length as new_tokens. + + Args: + cfg (Any): Configuration object containing logging settings. + new_tokens (List[str]): List of new tokens. + new_words (List[str]): List of new words. + frame_inds_seq (List[int]): List of frame indices. + + Returns: + List[int]: Adjusted frame indices sequence. + """ + if len(new_tokens) != len(frame_inds_seq): + # Sometimes there is a mismatch in the number of tokens between the new tokens and the frame indices sequence. + if len(frame_inds_seq) > len(new_words): + # Get unique frame indices sequence + frame_inds_seq = list(OrderedDict.fromkeys(frame_inds_seq)) + if len(frame_inds_seq) < len(new_tokens): + deficit = len(new_tokens) - len(frame_inds_seq) + frame_inds_seq = [frame_inds_seq[0]] * deficit + frame_inds_seq + elif len(frame_inds_seq) > len(new_tokens): + deficit = len(frame_inds_seq) - len(new_tokens) + frame_inds_seq = frame_inds_seq[deficit:] + + elif len(frame_inds_seq) < len(new_tokens): + deficit = len(new_tokens) - len(frame_inds_seq) + frame_inds_seq = [frame_inds_seq[0]] * deficit + frame_inds_seq + if cfg.log: + logging.warning( + f"Length of new token sequence ({len(new_tokens)}) does not match" + f"the length of frame indices sequence ({len(frame_inds_seq)}). Skipping this chunk." + ) + return frame_inds_seq + +def get_simulated_softmax(cfg, speaker_sigmoid: torch.Tensor) -> torch.Tensor: + """ + Simulate the softmax operation for speaker diarization. + + Args: + cfg (Any): Configuration object containing diarization settings. + speaker_sigmoid (torch.Tensor): Speaker sigmoid values. + + Returns: + speaker_softmax (torch.Tensor): Speaker softmax values. + """ + if speaker_sigmoid.ndim != 1: + raise ValueError(f"Expected 1D tensor for speaker_sigmoid, got shape {speaker_sigmoid.shape}") + if speaker_sigmoid.shape[0] < cfg.max_num_of_spks: + raise ValueError(f"speaker_sigmoid size {speaker_sigmoid.shape[0]} < max_num_of_spks {cfg.max_num_of_spks}") + + speaker_sigmoid = torch.clamp(speaker_sigmoid, min=cfg.min_sigmoid_val, max=1) + sigmoid_sum = speaker_sigmoid.sum() + if sigmoid_sum == 0: + logging.warning("speaker_sigmoid sum is zero, returning uniform distribution") + speaker_softmax = torch.ones_like(speaker_sigmoid) / speaker_sigmoid.shape[0] + else: + speaker_softmax = speaker_sigmoid / sigmoid_sum + speaker_softmax = speaker_softmax.cpu() + speaker_softmax[cfg.max_num_of_spks:] = 0.0 + return speaker_softmax + +def get_word_dict_content_offline( + cfg: Any, + word: str, + word_index: int, + diar_pred_out: torch.Tensor, + time_stt_end_tuple: Tuple[int], + frame_len: float = 0.08 +) -> Dict[str, Any]: + """ + Generate a dictionary containing word information and speaker diarization results. + + This function processes a single word and its associated tokens to determine + the start and end frames, speaker, and other relevant information. + + Args: + cfg (Any): Configuration object containing diarization settings. + word (str): The word being processed. + word_index (int): Index of the word in the sequence. + diar_pred_out (torch.Tensor): Diarization prediction output stream. + time_stt_end_tuple (int): Local time step offset. + + frame_len (float, optional): Length of each frame in seconds. Defaults to 0.08. + + Returns: + Dict[str, Any]: A dictionary containing word information and diarization results. + """ + frame_stt, frame_end = time_stt_end_tuple + + # Edge Cases: Sometimes, repeated token indexs can lead to incorrect frame and speaker assignment. + if frame_stt == frame_end: + if frame_stt >= diar_pred_out.shape[0] - 1: + frame_stt, frame_end = (diar_pred_out.shape[1] - 1, diar_pred_out.shape[0]) + else: + frame_end = frame_stt + 1 + + # Get the speaker based on the frame-wise softmax probabilities. + stt_p, end_p = max((frame_stt + cfg.left_frame_shift), 0), (frame_end + cfg.right_frame_shift) + speaker_sigmoid = diar_pred_out[stt_p:end_p, :].mean(dim=0) + speaker_softmax = get_simulated_softmax(cfg, speaker_sigmoid) + + speaker_softmax[cfg.max_num_of_spks:] = 0.0 + spk_id = speaker_softmax.argmax().item() + stt_sec, end_sec = frame_stt * frame_len, frame_end * frame_len + word_dict = {"word": word, + "word_index": word_index, + 'frame_stt': frame_stt, + 'frame_end': frame_end, + 'start_time': round(stt_sec, 3), + 'end_time': round(end_sec, 3), + 'speaker': f"speaker_{spk_id}", + 'speaker_softmax': speaker_softmax} + return word_dict + +def get_word_dict_content_online( + cfg: Any, + word: str, + word_index: int, + diar_pred_out_stream: torch.Tensor, + token_group: List[str], + frame_inds_seq: List[int], + time_step_local_offset: int, + frame_len: float = 0.08 +) -> Dict[str, Any]: + """ + Generate a dictionary containing word information and speaker diarization results. + + This function processes a single word and its associated tokens to determine + the start and end frames, speaker, and other relevant information. + + Args: + cfg (Any): Configuration object containing diarization settings. + word (str): The word being processed. + word_index (int): Index of the word in the sequence. + diar_pred_out_stream (torch.Tensor): Diarization prediction output stream. + Dimensions: (num_frames, max_num_of_spks) + token_group (List[str]): Group of tokens associated with the word. + frame_inds_seq (List[int]): Sequence of frame indices. + time_step_local_offset (int): Local time step offset. + frame_len (float, optional): Length of each frame in seconds. Defaults to 0.08. + + Returns: + Dict[str, Any]: A dictionary containing word information and diarization results. + """ + _stt, _end = time_step_local_offset, time_step_local_offset + len(token_group) - 1 + if len(token_group) == 1: + frame_stt, frame_end = frame_inds_seq[_stt], frame_inds_seq[_stt] + 1 + else: + try: + frame_stt, frame_end = frame_inds_seq[_stt], frame_inds_seq[_end] + except IndexError: + frame_stt, frame_end = frame_inds_seq[_stt], frame_inds_seq[_stt] + 1 + + # Edge Cases: Sometimes, repeated token indexs can lead to incorrect frame and speaker assignment. + if frame_stt == frame_end: + if frame_stt >= diar_pred_out_stream.shape[0] - 1: + frame_stt, frame_end = (diar_pred_out_stream.shape[0] - 1, diar_pred_out_stream.shape[0]) + else: + frame_end = frame_stt + 1 + + # Get the speaker based on the frame-wise softmax probabilities. + stt_p, end_p = max((frame_stt + cfg.left_frame_shift), 0), (frame_end + cfg.right_frame_shift) + speaker_sigmoid = diar_pred_out_stream[stt_p:end_p, :].mean(dim=0) + speaker_softmax = get_simulated_softmax(cfg, speaker_sigmoid) + + speaker_softmax[cfg.max_num_of_spks:] = 0.0 + spk_id = speaker_softmax.argmax().item() + stt_sec, end_sec = frame_stt * frame_len, frame_end * frame_len + word_dict = {"word": word, + "word_index": word_index, + 'frame_stt': frame_stt, + 'frame_end': frame_end, + 'start_time': round(stt_sec, 3), + 'end_time': round(end_sec, 3), + 'speaker': f"speaker_{spk_id}", + 'speaker_softmax': speaker_softmax} + return word_dict + +def get_multitoken_words( + cfg, + word_and_ts_seq: Dict[str, List], + word_seq: List[str], + new_words: List[str], + fix_prev_words_count: int = 5 +) -> Dict[str, List]: + """ + Fix multi-token words that were not fully captured by the previous chunk window. + + This function compares the words in the current sequence with the previously processed words, + and updates any multi-token words that may have been truncated in earlier processing. + + Args: + cfg (DiarizationConfig): Configuration object containing verbose setting. + word_and_ts_seq (Dict[str, List]): Dictionary containing word sequences and timestamps. + word_seq (List[str]): List of all words processed so far. + new_words (List[str]): List of new words in the current chunk. + fix_prev_words_count (int, optional): Number of previous words to check. Defaults to 5. + + Returns: + Dict[str, List]: Updated word_and_ts_seq with fixed multi-token words. + """ + prev_start = max(0, len(word_seq) - fix_prev_words_count - len(new_words)) + prev_end = max(0, len(word_seq) - len(new_words)) + for ct, prev_word in enumerate(word_seq[prev_start:prev_end]): + if len(word_and_ts_seq["words"]) > fix_prev_words_count - ct: + saved_word = word_and_ts_seq["words"][-fix_prev_words_count + ct]["word"] + if len(prev_word) > len(saved_word): + if cfg.verbose: + logging.info(f"[Replacing Multi-token Word]: {saved_word} with {prev_word}") + word_and_ts_seq["words"][-fix_prev_words_count + ct]["word"] = prev_word + return word_and_ts_seq + +def append_word_and_ts_seq( + cfg: Any, + word_idx_offset: int, + word_and_ts_seq: Dict[str, Any], + word_dict: Dict[str, Any] +) -> tuple[int, Dict[str, Any]]: + """ + Append the word dictionary to the word and time-stamp sequence. + + This function updates the word_and_ts_seq dictionary by appending new word information + and managing the buffered words and speaker count. + + Args: + cfg (Any): Configuration object containing parameters like word_window. + word_idx_offset (int): The current word index offset. + word_and_ts_seq (Dict[str, Any]): Dictionary containing word sequences and related information. + word_dict (Dict[str, Any]): Dictionary containing information about the current word. + + Returns: + tuple[int, Dict[str, Any]]: A tuple containing the updated word_idx_offset and word_and_ts_seq. + """ + word_and_ts_seq["words"].append(word_dict) + word_and_ts_seq["buffered_words"].append(word_dict) + word_and_ts_seq["speaker_count_buffer"].append(word_dict["speaker"]) + word_and_ts_seq["word_window_seq"].append(word_dict['word']) + + if len(word_and_ts_seq["words"]) >= cfg.word_window + 1: + word_and_ts_seq["buffered_words"].pop(0) + word_and_ts_seq["word_window_seq"].pop(0) + word_idx_offset = 0 + + word_and_ts_seq["speaker_count"] = len(set(word_and_ts_seq["speaker_count_buffer"])) + return word_idx_offset, word_and_ts_seq + + +class SpeakerTaggedASR: + def __init__( + self, + cfg, + asr_model, + diar_model, + ): + # Required configs, models and datasets for inference + self.cfg = cfg + if self.cfg.manifest_file: + self.test_manifest_dict = get_audio_rttm_map(self.cfg.manifest_file) + elif self.cfg.audio_file is not None: + uniq_id = get_uniqname_from_filepath(filepath=self.cfg.audio_file) + self.test_manifest_dict = {uniq_id: {'audio_filepath': self.cfg.audio_file, 'seglst_filepath': None, 'rttm_filepath': None}} + else: + raise ValueError("One of the audio_file and manifest_file should be non-empty!") + + self.asr_model = asr_model + self.diar_model = diar_model + + # ASR speaker tagging configs + self._fix_prev_words_count = cfg.fix_prev_words_count + self._sentence_render_length = int(self._fix_prev_words_count + cfg.update_prev_words_sentence) + self._frame_len_sec = 0.08 + self._initial_steps = cfg.ignored_initial_frame_steps + self._stt_words = [] + self._init_evaluator() + self._frame_hop_length = self.asr_model.encoder.streaming_cfg.valid_out_len + + # Multi-instance configs + self._max_num_of_spks = cfg.get("max_num_of_spks", 4) + self._offset_chunk_start_time = 0.0 + self._sent_break_sec = cfg.get("sent_break_sec", 5.0) + + self._att_context_size = cfg.att_context_size + self._nframes_per_chunk = self._att_context_size[1] + 1 + self._cache_gating = cfg.get("cache_gating", False) + self._cache_gating_buffer_size = cfg.get("cache_gating_buffer_size", 2) + self._binary_diar_preds = cfg.binary_diar_preds + + self._masked_asr = cfg.get("masked_asr", True) + self._use_mask_preencode = cfg.get("mask_preencode", False) + + self.instance_manager = MultiTalkerInstanceManager( + asr_model=self.asr_model, + diar_model=self.diar_model, + max_num_of_spks=self.diar_model._cfg.max_num_of_spks, + batch_size=cfg.batch_size, + sent_break_sec=self._sent_break_sec, + ) + self.n_active_speakers_per_stream = self.cfg.max_num_of_spks + + def _init_evaluator(self): + """ + Initialize the evaluator for the offline STT and speaker diarization. + """ + self.online_evaluators, self._word_and_ts_seq = [], {} + for _, (uniq_id, data_dict) in enumerate(self.test_manifest_dict.items()): + uniq_id = uniq_id.split(".")[0] # Make sure there is no "." in the uniq_id + self._word_and_ts_seq[uniq_id] = {"words": [], + "buffered_words": [], + "token_frame_index": [], + "offset_count": 0, + "status": "success", + "sentences": None, + "last_word_index": 0, + "speaker_count": None, + "transcription": None, + "max_spk_probs": [], + "word_window_seq": [], + "speaker_count_buffer": [], + "sentence_memory": {}, + } + + if 'seglst_filepath' in data_dict and data_dict['seglst_filepath'] is not None: + ref_seglst = read_seglst(data_dict['seglst_filepath']) + else: + ref_seglst = None + + if 'rttm_filepath' in data_dict and data_dict['rttm_filepath'] is not None: + ref_rttm_labels = rttm_to_labels(data_dict['rttm_filepath']) + else: + ref_rttm_labels = None + + eval_instance = OnlineEvaluation(ref_seglst=ref_seglst, + ref_rttm_labels=ref_rttm_labels, + hyp_seglst=None, + collar=0.25, + ignore_overlap=False, + verbose=True) + self.online_evaluators.append(eval_instance) + + def _get_offset_sentence(self, session_trans_dict: Dict[str, Any], offset: int) -> Dict[str, Any]: + """ + For the very first word in a session, get the offset sentence. + + Args: + session_trans_dict (dict): Dictionary containing session-related information. + offset (int): Index of the word for which the offset sentence is needed. + + Returns: + (Dict): Dictionary containing offset sentence information. + """ + word_dict = session_trans_dict['words'][offset] + return {'session_id': session_trans_dict['uniq_id'], + 'speaker': word_dict['speaker'], + 'start_time': word_dict['start_time'], + 'end_time': word_dict['end_time'], + 'words': f"{word_dict['word']} "} + + def _get_sentence(self, word_dict: Dict[str, Any]) -> Dict[str, Any]: + """ + Get the sentence for a given word. + + Args: + word_dict (Dict[str, Any]): Dictionary containing word-related information. + """ + return {'speaker': word_dict['speaker'], + 'start_time': word_dict['start_time'], + 'end_time': word_dict['end_time'], + 'words': ''} + + def get_sentences_values(self, session_trans_dict: dict, sentence_render_length: int): + """ + Get sentences (speaker-turn-level text) for a given session and sentence render length. + + Args: + session_trans_dict (Dict[str, Any]): Dictionary containing session-related information. + sentence_render_length (int): Length of the sentences to be generated. + + Returns: + sentences (List[Dict[str, Any]]): List of sentences in the session. + """ + stt_word_index = max(0, session_trans_dict['last_word_index'] - sentence_render_length) + if session_trans_dict['sentences'] is None: + sentence = self._get_offset_sentence(session_trans_dict=session_trans_dict, offset=0) + sentences = [] + session_trans_dict['last_word_index'] = stt_word_index + session_trans_dict['sentence_memory'].update({stt_word_index: + (deepcopy(sentences), + deepcopy(sentence), + sentence['speaker'] + )}) + prev_speaker = session_trans_dict['words'][stt_word_index]['speaker'] + else: + (_sentences, _sentence, prev_speaker) = session_trans_dict['sentence_memory'][stt_word_index] + sentences, sentence = deepcopy(_sentences), deepcopy(_sentence) + + for word_idx in range(stt_word_index + 1, len(session_trans_dict['words'])): + word_dict = session_trans_dict['words'][word_idx] + word, end_point = word_dict['word'], word_dict['end_time'] + if word_dict['speaker'] != prev_speaker: + sentence['words'] = sentence['words'].strip() + sentences.append(sentence) + sentence = self._get_sentence(word_dict=session_trans_dict['words'][word_idx]) + else: + sentence['end_time'] = end_point + sentence['words'] += word.strip() + ' ' + sentence['words'] = sentence['words'] + sentence['session_id'] = session_trans_dict['uniq_id'] + session_trans_dict['last_word_index'] = word_idx + prev_speaker = word_dict['speaker'] + session_trans_dict['sentence_memory'][word_idx] = (deepcopy(sentences), deepcopy(sentence), prev_speaker) + sentence['words'] = sentence['words'].strip() + sentences.append(sentence) + session_trans_dict['sentences'] = sentences + return session_trans_dict + + def merge_transcript_and_speakers( + self, + test_manifest_dict: dict, + asr_hypotheses: List[Hypothesis], + diar_pred_out: torch.Tensor + ) -> Tuple[List[str], Dict[str, Dict[str, Any]]]: + """ + Merge the transcript and speakers and generate real-time scripts if the config is set. + + Args: + test_manifest_dict (Dict): Dictionary containing test manifest data. + asr_hypotheses (List[Hypothesis]): List of ASR hypotheses. + diar_pred_out (torch.Tensor): Diarization prediction output stream. + + Returns: + transcribed_speaker_texts (List[str]): List of transcribed speaker texts. + self._word_and_ts_seq (Dict[str, Dict[str, Any]]): Dictionary of word-level dictionaries with uniq_id as key. + """ + transcribed_speaker_texts = [None] * len(test_manifest_dict) + + for idx, (uniq_id, _) in enumerate(test_manifest_dict.items()): + uniq_id = uniq_id.split(".")[0] # Make sure there is no "." in the uniq_id + if not len( asr_hypotheses[idx].text) == 0: + # Get the word-level dictionaries for each word in the chunk + self._word_and_ts_seq[uniq_id] = self.get_frame_and_words_offline(uniq_id=uniq_id, + diar_pred_out=diar_pred_out[idx].squeeze(0), + asr_hypothesis=asr_hypotheses[idx], + word_and_ts_seq=self._word_and_ts_seq[uniq_id], + ) + if len(self._word_and_ts_seq[uniq_id]["words"]) > 0: + self._word_and_ts_seq[uniq_id] = self.get_sentences_values(session_trans_dict=self._word_and_ts_seq[uniq_id], + sentence_render_length=self._sentence_render_length) + if self.cfg.generate_realtime_scripts: + transcribed_speaker_texts[idx] = \ + print_sentences(sentences=self._word_and_ts_seq[uniq_id]["sentences"], + color_palette=get_color_palette(), + params=self.cfg) + write_txt(f'{self.cfg.print_path}'.replace(".sh", f"_{idx}.sh"), + transcribed_speaker_texts[idx].strip()) + return transcribed_speaker_texts, self._word_and_ts_seq + + def get_frame_and_words_offline( + self, + uniq_id: str, + diar_pred_out: torch.Tensor, + asr_hypothesis: Hypothesis, + word_and_ts_seq: Dict[str, Any], + ): + """ + Get the frame and words for each word in the chunk. + + Args: + uniq_id (str): The unique id of the chunk. + diar_pred_out (torch.Tensor): Diarization prediction output stream. + asr_hypothesis (Hypothesis): ASR hypothesis. + word_and_ts_seq (Dict[str, Any]): Pre-existing word-level dictionaries. + + Returns: + word_and_ts_seq (Dict[str, Any]): The updated word-level dictionaries with new words. + """ + word_and_ts_seq['uniq_id'] = uniq_id + + for word_index, hyp_word_dict in enumerate(asr_hypothesis.timestamp['word']): + time_stt_end_tuple=(hyp_word_dict['start_offset'], hyp_word_dict['end_offset']) + word_dict = get_word_dict_content_offline(cfg=self.cfg, + word=hyp_word_dict['word'], + word_index=word_index, + diar_pred_out=diar_pred_out, + time_stt_end_tuple=time_stt_end_tuple, + frame_len=self._frame_len_sec + ) + word_and_ts_seq["words"].append(word_dict) + word_and_ts_seq["speaker_count_buffer"].append(word_dict["speaker"]) + word_and_ts_seq["word_window_seq"].append(word_dict['word']) + + word_and_ts_seq["buffered_words"] = word_and_ts_seq["words"] + word_and_ts_seq["speaker_count"] = len(set(word_and_ts_seq["speaker_count_buffer"])) + return word_and_ts_seq + + def get_frame_and_words_online( + self, + uniq_id: str, + step_num: int, + diar_pred_out_stream: torch.Tensor, + previous_hypothesis: Hypothesis, + word_and_ts_seq: Dict[str, Any], + ): + """ + Get the frame and words for each word object in the chunk during streaming inference. + + Args: + uniq_id (str): The unique id of the chunk. + step_num (int): The step number of the chunk. + diar_pred_out_stream (torch.Tensor): The diarization prediction output stream. + previous_hypothesis (Hypothesis): The previous hypothesis. + word_and_ts_seq (Dict[str, Any]): The word and timestamp sequence. + + Returns: + word_and_ts_seq (Dict[str, Any]): The word and timestamp sequence. + """ + offset = step_num * self._frame_hop_length + word_seq = previous_hypothesis.text.split() + new_words = word_seq[word_and_ts_seq["offset_count"]:] + new_token_group = self.asr_model.tokenizer.text_to_tokens(new_words) + new_tokens = list(itertools.chain(*new_token_group)) + frame_inds_seq = (torch.tensor(previous_hypothesis.timestamp) + offset).tolist() + frame_inds_seq = fix_frame_time_step(self.cfg, new_tokens, new_words, frame_inds_seq) + min_len = min(len(new_words), len(frame_inds_seq)) + word_and_ts_seq['uniq_id'] = uniq_id + + min_len = min(len(new_words), len(frame_inds_seq)) + for idx in range(min_len): + word_and_ts_seq["token_frame_index"].append((new_tokens[idx], frame_inds_seq[idx])) + word_and_ts_seq["offset_count"] += 1 + + time_step_local_offset, word_idx_offset = 0, 0 + word_and_ts_seq = get_multitoken_words(cfg=self.cfg, + word_and_ts_seq=word_and_ts_seq, + word_seq=word_seq, + new_words=new_words, + fix_prev_words_count=self._fix_prev_words_count + ) + + # Get the FIFO queue preds to word_and_ts_seq + for local_idx, (token_group, word) in enumerate(zip(new_token_group, new_words)): + word_dict = get_word_dict_content_online(cfg=self.cfg, + word=word, + word_index= ( len(word_and_ts_seq["words"]) + local_idx), + diar_pred_out_stream=diar_pred_out_stream, + token_group=token_group, + frame_inds_seq=frame_inds_seq, + time_step_local_offset=time_step_local_offset, + frame_len=self._frame_len_sec + ) + # Count the number of speakers in the word window + time_step_local_offset += len(token_group) + word_idx_offset, word_and_ts_seq = append_word_and_ts_seq(cfg=self.cfg, + word_idx_offset=word_idx_offset, + word_and_ts_seq=word_and_ts_seq, + word_dict=word_dict) + return word_and_ts_seq + + def _add_speaker_transcriptions( + self, + transcriptions: list, + speaker_transcriptions: List[str], + word_and_ts_seq: Dict[str, Dict[str, Any]], + test_manifest_dict: dict + ) -> Tuple[List[Hypothesis], List[Hypothesis]]: + """ + Add speaker tagging into the transcriptions generated from an ASR model. + + Args: + transcriptions (Tuple[List[Hypothesis], List[Hypothesis]]): + Tuple containing the transcriptions and n-best transcriptions. + speaker_transcriptions (List[str]): + List of speaker transcriptions. + word_and_ts_seq (Dict[str, Dict[str, Any]]): + Dictionary of word-level dictionaries with uniq_id as key. + test_manifest_dict (dict): + Dictionary containing test manifest data. + + Returns: + Tuple[List[Hypothesis], List[Hypothesis]]: Tuple containing the updated transcriptions with speaker tags. + """ + trans_hyp, _ = transcriptions + for sess_idx, (uniq_id, _) in enumerate(test_manifest_dict.items()): + uniq_id = uniq_id.split(".")[0] # Make sure there is no "." in the uniq_id + if speaker_transcriptions[sess_idx] is not None: + trans_hyp[sess_idx].text = speaker_transcriptions[sess_idx] + speaker_added_word_dicts = [] + for word_idx, trans_wdict in enumerate(trans_hyp[0].timestamp['word']): + trans_wdict_copy = deepcopy(trans_wdict) + trans_wdict_copy['speaker'] = word_and_ts_seq[uniq_id]['words'][word_idx]['speaker'] + speaker_added_word_dicts.append(trans_wdict_copy) + trans_hyp[sess_idx].timestamp['word'] = speaker_added_word_dicts + w_count, segment_list = 0, [] + for word_idx, trans_segdict in enumerate(trans_hyp[0].timestamp['segment']): + words = trans_segdict['segment'].split() + spk_vote_pool = [] + for word in words: + if word.lower() != word_and_ts_seq[uniq_id]['words'][w_count]['word'].lower(): + raise ValueError( + f"Word mismatch: '{word.lower()}' != '{word_and_ts_seq[uniq_id]['words'][w_count]['word'].lower()}' " + f"at session {sess_idx}, word count {w_count}." + ) + spk_int = int(word_and_ts_seq[uniq_id]['words'][w_count]['speaker'].split('_')[-1]) + spk_vote_pool.append(spk_int) + w_count += 1 + trans_segdict['speaker'] = f"speaker_{torch.mode(torch.tensor(spk_vote_pool), dim=0).values.item()}" + segment_list.append(trans_segdict) + trans_hyp[sess_idx].timestamp['segment'] = segment_list + transcriptions = (trans_hyp, trans_hyp) + return transcriptions + + def perform_offline_stt_spk(self, override_cfg: Dict[str, Any]): + """ + Perform offline STT and speaker diarization on the provided manifest file. + + Args: + override_cfg (dict): Override configuration parameters. + + Returns: + transcriptions (Tuple): Tuple containing the speaker-tagged transcripts. + """ + transcriptions = self.asr_model.transcribe( + audio=self.cfg.dataset_manifest, + override_config=override_cfg, + ) + best_hyp, _ = transcriptions + _, pred_tensors = self.diar_model.diarize(audio=self.cfg.manifest_file, + include_tensor_outputs=True) + speaker_transcriptions, word_and_ts_seq = self.merge_transcript_and_speakers( + test_manifest_dict=self.diar_model._diarize_audio_rttm_map, + asr_hypotheses=best_hyp, + diar_pred_out=pred_tensors + ) + transcriptions = self._add_speaker_transcriptions( + transcriptions=transcriptions, + speaker_transcriptions=speaker_transcriptions, + word_and_ts_seq=word_and_ts_seq, + test_manifest_dict=self.diar_model._diarize_audio_rttm_map, + ) + return transcriptions + + def generate_seglst_dicts_from_serial_streaming(self, samples: List[Dict[str, Any]]): + """ + Generate the seglst dictionary for SegLST format from serial streaming. + For SegLST format, the session_id is the name of the audio file + should not contain "." in the name. + + Args: + samples (List[Dict[str, Any]]): List of samples. + """ + # for _, word_ts_and_seq in enumerate(self._word_and_ts_seq): + for sample in samples: + uniq_id = get_uniqname_from_filepath(sample['audio_filepath']).split('.')[0] + word_ts_and_seq_dict = self._word_and_ts_seq[uniq_id] + for sentence_dict in word_ts_and_seq_dict['sentences']: + session_id = word_ts_and_seq_dict['uniq_id'].split('.')[0] + seglst_dict = get_new_sentence_dict( + speaker=sentence_dict['speaker'], + start_time=float(sentence_dict['start_time']), + end_time=float(sentence_dict['end_time']), + text=sentence_dict["words"], + session_id=session_id + ) + self.instance_manager.seglst_dict_list.append(seglst_dict) + + def generate_seglst_dicts_from_parallel_streaming(self, samples: List[Dict[str, Any]]): + """ + Generate the seglst dictionary for SegLST format from parallel streaming. + For SegLST format, the session_id is the name of the audio file + should not contain "." in the name. + + Args: + samples (List[Dict[str, Any]]): List of samples. + """ + self.instance_manager.previous_asr_states.extend(self.instance_manager.batch_asr_states) + for sample, asr_state in zip(samples, self.instance_manager.previous_asr_states): + audio_filepath = sample["audio_filepath"] + uniq_id = os.path.basename(audio_filepath).split('.')[0] + seglsts = [ + get_new_sentence_dict( + speaker=seg['speaker'], + start_time=seg['start_time'], + end_time=seg['end_time'], + text=seg['words'], + session_id=uniq_id + ) for seg in asr_state.seglsts + ] + seglsts = sorted(seglsts, key=lambda x: x['start_time']) + self.instance_manager.seglst_dict_list.extend(seglsts) + + def _find_active_speakers(self, diar_preds: torch.Tensor, n_active_speakers_per_stream: int) -> List[List[int]]: + """ + Find the active speakers from the diar prediction output. + + Args: + diar_preds (torch.Tensor): The diar prediction output. + n_active_speakers_per_stream (int): The number of active speakers per stream. + + Returns: + speaker_ids_list (List[List[int]]): The list of active speakers for each stream. + """ + if diar_preds.ndim != 3: + raise ValueError(f"diar_preds must be 3D (B, T, N), got shape {diar_preds.shape}") + if n_active_speakers_per_stream > diar_preds.shape[2]: + raise ValueError(f"n_active_speakers_per_stream ({n_active_speakers_per_stream}) " + f"> available speakers ({diar_preds.shape[2]})") + max_probs = torch.max(diar_preds, dim=1).values # (B, T, N) --> (B, N) + top_values, top_indices = torch.topk(max_probs, k=n_active_speakers_per_stream, dim=1) + masks = top_values > 0.5 + + speaker_ids_list = [] + for speaker_ids, mask in zip(top_indices, masks): + speaker_ids_list.append(sorted(speaker_ids[mask].tolist())) + return speaker_ids_list + + def forward_pre_encoded(self, audio_signal: torch.Tensor, length: torch.Tensor, drop_extra_pre_encoded: int=0) -> None: + """ + Forward the pre-encoded features through the ASR model. + + Args: + audio_signal (torch.Tensor): The audio signal. + length (torch.Tensor): The length of the audio signal. + drop_extra_pre_encoded (int): The number of extra pre-encoded tokens to drop. + + Returns: + audio_signal (torch.Tensor): The pre-encoded audio signal. + length (torch.Tensor): The length of the pre-encoded audio signal. + """ + audio_signal = torch.transpose(audio_signal, 1, 2) # (B, T, D) -> (B, D, T) + + audio_signal, length = self.asr_model.encoder.pre_encode(x=audio_signal, lengths=length) + length = length.to(torch.int64) + # `self.streaming_cfg` is set by setup_streaming_cfg(), called in the init + if drop_extra_pre_encoded: + audio_signal = audio_signal[:, drop_extra_pre_encoded :, :] + length = (length - drop_extra_pre_encoded).clamp(min=0) + return audio_signal, length + + def mask_features( + self, + chunk_audio: torch.Tensor, + mask: torch.Tensor, + threshold: float = 0.5, + mask_value: float = -16.6355 + ) -> torch.Tensor: + """ + Mask the features of the chunk audio. + + Args: + chunk_audio (torch.Tensor): The chunk audio. + mask (torch.Tensor): The mask. + threshold (float): The threshold for the mask. + mask_value (float): The value for the masked audio. + + Returns: + masked_chunk_audio (torch.Tensor): The masked chunk audio. + """ + if chunk_audio.ndim != 3: + raise ValueError(f"chunk_audio must be 3D (B, C, T), got {chunk_audio.ndim}D with shape {chunk_audio.shape}") + if mask.ndim != 2: + raise ValueError(f"mask must be 2D (B, T), got {mask.ndim}D with shape {mask.shape}") + if chunk_audio.shape[0] != mask.shape[0]: + raise ValueError(f"Batch size mismatch: chunk_audio={chunk_audio.shape[0]}, mask={mask.shape[0]}") + mask = (mask > threshold).float() + mask = mask.unsqueeze(-1).repeat(1, 1, 8).flatten(1, 2) + + if mask.shape[1] > chunk_audio.shape[2]: + logging.warning(f"Mask shape {mask.shape} is greater than chunk_audio shape {chunk_audio.shape}") + mask = mask[:, :chunk_audio.shape[2]] + elif mask.shape[1] < chunk_audio.shape[2]: + logging.warning(f"Mask shape {mask.shape} is less than chunk_audio shape {chunk_audio.shape}") + mask = torch.nn.functional.pad(mask, (chunk_audio.shape[2] - mask.shape[1], 0), mode='constant', value=0) + + masked_chunk_audio = chunk_audio * mask.unsqueeze(1) + masked_chunk_audio[torch.where(chunk_audio == 0)] = mask_value + + return masked_chunk_audio + + def mask_preencode(self, chunk_audio: torch.Tensor, mask: torch.Tensor, threshold: float = 0.5) -> torch.Tensor: + """ + Mask the pre-encoded features of the chunk audio. + + Args: + chunk_audio (torch.Tensor): The chunk audio. + mask (torch.Tensor): The mask. + threshold (float): The threshold for the mask. + + Returns: + masked_chunk_audio (torch.Tensor): The masked chunk audio. + """ + mask = (mask > threshold).float() + + if mask.shape[1] > chunk_audio.shape[1]: + logging.warning(f"Mask shape {mask.shape} is greater than chunk_audio shape {chunk_audio.shape}") + mask = mask[:, :chunk_audio.shape[1]] + elif mask.shape[1] < chunk_audio.shape[1]: + logging.warning(f"Mask shape {mask.shape} is less than chunk_audio shape {chunk_audio.shape}") + mask = torch.nn.functional.pad(mask, (chunk_audio.shape[1] - mask.shape[1], 0), mode='constant', value=0) + + masked_chunk_audio = chunk_audio * mask.unsqueeze(-1) + + return masked_chunk_audio + + def get_diar_pred_out_stream(self, step_num): + """ + Get the diar prediction output stream for the given step number. + + Args: + step_num (int): the step number + + Returns: + new_diar_pred_out_stream (torch.Tensor): the diar prediction output stream for the given step number + new_chunk_preds (torch.Tensor): the diar prediction output stream for the given step number + """ + start_frame_idx = step_num * self._nframes_per_chunk + end_frame_idx = start_frame_idx + self._nframes_per_chunk + new_diar_pred_out_stream = self.diar_model.rttms_mask_mats[:, :end_frame_idx] + new_chunk_preds = new_diar_pred_out_stream[:, start_frame_idx:end_frame_idx] + return new_diar_pred_out_stream, new_chunk_preds + + @measure_eta + def perform_serial_streaming_stt_spk( + self, + step_num: int, + chunk_audio: torch.Tensor, + chunk_lengths: torch.Tensor, + is_buffer_empty: bool, + drop_extra_pre_encoded: int, + ): + """ + Perform the serial streaming inference. + Serial streaming inference deploys a single ASR model instance to transcribe multiple speakers in a chunk. + All the updates are done to the instance manager in a `SpeakerTaggedASR` class instance. + + Args: + step_num (int): The step number of the chunk. + chunk_audio (torch.Tensor): The chunk audio. + chunk_lengths (torch.Tensor): The length of the chunk audio. + is_buffer_empty (bool): Whether the buffer is empty. + drop_extra_pre_encoded (int): The number of extra pre-encoded tokens to drop. + """ + # Initialize the instance manager with the batch size of the chunk audio. + if step_num == 0: + self.instance_manager.reset(batch_size=chunk_audio.shape[0]) + self.instance_manager.to(chunk_audio.device) + + # This part exists for compatibility with the parallel streaming inference. + self.instance_manager.get_active_speakers_info( + active_speakers=[[0] for _ in range(chunk_audio.shape[0])], + chunk_audio=chunk_audio, + chunk_lengths=chunk_lengths, + ) + + ( + asr_pred_out_stream, + _, + cache_last_channel, + cache_last_time, + cache_last_channel_len, + previous_hypotheses, + ) = self.asr_model.conformer_stream_step( + processed_signal=chunk_audio, + processed_signal_length=chunk_lengths, + cache_last_channel=self.instance_manager.active_cache_last_channel, + cache_last_time=self.instance_manager.active_cache_last_time, + cache_last_channel_len=self.instance_manager.active_cache_last_channel_len, + previous_hypotheses=self.instance_manager.active_previous_hypotheses, + previous_pred_out=self.instance_manager.active_asr_pred_out_stream, + keep_all_outputs=is_buffer_empty, + drop_extra_pre_encoded=drop_extra_pre_encoded, + return_transcription=True, + ) + + if self.diar_model.rttms_mask_mats is None: + + new_streaming_state, diar_pred_out_stream = self.diar_model.forward_streaming_step( + processed_signal=chunk_audio.transpose(1, 2), + processed_signal_length=chunk_lengths, + streaming_state=self.instance_manager.diar_states.streaming_state, + total_preds=self.instance_manager.diar_states.diar_pred_out_stream, + drop_extra_pre_encoded=drop_extra_pre_encoded + ) + self.instance_manager.update_diar_state( + diar_pred_out_stream=diar_pred_out_stream, + previous_chunk_preds=diar_pred_out_stream[:, -self._nframes_per_chunk:], + diar_streaming_state=new_streaming_state + ) + else: + _, new_chunk_preds = self.get_diar_pred_out_stream(step_num) + diar_pred_out_stream = new_chunk_preds + + transcribed_speaker_texts = [None] * len(self.test_manifest_dict) + for idx, (uniq_id, _) in enumerate(self.test_manifest_dict.items()): + if not (len(previous_hypotheses[idx].text) == 0 and step_num <= self._initial_steps): + # Get the word-level dictionaries for each word in the chunk + self._word_and_ts_seq[uniq_id] = self.get_frame_and_words_online(uniq_id=uniq_id, + step_num=step_num, + diar_pred_out_stream=diar_pred_out_stream[idx, :, :], + previous_hypothesis=previous_hypotheses[idx], + word_and_ts_seq=self._word_and_ts_seq[uniq_id], + ) + if len(self._word_and_ts_seq[uniq_id]["words"]) > 0: + self._word_and_ts_seq[uniq_id] = self.get_sentences_values(session_trans_dict=self._word_and_ts_seq[uniq_id], + sentence_render_length=self._sentence_render_length) + if self.cfg.generate_realtime_scripts: + transcribed_speaker_texts[idx] = \ + print_sentences(sentences=self._word_and_ts_seq[uniq_id]["sentences"], + color_palette=get_color_palette(), + params=self.cfg) + write_txt(f'{self.cfg.print_path}'.replace(".sh", f"_{idx}.sh"), + transcribed_speaker_texts[idx].strip()) + + for batch_idx in range(chunk_audio.shape[0]): + self.instance_manager.update_asr_state( + batch_idx, + speaker_id=0, + cache_last_channel=cache_last_channel[:, batch_idx], + cache_last_time=cache_last_time[:, batch_idx], + cache_last_channel_len=cache_last_channel_len[batch_idx], + previous_hypotheses=previous_hypotheses[batch_idx], + previous_pred_out=asr_pred_out_stream[batch_idx] + ) + + + @measure_eta + def perform_parallel_streaming_stt_spk( + self, + step_num, + chunk_audio, + chunk_lengths, + is_buffer_empty, + drop_extra_pre_encoded, + ): + """ + Perform the parallel streaming inference. + Parallel streaming inference deploys multiple ASR model instances to transcribe multiple speakers in a chunk. + All the updates are done to the instance manager in a `SpeakerTaggedASR` class instance. + + Args: + step_num (int): The step number of the chunk. + chunk_audio (torch.Tensor): The chunk audio. + chunk_lengths (torch.Tensor): The length of the chunk audio. + is_buffer_empty (bool): Whether the buffer is empty. + drop_extra_pre_encoded (int): The number of extra pre-encoded tokens to drop. + """ + # Initialize the instance manager with the batch size of the chunk audio. + if step_num == 0: + self._offset_chunk_start_time = 0 + self.instance_manager.reset(batch_size=chunk_audio.shape[0]) + self.instance_manager.to(chunk_audio.device) + + + # Step 2: diarize or get GT rttms + if self.diar_model.rttms_mask_mats is None: + new_streaming_state, new_diar_pred_out_stream = self.diar_model.forward_streaming_step( + processed_signal=chunk_audio.transpose(1, 2), + processed_signal_length=chunk_lengths, + streaming_state=self.instance_manager.diar_states.streaming_state, + total_preds=self.instance_manager.diar_states.diar_pred_out_stream, + drop_extra_pre_encoded=drop_extra_pre_encoded + ) + new_chunk_preds = new_diar_pred_out_stream[:, -self._nframes_per_chunk:] + + else: + new_diar_pred_out_stream, new_chunk_preds = self.get_diar_pred_out_stream(step_num) + new_streaming_state = self.instance_manager.diar_states.streaming_state + + # Step 3: update diar states + self.instance_manager.update_diar_state( + diar_pred_out_stream=new_diar_pred_out_stream, + previous_chunk_preds=new_chunk_preds, + diar_streaming_state=new_streaming_state + ) + # Step 4: find active speakers + diar_chunk_preds = new_diar_pred_out_stream[:, -self._nframes_per_chunk*self._cache_gating_buffer_size:] + if self._cache_gating: + active_speakers = self._find_active_speakers(diar_chunk_preds, n_active_speakers_per_stream=self.n_active_speakers_per_stream) + else: + active_speakers = [list(range(self.n_active_speakers_per_stream)) for _ in range(chunk_audio.shape[0])] + + if (self._masked_asr and self._use_mask_preencode) or not self._masked_asr: + chunk_audio, chunk_lengths = self.forward_pre_encoded(chunk_audio, chunk_lengths, drop_extra_pre_encoded) + bypass_pre_encode = True + else: + bypass_pre_encode = False + + # Step 5: generate instance for active speakers + ( + active_chunk_audio, + active_chunk_lengths, + active_speaker_targets, + inactive_speaker_targets, + ) = self.instance_manager.get_active_speakers_info( + active_speakers=active_speakers, + chunk_audio=chunk_audio, + chunk_lengths=chunk_lengths, + ) + + # skip current chunk if no active speakers are found + if active_chunk_audio is None: + return + + # Step 6: + # 1. mask the non-active speakers for masked ASR + # 2. set speaker targets for multitalker ASR + if self._masked_asr: + if self._use_mask_preencode: + active_chunk_audio = self.mask_preencode(chunk_audio=active_chunk_audio, mask=active_speaker_targets) + else: + active_chunk_audio = self.mask_features(chunk_audio=active_chunk_audio, mask=active_speaker_targets) + else: + if self._binary_diar_preds: + active_speaker_targets = (active_speaker_targets > 0.5).float() + inactive_speaker_targets = (inactive_speaker_targets > 0.5).float() + self.asr_model.set_speaker_targets(active_speaker_targets, inactive_speaker_targets) + + # Step 7: ASR forward pass for active speakers + ( + pred_out_stream, + _, + cache_last_channel, + cache_last_time, + cache_last_channel_len, + previous_hypotheses, + ) = self.asr_model.conformer_stream_step( + processed_signal=active_chunk_audio, + processed_signal_length=active_chunk_lengths, + cache_last_channel=self.instance_manager.active_cache_last_channel, + cache_last_time=self.instance_manager.active_cache_last_time, + cache_last_channel_len=self.instance_manager.active_cache_last_channel_len, + keep_all_outputs=is_buffer_empty, + previous_hypotheses=self.instance_manager.active_previous_hypotheses, + previous_pred_out=self.instance_manager.active_asr_pred_out_stream, + drop_extra_pre_encoded=drop_extra_pre_encoded, + return_transcription=True, + bypass_pre_encode=bypass_pre_encode + ) + + # Step 8: update ASR states + active_id = 0 + for batch_idx, speaker_ids in enumerate(active_speakers): + for speaker_id in speaker_ids: + self.instance_manager.update_asr_state( + batch_idx, + speaker_id, + cache_last_channel[:, active_id], + cache_last_time[:, active_id], + cache_last_channel_len[active_id], + previous_hypotheses[active_id], + pred_out_stream[active_id] + ) + active_id += 1 + + # Step 9: update seglsts with timestamps + self.instance_manager.update_seglsts(offset=self._offset_chunk_start_time) + self._offset_chunk_start_time += self._nframes_per_chunk * self._frame_len_sec + + if self.cfg.generate_realtime_scripts: + for session_idx in self.cfg.print_sample_indices: + asr_state = self.instance_manager.batch_asr_states[session_idx] + transcribed_speaker_texts = print_sentences(sentences=asr_state.seglsts, + color_palette=get_color_palette(), + params=self.cfg) + write_txt(f'{self.cfg.print_path.replace(".sh", f"_{session_idx}.sh")}', transcribed_speaker_texts.strip()) + + +class MultiTalkerInstanceManager: + """ + For multi-talker inference, we need to manage the information per speaker. + Each sample in a batch can be considered as a multi-talker instance, + and each instance may contain multiple speakers, which is the real + batch size for inference. If there are at most N speakers and the batch + size is B, then the real batch size for inference is at most B * N. + """ + class ASRState: + """ + ASR state for each instance. + 1. In parallel mode, each instance handles each potential speaker. + 2. In serial mode, each instance handles one session. + + The goal of ASR-State class is to handle the ASR cache state between streaming steps. + The ASR-states required to perform streaming inference are all included in this class. + """ + def __init__( + self, + max_num_of_spks: int = 4, + frame_len_sec: float = 0.08, + sent_break_sec: float = 5.0 + ): + """ + Initialize the ASR-State class with the initial parameters. + + Args: + max_num_of_spks (int): The maximum number of speakers. + frame_len_sec (float): The length of the frame in seconds. + sent_break_sec (float): The minimum time gap between two sentences in seconds. + """ + # Initialize the ASR state with the initial parameters. + self.speakers: Optional[List[str]] = None + self.cache_last_channel = None + self.cache_last_time = None + self.cache_last_channel_len = None + self.previous_hypothesis = None + self.previous_pred_out = None + + self.max_num_of_spks = max_num_of_spks + + self._frame_len_sec = frame_len_sec + self._sent_break_sec = sent_break_sec + self._speaker_wise_sentences = {} + self._prev_history_speaker_texts = [ "" for _ in range(self.max_num_of_spks) ] + + self.seglsts = [] + + def _reset_speaker_wise_sentences(self): + """ + Reset the speaker-wise sentences which will be used to generate the SegLST transcription outputs. + """ + self._speaker_wise_sentences = {} + self._prev_history_speaker_texts = [ "" for _ in range(self.max_num_of_spks) ] + + + def reset(self, asr_cache_state: Tuple[torch.Tensor, torch.Tensor, torch.Tensor]): + """ + Reset the ASR state. + + Args: + asr_cache_state (Tuple[torch.Tensor, torch.Tensor, torch.Tensor]): The ASR cache state. + - cache_last_channel (torch.Tensor): The cache last channel. + - cache_last_time (torch.Tensor): The cache last time. + - cache_last_channel_len (torch.Tensor): The cache last channel length. + """ + self.speakers = [0] + self.cache_last_channel, self.cache_last_time, self.cache_last_channel_len = asr_cache_state + self.previous_hypothesis = [None] + self.previous_pred_out = [None] + self.seglsts = [] + self._speaker_wise_sentences = {} + self._prev_history_speaker_texts = [ "" for _ in range(self.max_num_of_spks) ] + + def update_asr_state(self, + speaker_id, + cache_last_channel, + cache_last_time, + cache_last_channel_len, + previous_hypothesis, + previous_pred_out + ): + """ + Update the ASR state with the new ASR cache state. + This function should be called at every streaming step to update the ASR cache state. + + Args: + speaker_id (int): The speaker id. + cache_last_channel (torch.Tensor): The cache last channel. + cache_last_time (torch.Tensor): The cache last time. + cache_last_channel_len (torch.Tensor): The cache last channel length. + previous_hypothesis (Hypothesis): The previous hypothesis. + previous_pred_out (torch.Tensor): The previous prediction output. + """ + self.cache_last_channel[:, speaker_id] = cache_last_channel + self.cache_last_time[:, speaker_id] = cache_last_time + self.cache_last_channel_len[speaker_id] = cache_last_channel_len + self.previous_hypothesis[speaker_id] = previous_hypothesis + self.previous_pred_out[speaker_id] = previous_pred_out + + def to(self, device): + """ + Override the to method to move the ASR state to the device. + + Args: + device (torch.device): The device to move the ASR state to. + """ + self.cache_last_channel = self.cache_last_channel.to(device) + self.cache_last_time = self.cache_last_time.to(device) + self.cache_last_channel_len = self.cache_last_channel_len.to(device) + + def get_speakers(self): + """ + Get the speaker ids (int) for each instance. + This function is used for serial streaming mode. + """ + return self.speakers + + def add_speaker(self, speaker_id: int, asr_cache_state: Tuple[torch.Tensor, torch.Tensor, torch.Tensor]): + """ + Add a speaker index and its initial cache state to the ASR state. + + Args: + speaker_id (int): The speaker id. + asr_cache_state (Tuple[torch.Tensor, torch.Tensor, torch.Tensor]): The ASR cache state. + """ + self.speakers.append(speaker_id) + cache_last_channel, cache_last_time, cache_last_channel_len = asr_cache_state + self.cache_last_channel = torch.cat([self.cache_last_channel, cache_last_channel], dim=1) + self.cache_last_time = torch.cat([self.cache_last_time, cache_last_time], dim=1) + self.cache_last_channel_len = torch.cat([self.cache_last_channel_len, cache_last_channel_len], dim=0) + self.previous_hypothesis.append(None) + self.previous_pred_out.append(None) + + def _update_last_sentence(self, spk_idx: int, end_time: float, diff_text: str): + """ + Update the end time of the last sentence for a speaker. + + Args: + spk_idx (int): The speaker id. + end_time (float): The end time of the last sentence. + diff_text (str): The difference text. + """ + if end_time is not None: + self._speaker_wise_sentences[spk_idx][-1]['end_time'] = end_time + new_words = self._speaker_wise_sentences[spk_idx][-1]['words'] + diff_text + self._speaker_wise_sentences[spk_idx][-1]['words'] = new_words.strip() + + def _is_new_text(self, spk_idx: int, text: str): + """ + Check if the text is new for a speaker. + + Args: + spk_idx (int): The speaker id. + text (str): The text. + """ + if text is None or text == self._prev_history_speaker_texts[spk_idx]: + return None + else: + # Get the difference between the current text and the previous text + if self._prev_history_speaker_texts[spk_idx] in text: + return text.replace(self._prev_history_speaker_texts[spk_idx], "") + else: + return text.strip() + + def _compute_hypothesis_timestamps( + self, + hypothesis: Hypothesis, + offset: float + ) -> Tuple[float, float, bool]: + """ + Compute start and end timestamps for a hypothesis based on available timing information. + + This method calculates the temporal boundaries of a speech hypothesis, prioritizing + frame-level timestamps when available. When timestamps are not available, it falls + back to computing timing based on the hypothesis length. + + Args: + hypothesis (Hypothesis): The ASR hypothesis object containing either frame-level + offset (float): The time offset (in seconds) to add to the computed timestamps, + typically representing the start time of the current audio chunk. + + Returns: + Tuple[float, float, bool]: A tuple containing: + - start_time (float): The absolute start time of the hypothesis in seconds + - end_time (float): The absolute end time of the hypothesis in seconds + - sep_flag (bool): A flag indicating whether timing was computed from length + rather than timestamps. + + Note: + The end_time calculation from timestamps adds 1 to the last timestamp to account + for the full duration of the final frame. + """ + sep_flag = False + if len(hypothesis.timestamp) > 0: + start_time = offset + (hypothesis.timestamp[0]) * self._frame_len_sec + end_time = offset + (hypothesis.timestamp[-1] + 1) * self._frame_len_sec + else: + start_time = offset + end_time = offset + hypothesis.length.item() * self._frame_len_sec + sep_flag = True + + return start_time, end_time, sep_flag + + def update_sessionwise_seglsts_for_parallel(self, offset: float): + """ + Update the seglsts for the parallel mode streaming. + Note that this function is NOT used for serial mode streaming. + + Args: + offset (float): The offset in seconds. + This is usally the start time of the current audio chunk. + """ + valid_speakers = set() + for spk_idx in self.get_speakers(): + hypothesis = self.previous_hypothesis[spk_idx] + if hypothesis is None: + continue + valid_speakers.add(spk_idx) + + if spk_idx not in self._speaker_wise_sentences: + self._speaker_wise_sentences[spk_idx] = [] + + diff_text = self._is_new_text(spk_idx=spk_idx, text=hypothesis.text) + if diff_text is not None: + + start_time, end_time, sep_flag = self._compute_hypothesis_timestamps( + hypothesis=hypothesis, + offset=offset + ) + + # Get the last end time of the previous sentence or None if no sentences are present + if len(self._speaker_wise_sentences[spk_idx]) > 0: + last_end_time = self._speaker_wise_sentences[spk_idx][-1]['end_time'] + else: + last_end_time = 0.0 + + # Case 1 - If start_tiime is greater than end_time + sent_break_sec, then we need to add the sentence + if sep_flag or (last_end_time == 0.0 or start_time > last_end_time + self._sent_break_sec): + if len(diff_text) > 0 and diff_text.strip()[0] in ['.', ',', '?', '!']: + # This handles the case where the first character should be assigned to the previous sentence. + the_first_char, diff_text = diff_text.strip()[0], diff_text.strip()[1:] + self._update_last_sentence(spk_idx=spk_idx, end_time=None, diff_text=the_first_char) + self._speaker_wise_sentences[spk_idx].append(get_new_sentence_dict(speaker=f"speaker_{spk_idx}", + start_time=start_time, + end_time=end_time, + text=diff_text)) + # Case 2 - If start_time is less than end_time + sent_break_sec, then we need to update the end_time + else: + self._update_last_sentence(spk_idx=spk_idx, end_time=end_time, diff_text=diff_text) + + # Update the previous history of the speaker text + if hypothesis.text is not None: + self._prev_history_speaker_texts[spk_idx] = hypothesis.text + + self.seglsts = [] + + # Merge all sentences for each speaker but sort by start_time + for spk_idx in valid_speakers: + self.seglsts.extend(self._speaker_wise_sentences[spk_idx]) + + # Finally, sort the seglsts by start_time + self.seglsts = sorted(self.seglsts, key=lambda x: x['start_time']) + + class DiarState: + """ + Diar state for each diarization instance. + There is no difference between serial and parallel mode for the diarization state. + The goal of Diar-State class is to handle the diarization cache state between streaming steps. + """ + def __init__(self, batch_size: int=1, max_num_of_spks: int=4): + """ + Initialize the Diar-State class with the initial parameters. + + Args: + batch_size (int): The batch size. + max_num_of_spks (int): The maximum number of speakers. + """ + self.batch_size = batch_size + self.max_num_of_spks = max_num_of_spks + self.diar_pred_out_stream = None + self.previous_chunk_preds = None + self.streaming_state = None + + def reset(self, diar_streaming_state: StreamingSortformerState): + self.diar_pred_out_stream = torch.zeros((self.batch_size, 0, self.max_num_of_spks)) + self.previous_chunk_preds = torch.zeros((self.batch_size, 0, self.max_num_of_spks)) + self.streaming_state = diar_streaming_state + + def to(self, device): + self.diar_pred_out_stream = self.diar_pred_out_stream.to(device) + self.previous_chunk_preds = self.previous_chunk_preds.to(device) + self.streaming_state.to(device) + + def __init__(self, + asr_model=None, + diar_model=None, + batch_size: int=1, + max_num_of_spks: int=4, + sent_break_sec: float=5.0, + ): + """ + Initialize the MultiTalkerInstanceManager class with the initial parameters. + + Args: + asr_model: The ASR model. + diar_model: The diarization model. + batch_size (int): The batch size for ASR. + 1. For parallel mode, this is the number of potential speakers + multiplied by the session counts. + 2. For serial mode, this is the number of sessions. + max_num_of_spks (int): The maximum number of speakers. + """ + self.asr_model = asr_model + self.diar_model = diar_model + + self.batch_size = batch_size + self.max_num_of_spks = max_num_of_spks + self._sent_break_sec = sent_break_sec + + # ASR state bank + self.batch_asr_states = [] + self.previous_asr_states = [] + + # Diar states + self.diar_states = None + + # SegLST output list + self.seglst_dict_list = [] + + # Active speaker buffer lists + self._active_chunk_audio: List[torch.Tensor] = [] + self._active_chunk_lengths: List[torch.Tensor] = [] + self._active_speaker_targets: List[torch.Tensor] = [] + self._inactive_speaker_targets: List[torch.Tensor] = [] + self._active_previous_hypotheses: List[Hypothesis] = [] + self._active_asr_pred_out_stream: List[torch.Tensor] = [] + self._active_cache_last_channel: List[torch.Tensor] = [] + self._active_cache_last_time: List[torch.Tensor] = [] + self._active_cache_last_channel_len: List[torch.Tensor] = [] + + # Active speaker attributes + self.active_previous_hypotheses: Optional[List[Hypothesis]] = None + self.active_asr_pred_out_stream: Optional[List[torch.Tensor]] = None + self.active_cache_last_channel: Optional[torch.Tensor] = None + self.active_cache_last_time: Optional[torch.Tensor] = None + self.active_cache_last_channel_len: Optional[torch.Tensor] = None + + def _reset_active_speaker_buffers(self): + """ + Reset the active speaker buffers need to update the active speaker information. + """ + self._active_chunk_audio = [] + self._active_chunk_lengths = [] + self._active_speaker_targets = [] + self._inactive_speaker_targets = [] + self._active_previous_hypotheses = [] + + self._active_asr_pred_out_stream = [] + self._active_cache_last_channel = [] + self._active_cache_last_time = [] + self._active_cache_last_channel_len = [] + + def reset(self, batch_size: Optional[int] = None, max_num_of_spks: Optional[int] = None): + """ + Reset the active speaker buffers need to update the active speaker information. + + Args: + batch_size (Optional[int]): The batch size. + max_num_of_spks (Optional[int]): The maximum number of speakers. + """ + if batch_size is not None: + self.batch_size = batch_size + if max_num_of_spks is not None: + self.max_num_of_spks = max_num_of_spks + + if len(self.batch_asr_states) > 0: + self.previous_asr_states.extend(deepcopy(self.batch_asr_states)) + self.batch_asr_states = [self.ASRState(self.max_num_of_spks, sent_break_sec=self._sent_break_sec) for _ in range(self.batch_size)] + + for i in range(self.batch_size): + self.batch_asr_states[i].reset(self.asr_model.encoder.get_initial_cache_state(batch_size=1)) + + self.diar_states = self.DiarState(batch_size=self.batch_size, max_num_of_spks=self.max_num_of_spks) + self.diar_states.reset(self.diar_model.sortformer_modules.init_streaming_state(batch_size=self.batch_size)) + + self.seglst_dict_list = [] + + def add_speaker(self, batch_idx: int, speaker_id: int): + """ + Add a speaker index and its initial cache state to the ASR state. + + Args: + batch_idx (int): The batch index. + speaker_id (int): The speaker id. + """ + speakers = self.batch_asr_states[batch_idx].get_speakers() + for speaker_index in range(0, speaker_id+1): + if speaker_index not in speakers: + self.batch_asr_states[batch_idx].add_speaker( + speaker_id=speaker_index, + asr_cache_state=self.asr_model.encoder.get_initial_cache_state(batch_size=1) + ) + + def get_speakers(self, batch_idx: int): + """ + Get the speaker ids (int) for each instance. + + Args: + batch_idx (int): The batch index. + """ + return self.batch_asr_states[batch_idx].get_speakers() + + def to(self, device: torch.device): + """ + Override the to method to move the ASR and Diar states to the device. + + Args: + device (torch.device): The device to move the ASR and Diar states to. + """ + for batch_idx in range(self.batch_size): + self.batch_asr_states[batch_idx].to(device) + self.diar_states.to(device) + + def update_diar_state( + self, + diar_pred_out_stream: torch.Tensor, + previous_chunk_preds: torch.Tensor, + diar_streaming_state: StreamingSortformerState + ): + """ + Update the diarization state from the diarization step. + The diarization results are updated as a form of torch.Tensor. + + Args: + diar_pred_out_stream (torch.Tensor): The diarization prediction output stream. + previous_chunk_preds (torch.Tensor): The previous chunk prediction output. + diar_streaming_state (StreamingSortformerState): The diarization streaming state. + """ + self.diar_states.diar_pred_out_stream = diar_pred_out_stream + self.diar_states.previous_chunk_preds = previous_chunk_preds + self.diar_states.streaming_state = diar_streaming_state + + def update_asr_state( + self, + batch_idx, + speaker_id, + cache_last_channel, + cache_last_time, + cache_last_channel_len, + previous_hypotheses, + previous_pred_out + ): + """ + A function to update the ASR state with the new ASR cache state. + This function should be called at every streaming step to update the ASR cache state. + + Args: + batch_idx (int): The batch index. + If parallel mode, this is the index of the potential speaker. + If serial mode, this is the index of the session. + speaker_id (int): The speaker id in the given session. + -- Cache aware ASR related parameters -- + cache_last_channel (torch.Tensor) + cache_last_time (torch.Tensor) + cache_last_channel_len (torch.Tensor) + previous_hypotheses (Hypothesis) + previous_pred_out (torch.Tensor) The previous prediction output. + """ + self.batch_asr_states[batch_idx].update_asr_state( + speaker_id, + cache_last_channel, + cache_last_time, + cache_last_channel_len, + previous_hypotheses, + previous_pred_out + ) + + def get_active_speakers_info(self, active_speakers, chunk_audio, chunk_lengths): + """ + Collect the active speaker information for the next streaming step and + update the active speaker buffers. + + Args: + active_speakers (List[List[int]]): The active speakers for each chunk. + chunk_audio (torch.Tensor): The chunk audio. + chunk_lengths (torch.Tensor): The chunk lengths. + """ + # Reset the active speaker buffers + self._reset_active_speaker_buffers() + + # Loop through the active speakers and update the active speaker buffers + for batch_idx, speaker_ids in enumerate(active_speakers): + for speaker_id in speaker_ids: + self._active_chunk_audio.append(chunk_audio[batch_idx, :]) + self._active_chunk_lengths.append(chunk_lengths[batch_idx]) + self._active_speaker_targets.append(self.diar_states.previous_chunk_preds[batch_idx, :, speaker_id]) + inactive_speaker_ids = [i for i in range(len(speaker_ids)) if i != speaker_id] + self._inactive_speaker_targets.append((self.diar_states.previous_chunk_preds[batch_idx, :, inactive_speaker_ids] > 0.5).sum(dim=-1) > 0) + if speaker_id not in self.batch_asr_states[batch_idx].get_speakers(): + self.add_speaker(batch_idx, speaker_id) + + self._active_previous_hypotheses.append(self.batch_asr_states[batch_idx].previous_hypothesis[speaker_id]) + self._active_asr_pred_out_stream.append(self.batch_asr_states[batch_idx].previous_pred_out[speaker_id]) + self._active_cache_last_channel.append(self.batch_asr_states[batch_idx].cache_last_channel[:, speaker_id]) + self._active_cache_last_time.append(self.batch_asr_states[batch_idx].cache_last_time[:, speaker_id]) + self._active_cache_last_channel_len.append(self.batch_asr_states[batch_idx].cache_last_channel_len[speaker_id]) + if len(self._active_chunk_audio) == 0: + return None, None, None, None + + # Convert chunk audio and target info to tensors + active_chunk_audio = torch.stack(self._active_chunk_audio) + active_chunk_lengths = torch.stack(self._active_chunk_lengths) + active_speaker_targets = torch.stack(self._active_speaker_targets) + inactive_speaker_targets = torch.stack(self._inactive_speaker_targets) + + # Update active speaker attributes + self.active_previous_hypotheses = deepcopy(self._active_previous_hypotheses) + self.active_asr_pred_out_stream = deepcopy(self._active_asr_pred_out_stream) + self.active_cache_last_channel = torch.stack(self._active_cache_last_channel).transpose(0, 1) + self.active_cache_last_time = torch.stack(self._active_cache_last_time).transpose(0, 1) + self.active_cache_last_channel_len = torch.stack(self._active_cache_last_channel_len) + return active_chunk_audio, active_chunk_lengths, active_speaker_targets, inactive_speaker_targets + + def update_seglsts(self, offset: int): + """ + Take the ASR states and update the seglsts. + + Args: + offset (int): The offset of the chunk. + """ + for asr_state in self.batch_asr_states: + asr_state.update_sessionwise_seglsts_for_parallel(offset=offset) \ No newline at end of file diff --git a/nemo/collections/common/data/lhotse/cutset.py b/nemo/collections/common/data/lhotse/cutset.py index 781f22207d64..7bbbf109b48e 100644 --- a/nemo/collections/common/data/lhotse/cutset.py +++ b/nemo/collections/common/data/lhotse/cutset.py @@ -42,6 +42,7 @@ TextTurn, ) from nemo.collections.common.parts.preprocessing.manifest import get_full_path +from nemo.collections.asr.parts.utils.asr_multispeaker_utils import MultiSpeakerMixtureGenerator def read_cutset_from_config(config: Union[DictConfig, dict]) -> Tuple[CutSet, bool]: @@ -815,6 +816,25 @@ def read_nemo_manifest(config) -> tuple[CutSet, bool]: return cuts, is_tarred +@data_type_parser("multi_speaker_simulator") +def read_multi_speaker_simulator(config: DictConfig) -> tuple[CutSet, bool]: + multi_speaker_cuts = CutSet( + MultiSpeakerMixtureGenerator( + manifest_filepath=config.manifest_filepath, + simulator_type=config.simulator_type, + sample_rate=config.get("sample_rate", 16000), + min_delay=config.get("min_delay", 0.5), + min_duration=config.get("min_duration", 0.1), + max_duration=config.get("max_duration", 60), + num_speakers=config.get("num_speakers", 2), + global_rank=config.get("global_rank", 0), + world_size=config.get("world_size", 1), + ) + ) + is_tarred = config.get("is_tarred", False) + return multi_speaker_cuts, is_tarred + + def mux( *cutsets: CutSet, weights: list[Union[int, float]], From fd9219144cd23cde534957f47e037fd9ecdcebba Mon Sep 17 00:00:00 2001 From: weiqingw4ng Date: Thu, 9 Oct 2025 01:47:13 +0000 Subject: [PATCH 02/29] Apply isort and black reformatting Signed-off-by: weiqingw4ng --- ...ech_to_text_multitalker_streaming_infer.py | 141 +++-- .../asr/data/audio_to_text_lhotse_speaker.py | 41 +- nemo/collections/asr/models/__init__.py | 2 +- .../asr/models/multitalker_asr_models.py | 25 +- .../asr/models/sortformer_diar_models.py | 8 +- nemo/collections/asr/parts/mixins/__init__.py | 2 +- .../parts/mixins/multitalker_asr_mixins.py | 71 +-- .../collections/asr/parts/mixins/streaming.py | 3 +- .../asr/parts/utils/asr_multispeaker_utils.py | 192 ++++--- .../asr/parts/utils/data_simulation_utils.py | 147 +++-- .../asr/parts/utils/diarization_utils.py | 163 +++--- .../parts/utils/multispk_transcribe_utils.py | 542 +++++++++--------- nemo/collections/common/data/lhotse/cutset.py | 4 +- 13 files changed, 735 insertions(+), 606 deletions(-) diff --git a/examples/asr/asr_cache_aware_streaming/speech_to_text_multitalker_streaming_infer.py b/examples/asr/asr_cache_aware_streaming/speech_to_text_multitalker_streaming_infer.py index 937d1c29d1fd..c2a7453db7ff 100644 --- a/examples/asr/asr_cache_aware_streaming/speech_to_text_multitalker_streaming_infer.py +++ b/examples/asr/asr_cache_aware_streaming/speech_to_text_multitalker_streaming_infer.py @@ -12,50 +12,43 @@ # See the License for the specific language governing permissions and # limitations under the License. +import itertools import json -from dataclasses import dataclass, is_dataclass, field -from typing import Optional, Union, List, Tuple, Dict, Any - -import torch +import math import os +import time +from collections import OrderedDict +from copy import deepcopy +from dataclasses import dataclass, field, is_dataclass +from functools import wraps +from typing import Any, Dict, List, Optional, Tuple, Union + import pytorch_lightning as pl -from omegaconf import OmegaConf -from omegaconf import open_dict +import torch from lhotse.dataset.collation import collate_matrices - +from omegaconf import OmegaConf, open_dict import nemo.collections.asr as nemo_asr -from nemo.collections.asr.parts.utils.rnnt_utils import Hypothesis -from nemo.collections.asr.parts.utils.streaming_utils import CacheAwareStreamingAudioBuffer - -from copy import deepcopy -from nemo.collections.asr.parts.utils.diarization_utils import read_seglst, OnlineEvaluation -from nemo.utils import logging - +from nemo.collections.asr.data.audio_to_diar_label import extract_frame_info_from_rttm, get_frame_targets_from_rttm from nemo.collections.asr.models.sortformer_diar_models import SortformerEncLabelModel -from nemo.core.config import hydra_runner - -from nemo.collections.asr.parts.utils.multispk_transcribe_utils import SpeakerTaggedASR, get_multi_talker_samples_from_manifest -from nemo.collections.asr.parts.utils.speaker_utils import ( -audio_rttm_map as get_audio_rttm_map, -rttm_to_labels, -) from nemo.collections.asr.parts.utils.diarization_utils import ( -print_sentences, -get_color_palette, -write_txt, + OnlineEvaluation, + get_color_palette, + print_sentences, + read_seglst, + write_txt, ) -from nemo.collections.asr.data.audio_to_diar_label import get_frame_targets_from_rttm, extract_frame_info_from_rttm - - -from typing import List, Optional -from dataclasses import dataclass -from collections import OrderedDict -import itertools +from nemo.collections.asr.parts.utils.multispk_transcribe_utils import ( + SpeakerTaggedASR, + get_multi_talker_samples_from_manifest, +) +from nemo.collections.asr.parts.utils.rnnt_utils import Hypothesis +from nemo.collections.asr.parts.utils.speaker_utils import audio_rttm_map as get_audio_rttm_map +from nemo.collections.asr.parts.utils.speaker_utils import rttm_to_labels +from nemo.collections.asr.parts.utils.streaming_utils import CacheAwareStreamingAudioBuffer +from nemo.core.config import hydra_runner +from nemo.utils import logging -import time -from functools import wraps -import math @dataclass class DiarizationConfig: @@ -66,13 +59,13 @@ class DiarizationConfig: parallel_speaker_strategy: bool = True # General configs - session_len_sec: float = -1 # End-to-end diarization session length in seconds + session_len_sec: float = -1 # End-to-end diarization session length in seconds num_workers: int = 8 random_seed: Optional[int] = None # seed number going to be used in seed_everything() - log: bool = True # If True, log will be printed + log: bool = True # If True, log will be printed # Streaming diarization configs - streaming_mode: bool = True # If True, streaming diarization will be used. + streaming_mode: bool = True # If True, streaming diarization will be used. spkcache_len: int = 188 spkcache_refresh_rate: int = 0 fifo_len: int = 188 @@ -99,7 +92,7 @@ class DiarizationConfig: online_normalization: bool = False output_path: Optional[str] = None pad_and_drop_preencoded: bool = False - set_decoder: Optional[str] = None # ["ctc", "rnnt"] + set_decoder: Optional[str] = None # ["ctc", "rnnt"] att_context_size: Optional[list] = None generate_realtime_scripts: bool = True @@ -123,7 +116,7 @@ class DiarizationConfig: feat_len_sec: float = 0.01 finetune_realtime_ratio: float = 0.01 - spk_supervision: str = "diar" # ["diar", "rttm"] + spk_supervision: str = "diar" # ["diar", "rttm"] binary_diar_preds: bool = False @@ -132,6 +125,7 @@ def format_time(seconds): sec = seconds % 60 return f"{minutes}:{sec:05.2f}" + def calc_drop_extra_pre_encoded(asr_model, step_num, pad_and_drop_preencoded): # for the first step there is no need to drop any tokens after the downsampling as no caching is being used if step_num == 0 and not pad_and_drop_preencoded: @@ -139,6 +133,7 @@ def calc_drop_extra_pre_encoded(asr_model, step_num, pad_and_drop_preencoded): else: return asr_model.encoder.streaming_cfg.drop_extra_pre_encoded + def add_delay_for_real_time(cfg, chunk_audio, session_start_time, feat_frame_count, loop_end_time, loop_start_time): """ Add artificial delay for real-time mode by calculating the time difference between @@ -149,10 +144,18 @@ def add_delay_for_real_time(cfg, chunk_audio, session_start_time, feat_frame_cou """ time_diff = max(0, (time.time() - session_start_time) - feat_frame_count * cfg.feat_len_sec) eta_min_sec = format_time(time.time() - session_start_time) - logging.info(f"[ REAL TIME MODE ] min:sec - {eta_min_sec} " - f"Time difference for real-time mode: {time_diff:.4f} seconds") - time.sleep(max(0, (chunk_audio.shape[-1] - cfg.discarded_frames)*cfg.feat_len_sec - - (loop_end_time - loop_start_time) - time_diff * cfg.finetune_realtime_ratio)) + logging.info( + f"[ REAL TIME MODE ] min:sec - {eta_min_sec} " + f"Time difference for real-time mode: {time_diff:.4f} seconds" + ) + time.sleep( + max( + 0, + (chunk_audio.shape[-1] - cfg.discarded_frames) * cfg.feat_len_sec + - (loop_end_time - loop_start_time) + - time_diff * cfg.finetune_realtime_ratio, + ) + ) def write_seglst_file(seglst_dict_list, output_path): @@ -162,6 +165,7 @@ def write_seglst_file(seglst_dict_list, output_path): f.write(json.dumps(seglst_dict_list, indent=4) + '\n') logging.info(f"Saved the transcriptions of the streaming inference in\n:{output_path}") + def launch_serial_streaming( cfg, asr_model, @@ -189,24 +193,26 @@ def launch_serial_streaming( drop_extra_pre_encoded=drop_extra_pre_encoded, ) - feat_frame_count += (chunk_audio.shape[-1] - cfg.discarded_frames) + feat_frame_count += chunk_audio.shape[-1] - cfg.discarded_frames if cfg.real_time_mode: - add_delay_for_real_time(cfg, - chunk_audio=chunk_audio, - session_start_time=session_start_time, - feat_frame_count=feat_frame_count, - loop_end_time=time.time(), - loop_start_time=loop_start_time - ) + add_delay_for_real_time( + cfg, + chunk_audio=chunk_audio, + session_start_time=session_start_time, + feat_frame_count=feat_frame_count, + loop_end_time=time.time(), + loop_start_time=loop_start_time, + ) return multispk_asr_streamer + def launch_parallel_streaming( cfg, asr_model, diar_model, streaming_buffer, pad_and_drop_preencoded=False, - ): +): streaming_buffer_iter = iter(streaming_buffer) multispk_asr_streamer = SpeakerTaggedASR(cfg, asr_model, diar_model) @@ -263,11 +269,11 @@ def main(cfg: DiarizationConfig) -> Union[DiarizationConfig]: map_location = torch.device(f'cuda:{cfg.cuda}') if cfg.diar_model_path.endswith(".ckpt"): - diar_model = SortformerEncLabelModel.load_from_checkpoint(checkpoint_path=cfg.diar_model_path, - map_location=map_location, strict=False) + diar_model = SortformerEncLabelModel.load_from_checkpoint( + checkpoint_path=cfg.diar_model_path, map_location=map_location, strict=False + ) elif cfg.diar_model_path.endswith(".nemo"): - diar_model = SortformerEncLabelModel.restore_from(restore_path=cfg.diar_model_path, - map_location=map_location) + diar_model = SortformerEncLabelModel.restore_from(restore_path=cfg.diar_model_path, map_location=map_location) else: raise ValueError("cfg.diar_model_path must end with.ckpt or.nemo!") @@ -372,7 +378,11 @@ def main(cfg: DiarizationConfig) -> Union[DiarizationConfig]: if cfg.audio_file is not None: # Stream a single audio file - samples = [{'audio_filepath': cfg.audio_file,}] + samples = [ + { + 'audio_filepath': cfg.audio_file, + } + ] streaming_buffer = CacheAwareStreamingAudioBuffer( model=asr_model, online_normalization=online_normalization, @@ -399,7 +409,9 @@ def main(cfg: DiarizationConfig) -> Union[DiarizationConfig]: else: # Stream audio files in a manifest file in batched mode feat_per_sec = round(asr_model.cfg.preprocessor.window_stride * asr_model.cfg.encoder.subsampling_factor, 2) - samples, rttms_mask_mats = get_multi_talker_samples_from_manifest(cfg, manifest_file=cfg.manifest_file, feat_per_sec=feat_per_sec, max_spks=cfg.max_num_of_spks) + samples, rttms_mask_mats = get_multi_talker_samples_from_manifest( + cfg, manifest_file=cfg.manifest_file, feat_per_sec=feat_per_sec, max_spks=cfg.max_num_of_spks + ) cfg.batch_size = len(samples) # Note: rttms_mask_mats contains PyTorch tensors, so we pass it directly instead of storing in config if cfg.spk_supervision == "rttm": @@ -439,12 +451,15 @@ def main(cfg: DiarizationConfig) -> Union[DiarizationConfig]: if cfg.output_path is not None and multispk_asr_streamer is not None: if cfg.parallel_speaker_strategy: multispk_asr_streamer.generate_seglst_dicts_from_parallel_streaming(samples=samples) - write_seglst_file(seglst_dict_list=multispk_asr_streamer.instance_manager.seglst_dict_list, - output_path=cfg.output_path) + write_seglst_file( + seglst_dict_list=multispk_asr_streamer.instance_manager.seglst_dict_list, output_path=cfg.output_path + ) else: multispk_asr_streamer.generate_seglst_dicts_from_serial_streaming(samples=samples) - write_seglst_file(seglst_dict_list=multispk_asr_streamer.instance_manager.seglst_dict_list, - output_path=cfg.output_path) + write_seglst_file( + seglst_dict_list=multispk_asr_streamer.instance_manager.seglst_dict_list, output_path=cfg.output_path + ) + if __name__ == '__main__': - main() \ No newline at end of file + main() diff --git a/nemo/collections/asr/data/audio_to_text_lhotse_speaker.py b/nemo/collections/asr/data/audio_to_text_lhotse_speaker.py index 16a4ee0463f5..06a28769536d 100644 --- a/nemo/collections/asr/data/audio_to_text_lhotse_speaker.py +++ b/nemo/collections/asr/data/audio_to_text_lhotse_speaker.py @@ -12,29 +12,28 @@ # See the License for the specific language governing permissions and # limitations under the License. -import re import random +import re from typing import Dict, Optional, Tuple -import soundfile +import numpy as np +import soundfile import torch.utils.data -from lhotse.cut import MixedCut, MonoCut, MixTrack, PaddingCut +from lhotse import AudioSource, CutSet, MonoCut, Recording, SupervisionSegment, SupervisionSet +from lhotse.cut import MixedCut, MixTrack, MonoCut, PaddingCut from lhotse.dataset import AudioSamples -from lhotse.dataset.collation import collate_vectors, collate_matrices +from lhotse.dataset.collation import collate_matrices, collate_vectors from lhotse.utils import compute_num_samples -from lhotse import SupervisionSet, SupervisionSegment, MonoCut, Recording, CutSet, AudioSource - -import numpy as np from nemo.collections.asr.data.audio_to_text_lhotse import TokenizerWrapper +from nemo.collections.asr.parts.utils.asr_multispeaker_utils import ( + get_hidden_length_from_sample_length, + speaker_to_target, +) from nemo.collections.common.tokenizers.aggregate_tokenizer import AggregateTokenizer from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec from nemo.core.neural_types import AudioSignal, LabelsType, LengthsType, NeuralType -from nemo.collections.asr.parts.utils.asr_multispeaker_utils import ( - speaker_to_target, - get_hidden_length_from_sample_length, -) class LhotseSpeechToTextSpkBpeDataset(torch.utils.data.Dataset): """ @@ -53,8 +52,8 @@ def output_types(self) -> Optional[Dict[str, NeuralType]]: 'a_sig_length': NeuralType(tuple('B'), LengthsType()), 'transcripts': NeuralType(('B', 'T'), LabelsType()), 'transcript_length': NeuralType(tuple('B'), LengthsType()), - 'spk_targets': NeuralType(('B','T'), LabelsType()), - 'bg_spk_targets': NeuralType(('B','T'), LabelsType()), + 'spk_targets': NeuralType(('B', 'T'), LabelsType()), + 'bg_spk_targets': NeuralType(('B', 'T'), LabelsType()), } def __init__(self, cfg, tokenizer): @@ -77,14 +76,18 @@ def __getitem__(self, cuts) -> Tuple[torch.Tensor, ...]: bg_spk_targets = [] if self.inference_mode: - speaker_targets = [speaker_to_target(cut, self.num_sample_per_mel_frame, self.num_mel_frame_per_asr_frame) for cut in cuts] + speaker_targets = [ + speaker_to_target(cut, self.num_sample_per_mel_frame, self.num_mel_frame_per_asr_frame) for cut in cuts + ] spk_targets = collate_matrices(speaker_targets, padding_value=0) return audio, audio_lens, None, None, spk_targets for idx, cut in enumerate(cuts): - speaker_targets, texts = speaker_to_target(cut, self.num_sample_per_mel_frame, self.num_mel_frame_per_asr_frame, return_text=True) - speaker_targets = speaker_targets.transpose(0, 1)[:len(texts)] + speaker_targets, texts = speaker_to_target( + cut, self.num_sample_per_mel_frame, self.num_mel_frame_per_asr_frame, return_text=True + ) + speaker_targets = speaker_targets.transpose(0, 1)[: len(texts)] target_speaker_id = random.choice(range(len(texts))) non_target_speaker_ids = [i for i in range(len(texts)) if i != target_speaker_id] @@ -95,10 +98,10 @@ def __getitem__(self, cuts) -> Tuple[torch.Tensor, ...]: tokens.append(torch.as_tensor(self.tokenizer(text, cut.supervisions[0].language))) spk_targets.append(speaker_target) bg_spk_targets.append(bg_speaker_target) - + token_lens = torch.tensor([t.size(0) for t in tokens], dtype=torch.long) tokens = collate_vectors(tokens, padding_value=0) spk_targets = collate_vectors(spk_targets, padding_value=0) bg_spk_targets = collate_vectors(bg_spk_targets, padding_value=0) - - return audio, audio_lens, tokens, token_lens, spk_targets, bg_spk_targets \ No newline at end of file + + return audio, audio_lens, tokens, token_lens, spk_targets, bg_spk_targets diff --git a/nemo/collections/asr/models/__init__.py b/nemo/collections/asr/models/__init__.py index 11fd592b0f40..ee277158c9a0 100644 --- a/nemo/collections/asr/models/__init__.py +++ b/nemo/collections/asr/models/__init__.py @@ -32,6 +32,7 @@ ) from nemo.collections.asr.models.label_models import EncDecSpeakerLabelModel from nemo.collections.asr.models.msdd_models import EncDecDiarLabelModel, NeuralDiarizer +from nemo.collections.asr.models.multitalker_asr_models import EncDecMultiTalkerRNNTBPEModel from nemo.collections.asr.models.rnnt_bpe_models import EncDecRNNTBPEModel from nemo.collections.asr.models.rnnt_models import EncDecRNNTModel from nemo.collections.asr.models.slu_models import SLUIntentSlotBPEModel @@ -42,4 +43,3 @@ SpeechEncDecSelfSupervisedModel, ) from nemo.collections.asr.models.transformer_bpe_models import EncDecTransfModelBPE -from nemo.collections.asr.models.multitalker_asr_models import EncDecMultiTalkerRNNTBPEModel diff --git a/nemo/collections/asr/models/multitalker_asr_models.py b/nemo/collections/asr/models/multitalker_asr_models.py index ea7a66eb15a7..887736d88510 100644 --- a/nemo/collections/asr/models/multitalker_asr_models.py +++ b/nemo/collections/asr/models/multitalker_asr_models.py @@ -14,22 +14,18 @@ # import os from typing import Any, Dict, List, Optional + import torch import torch.nn.functional as F from omegaconf import DictConfig, ListConfig, OmegaConf, open_dict from pytorch_lightning import Trainer from nemo.collections.asr.data.audio_to_text_lhotse_speaker import LhotseSpeechToTextSpkBpeDataset - -from nemo.collections.asr.parts.mixins import ( - TranscribeConfig, - TranscriptionReturnType, -) -from nemo.collections.asr.parts.mixins.multitalker_asr_mixins import SpeakerKernelMixin - from nemo.collections.asr.models.rnnt_bpe_models import EncDecRNNTBPEModel -from nemo.collections.common.data.lhotse import get_lhotse_dataloader_from_config +from nemo.collections.asr.parts.mixins import TranscribeConfig, TranscriptionReturnType +from nemo.collections.asr.parts.mixins.multitalker_asr_mixins import SpeakerKernelMixin from nemo.collections.asr.parts.utils.rnnt_utils import Hypothesis +from nemo.collections.common.data.lhotse import get_lhotse_dataloader_from_config class EncDecMultiTalkerRNNTBPEModel(EncDecRNNTBPEModel, SpeakerKernelMixin): @@ -46,12 +42,15 @@ def _setup_dataloader_from_config(self, config: Optional[Dict]): with open_dict(config): config.global_rank = self.global_rank config.world_size = self.world_size - + return get_lhotse_dataloader_from_config( config, global_rank=self.global_rank, world_size=self.world_size, - dataset=LhotseSpeechToTextSpkBpeDataset(cfg = config, tokenizer=self.tokenizer,), + dataset=LhotseSpeechToTextSpkBpeDataset( + cfg=config, + tokenizer=self.tokenizer, + ), ) def training_step(self, batch, batch_nb): @@ -86,7 +85,7 @@ def _transcribe_forward(self, batch: Any, trcfg: TranscribeConfig): batch = (signal, signal_len, transcript, transcript_len) return super()._transcribe_forward(batch, trcfg) - + def _setup_transcribe_dataloader(self, config: Dict) -> 'torch.utils.data.DataLoader': """ Setup function for a temporary data loader which wraps the provided audio file. @@ -121,7 +120,7 @@ def _setup_transcribe_dataloader(self, config: Dict) -> 'torch.utils.data.DataLo 'use_bucketing': False, 'channel_selector': config.get('channel_selector', None), 'inference_mode': self.cfg.test_ds.get('inference_mode', True), - 'fixed_spk_id': config.get('fixed_spk_id', None) + 'fixed_spk_id': config.get('fixed_spk_id', None), } if config.get("augmentor"): @@ -129,4 +128,4 @@ def _setup_transcribe_dataloader(self, config: Dict) -> 'torch.utils.data.DataLo temporary_datalayer = self._setup_dataloader_from_config(config=DictConfig(dl_config)) - return temporary_datalayer \ No newline at end of file + return temporary_datalayer diff --git a/nemo/collections/asr/models/sortformer_diar_models.py b/nemo/collections/asr/models/sortformer_diar_models.py index 8d3953ff8302..75035dba0617 100644 --- a/nemo/collections/asr/models/sortformer_diar_models.py +++ b/nemo/collections/asr/models/sortformer_diar_models.py @@ -128,7 +128,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): self.max_batch_dur = self._cfg.get("max_batch_dur", 20000) self.concat_and_pad_script = torch.jit.script(self.sortformer_modules.concat_and_pad) - self.rttms_mask_mats: List[torch.Tensor] = None # Used when GT diarization needs to be tested. + self.rttms_mask_mats: List[torch.Tensor] = None # Used when GT diarization needs to be tested. def add_rttms_mask_mats(self, rttms_mask_mats, device: torch.device): """ @@ -140,7 +140,9 @@ def add_rttms_mask_mats(self, rttms_mask_mats, device: torch.device): if self.rttms_mask_mats is None: self.rttms_mask_mats = rttms_mask_mats.to(device) else: - raise ValueError(f"{self.rttms_mask_mats.shape}: rttms_mask_mats already exist but new one is being added.") + raise ValueError( + f"{self.rttms_mask_mats.shape}: rttms_mask_mats already exist but new one is being added." + ) def _init_loss_weights(self): pil_weight = self._cfg.get("pil_weight", 0.0) @@ -1141,4 +1143,4 @@ def diarize( num_workers=num_workers, verbose=verbose, override_config=override_config, - ) \ No newline at end of file + ) diff --git a/nemo/collections/asr/parts/mixins/__init__.py b/nemo/collections/asr/parts/mixins/__init__.py index 7ea8ca2e1584..3c4f837dbf5a 100644 --- a/nemo/collections/asr/parts/mixins/__init__.py +++ b/nemo/collections/asr/parts/mixins/__init__.py @@ -14,13 +14,13 @@ from nemo.collections.asr.parts.mixins.asr_adapter_mixins import ASRAdapterModelMixin from nemo.collections.asr.parts.mixins.interctc_mixin import InterCTCMixin -from nemo.collections.asr.parts.mixins.multitalker_asr_mixins import SpeakerKernelMixin from nemo.collections.asr.parts.mixins.mixins import ( ASRAdapterModelMixin, ASRBPEMixin, ASRModuleMixin, DiarizationMixin, ) +from nemo.collections.asr.parts.mixins.multitalker_asr_mixins import SpeakerKernelMixin from nemo.collections.asr.parts.mixins.transcription import ( ASRTranscriptionMixin, TranscribeConfig, diff --git a/nemo/collections/asr/parts/mixins/multitalker_asr_mixins.py b/nemo/collections/asr/parts/mixins/multitalker_asr_mixins.py index 9c4f6eea8109..afbc74e60936 100644 --- a/nemo/collections/asr/parts/mixins/multitalker_asr_mixins.py +++ b/nemo/collections/asr/parts/mixins/multitalker_asr_mixins.py @@ -12,48 +12,48 @@ # See the License for the specific language governing permissions and # limitations under the License. +from abc import ABC, abstractmethod from typing import Any, Dict, List, Optional + import torch import torch.nn as nn -from abc import ABC, abstractmethod from omegaconf import ListConfig from nemo.utils import logging __all__ = ['SpeakerKernelMixin'] -def get_spk_kernel_class( - spk_kernel_type, - input_size, - d_model, - dropout=0.5 -): + +def get_spk_kernel_class(spk_kernel_type, input_size, d_model, dropout=0.5): if spk_kernel_type == 'ff': - return nn.Sequential(nn.Linear(input_size, d_model), nn.ReLU(), nn.Dropout(dropout), nn.Linear(d_model, input_size)) + return nn.Sequential( + nn.Linear(input_size, d_model), nn.ReLU(), nn.Dropout(dropout), nn.Linear(d_model, input_size) + ) elif spk_kernel_type == 'conv2d': - return + return elif spk_kernel_type == 'mha': return + class SpeakerKernelMixin(ABC): """ Mixin class for models that need speaker kernel functionality. - + This mixin provides: - Speaker kernel initialization - Hook attachment for applying speaker kernels at specific encoder layers - Support for both active and background speaker kernels - + Models using this mixin should have the following config parameters: - spk_kernel_type: Type of speaker kernel ('mask', 'concat', 'sinusoidal') - spk_kernel_layers: List of layer indices where to apply speaker kernels - add_bg_spk_kernel: Whether to add background speaker kernels """ - + def _init_speaker_kernel_config(self, cfg): """ Initialize speaker kernel configuration from model config. - + Args: cfg: Model configuration containing speaker kernel parameters """ @@ -61,15 +61,15 @@ def _init_speaker_kernel_config(self, cfg): self.spk_kernel_type = cfg.get('spk_kernel_type', None) self.spk_kernel_layers = cfg.get('spk_kernel_layers', [0]) self.add_bg_spk_kernel = cfg.get('add_bg_spk_kernel', True) - + # Initialize speaker target containers self.spk_targets = None - if self.add_bg_spk_kernel: + if self.add_bg_spk_kernel: self.bg_spk_targets = None - + # Initialize speaker kernels self._init_spk_kernel() - + def _init_spk_kernel(self): """Initialize speaker kernel modules and register them to encoder layers.""" if not isinstance(self.spk_kernel_layers, ListConfig): @@ -82,21 +82,21 @@ def _init_spk_kernel(self): self.spk_kernels = torch.nn.ModuleDict() if self.add_bg_spk_kernel: self.bg_spk_kernels = torch.nn.ModuleDict() - + # Create kernel for each layer index for layer_idx in self.spk_kernel_layers: self.spk_kernels[str(layer_idx)] = get_spk_kernel_class( spk_kernel_type=self.spk_kernel_type, input_size=hidden_size, d_model=self.cfg.encoder.d_model, - dropout=0.5 + dropout=0.5, ) if self.add_bg_spk_kernel: self.bg_spk_kernels[str(layer_idx)] = get_spk_kernel_class( spk_kernel_type=self.spk_kernel_type, input_size=hidden_size, d_model=self.cfg.encoder.d_model, - dropout=0.5 + dropout=0.5, ) if self.spk_kernels: @@ -134,10 +134,10 @@ def _attach_spk_kernel_hooks(self): def _get_spk_kernel_hook_pre_layer(self, layer_idx: str): """ Returns a hook function for applying speaker kernel transformation. - + Args: layer_idx (str): Index of the layer to apply the kernel - + Returns: callable: Hook function that applies speaker kernel """ @@ -145,7 +145,7 @@ def _get_spk_kernel_hook_pre_layer(self, layer_idx: str): def hook_fn(module, args, kwargs): # Pre-hooks with with_kwargs=True must return a (new_args, new_kwargs) tuple. # The input tensor is passed as a keyword argument, so we find it in 'kwargs'. - + if 'x' in kwargs: x = kwargs['x'] x_spk = self.spk_kernels[layer_idx](self.mask_with_speaker_targets(x, self.spk_targets)) @@ -173,17 +173,18 @@ def hook_fn(module, args, kwargs): def _get_spk_kernel_hook_post_layer(self, layer_idx: str): """ Returns a hook function for applying speaker kernel transformation. - + Args: layer_idx (str): Index of the layer to apply the kernel - + Returns: callable: Hook function that applies speaker kernel """ + def hook_fn(module, input, output): if self.spk_targets is None: return output - + if isinstance(output, tuple): x, *cache = output else: @@ -200,7 +201,7 @@ def hook_fn(module, input, output): if isinstance(output, tuple): return (x, *cache) return x - + return hook_fn def _cleanup_speaker_kernel_hooks(self): @@ -217,11 +218,12 @@ def _cleanup_speaker_kernel_hooks(self): delattr(self, 'encoder_hooks') logging.info("Speaker kernel hooks cleaned up") - def set_speaker_targets(self, spk_targets: Optional[torch.Tensor] = None, - bg_spk_targets: Optional[torch.Tensor] = None): + def set_speaker_targets( + self, spk_targets: Optional[torch.Tensor] = None, bg_spk_targets: Optional[torch.Tensor] = None + ): """ Set speaker targets for the model. - + Args: spk_targets: Main speaker targets tensor bg_spk_targets: Background speaker targets tensor @@ -235,21 +237,23 @@ def clear_speaker_targets(self): self.spk_targets = None if self.add_bg_spk_kernel: self.bg_spk_targets = None - + def solve_length_mismatch(self, x: torch.Tensor, mask: torch.Tensor): """ Solve length mismatch between x and mask. """ if mask is None: mask = torch.ones_like(x[:, :, 0]) - logging.warning(f"Mask is None, triggering single speaker mode and assigning all ones with shape: {mask.shape}") + logging.warning( + f"Mask is None, triggering single speaker mode and assigning all ones with shape: {mask.shape}" + ) if mask.shape[1] < x.shape[1]: # pad zero to the left mask = torch.nn.functional.pad(mask, (x.shape[1] - mask.shape[1], 0), mode='constant', value=1) if mask.shape[1] > x.shape[1]: - mask = mask[:, -x.shape[1]:] + mask = mask[:, -x.shape[1] :] return mask @@ -268,4 +272,3 @@ def concat_with_speaker_targets(self, x: torch.Tensor, spk_targets: torch.Tensor mask = self.solve_length_mismatch(x, spk_targets) x_spk = x * mask.unsqueeze(2) return x_spk - \ No newline at end of file diff --git a/nemo/collections/asr/parts/mixins/streaming.py b/nemo/collections/asr/parts/mixins/streaming.py index 7a2d921af115..04056540ac04 100644 --- a/nemo/collections/asr/parts/mixins/streaming.py +++ b/nemo/collections/asr/parts/mixins/streaming.py @@ -20,7 +20,8 @@ class StreamingEncoder(ABC): @abstractmethod def setup_streaming_params( - self, max_look_ahead: int = 10000, + self, + max_look_ahead: int = 10000, ): """ This function sets the needed values and parameters to perform streaming. The configuration (CacheAwareStreamingConfig) need to be stored in self.streaming_cfg. diff --git a/nemo/collections/asr/parts/utils/asr_multispeaker_utils.py b/nemo/collections/asr/parts/utils/asr_multispeaker_utils.py index b0b75212f62c..dd1c95218013 100644 --- a/nemo/collections/asr/parts/utils/asr_multispeaker_utils.py +++ b/nemo/collections/asr/parts/utils/asr_multispeaker_utils.py @@ -11,38 +11,37 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import re -import math +import itertools import json -import random import logging -import itertools -from copy import deepcopy -from cytoolz import groupby +import math +import os +import random +import re import time from collections import defaultdict +from copy import deepcopy +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional, Tuple, Union import numpy as np import soundfile -from tqdm import tqdm -from scipy.stats import norm - +import soundfile as sf import torch.utils.data +from cytoolz import groupby +from lhotse import AudioSource, Recording, SupervisionSegment, SupervisionSet, dill_enabled +from lhotse.cut import Cut, CutSet, MixedCut, MixTrack, MonoCut from lhotse.cut.set import mix -from lhotse.cut import Cut, CutSet, MixedCut, MonoCut, MixTrack -from lhotse import SupervisionSet, SupervisionSegment, dill_enabled, AudioSource, Recording -from lhotse.utils import uuid4, compute_num_samples, ifnone from lhotse.lazy import LazyIteratorChain, LazyJsonlIterator +from lhotse.utils import compute_num_samples, ifnone, uuid4 +from omegaconf import OmegaConf +from scipy.stats import norm +from tqdm import tqdm + from nemo.collections.asr.data.data_simulation import MultiSpeakerSimulator from nemo.collections.asr.parts.utils.data_simulation_utils import read_rir_manifest -from typing import Optional, Union, List, Tuple, Dict, Any -from omegaconf import OmegaConf -from dataclasses import dataclass, field -from typing import List, Optional -import soundfile as sf -import os @dataclass class SessionConfig: num_speakers: int = 1 @@ -50,6 +49,7 @@ class SessionConfig: session_length: int = 15 session_length_range: List[int] = field(default_factory=lambda: [10, 40]) + @dataclass class SessionParams: max_audio_read_sec: float = 20.0 @@ -82,6 +82,7 @@ class SessionParams: end_buffer: float = 0.5 random_offset: bool = True + @dataclass class OutputConfig: output_dir: str = "" @@ -89,6 +90,7 @@ class OutputConfig: overwrite_output: bool = True output_precision: int = 3 + @dataclass class BackgroundNoise: add_bg: bool = True @@ -99,6 +101,7 @@ class BackgroundNoise: snr_min: Optional[float] = None snr_max: Optional[float] = None + @dataclass class SegmentAugmentor: add_seg_aug: bool = False @@ -106,6 +109,7 @@ class SegmentAugmentor: min_gain_dbfs: float = -10.0 max_gain_dbfs: float = 10.0 + @dataclass class SessionAugmentor: add_sess_aug: bool = False @@ -113,11 +117,13 @@ class SessionAugmentor: min_white_noise_level: int = -90 max_white_noise_level: int = -46 + @dataclass class SpeakerEnforcement: enforce_num_speakers: bool = True enforce_time: List[float] = field(default_factory=lambda: [0.25, 0.75]) + @dataclass class SegmentManifest: window: float = 0.5 @@ -125,6 +131,7 @@ class SegmentManifest: step_count: int = 50 deci: int = 3 + @dataclass class RIRGeneration: use_rir: bool = False @@ -141,9 +148,11 @@ class RIRGeneration: att_diff: float = 15.0 att_max: float = 60.0 + @dataclass class DataSimConfig: """Configuration for data simulation.""" + manifest_filepath: str = "" sr: int = 16000 random_seed: int = 42 @@ -159,20 +168,23 @@ class DataSimConfig: segment_manifest: SegmentManifest = field(default_factory=SegmentManifest) rir_generation: RIRGeneration = field(default_factory=RIRGeneration) + @dataclass class MultiSpeakerSimulatorConfig: data_simulator: DataSimConfig = field(default_factory=DataSimConfig) + class Segment: def __init__(self, start, end, speaker_id, text): self.start = start self.end = end self.speaker_id = speaker_id self.text = text - + def __str__(self): return f"Segment(start={self.start}, end={self.end}, speaker_id={self.speaker_id}, text=\"{self.text}\")" + class SegList: def __init__(self, segments: List[Segment] = None, seglst_filepath: str = None): if segments is not None: @@ -181,8 +193,8 @@ def __init__(self, segments: List[Segment] = None, seglst_filepath: str = None): self._load_seglst(seglst_filepath) else: raise ValueError("Either segments or seglst_filepath must be provided") - - def _load_seglst(self, seglst_filepath: str|list[str]): + + def _load_seglst(self, seglst_filepath: str | list[str]): if isinstance(seglst_filepath, str): with open(seglst_filepath, 'r', encoding='utf-8') as f: seglst = json.load(f) @@ -200,21 +212,21 @@ def _load_seglst(self, seglst_filepath: str|list[str]): else: raise ValueError("seglst_filepath must be a string or a list of strings") self.sort() - + def __len__(self): return len(self.segments) - + def __getitem__(self, idx): return self.segments[idx] - + def __iter__(self): return iter(self.segments) - + def sort(self): self.segments.sort(key=lambda x: x.start) def get_segments(self, min_duration: float, max_duration: float): - + duration = random.uniform(min_duration, max_duration) first_segment_idx = random.randint(0, len(self) - 1) @@ -226,13 +238,17 @@ def get_segments(self, min_duration: float, max_duration: float): segments.append(self[i]) else: break - + return segments - - def get_text_from_segments(self, segments: list[Segment], speaker_token_style='<|spltoken*|>', speaker_token_position='sot'): + + def get_text_from_segments( + self, segments: list[Segment], speaker_token_style='<|spltoken*|>', speaker_token_position='sot' + ): text = '' speakers = set([segment.speaker_id for segment in segments]) - speaker2start = {spk_id: min(segment.start for segment in segments if segment.speaker_id == spk_id) for spk_id in speakers} + speaker2start = { + spk_id: min(segment.start for segment in segments if segment.speaker_id == spk_id) for spk_id in speakers + } sorted_speakers = sorted(speakers, key=lambda x: speaker2start[x]) speaker2token = {spk: speaker_token_style.replace('*', str(i)) for i, spk in enumerate(sorted_speakers)} for segment in segments: @@ -241,7 +257,6 @@ def get_text_from_segments(self, segments: list[Segment], speaker_token_style='< return text.strip() - def find_first_nonzero(mat: torch.Tensor, max_cap_val=-1, thres: float = 0.5) -> torch.Tensor: """ Finds the first nonzero value in the matrix, discretizing it to the specified maximum capacity. @@ -534,18 +549,18 @@ def get_hidden_length_from_sample_length( mel_frame_count = math.ceil(num_samples / num_sample_per_mel_frame) hidden_length = math.ceil(mel_frame_count / num_mel_frame_per_asr_frame) return int(hidden_length) - + def speaker_to_target( a_cut, - num_sample_per_mel_frame: int = 160, - num_mel_frame_per_asr_frame: int = 8, + num_sample_per_mel_frame: int = 160, + num_mel_frame_per_asr_frame: int = 8, boundary_segments: bool = False, soft_label: bool = False, soft_thres: float = 0.5, ignore_num_spk_mismatch: bool = True, return_text: bool = False, - ): +): ''' Get rttm samples corresponding to one cut, generate speaker mask numpy.ndarray with shape (num_speaker, hidden_length) This function is needed for speaker diarization with ASR model trainings. @@ -560,7 +575,7 @@ def speaker_to_target( soft_thres (float): the threshold for the soft label, 0.5 by default. ignore_num_spk_mismatch (bool): This is a temporary solution to handle speaker mismatch. Will be removed in the future. return_text (bool): set to True to return the text of the speakers (if it is available), False by default. - + Returns: mask (Tensor): speaker mask with shape (num_speaker, hidden_lenght) ''' @@ -574,7 +589,7 @@ def speaker_to_target( offsets = [0] else: raise ValueError(f"Unsupported cut type type{a_cut}: only MixedCut and MonoCut are supported") - + segments_total = [] for i, cut in enumerate(cut_list): @@ -585,15 +600,19 @@ def speaker_to_target( else: logging.warning(f"No rttm or supervisions found for cut {cut.id}") continue - + start = cut.offset if hasattr(cut, 'offset') else cut.start end = start + cut.duration recording_id = rttms[0].recording_id if len(rttms) > 0 else cut.recording_id - if boundary_segments: # segments with seg_start < total_end and seg_end > total_start are included - segments_iterator = find_segments_from_rttm(recording_id=recording_id, rttms=rttms, start_after=start, end_before=end, tolerance=0.0) - else: # segments with seg_start > total_start and seg_end < total_end are included - segments_iterator = rttms.find(recording_id=recording_id, start_after=start, end_before=end, adjust_offset=True) #, tolerance=0.0) - + if boundary_segments: # segments with seg_start < total_end and seg_end > total_start are included + segments_iterator = find_segments_from_rttm( + recording_id=recording_id, rttms=rttms, start_after=start, end_before=end, tolerance=0.0 + ) + else: # segments with seg_start > total_start and seg_end < total_end are included + segments_iterator = rttms.find( + recording_id=recording_id, start_after=start, end_before=end, adjust_offset=True + ) # , tolerance=0.0) + for seg in segments_iterator: if seg.start < 0: seg.duration += seg.start @@ -603,29 +622,30 @@ def speaker_to_target( seg.start += offsets[i] segments_total.append(seg) # apply arrival time sorting to the existing segments - segments_total.sort(key = lambda rttm_sup: rttm_sup.start) + segments_total.sort(key=lambda rttm_sup: rttm_sup.start) seen = set() seen_add = seen.add speaker_ats = [s.speaker for s in segments_total if not (s.speaker in seen or seen_add(s.speaker))] - - speaker_to_idx_map = { - spk: idx - for idx, spk in enumerate(speaker_ats) - } + + speaker_to_idx_map = {spk: idx for idx, spk in enumerate(speaker_ats)} num_speakers = len(speaker_ats) - + # initialize mask matrices (num_speaker, encoder_hidden_len) - feat_per_sec = int(a_cut.sampling_rate / num_sample_per_mel_frame) # 100 by default - num_samples = get_hidden_length_from_sample_length(a_cut.num_samples, num_sample_per_mel_frame, num_mel_frame_per_asr_frame) - frame_mask = get_mask_from_segments(segments_total, a_cut, speaker_to_idx_map, num_speakers, feat_per_sec, ignore_num_spk_mismatch) + feat_per_sec = int(a_cut.sampling_rate / num_sample_per_mel_frame) # 100 by default + num_samples = get_hidden_length_from_sample_length( + a_cut.num_samples, num_sample_per_mel_frame, num_mel_frame_per_asr_frame + ) + frame_mask = get_mask_from_segments( + segments_total, a_cut, speaker_to_idx_map, num_speakers, feat_per_sec, ignore_num_spk_mismatch + ) soft_mask = get_soft_mask(frame_mask, num_samples, num_mel_frame_per_asr_frame) if soft_label: mask = soft_mask else: mask = (soft_mask > soft_thres).float() - + if return_text: speaker2text = defaultdict(list) for seg in segments_total: @@ -635,6 +655,7 @@ def speaker_to_target( else: return mask + def read_seglst(seglst_filepath: str, session_id: Optional[str] = None): """ Read the seglst file and return a list of segments. @@ -648,17 +669,20 @@ def read_seglst(seglst_filepath: str, session_id: Optional[str] = None): start=float(seg['start_time']), duration=float(seg['end_time']) - float(seg['start_time']), text=seg['words'], - speaker=seg['speaker'] - ) for i, seg in enumerate(seglst) + speaker=seg['speaker'], + ) + for i, seg in enumerate(seglst) ] - -class MultiSpeakerMixtureGenerator(): + + +class MultiSpeakerMixtureGenerator: """ This class is used to simulate multi-speaker audio data, which can be used for multi-speaker ASR and speaker diarization training. """ + def __init__( - self, + self, manifest_filepath, sample_rate, simulator_type, @@ -673,7 +697,7 @@ def __init__( """ Args: cuts (CutSet): The cutset that contains single-speaker audio cuts. - Please make sure that the cuts have the 'speaker_id' attribute. + Please make sure that the cuts have the 'speaker_id' attribute. num_speakers (int): The number of speakers in the simulated audio. We only simulate the samples with the fixed number of speakers. The variation of the number of speakers is controlled by the weights in Lhotse dataloader config. @@ -691,10 +715,10 @@ def __init__( self.global_rank = global_rank self.world_size = world_size - self.manifest_filepath = manifest_filepath + self.manifest_filepath = manifest_filepath self.manifests = list(LazyJsonlIterator(manifest_filepath)) self.sample_rate = sample_rate - + self.min_duration = min_duration self.max_duration = max_duration self.min_delay = min_delay @@ -703,10 +727,7 @@ def __init__( print("====== simulator_type", simulator_type) - type2simulator = { - 'lsmix': self.LibriSpeechMixSimulator, - 'mixture_loader': self.MultiSpeakerMixtureLoader - } + type2simulator = {'lsmix': self.LibriSpeechMixSimulator, 'mixture_loader': self.MultiSpeakerMixtureLoader} self.simulator = type2simulator[simulator_type] @@ -718,7 +739,7 @@ def __init__( def __iter__(self): return self - + def __next__(self): self.count += 1 return self.simulator() @@ -744,17 +765,19 @@ def LibriSpeechMixSimulator(self): start=0.0, duration=mono_cuts[-1].duration, text=mono_cuts[-1].custom['text'], - speaker=speaker_id + speaker=speaker_id, ) ) - + tracks = [] offset = 0.0 for speaker_id, mono_cut in zip(sampled_speaker_ids, mono_cuts): tracks.append(MixTrack(cut=deepcopy(mono_cut), type=type(mono_cut), offset=offset)) offset += random.uniform(self.min_delay, mono_cut.duration) - - mixed_cut = MixedCut(id='lsmix_' + '_'.join([track.cut.id for track in tracks]) + '_' + str(uuid4()), tracks=tracks) + + mixed_cut = MixedCut( + id='lsmix_' + '_'.join([track.cut.id for track in tracks]) + '_' + str(uuid4()), tracks=tracks + ) return mixed_cut @@ -784,12 +807,18 @@ def MultiSpeakerMixtureLoader(self): supervisions = sorted(supervisions, key=lambda x: x.start) segment_offset, segment_duration = self._get_offset_and_duration(supervisions) - + json_dict = { 'audio_filepath': audio_filepath, 'duration': segment_duration, 'offset': segment_offset, - 'supervisions': find_segments_from_rttm(recording_id=supervisions[0].recording_id, rttms=SupervisionSet(supervisions), start_after=segment_offset, end_before=segment_offset + segment_duration, adjust_offset=False) + 'supervisions': find_segments_from_rttm( + recording_id=supervisions[0].recording_id, + rttms=SupervisionSet(supervisions), + start_after=segment_offset, + end_before=segment_offset + segment_duration, + adjust_offset=False, + ), } cut = self._json_to_cut(json_dict) @@ -803,7 +832,7 @@ def _get_offset_and_duration(self, supervisions): non_overlap_supervisions_indices = self._get_non_overlap_supervisions_indices(supervisions) # find the start and the end of the segment start_idx = random.choice(non_overlap_supervisions_indices) - end_idx = start_idx + end_idx = start_idx offset = supervisions[start_idx].start for i in range(start_idx + 1, len(supervisions)): end_idx = i @@ -816,7 +845,7 @@ def _get_offset_and_duration(self, supervisions): segment_duration = supervisions[end_idx].end - offset return segment_offset, segment_duration - + def _get_non_overlap_supervisions_indices(self, supervisions): """ Get the indices of the non-overlapping supervisions. @@ -829,7 +858,7 @@ def _get_non_overlap_supervisions_indices(self, supervisions): non_overlap_supervisions_indices.append(i) max_end = max(max_end, supervisions[i].end) return non_overlap_supervisions_indices - + def _json_to_cut(self, json_dict): """ Convert a json dictionary to a Cut instance. @@ -839,7 +868,10 @@ def _json_to_cut(self, json_dict): offset = json_dict.get("offset", 0.0) supervisions = json_dict.get("supervisions", []) cut = self._create_cut( - audio_path=audio_path, offset=offset, duration=duration, sampling_rate=json_dict.get("sampling_rate", None), + audio_path=audio_path, + offset=offset, + duration=duration, + sampling_rate=json_dict.get("sampling_rate", None), ) # Note that start=0 and not start=offset because supervision's start if relative to the # start of the cut; and cut.start is already set to offset @@ -864,7 +896,7 @@ def _create_cut( sampling_rate: int | None = None, channel: int = 0, ) -> Cut: - + recording = self._create_recording(audio_path, duration, sampling_rate) cut = recording.to_cut() if isinstance(cut.channel, list) and len(cut.channel) > 1: @@ -873,7 +905,7 @@ def _create_cut( cut = cut.truncate(offset=offset, duration=duration, preserve_id=True) cut.id = f"{cut.id}-{round(offset * 1e2):06d}-{round(duration * 1e2):06d}" return cut - + def _create_recording( self, audio_path: str, @@ -891,4 +923,4 @@ def _create_recording( channel_ids=[0], ) else: - return Recording.from_file(audio_path) \ No newline at end of file + return Recording.from_file(audio_path) diff --git a/nemo/collections/asr/parts/utils/data_simulation_utils.py b/nemo/collections/asr/parts/utils/data_simulation_utils.py index d90e8a604231..e44bf63e11b9 100644 --- a/nemo/collections/asr/parts/utils/data_simulation_utils.py +++ b/nemo/collections/asr/parts/utils/data_simulation_utils.py @@ -67,10 +67,13 @@ def get_cleaned_base_path(output_dir: str, overwrite_output: bool = True) -> str def binary_search_alignments( - inds: List[int], max_audio_read_sec: float, min_alignment_count: int, alignments: List[float], + inds: List[int], + max_audio_read_sec: float, + min_alignment_count: int, + alignments: List[float], ) -> int: """ - Binary search to find the index of the alignment that satisfies the maximum audio read duration, + Binary search to find the index of the alignment that satisfies the maximum audio read duration, `max_audio_read_sec`. This is used to avoid reading the short audio files. NOTE: `offset_max` should be at least 1 to avoid feeding max=0 to random sampling function. @@ -103,7 +106,10 @@ def binary_search_alignments( def get_subset_of_audio_manifest( - audio_manifest: dict, offset_index: int, max_audio_read_sec: float, min_alignment_count: int, + audio_manifest: dict, + offset_index: int, + max_audio_read_sec: float, + min_alignment_count: int, ) -> dict: """ Get a subset of `audio_manifest` for faster audio-file reading. @@ -202,7 +208,6 @@ def read_audio_from_buffer( def perturb_audio( audio: torch.Tensor, sr: int, augmentor: Optional[AudioAugmentor] = None, device: Optional[torch.device] = None ) -> torch.Tensor: - """ Perturb the audio (segment or session) using audio augmentor. @@ -285,7 +290,10 @@ def get_scaled_audio_signal( def get_desired_avg_power_noise( - power_array: float, snr_min: float, snr_max: float, background_noise_snr: float, + power_array: float, + snr_min: float, + snr_max: float, + background_noise_snr: float, ): """ Calculate the desired average power of the noise. @@ -333,7 +341,7 @@ def get_background_noise( background_noise_snr (float): SNR of the background noise. seed (int): Seed for random number generator. device (torch.device): Device to use. - + Returns: bg_array (tensor): Tensor containing background noise. desired_snr (float): Desired SNR for adding background noise. @@ -358,10 +366,10 @@ def get_background_noise( ) # noise_segment_list.append( noise_manifest_dict = copy.deepcopy(audio_manifest) - noise_manifest_dict['duration'] = float(min(len(audio_file), len_array - running_len_samples-1) / sr) - noise_manifest_dict['offset'] = 0 + noise_manifest_dict['duration'] = float(min(len(audio_file), len_array - running_len_samples - 1) / sr) + noise_manifest_dict['offset'] = 0 noise_manifest_dict['volume'] = 1.0 - noise_manifest_dict['mixed_cut_offset'] = last_mixed_cut_offset + noise_manifest_dict['mixed_cut_offset'] = last_mixed_cut_offset last_mixed_cut_offset += noise_manifest_dict['duration'] noise_segment_list.append(noise_manifest_dict) @@ -392,7 +400,7 @@ def get_random_offset_index( min_alignment_count: int = 2, ) -> int: """ - Get an index for randomly accessing the silence in alignment timestamps. + Get an index for randomly accessing the silence in alignment timestamps. Args: audio_manifest (dict): Audio manifest dictionary. @@ -402,7 +410,7 @@ def get_random_offset_index( max_audio_read_sec (float): Maximum audio read duration in seconds. (Default: 2.5) min_alignment_count (int): Minimum number of alignment timestamps. (Default: 2) - Returns: + Returns: (int): Random offset index smaller than `offset_count`. """ if len(audio_manifest['alignments']) <= min_alignment_count: @@ -455,7 +463,7 @@ def get_speaker_ids(sess_idx: int, speaker_samples: dict, permutated_speaker_ind speaker_ids (list): List of speaker IDs """ all_speaker_ids = list(speaker_samples.keys()) - # Measure the length of permutated_speaker_inds and mod the sess_idx number so that + # Measure the length of permutated_speaker_inds and mod the sess_idx number so that # sess_idx is always less than the length of permutated_speaker_inds sess_idx_circular = sess_idx % permutated_speaker_inds.shape[0] idx_list = permutated_speaker_inds[sess_idx_circular, :] @@ -474,7 +482,7 @@ def build_speaker_samples_map(manifest: dict, tqdm_bar: bool = False) -> dict: speaker_samples = defaultdict(list) # logging.info("Building speaker to samples map...") for sample in tqdm(manifest, total=len(manifest), disable=not tqdm_bar): - # for sample in manifest: + # for sample in manifest: speaker_id = sample['speaker_id'] speaker_samples[speaker_id].append(sample) return speaker_samples @@ -535,7 +543,7 @@ def get_speaker_samples(speaker_ids: List[str], speaker_samples: dict) -> Dict[s Args: speaker_ids (list): LibriSpeech speaker IDs for each speaker in the current session. speaker_samples (dict): Dictionary mapping speaker ID to their list of samples. - + Returns: speaker_wav_align_map (dict): Dictionary containing speaker IDs and their corresponding wav filepath and alignments. """ @@ -563,7 +571,10 @@ def add_silence_to_alignments(audio_manifest: dict): def load_speaker_sample( - speaker_wav_align_map: List[dict], speaker_ids: List[str], speaker_turn: int, min_alignment_count: int, + speaker_wav_align_map: List[dict], + speaker_ids: List[str], + speaker_turn: int, + min_alignment_count: int, ) -> str: """ Load a sample for the selected speaker ID. @@ -575,7 +586,7 @@ def load_speaker_sample( speaker_turn (int): Current speaker turn. output_precision (int): Precision of the output alignments in integer. min_alignment_count (int): Minimum number of alignments in the audio file. - + Returns: audio_manifest (dict): Audio manifest dictionary containing the wav filepath, words and alignments. """ @@ -623,12 +634,15 @@ def get_split_points_in_alignments( splits = [] for i in range(len(words)): if words[i] == "" and i != 0 and i != len(words) - 1: - # if words[i] == "" and i != 0 and i != len(words) - 1: + # if words[i] == "" and i != 0 and i != len(words) - 1: silence_length = alignments[i] - alignments[i - 1] if silence_length > 2 * split_buffer: # split utterance on silence new_end = alignments[i - 1] + split_buffer splits.append( - [int(new_start * sr), int(new_end * sr),] + [ + int(new_start * sr), + int(new_end * sr), + ] ) new_start = alignments[i] - split_buffer # The last split point should be added @@ -667,12 +681,12 @@ class DataAnnotator(object): Class containing the functions that create RTTM, CTM, JSON files. Arguments in config: - + data_simulator: session_config: num_speakers (int): Number of unique speakers per multispeaker audio session session_params: - split_buffer (float): Split RTTM labels if greater than twice this amount of silence (to avoid long gaps between + split_buffer (float): Split RTTM labels if greater than twice this amount of silence (to avoid long gaps between utterances as being labelled as speech) outputs: output_dir (str): Output directory for audio sessions and corresponding label files @@ -715,9 +729,14 @@ def init_annotation_lists(self): self.annote_lists[file_type] = [] def create_new_rttm_entry( - self, words: List[str], alignments: List[float], start: int, end: int, speaker_id: int, add_split_buffer: bool = False + self, + words: List[str], + alignments: List[float], + start: int, + end: int, + speaker_id: int, + add_split_buffer: bool = False, ) -> List[str]: - """ Create new RTTM entries (to write to output rttm file) @@ -727,7 +746,7 @@ def create_new_rttm_entry( start (int): Current start of the audio file being inserted. end (int): End of the audio file being inserted. speaker_id (int): LibriSpeech speaker ID for the current entry. - + Returns: rttm_list (list): List of rttm entries """ @@ -742,7 +761,7 @@ def create_new_rttm_entry( ): # split utterance on silence new_end = start + alignments[i - 1] silence_duration = alignments[i] - alignments[i - 1] - + # new_end = start + alignments[i - 1] + self._params.data_simulator.session_params.split_buffer # import ipdb; ipdb.set_trace() @@ -751,7 +770,7 @@ def create_new_rttm_entry( t_stt = round(float(new_start), self._params.data_simulator.outputs.output_precision) t_end = round(float(new_end), self._params.data_simulator.outputs.output_precision) rttm_list.append(f"{t_stt} {t_end} {speaker_id}") - new_start = start + alignments[i] + new_start = start + alignments[i] # new_start = start + alignments[i] - self._params.data_simulator.session_params.split_buffer t_stt = round(float(new_start), self._params.data_simulator.outputs.output_precision) @@ -798,7 +817,7 @@ def create_new_json_entry( "uem_filepath": None, } return meta - + def create_ctm_entry_from_segment_list( self, source_segment_list, session_name: str, speaker_id: int, start: int ) -> List[str]: @@ -811,7 +830,7 @@ def create_ctm_entry_from_segment_list( session_name (str): Current session name. speaker_id (int): LibriSpeech speaker ID for the current entry. start (int): Current start of the audio file being inserted. - + Returns: arr (list): List of ctm entries """ @@ -830,8 +849,14 @@ def create_ctm_entry_from_segment_list( ): # note that using the current alignments the first word is always empty, so there is no error from indexing the array with i-1 # prev_align = 0 if i == 0 else alignments[i - 1] # align1 = round(float(prev_align + start), self._params.data_simulator.outputs.output_precision) - align1 = round(float(start_offset + alignments[i] - alignment_offset), self._params.data_simulator.outputs.output_precision) - align2 = round(float(start_offset + alignments[i+1] - alignment_offset - align1), self._params.data_simulator.outputs.output_precision) + align1 = round( + float(start_offset + alignments[i] - alignment_offset), + self._params.data_simulator.outputs.output_precision, + ) + align2 = round( + float(start_offset + alignments[i + 1] - alignment_offset - align1), + self._params.data_simulator.outputs.output_precision, + ) text = get_ctm_line( source=session_name, channel=1, @@ -846,7 +871,12 @@ def create_ctm_entry_from_segment_list( return arr def create_new_ctm_entry( - self, words: List[str], alignments: List[float], session_name: str, speaker_id: int, start: int, + self, + words: List[str], + alignments: List[float], + session_name: str, + speaker_id: int, + start: int, ) -> List[str]: """ Create new CTM entry (to write to output ctm file) @@ -857,7 +887,7 @@ def create_new_ctm_entry( session_name (str): Current session name. speaker_id (int): LibriSpeech speaker ID for the current entry. start (int): Current start of the audio file being inserted. - + Returns: arr (list): List of ctm entries """ @@ -946,12 +976,15 @@ def write_annotation_rttm_and_ctm(self, basepath: str, filename: str): json_list (list): List of JSON entries. ctm_list (list): List of CTM entries. """ - labels_to_rttmfile(self.annote_lists['rttm'], os.path.join(basepath, filename), self._params.data_simulator.outputs.output_dir) + labels_to_rttmfile( + self.annote_lists['rttm'], os.path.join(basepath, filename), self._params.data_simulator.outputs.output_dir + ) write_ctm(os.path.join(basepath, filename + '.ctm'), self.annote_lists['ctm']) + class SpeechSampler(object): """ - Class for sampling speech samples for Multispeaker Audio Session Simulator + Class for sampling speech samples for Multispeaker Audio Session Simulator Args: cfg: OmegaConf configuration loaded from yaml file. @@ -969,23 +1002,23 @@ class SpeechSampler(object): self.per_overlap_min_len (int): Minimum number of overlap samples in the overlap segment. self.per_overlap_max_len (int): Maximum number of overlap samples in the overlap segment. - data_simulator: - session_params: + data_simulator: + session_params: mean_silence (float): Mean proportion of silence to speaking time in the audio session. Should be in range [0, 1). - mean_silence_var (float): Variance for mean silence in all audio sessions. + mean_silence_var (float): Variance for mean silence in all audio sessions. This value should be 0 <= mean_silence_var < mean_silence * (1 - mean_silence). per_silence_var (float): Variance for each silence in an audio session, set large values (e.g., 20) for de-correlation. per_silence_min (float): Minimum duration for each silence, default to 0. per_silence_max (float): Maximum duration for each silence, default to -1 for no maximum. - - mean_overlap (float): Mean proportion of overlap in the overall non-silence duration. Should be in range [0, 1) and + + mean_overlap (float): Mean proportion of overlap in the overall non-silence duration. Should be in range [0, 1) and recommend [0, 0.15] range for accurate results. - mean_overlap_var (float): Variance for mean overlap in all audio sessions. + mean_overlap_var (float): Variance for mean overlap in all audio sessions. This value should be 0 <= mean_overlap_var < mean_overlap * (1 - mean_overlap). - per_overlap_var (float): Variance for per overlap in each session, set large values to de-correlate silence lengths + per_overlap_var (float): Variance for per overlap in each session, set large values to de-correlate silence lengths with the latest speech segment lengths per_overlap_min (float): Minimum per overlap duration in seconds - per_overlap_max (float): Maximum per overlap duration in seconds, set -1 for no maximum + per_overlap_max (float): Maximum per overlap duration in seconds, set -1 for no maximum """ def __init__(self, cfg): @@ -1029,7 +1062,7 @@ def _mean_var_to_a_and_b(self, mean: float, var: float) -> Tuple[float, float]: Returns: Tuple[float, float]: a and b parameters for beta distribution. """ - a = mean ** 2 * (1 - mean) / var - mean + a = mean**2 * (1 - mean) / var - mean b = mean * (1 - mean) ** 2 / var - (1 - mean) return a, b @@ -1072,7 +1105,7 @@ def _init_overlap_params(self): def silence_vs_overlap_selector(self, running_len_samples: int, non_silence_len_samples: int) -> bool: """ - Compare the current silence ratio to the current overlap ratio. Switch to either silence or overlap mode according + Compare the current silence ratio to the current overlap ratio. Switch to either silence or overlap mode according to the amount of the gap between current ratio and session mean in config. Args: @@ -1106,7 +1139,7 @@ def get_session_silence_mean(self): 0 < mean_silence_var < mean_silence * (1 - mean_silence) Args: - silence_mean (float): + silence_mean (float): Target mean silence for the current session """ self._init_silence_params() @@ -1158,16 +1191,16 @@ def sample_from_silence_model(self, running_len_samples: int) -> int: Sample from the silence model to determine the amount of silence to add between sentences. Gamma distribution is employed for modeling the highly skewed distribution of silence length distribution. When we add silence between sentences, we want to ensure that the proportion of silence meets the `sess_silence_mean`. - Thus, [Session Silence Mean] = [Total Running Silence Time] / [Total Running Session Time] equation holds. We employ the following + Thus, [Session Silence Mean] = [Total Running Silence Time] / [Total Running Session Time] equation holds. We employ the following formula to determine the amount of silence to add, which is `silence_mean`: self.sess_silence_mean = (silence_mean + self.running_silence_len_samples) / (silence_mean + running_len_samples) - The above equation is setting `silence_mean` to yield the desired silence ratio `self.sess_silence_mean`. + The above equation is setting `silence_mean` to yield the desired silence ratio `self.sess_silence_mean`. We use the above `silence_mean` value to sample silence-length for each silence occurrence. Args: - running_len_samples (int): + running_len_samples (int): Running length of the session (in terms of number of samples). session_len_samples (int): Targeted total session length (in terms of number of samples). @@ -1182,11 +1215,7 @@ def sample_from_silence_model(self, running_len_samples: int) -> int: if silence_mean > 0: self.per_silence_var = self._params.data_simulator.session_params.per_silence_var silence_amount = ( - int( - gamma( - a=(silence_mean ** 2) / self.per_silence_var, scale=self.per_silence_var / silence_mean - ).rvs() - ) + int(gamma(a=(silence_mean**2) / self.per_silence_var, scale=self.per_silence_var / silence_mean).rvs()) if self.per_silence_var > 0 else int(silence_mean) ) @@ -1205,15 +1234,15 @@ def sample_from_overlap_model(self, non_silence_len_samples: int): self.sess_overlap_mean = (overlap_mean + self.running_overlap_len_samples) / (non_silence_len_samples - overlap_mean) - The above equation is setting `overlap_mean` to yield the desired overlap ratio `self.sess_overlap_mean`. + The above equation is setting `overlap_mean` to yield the desired overlap ratio `self.sess_overlap_mean`. We use the above `overlap_mean` value to sample overlap-length for each overlap occurrence. - + Args: - non_silence_len_samples (int): + non_silence_len_samples (int): The total amount of non-silence (speech) region regardless of overlap status Returns: - desired_overlap_amount (int): + desired_overlap_amount (int): Amount of overlap between segments (in terms of number of samples). """ overlap_mean = ((self.sess_overlap_mean * non_silence_len_samples) - self.running_overlap_len_samples) / ( @@ -1223,7 +1252,7 @@ def sample_from_overlap_model(self, non_silence_len_samples: int): if overlap_mean > 0: desired_overlap_amount = ( - int(gamma(a=overlap_mean ** 2 / self.per_overlap_var, scale=self.per_overlap_var / overlap_mean).rvs()) + int(gamma(a=overlap_mean**2 / self.per_overlap_var, scale=self.per_overlap_var / overlap_mean).rvs()) if self.per_overlap_var > 0 else int(overlap_mean) ) @@ -1239,7 +1268,7 @@ def sample_noise_manifest(self, noise_manifest: dict) -> list: Sample noise manifest to a specified count `num_noise_files` for the current simulated audio session. Args: - noise_manifest (list): + noise_manifest (list): List of noise source samples to be sampled from. Returns: @@ -1252,4 +1281,4 @@ def sample_noise_manifest(self, noise_manifest: dict) -> list: selected_noise_ids = np.random.choice(range(len(noise_manifest)), num_noise_files, replace=False) for k in selected_noise_ids: sampled_noise_manifest.append(noise_manifest[k]) - return sampled_noise_manifest \ No newline at end of file + return sampled_noise_manifest diff --git a/nemo/collections/asr/parts/utils/diarization_utils.py b/nemo/collections/asr/parts/utils/diarization_utils.py index 82f6a081c602..6d3ac54fdaa4 100644 --- a/nemo/collections/asr/parts/utils/diarization_utils.py +++ b/nemo/collections/asr/parts/utils/diarization_utils.py @@ -16,24 +16,25 @@ import csv import json import os -from collections import defaultdict, OrderedDict as od +from collections import OrderedDict as od +from collections import defaultdict from datetime import datetime -from typing import Dict, List, Tuple, Optional +from typing import Dict, List, Optional, Tuple import numpy as np from pyannote.metrics.diarization import DiarizationErrorRate -from nemo.collections.asr.metrics.der import concat_perm_word_error_rate, calculate_session_cpWER +from nemo.collections.asr.metrics.der import calculate_session_cpWER, concat_perm_word_error_rate from nemo.collections.asr.metrics.wer import word_error_rate from nemo.collections.asr.models import ClusteringDiarizer from nemo.collections.asr.parts.utils.speaker_utils import ( audio_rttm_map, + generate_diarization_output_lines, get_uniqname_from_filepath, + labels_to_pyannote_object, labels_to_rttmfile, rttm_to_labels, write_rttm2manifest, - generate_diarization_output_lines, - labels_to_pyannote_object, ) from nemo.utils import logging @@ -64,7 +65,6 @@ def get_color_palette() -> Dict[str, str]: } - def dump_json_to_file(file_path: str, session_trans_dict: dict): """ Write a json file from the session_trans_dict dictionary. @@ -92,6 +92,7 @@ def write_txt(w_path: str, val: str): with open(w_path, "w") as output: output.write(val + '\n') + def init_session_trans_dict(uniq_id: str, n_spk: int): """ Initialize json (in dictionary variable) formats for session level result and Gecko style json. @@ -110,6 +111,7 @@ def init_session_trans_dict(uniq_id: str, n_spk: int): } ) + def init_session_gecko_dict(): """ Initialize a dictionary format for Gecko style json. @@ -120,6 +122,7 @@ def init_session_gecko_dict(): """ return od({'schemaVersion': 2.0, 'monologues': []}) + def convert_ctm_to_text(ctm_file_path: str) -> Tuple[List[str], str]: """ Convert ctm file into a list containing transcription (space seperated string) per each speaker. @@ -234,8 +237,7 @@ def convert_word_dict_seq_to_ctm( ctm_lines.append(ctm_line_str) return ctm_lines - - def break_transcript_lines(self, string_out: str, params: Dict[str, str], max_chars_in_line: int = 90) -> str: + def break_transcript_lines(self, string_out: str, params: Dict[str, str], max_chars_in_line: int = 90) -> str: """ Break the lines in the transcript. @@ -267,8 +269,11 @@ def break_transcript_lines(self, string_out: str, params: Dict[str, str], max_ch return_string_out = '\n'.join(return_string_out) return return_string_out + def get_total_result_dict( - der_results: Dict[str, Dict[str, float]], wer_results: Dict[str, Dict[str, float]], csv_columns: List[str], + der_results: Dict[str, Dict[str, float]], + wer_results: Dict[str, Dict[str, float]], + csv_columns: List[str], ): """ Merge WER results and DER results into a single dictionary variable. @@ -338,13 +343,14 @@ def get_num_of_spk_from_labels(labels: List[str]) -> int: spk_set = [x.split(' ')[-1].strip() for x in labels] return len(set(spk_set)) + def read_seglst(seglst_filepath, round_digits=3, return_rttm=False, sort_by_start_time=False): """ Read a seglst file and return the speaker & text information dictionary. Args: seglst_filepath: path to the seglst file - seglst format: + seglst format: [ { "session_id": "Bed008", @@ -367,7 +373,7 @@ def read_seglst(seglst_filepath, round_digits=3, return_rttm=False, sort_by_star seglst = [] with open(seglst_filepath, 'r') as f: seglst_lines = json.loads(f.read()) - + for idx, line in enumerate(seglst_lines): spk, start, end = line['speaker'], float(line['start_time']), float(line['end_time']) dur = round(end - start, round_digits) @@ -384,13 +390,14 @@ def read_seglst(seglst_filepath, round_digits=3, return_rttm=False, sort_by_star 'end_time': end, 'duration': dur, } - ) + ) if sort_by_start_time: seglst = sorted(seglst, key=lambda x: (x['start_time'], x['end_time'])) if return_rttm: return seglst, rttm_lines return seglst + def convert_seglst(seglst, all_speakers): ''' convert the seglst to a format that can be used for scoring @@ -402,7 +409,7 @@ def convert_seglst(seglst, all_speakers): timestamps: (list of list) [ [[st1, et1], [st2, et2]], # timestamps list for speaker 1 - [[st1, et1], ...], # timestamps list for speaker 2 + [[st1, et1], ...], # timestamps list for speaker 2 ...] words (list[[s1], [s2], [s3], [s4]]): list of words for each speaker 1 to 4 ''' @@ -428,17 +435,15 @@ def get_session_trans_dict(uniq_id, word_dict_seq_list, diar_labels): prev_speaker = speaker sentences, terms_list = [], [] - sentence = {'speaker': speaker, 'start_time': start_point, 'end_time': end_point, 'text': ''} - + sentence = {'speaker': speaker, 'start_time': start_point, 'end_time': end_point, 'text': ''} + for k, word_dict in enumerate(word_dict_seq_list): word, speaker = word_dict['word'], word_dict['speaker'] word_seq_list.append(word) start_point, end_point = word_dict['start_time'], word_dict['end_time'] if speaker != prev_speaker: if len(terms_list) != 0: - gecko_dict['monologues'].append( - {'speaker': {'name': None, 'id': prev_speaker}, 'terms': terms_list} - ) + gecko_dict['monologues'].append({'speaker': {'name': None, 'id': prev_speaker}, 'terms': terms_list}) terms_list = [] # remove trailing space in text @@ -474,11 +479,9 @@ def get_session_trans_dict(uniq_id, word_dict_seq_list, diar_labels): session_trans_dict['sentences'] = sentences gecko_dict['monologues'].append({'speaker': {'name': None, 'id': speaker}, 'terms': terms_list}) return session_trans_dict, gecko_dict, audacity_label_words, sentences - -def print_sentences(sentences: List[Dict[str, float]], - color_palette: Dict[str, str], - params: Dict[str, bool]) -> None: + +def print_sentences(sentences: List[Dict[str, float]], color_palette: Dict[str, str], params: Dict[str, bool]) -> None: """ Print a transcript with speaker labels and timestamps. @@ -494,7 +497,7 @@ def print_sentences(sentences: List[Dict[str, float]], string_out = '' # time_color = color_palette.get('black', '\033[0;30m') time_color = color_palette.get('white', '\033[0;30m') - + for sentence in sentences: # extract info speaker = sentence['speaker'] @@ -533,13 +536,14 @@ def print_sentences(sentences: List[Dict[str, float]], return string_out + def read_seglst(seglst_filepath, round_digits=3, return_rttm=False, sort_by_start_time=False, sort_by_end_time=False): """ Read a seglst file and return the speaker & text information dictionary. Args: seglst_filepath: path to the seglst file - seglst format: + seglst format: [ { "session_id": "Bed008", @@ -562,7 +566,7 @@ def read_seglst(seglst_filepath, round_digits=3, return_rttm=False, sort_by_star seglst = [] with open(seglst_filepath, 'r') as f: seglst_lines = json.loads(f.read()) - + for idx, line in enumerate(seglst_lines): spk, start, end = line['speaker'], float(line['start_time']), float(line['end_time']) dur = round(end - start, round_digits) @@ -579,7 +583,7 @@ def read_seglst(seglst_filepath, round_digits=3, return_rttm=False, sort_by_star 'end_time': end, 'duration': dur, } - ) + ) if sort_by_start_time and sort_by_end_time: raise ValueError("Cannot sort by both start and end time") if sort_by_start_time: @@ -590,6 +594,7 @@ def read_seglst(seglst_filepath, round_digits=3, return_rttm=False, sort_by_star return seglst, rttm_lines return seglst + def convert_seglst(seglst, all_speakers): ''' convert the seglst to a format that can be used for scoring @@ -601,7 +606,7 @@ def convert_seglst(seglst, all_speakers): timestamps: (list of list) [ [[st1, et1], [st2, et2]], # timestamps list for speaker 1 - [[st1, et1], ...], # timestamps list for speaker 2 + [[st1, et1], ...], # timestamps list for speaker 2 ...] words (list[[s1], [s2], [s3], [s4]]): list of words for each speaker 1 to 4 ''' @@ -617,10 +622,8 @@ def convert_seglst(seglst, all_speakers): return timestamps, words -def chunk_seglst( - seglst : List[Dict], - chunk_size: float = 10.0 -): + +def chunk_seglst(seglst: List[Dict], chunk_size: float = 10.0): ''' Get chunked timestamps and words for each speaker @@ -636,26 +639,26 @@ def chunk_seglst( chunk_id2timestamps = defaultdict(list) speakers = set() session_ids = set() - + for segment in seglst: session_id = segment['session_id'] start_time = segment['start_time'] end_time = segment['end_time'] - + # Determine interval bounds chunk_start = int(start_time // chunk_size) chunk_end = int(end_time // chunk_size) - + # Split and assign the segment across overlapping intervals words = segment['words'] for chunk_idx in range(chunk_start, chunk_end + 1): chunk_start_time = chunk_idx * chunk_size chunk_end_time = (chunk_idx + 1) * chunk_size - + # Calculate the adjusted start and end times for the split segment segment_start = max(start_time, chunk_start_time) segment_end = min(end_time, chunk_end_time) - + # Create a split segment and add it to the corresponding interval split_segment = { 'session_id': session_id, @@ -676,16 +679,17 @@ def chunk_seglst( session_id = None else: session_id = session_ids.pop() - + return chunk_id2timestamps, speakers, session_id + # def streaming_evaluation( # ref_seglst: List[Dict], # ref_rttm_labels: List[str], # hyp_seglst: List[Dict], -# collar: float = 0.25, -# ignore_overlap: bool = False, -# verbose: bool = True, +# collar: float = 0.25, +# ignore_overlap: bool = False, +# verbose: bool = True, # chunk_size: float = 10.0, # ): # """ @@ -709,7 +713,7 @@ def chunk_seglst( # ref_session_id = hyp_session_id # assert ref_session_id == hyp_session_id, "Session IDs of reference and hypothesis should match" - + # # Only care about the sessions in reference only # session_id = ref_session_id # ref_speaker_words = defaultdict(list) @@ -736,12 +740,12 @@ def chunk_seglst( # hyp_labels = generate_diarization_output_lines(speaker_timestamps=hyp_speaker_timestamps, model_spk_num=len(hyp_speakers)) # reference = labels_to_pyannote_object(ref_labels, uniq_name=session_id) # hypothesis = labels_to_pyannote_object(hyp_labels, uniq_name=session_id) - + # for idx, speaker in enumerate(ref_speakers): # ref_speaker_words[idx] += ref_speaker_word[idx] # for idx, speaker in enumerate(hyp_speakers): # hyp_speaker_words[idx] += hyp_speaker_word[idx] - + # der_met = der_metric(reference, hypothesis) # cpWER, min_perm_hyp_trans, ref_trans = cpwer_metric(ref_speaker_words, hyp_speaker_words) @@ -751,9 +755,10 @@ def chunk_seglst( # der_list.append(abs(der_metric) * 100) # cpwer_list.append(cpWER) - + # return der_list, cpwer_list + class OnlineEvaluation: """ A class designed for performing online evaluation of diarization and ASR. @@ -768,16 +773,17 @@ class OnlineEvaluation: ignore_overlap (bool): Whether to ignore overlapping segments verbose (bool): - Whether to print verbose output + Whether to print verbose output """ - def __init__(self, + def __init__( + self, ref_seglst: List[Dict], ref_rttm_labels: List[str], hyp_seglst: Optional[List[Dict]] = None, - collar: float = 0.25, - ignore_overlap: bool = False, - verbose: bool = True, + collar: float = 0.25, + ignore_overlap: bool = False, + verbose: bool = True, ): self.ref_seglst = ref_seglst self.ref_rttm_labels = ref_rttm_labels @@ -802,8 +808,8 @@ def evaluate_inloop(self, hyp_seglst, end_step_time=0.0): if end_step_time > self.ref_seglst[self.current_idx]['end_time']: self.current_idx += 1 is_update = True - ref_seglst = self.ref_seglst[:self.current_idx] - der_cumul, cpwer_cumul= self.evaluate(ref_seglst, hyp_seglst, chunk_size=-1, verbose=False) + ref_seglst = self.ref_seglst[: self.current_idx] + der_cumul, cpwer_cumul = self.evaluate(ref_seglst, hyp_seglst, chunk_size=-1, verbose=False) der, cpwer = der_cumul[-1], cpwer_cumul[-1] if self.verbose: logging.info(f"Session ID: {self.ref_seglst[0]['session_id']} from 0.0s to {end_step_time:.3f}s") @@ -840,12 +846,12 @@ def evaluate(self, ref_seglst, hyp_seglst, chunk_size=10.0, verbose=True): hyp_session_id = ref_session_id assert ref_session_id == hyp_session_id, "Session IDs of reference and hypothesis should match" - + # Only care about the sessions in reference only session_id = ref_session_id ref_speaker_words = defaultdict(list) hyp_speaker_words = defaultdict(list) - + der_metric = DiarizationErrorRate(collar=2 * self.collar, skip_overlap=self.ignore_overlap) cpwer_metric = calculate_session_cpWER der_list, cpwer_list = [], [] @@ -860,11 +866,15 @@ def evaluate(self, ref_seglst, hyp_seglst, chunk_size=10.0, verbose=True): hyp_speaker_timestamps, hyp_speaker_word = convert_seglst(hyp_seglst, hyp_speakers) ref_speaker_timestamps, ref_speaker_word = convert_seglst(ref_seglst, ref_speakers) - ref_labels = generate_diarization_output_lines(speaker_timestamps=ref_speaker_timestamps, model_spk_num=len(ref_speakers)) - hyp_labels = generate_diarization_output_lines(speaker_timestamps=hyp_speaker_timestamps, model_spk_num=len(hyp_speakers)) + ref_labels = generate_diarization_output_lines( + speaker_timestamps=ref_speaker_timestamps, model_spk_num=len(ref_speakers) + ) + hyp_labels = generate_diarization_output_lines( + speaker_timestamps=hyp_speaker_timestamps, model_spk_num=len(hyp_speakers) + ) reference = labels_to_pyannote_object(ref_labels, uniq_name=session_id) hypothesis = labels_to_pyannote_object(hyp_labels, uniq_name=session_id) - + for idx, speaker in enumerate(ref_speakers): ref_speaker_words[idx] += ref_speaker_word[idx] for idx, speaker in enumerate(hyp_speakers): @@ -873,18 +883,23 @@ def evaluate(self, ref_seglst, hyp_seglst, chunk_size=10.0, verbose=True): der_instance = der_metric(reference, hypothesis) # Normalize the text for spk_idx in range(len(hyp_speaker_words)): - hyp_speaker_words[spk_idx] = hyp_speaker_words[spk_idx].translate(str.maketrans('', '', string.punctuation)).lower() + hyp_speaker_words[spk_idx] = ( + hyp_speaker_words[spk_idx].translate(str.maketrans('', '', string.punctuation)).lower() + ) cpWER, min_perm_hyp_trans, ref_trans = cpwer_metric(ref_speaker_words, hyp_speaker_words) if verbose: - logging.info(f"Session ID: {session_id} Chunk ID: {chunk_idx} from 0.0s to {(chunk_idx+1)*chunk_size}s") + logging.info( + f"Session ID: {session_id} Chunk ID: {chunk_idx} from 0.0s to {(chunk_idx+1)*chunk_size}s" + ) logging.info(f"DER: {abs(der_metric)*100:.2f}%, cpWER: {cpWER*100:.2f}%") der_list.append(abs(der_metric) * 100) cpwer_list.append(cpWER * 100) return der_list, cpwer_list - + + class OfflineDiarWithASR: """ A class designed for performing ASR and diarization together. @@ -1020,7 +1035,8 @@ def _save_VAD_labels_list(self, word_ts_dict: Dict[str, Dict[str, List[float]]]) @staticmethod def get_speech_labels_from_decoded_prediction( - input_word_ts: List[float], nonspeech_threshold: float, + input_word_ts: List[float], + nonspeech_threshold: float, ) -> List[float]: """ Extract speech labels from the ASR output (decoded predictions) @@ -1222,7 +1238,9 @@ def _get_the_closest_silence_start( return cursor def _compensate_word_ts_list( - self, audio_file_list: List[str], word_ts_dict: Dict[str, List[float]], + self, + audio_file_list: List[str], + word_ts_dict: Dict[str, List[float]], ) -> Dict[str, List[List[float]]]: """ Compensate the word timestamps based on the VAD output. @@ -1364,7 +1382,7 @@ def get_word_level_json_list( Example: >>> word_ts = [[0.0, 0.04], [0.64, 0.68], [0.84, 0.88], ...] - + word_ts_refined (list): Dictionary containing the refined (end point fixed) word timestamps based on hypothesis word timestamps. Indexed by unique IDs. @@ -1377,9 +1395,9 @@ def get_word_level_json_list( List containing word by word dictionary containing word, timestamps and speaker labels. Example: - >>> [{'word': 'right', 'start_time': 0.0, 'end_time': 0.04, 'speaker': 'speaker_0'}, - {'word': 'and', 'start_time': 0.64, 'end_time': 0.68, 'speaker': 'speaker_1'}, - {'word': 'i', 'start_time': 0.84, 'end_time': 0.88, 'speaker': 'speaker_1'}, + >>> [{'word': 'right', 'start_time': 0.0, 'end_time': 0.04, 'speaker': 'speaker_0'}, + {'word': 'and', 'start_time': 0.64, 'end_time': 0.68, 'speaker': 'speaker_1'}, + {'word': 'i', 'start_time': 0.84, 'end_time': 0.88, 'speaker': 'speaker_1'}, ...] """ if word_rfnd_ts is None: @@ -1399,7 +1417,10 @@ def get_word_level_json_list( return word_dict_seq_list def _make_json_output( - self, uniq_id: str, diar_labels: List[str], word_dict_seq_list: List[Dict[str, float]], + self, + uniq_id: str, + diar_labels: List[str], + word_dict_seq_list: List[Dict[str, float]], ) -> Dict[str, Dict[str, str]]: """ Generate json output files and transcripts from the ASR and diarization results. @@ -1448,7 +1469,9 @@ def _make_json_output( } """ logging.info(f"Creating results for Session: {uniq_id} n_spk: {n_spk} ") - session_trans_dict, gecko_dict, audacity_label_words, sentences = get_session_trans_dict(uniq_id, word_dict_seq_list, diar_labels) + session_trans_dict, gecko_dict, audacity_label_words, sentences = get_session_trans_dict( + uniq_id, word_dict_seq_list, diar_labels + ) self._write_and_log(uniq_id, session_trans_dict, audacity_label_words, gecko_dict, sentences) return session_trans_dict @@ -1637,7 +1660,7 @@ def evaluate( wer_results['total']['average_cpWER'] = word_error_rate(hypotheses=hyps_spk, references=refs_spk) wer_results['total']['average_WER'] = word_error_rate(hypotheses=mix_hypotheses, references=mix_references) - for (uniq_id, cpWER, WER) in zip(uniq_id_list, cpWER_values, WER_values): + for uniq_id, cpWER, WER in zip(uniq_id_list, cpWER_values, WER_values): # Save session-level cpWER and WER values wer_results[uniq_id] = {} wer_results[uniq_id]['cpWER'] = cpWER @@ -1691,8 +1714,6 @@ def write_session_level_result_in_csv( except IOError: logging.info("I/O error has occurred while writing a csv file.") - - def _write_and_log( self, uniq_id: str, @@ -1755,4 +1776,4 @@ def print_errors(der_results: Dict[str, Dict[str, float]], wer_results: Dict[str \nWER : {wer_results['total']['average_WER']:.4f}" ) else: - logging.info(DER_info) \ No newline at end of file + logging.info(DER_info) diff --git a/nemo/collections/asr/parts/utils/multispk_transcribe_utils.py b/nemo/collections/asr/parts/utils/multispk_transcribe_utils.py index fd47d5ae8cea..68e7a4d48434 100644 --- a/nemo/collections/asr/parts/utils/multispk_transcribe_utils.py +++ b/nemo/collections/asr/parts/utils/multispk_transcribe_utils.py @@ -12,40 +12,34 @@ # See the License for the specific language governing permissions and # limitations under the License. -import os, json -from typing import Optional, List, Tuple, Dict, Any +import itertools +import json +import os +import time from collections import OrderedDict from copy import deepcopy +from functools import wraps +from typing import Any, Dict, List, Optional, Tuple import torch +from lhotse.dataset.collation import collate_matrices from omegaconf import DictConfig -from nemo.collections.asr.parts.utils.rnnt_utils import Hypothesis - -from nemo.collections.asr.parts.utils.diarization_utils import read_seglst, OnlineEvaluation -from nemo.utils import logging - +from nemo.collections.asr.data.audio_to_diar_label import extract_frame_info_from_rttm, get_frame_targets_from_rttm from nemo.collections.asr.models.sortformer_diar_models import SortformerEncLabelModel - -from nemo.collections.asr.parts.utils.speaker_utils import ( -audio_rttm_map as get_audio_rttm_map, -rttm_to_labels, -get_uniqname_from_filepath, -) +from nemo.collections.asr.modules.sortformer_modules import StreamingSortformerState from nemo.collections.asr.parts.utils.diarization_utils import ( -print_sentences, -get_color_palette, -write_txt, + OnlineEvaluation, + get_color_palette, + print_sentences, + read_seglst, + write_txt, ) -from nemo.collections.asr.data.audio_to_diar_label import get_frame_targets_from_rttm, extract_frame_info_from_rttm -from nemo.collections.asr.modules.sortformer_modules import StreamingSortformerState - - -from lhotse.dataset.collation import collate_matrices -import itertools +from nemo.collections.asr.parts.utils.rnnt_utils import Hypothesis +from nemo.collections.asr.parts.utils.speaker_utils import audio_rttm_map as get_audio_rttm_map +from nemo.collections.asr.parts.utils.speaker_utils import get_uniqname_from_filepath, rttm_to_labels +from nemo.utils import logging -import time -from functools import wraps def measure_eta(func): """ @@ -57,6 +51,7 @@ def measure_eta(func): Returns: callable: The wrapped function. """ + @wraps(func) def wrapper(*args, **kwargs): start_time = time.time() # Record the start time @@ -65,8 +60,10 @@ def wrapper(*args, **kwargs): eta = end_time - start_time # Calculate the elapsed time logging.info(f"[ Step-{kwargs['step_num']} ] for '{func.__name__}': {eta:.4f} seconds") # Print the ETA return result # Return the original function's result + return wrapper + def get_multi_talker_samples_from_manifest(cfg, manifest_file: str, feat_per_sec: float, max_spks: int): """ Get the multi-talker samples from the manifest file and save it to a list named 'samples'. @@ -124,12 +121,12 @@ def get_multi_talker_samples_from_manifest(cfg, manifest_file: str, feat_per_sec def setup_diarization_model(cfg: DictConfig, map_location: Optional[str] = None) -> SortformerEncLabelModel: """Setup model from cfg and return diarization model and model name for next step""" if cfg.diar_model_path.endswith(".ckpt"): - diar_model = SortformerEncLabelModel.load_from_checkpoint(checkpoint_path=cfg.diar_model_path, - map_location=map_location, strict=False) + diar_model = SortformerEncLabelModel.load_from_checkpoint( + checkpoint_path=cfg.diar_model_path, map_location=map_location, strict=False + ) model_name = os.path.splitext(os.path.basename(cfg.diar_model_path))[0] elif cfg.diar_model_path.endswith(".nemo"): - diar_model = SortformerEncLabelModel.restore_from(restore_path=cfg.diar_model_path, - map_location=map_location) + diar_model = SortformerEncLabelModel.restore_from(restore_path=cfg.diar_model_path, map_location=map_location) model_name = os.path.splitext(os.path.basename(cfg.diar_model_path))[0] elif cfg.diar_pretrained_name.startswith("nvidia/"): diar_model = SortformerEncLabelModel.from_pretrained(cfg.diar_pretrained_name) @@ -138,6 +135,7 @@ def setup_diarization_model(cfg: DictConfig, map_location: Optional[str] = None) raise ValueError("cfg.diar_model_path must end with.ckpt or.nemo!") return diar_model, model_name + def write_seglst(output_filepath: str, seglst_list: list) -> None: """ Write the segmentation list to a file. @@ -149,6 +147,7 @@ def write_seglst(output_filepath: str, seglst_list: list) -> None: with open(output_filepath, "w", encoding="utf-8") as f: f.write(json.dumps(seglst_list, indent=2) + "\n") + def get_new_sentence_dict( speaker: str, start_time: float, @@ -190,19 +189,15 @@ def calc_drop_extra_pre_encoded(asr_model: SortformerEncLabelModel, step_num: in Returns: int: The number of extra tokens to drop. """ - # for the first step there is no need to drop any tokens + # for the first step there is no need to drop any tokens # after the downsampling as no caching is being used if step_num == 0 and not pad_and_drop_preencoded: return 0 else: return asr_model.encoder.streaming_cfg.drop_extra_pre_encoded -def fix_frame_time_step( - cfg: Any, - new_tokens: List[str], - new_words: List[str], - frame_inds_seq: List[int] - ) -> List[int]: + +def fix_frame_time_step(cfg: Any, new_tokens: List[str], new_words: List[str], frame_inds_seq: List[int]) -> List[int]: """ Adjust the frame indices sequence to match the length of new tokens. @@ -240,6 +235,7 @@ def fix_frame_time_step( ) return frame_inds_seq + def get_simulated_softmax(cfg, speaker_sigmoid: torch.Tensor) -> torch.Tensor: """ Simulate the softmax operation for speaker diarization. @@ -264,16 +260,17 @@ def get_simulated_softmax(cfg, speaker_sigmoid: torch.Tensor) -> torch.Tensor: else: speaker_softmax = speaker_sigmoid / sigmoid_sum speaker_softmax = speaker_softmax.cpu() - speaker_softmax[cfg.max_num_of_spks:] = 0.0 + speaker_softmax[cfg.max_num_of_spks :] = 0.0 return speaker_softmax + def get_word_dict_content_offline( cfg: Any, word: str, word_index: int, diar_pred_out: torch.Tensor, time_stt_end_tuple: Tuple[int], - frame_len: float = 0.08 + frame_len: float = 0.08, ) -> Dict[str, Any]: """ Generate a dictionary containing word information and speaker diarization results. @@ -307,19 +304,22 @@ def get_word_dict_content_offline( speaker_sigmoid = diar_pred_out[stt_p:end_p, :].mean(dim=0) speaker_softmax = get_simulated_softmax(cfg, speaker_sigmoid) - speaker_softmax[cfg.max_num_of_spks:] = 0.0 + speaker_softmax[cfg.max_num_of_spks :] = 0.0 spk_id = speaker_softmax.argmax().item() stt_sec, end_sec = frame_stt * frame_len, frame_end * frame_len - word_dict = {"word": word, - "word_index": word_index, - 'frame_stt': frame_stt, - 'frame_end': frame_end, - 'start_time': round(stt_sec, 3), - 'end_time': round(end_sec, 3), - 'speaker': f"speaker_{spk_id}", - 'speaker_softmax': speaker_softmax} + word_dict = { + "word": word, + "word_index": word_index, + 'frame_stt': frame_stt, + 'frame_end': frame_end, + 'start_time': round(stt_sec, 3), + 'end_time': round(end_sec, 3), + 'speaker': f"speaker_{spk_id}", + 'speaker_softmax': speaker_softmax, + } return word_dict + def get_word_dict_content_online( cfg: Any, word: str, @@ -328,7 +328,7 @@ def get_word_dict_content_online( token_group: List[str], frame_inds_seq: List[int], time_step_local_offset: int, - frame_len: float = 0.08 + frame_len: float = 0.08, ) -> Dict[str, Any]: """ Generate a dictionary containing word information and speaker diarization results. @@ -371,25 +371,24 @@ def get_word_dict_content_online( speaker_sigmoid = diar_pred_out_stream[stt_p:end_p, :].mean(dim=0) speaker_softmax = get_simulated_softmax(cfg, speaker_sigmoid) - speaker_softmax[cfg.max_num_of_spks:] = 0.0 + speaker_softmax[cfg.max_num_of_spks :] = 0.0 spk_id = speaker_softmax.argmax().item() stt_sec, end_sec = frame_stt * frame_len, frame_end * frame_len - word_dict = {"word": word, - "word_index": word_index, - 'frame_stt': frame_stt, - 'frame_end': frame_end, - 'start_time': round(stt_sec, 3), - 'end_time': round(end_sec, 3), - 'speaker': f"speaker_{spk_id}", - 'speaker_softmax': speaker_softmax} + word_dict = { + "word": word, + "word_index": word_index, + 'frame_stt': frame_stt, + 'frame_end': frame_end, + 'start_time': round(stt_sec, 3), + 'end_time': round(end_sec, 3), + 'speaker': f"speaker_{spk_id}", + 'speaker_softmax': speaker_softmax, + } return word_dict + def get_multitoken_words( - cfg, - word_and_ts_seq: Dict[str, List], - word_seq: List[str], - new_words: List[str], - fix_prev_words_count: int = 5 + cfg, word_and_ts_seq: Dict[str, List], word_seq: List[str], new_words: List[str], fix_prev_words_count: int = 5 ) -> Dict[str, List]: """ Fix multi-token words that were not fully captured by the previous chunk window. @@ -418,11 +417,9 @@ def get_multitoken_words( word_and_ts_seq["words"][-fix_prev_words_count + ct]["word"] = prev_word return word_and_ts_seq + def append_word_and_ts_seq( - cfg: Any, - word_idx_offset: int, - word_and_ts_seq: Dict[str, Any], - word_dict: Dict[str, Any] + cfg: Any, word_idx_offset: int, word_and_ts_seq: Dict[str, Any], word_dict: Dict[str, Any] ) -> tuple[int, Dict[str, Any]]: """ Append the word dictionary to the word and time-stamp sequence. @@ -466,7 +463,9 @@ def __init__( self.test_manifest_dict = get_audio_rttm_map(self.cfg.manifest_file) elif self.cfg.audio_file is not None: uniq_id = get_uniqname_from_filepath(filepath=self.cfg.audio_file) - self.test_manifest_dict = {uniq_id: {'audio_filepath': self.cfg.audio_file, 'seglst_filepath': None, 'rttm_filepath': None}} + self.test_manifest_dict = { + uniq_id: {'audio_filepath': self.cfg.audio_file, 'seglst_filepath': None, 'rttm_filepath': None} + } else: raise ValueError("One of the audio_file and manifest_file should be non-empty!") @@ -511,21 +510,22 @@ def _init_evaluator(self): """ self.online_evaluators, self._word_and_ts_seq = [], {} for _, (uniq_id, data_dict) in enumerate(self.test_manifest_dict.items()): - uniq_id = uniq_id.split(".")[0] # Make sure there is no "." in the uniq_id - self._word_and_ts_seq[uniq_id] = {"words": [], - "buffered_words": [], - "token_frame_index": [], - "offset_count": 0, - "status": "success", - "sentences": None, - "last_word_index": 0, - "speaker_count": None, - "transcription": None, - "max_spk_probs": [], - "word_window_seq": [], - "speaker_count_buffer": [], - "sentence_memory": {}, - } + uniq_id = uniq_id.split(".")[0] # Make sure there is no "." in the uniq_id + self._word_and_ts_seq[uniq_id] = { + "words": [], + "buffered_words": [], + "token_frame_index": [], + "offset_count": 0, + "status": "success", + "sentences": None, + "last_word_index": 0, + "speaker_count": None, + "transcription": None, + "max_spk_probs": [], + "word_window_seq": [], + "speaker_count_buffer": [], + "sentence_memory": {}, + } if 'seglst_filepath' in data_dict and data_dict['seglst_filepath'] is not None: ref_seglst = read_seglst(data_dict['seglst_filepath']) @@ -537,12 +537,14 @@ def _init_evaluator(self): else: ref_rttm_labels = None - eval_instance = OnlineEvaluation(ref_seglst=ref_seglst, - ref_rttm_labels=ref_rttm_labels, - hyp_seglst=None, - collar=0.25, - ignore_overlap=False, - verbose=True) + eval_instance = OnlineEvaluation( + ref_seglst=ref_seglst, + ref_rttm_labels=ref_rttm_labels, + hyp_seglst=None, + collar=0.25, + ignore_overlap=False, + verbose=True, + ) self.online_evaluators.append(eval_instance) def _get_offset_sentence(self, session_trans_dict: Dict[str, Any], offset: int) -> Dict[str, Any]: @@ -557,11 +559,13 @@ def _get_offset_sentence(self, session_trans_dict: Dict[str, Any], offset: int) (Dict): Dictionary containing offset sentence information. """ word_dict = session_trans_dict['words'][offset] - return {'session_id': session_trans_dict['uniq_id'], - 'speaker': word_dict['speaker'], - 'start_time': word_dict['start_time'], - 'end_time': word_dict['end_time'], - 'words': f"{word_dict['word']} "} + return { + 'session_id': session_trans_dict['uniq_id'], + 'speaker': word_dict['speaker'], + 'start_time': word_dict['start_time'], + 'end_time': word_dict['end_time'], + 'words': f"{word_dict['word']} ", + } def _get_sentence(self, word_dict: Dict[str, Any]) -> Dict[str, Any]: """ @@ -570,10 +574,12 @@ def _get_sentence(self, word_dict: Dict[str, Any]) -> Dict[str, Any]: Args: word_dict (Dict[str, Any]): Dictionary containing word-related information. """ - return {'speaker': word_dict['speaker'], - 'start_time': word_dict['start_time'], - 'end_time': word_dict['end_time'], - 'words': ''} + return { + 'speaker': word_dict['speaker'], + 'start_time': word_dict['start_time'], + 'end_time': word_dict['end_time'], + 'words': '', + } def get_sentences_values(self, session_trans_dict: dict, sentence_render_length: int): """ @@ -591,11 +597,9 @@ def get_sentences_values(self, session_trans_dict: dict, sentence_render_length: sentence = self._get_offset_sentence(session_trans_dict=session_trans_dict, offset=0) sentences = [] session_trans_dict['last_word_index'] = stt_word_index - session_trans_dict['sentence_memory'].update({stt_word_index: - (deepcopy(sentences), - deepcopy(sentence), - sentence['speaker'] - )}) + session_trans_dict['sentence_memory'].update( + {stt_word_index: (deepcopy(sentences), deepcopy(sentence), sentence['speaker'])} + ) prev_speaker = session_trans_dict['words'][stt_word_index]['speaker'] else: (_sentences, _sentence, prev_speaker) = session_trans_dict['sentence_memory'][stt_word_index] @@ -622,10 +626,7 @@ def get_sentences_values(self, session_trans_dict: dict, sentence_render_length: return session_trans_dict def merge_transcript_and_speakers( - self, - test_manifest_dict: dict, - asr_hypotheses: List[Hypothesis], - diar_pred_out: torch.Tensor + self, test_manifest_dict: dict, asr_hypotheses: List[Hypothesis], diar_pred_out: torch.Tensor ) -> Tuple[List[str], Dict[str, Dict[str, Any]]]: """ Merge the transcript and speakers and generate real-time scripts if the config is set. @@ -642,24 +643,30 @@ def merge_transcript_and_speakers( transcribed_speaker_texts = [None] * len(test_manifest_dict) for idx, (uniq_id, _) in enumerate(test_manifest_dict.items()): - uniq_id = uniq_id.split(".")[0] # Make sure there is no "." in the uniq_id - if not len( asr_hypotheses[idx].text) == 0: + uniq_id = uniq_id.split(".")[0] # Make sure there is no "." in the uniq_id + if not len(asr_hypotheses[idx].text) == 0: # Get the word-level dictionaries for each word in the chunk - self._word_and_ts_seq[uniq_id] = self.get_frame_and_words_offline(uniq_id=uniq_id, - diar_pred_out=diar_pred_out[idx].squeeze(0), - asr_hypothesis=asr_hypotheses[idx], - word_and_ts_seq=self._word_and_ts_seq[uniq_id], - ) + self._word_and_ts_seq[uniq_id] = self.get_frame_and_words_offline( + uniq_id=uniq_id, + diar_pred_out=diar_pred_out[idx].squeeze(0), + asr_hypothesis=asr_hypotheses[idx], + word_and_ts_seq=self._word_and_ts_seq[uniq_id], + ) if len(self._word_and_ts_seq[uniq_id]["words"]) > 0: - self._word_and_ts_seq[uniq_id] = self.get_sentences_values(session_trans_dict=self._word_and_ts_seq[uniq_id], - sentence_render_length=self._sentence_render_length) + self._word_and_ts_seq[uniq_id] = self.get_sentences_values( + session_trans_dict=self._word_and_ts_seq[uniq_id], + sentence_render_length=self._sentence_render_length, + ) if self.cfg.generate_realtime_scripts: - transcribed_speaker_texts[idx] = \ - print_sentences(sentences=self._word_and_ts_seq[uniq_id]["sentences"], + transcribed_speaker_texts[idx] = print_sentences( + sentences=self._word_and_ts_seq[uniq_id]["sentences"], color_palette=get_color_palette(), - params=self.cfg) - write_txt(f'{self.cfg.print_path}'.replace(".sh", f"_{idx}.sh"), - transcribed_speaker_texts[idx].strip()) + params=self.cfg, + ) + write_txt( + f'{self.cfg.print_path}'.replace(".sh", f"_{idx}.sh"), + transcribed_speaker_texts[idx].strip(), + ) return transcribed_speaker_texts, self._word_and_ts_seq def get_frame_and_words_offline( @@ -684,14 +691,15 @@ def get_frame_and_words_offline( word_and_ts_seq['uniq_id'] = uniq_id for word_index, hyp_word_dict in enumerate(asr_hypothesis.timestamp['word']): - time_stt_end_tuple=(hyp_word_dict['start_offset'], hyp_word_dict['end_offset']) - word_dict = get_word_dict_content_offline(cfg=self.cfg, - word=hyp_word_dict['word'], - word_index=word_index, - diar_pred_out=diar_pred_out, - time_stt_end_tuple=time_stt_end_tuple, - frame_len=self._frame_len_sec - ) + time_stt_end_tuple = (hyp_word_dict['start_offset'], hyp_word_dict['end_offset']) + word_dict = get_word_dict_content_offline( + cfg=self.cfg, + word=hyp_word_dict['word'], + word_index=word_index, + diar_pred_out=diar_pred_out, + time_stt_end_tuple=time_stt_end_tuple, + frame_len=self._frame_len_sec, + ) word_and_ts_seq["words"].append(word_dict) word_and_ts_seq["speaker_count_buffer"].append(word_dict["speaker"]) word_and_ts_seq["word_window_seq"].append(word_dict['word']) @@ -723,7 +731,7 @@ def get_frame_and_words_online( """ offset = step_num * self._frame_hop_length word_seq = previous_hypothesis.text.split() - new_words = word_seq[word_and_ts_seq["offset_count"]:] + new_words = word_seq[word_and_ts_seq["offset_count"] :] new_token_group = self.asr_model.tokenizer.text_to_tokens(new_words) new_tokens = list(itertools.chain(*new_token_group)) frame_inds_seq = (torch.tensor(previous_hypothesis.timestamp) + offset).tolist() @@ -737,30 +745,31 @@ def get_frame_and_words_online( word_and_ts_seq["offset_count"] += 1 time_step_local_offset, word_idx_offset = 0, 0 - word_and_ts_seq = get_multitoken_words(cfg=self.cfg, - word_and_ts_seq=word_and_ts_seq, - word_seq=word_seq, - new_words=new_words, - fix_prev_words_count=self._fix_prev_words_count - ) + word_and_ts_seq = get_multitoken_words( + cfg=self.cfg, + word_and_ts_seq=word_and_ts_seq, + word_seq=word_seq, + new_words=new_words, + fix_prev_words_count=self._fix_prev_words_count, + ) # Get the FIFO queue preds to word_and_ts_seq for local_idx, (token_group, word) in enumerate(zip(new_token_group, new_words)): - word_dict = get_word_dict_content_online(cfg=self.cfg, - word=word, - word_index= ( len(word_and_ts_seq["words"]) + local_idx), - diar_pred_out_stream=diar_pred_out_stream, - token_group=token_group, - frame_inds_seq=frame_inds_seq, - time_step_local_offset=time_step_local_offset, - frame_len=self._frame_len_sec - ) + word_dict = get_word_dict_content_online( + cfg=self.cfg, + word=word, + word_index=(len(word_and_ts_seq["words"]) + local_idx), + diar_pred_out_stream=diar_pred_out_stream, + token_group=token_group, + frame_inds_seq=frame_inds_seq, + time_step_local_offset=time_step_local_offset, + frame_len=self._frame_len_sec, + ) # Count the number of speakers in the word window time_step_local_offset += len(token_group) - word_idx_offset, word_and_ts_seq = append_word_and_ts_seq(cfg=self.cfg, - word_idx_offset=word_idx_offset, - word_and_ts_seq=word_and_ts_seq, - word_dict=word_dict) + word_idx_offset, word_and_ts_seq = append_word_and_ts_seq( + cfg=self.cfg, word_idx_offset=word_idx_offset, word_and_ts_seq=word_and_ts_seq, word_dict=word_dict + ) return word_and_ts_seq def _add_speaker_transcriptions( @@ -768,7 +777,7 @@ def _add_speaker_transcriptions( transcriptions: list, speaker_transcriptions: List[str], word_and_ts_seq: Dict[str, Dict[str, Any]], - test_manifest_dict: dict + test_manifest_dict: dict, ) -> Tuple[List[Hypothesis], List[Hypothesis]]: """ Add speaker tagging into the transcriptions generated from an ASR model. @@ -788,7 +797,7 @@ def _add_speaker_transcriptions( """ trans_hyp, _ = transcriptions for sess_idx, (uniq_id, _) in enumerate(test_manifest_dict.items()): - uniq_id = uniq_id.split(".")[0] # Make sure there is no "." in the uniq_id + uniq_id = uniq_id.split(".")[0] # Make sure there is no "." in the uniq_id if speaker_transcriptions[sess_idx] is not None: trans_hyp[sess_idx].text = speaker_transcriptions[sess_idx] speaker_added_word_dicts = [] @@ -831,19 +840,18 @@ def perform_offline_stt_spk(self, override_cfg: Dict[str, Any]): override_config=override_cfg, ) best_hyp, _ = transcriptions - _, pred_tensors = self.diar_model.diarize(audio=self.cfg.manifest_file, - include_tensor_outputs=True) + _, pred_tensors = self.diar_model.diarize(audio=self.cfg.manifest_file, include_tensor_outputs=True) speaker_transcriptions, word_and_ts_seq = self.merge_transcript_and_speakers( - test_manifest_dict=self.diar_model._diarize_audio_rttm_map, - asr_hypotheses=best_hyp, - diar_pred_out=pred_tensors - ) + test_manifest_dict=self.diar_model._diarize_audio_rttm_map, + asr_hypotheses=best_hyp, + diar_pred_out=pred_tensors, + ) transcriptions = self._add_speaker_transcriptions( - transcriptions=transcriptions, - speaker_transcriptions=speaker_transcriptions, - word_and_ts_seq=word_and_ts_seq, - test_manifest_dict=self.diar_model._diarize_audio_rttm_map, - ) + transcriptions=transcriptions, + speaker_transcriptions=speaker_transcriptions, + word_and_ts_seq=word_and_ts_seq, + test_manifest_dict=self.diar_model._diarize_audio_rttm_map, + ) return transcriptions def generate_seglst_dicts_from_serial_streaming(self, samples: List[Dict[str, Any]]): @@ -866,7 +874,7 @@ def generate_seglst_dicts_from_serial_streaming(self, samples: List[Dict[str, An start_time=float(sentence_dict['start_time']), end_time=float(sentence_dict['end_time']), text=sentence_dict["words"], - session_id=session_id + session_id=session_id, ) self.instance_manager.seglst_dict_list.append(seglst_dict) @@ -889,8 +897,9 @@ def generate_seglst_dicts_from_parallel_streaming(self, samples: List[Dict[str, start_time=seg['start_time'], end_time=seg['end_time'], text=seg['words'], - session_id=uniq_id - ) for seg in asr_state.seglsts + session_id=uniq_id, + ) + for seg in asr_state.seglsts ] seglsts = sorted(seglsts, key=lambda x: x['start_time']) self.instance_manager.seglst_dict_list.extend(seglsts) @@ -909,9 +918,11 @@ def _find_active_speakers(self, diar_preds: torch.Tensor, n_active_speakers_per_ if diar_preds.ndim != 3: raise ValueError(f"diar_preds must be 3D (B, T, N), got shape {diar_preds.shape}") if n_active_speakers_per_stream > diar_preds.shape[2]: - raise ValueError(f"n_active_speakers_per_stream ({n_active_speakers_per_stream}) " - f"> available speakers ({diar_preds.shape[2]})") - max_probs = torch.max(diar_preds, dim=1).values # (B, T, N) --> (B, N) + raise ValueError( + f"n_active_speakers_per_stream ({n_active_speakers_per_stream}) " + f"> available speakers ({diar_preds.shape[2]})" + ) + max_probs = torch.max(diar_preds, dim=1).values # (B, T, N) --> (B, N) top_values, top_indices = torch.topk(max_probs, k=n_active_speakers_per_stream, dim=1) masks = top_values > 0.5 @@ -920,7 +931,9 @@ def _find_active_speakers(self, diar_preds: torch.Tensor, n_active_speakers_per_ speaker_ids_list.append(sorted(speaker_ids[mask].tolist())) return speaker_ids_list - def forward_pre_encoded(self, audio_signal: torch.Tensor, length: torch.Tensor, drop_extra_pre_encoded: int=0) -> None: + def forward_pre_encoded( + self, audio_signal: torch.Tensor, length: torch.Tensor, drop_extra_pre_encoded: int = 0 + ) -> None: """ Forward the pre-encoded features through the ASR model. @@ -933,22 +946,18 @@ def forward_pre_encoded(self, audio_signal: torch.Tensor, length: torch.Tensor, audio_signal (torch.Tensor): The pre-encoded audio signal. length (torch.Tensor): The length of the pre-encoded audio signal. """ - audio_signal = torch.transpose(audio_signal, 1, 2) # (B, T, D) -> (B, D, T) + audio_signal = torch.transpose(audio_signal, 1, 2) # (B, T, D) -> (B, D, T) audio_signal, length = self.asr_model.encoder.pre_encode(x=audio_signal, lengths=length) length = length.to(torch.int64) # `self.streaming_cfg` is set by setup_streaming_cfg(), called in the init if drop_extra_pre_encoded: - audio_signal = audio_signal[:, drop_extra_pre_encoded :, :] + audio_signal = audio_signal[:, drop_extra_pre_encoded:, :] length = (length - drop_extra_pre_encoded).clamp(min=0) return audio_signal, length def mask_features( - self, - chunk_audio: torch.Tensor, - mask: torch.Tensor, - threshold: float = 0.5, - mask_value: float = -16.6355 + self, chunk_audio: torch.Tensor, mask: torch.Tensor, threshold: float = 0.5, mask_value: float = -16.6355 ) -> torch.Tensor: """ Mask the features of the chunk audio. @@ -963,7 +972,9 @@ def mask_features( masked_chunk_audio (torch.Tensor): The masked chunk audio. """ if chunk_audio.ndim != 3: - raise ValueError(f"chunk_audio must be 3D (B, C, T), got {chunk_audio.ndim}D with shape {chunk_audio.shape}") + raise ValueError( + f"chunk_audio must be 3D (B, C, T), got {chunk_audio.ndim}D with shape {chunk_audio.shape}" + ) if mask.ndim != 2: raise ValueError(f"mask must be 2D (B, T), got {mask.ndim}D with shape {mask.shape}") if chunk_audio.shape[0] != mask.shape[0]: @@ -973,7 +984,7 @@ def mask_features( if mask.shape[1] > chunk_audio.shape[2]: logging.warning(f"Mask shape {mask.shape} is greater than chunk_audio shape {chunk_audio.shape}") - mask = mask[:, :chunk_audio.shape[2]] + mask = mask[:, : chunk_audio.shape[2]] elif mask.shape[1] < chunk_audio.shape[2]: logging.warning(f"Mask shape {mask.shape} is less than chunk_audio shape {chunk_audio.shape}") mask = torch.nn.functional.pad(mask, (chunk_audio.shape[2] - mask.shape[1], 0), mode='constant', value=0) @@ -999,7 +1010,7 @@ def mask_preencode(self, chunk_audio: torch.Tensor, mask: torch.Tensor, threshol if mask.shape[1] > chunk_audio.shape[1]: logging.warning(f"Mask shape {mask.shape} is greater than chunk_audio shape {chunk_audio.shape}") - mask = mask[:, :chunk_audio.shape[1]] + mask = mask[:, : chunk_audio.shape[1]] elif mask.shape[1] < chunk_audio.shape[1]: logging.warning(f"Mask shape {mask.shape} is less than chunk_audio shape {chunk_audio.shape}") mask = torch.nn.functional.pad(mask, (chunk_audio.shape[1] - mask.shape[1], 0), mode='constant', value=0) @@ -1085,12 +1096,12 @@ def perform_serial_streaming_stt_spk( processed_signal_length=chunk_lengths, streaming_state=self.instance_manager.diar_states.streaming_state, total_preds=self.instance_manager.diar_states.diar_pred_out_stream, - drop_extra_pre_encoded=drop_extra_pre_encoded + drop_extra_pre_encoded=drop_extra_pre_encoded, ) self.instance_manager.update_diar_state( diar_pred_out_stream=diar_pred_out_stream, - previous_chunk_preds=diar_pred_out_stream[:, -self._nframes_per_chunk:], - diar_streaming_state=new_streaming_state + previous_chunk_preds=diar_pred_out_stream[:, -self._nframes_per_chunk :], + diar_streaming_state=new_streaming_state, ) else: _, new_chunk_preds = self.get_diar_pred_out_stream(step_num) @@ -1100,22 +1111,28 @@ def perform_serial_streaming_stt_spk( for idx, (uniq_id, _) in enumerate(self.test_manifest_dict.items()): if not (len(previous_hypotheses[idx].text) == 0 and step_num <= self._initial_steps): # Get the word-level dictionaries for each word in the chunk - self._word_and_ts_seq[uniq_id] = self.get_frame_and_words_online(uniq_id=uniq_id, - step_num=step_num, - diar_pred_out_stream=diar_pred_out_stream[idx, :, :], - previous_hypothesis=previous_hypotheses[idx], - word_and_ts_seq=self._word_and_ts_seq[uniq_id], - ) + self._word_and_ts_seq[uniq_id] = self.get_frame_and_words_online( + uniq_id=uniq_id, + step_num=step_num, + diar_pred_out_stream=diar_pred_out_stream[idx, :, :], + previous_hypothesis=previous_hypotheses[idx], + word_and_ts_seq=self._word_and_ts_seq[uniq_id], + ) if len(self._word_and_ts_seq[uniq_id]["words"]) > 0: - self._word_and_ts_seq[uniq_id] = self.get_sentences_values(session_trans_dict=self._word_and_ts_seq[uniq_id], - sentence_render_length=self._sentence_render_length) + self._word_and_ts_seq[uniq_id] = self.get_sentences_values( + session_trans_dict=self._word_and_ts_seq[uniq_id], + sentence_render_length=self._sentence_render_length, + ) if self.cfg.generate_realtime_scripts: - transcribed_speaker_texts[idx] = \ - print_sentences(sentences=self._word_and_ts_seq[uniq_id]["sentences"], + transcribed_speaker_texts[idx] = print_sentences( + sentences=self._word_and_ts_seq[uniq_id]["sentences"], color_palette=get_color_palette(), - params=self.cfg) - write_txt(f'{self.cfg.print_path}'.replace(".sh", f"_{idx}.sh"), - transcribed_speaker_texts[idx].strip()) + params=self.cfg, + ) + write_txt( + f'{self.cfg.print_path}'.replace(".sh", f"_{idx}.sh"), + transcribed_speaker_texts[idx].strip(), + ) for batch_idx in range(chunk_audio.shape[0]): self.instance_manager.update_asr_state( @@ -1125,10 +1142,9 @@ def perform_serial_streaming_stt_spk( cache_last_time=cache_last_time[:, batch_idx], cache_last_channel_len=cache_last_channel_len[batch_idx], previous_hypotheses=previous_hypotheses[batch_idx], - previous_pred_out=asr_pred_out_stream[batch_idx] + previous_pred_out=asr_pred_out_stream[batch_idx], ) - @measure_eta def perform_parallel_streaming_stt_spk( self, @@ -1156,7 +1172,6 @@ def perform_parallel_streaming_stt_spk( self.instance_manager.reset(batch_size=chunk_audio.shape[0]) self.instance_manager.to(chunk_audio.device) - # Step 2: diarize or get GT rttms if self.diar_model.rttms_mask_mats is None: new_streaming_state, new_diar_pred_out_stream = self.diar_model.forward_streaming_step( @@ -1164,9 +1179,9 @@ def perform_parallel_streaming_stt_spk( processed_signal_length=chunk_lengths, streaming_state=self.instance_manager.diar_states.streaming_state, total_preds=self.instance_manager.diar_states.diar_pred_out_stream, - drop_extra_pre_encoded=drop_extra_pre_encoded + drop_extra_pre_encoded=drop_extra_pre_encoded, ) - new_chunk_preds = new_diar_pred_out_stream[:, -self._nframes_per_chunk:] + new_chunk_preds = new_diar_pred_out_stream[:, -self._nframes_per_chunk :] else: new_diar_pred_out_stream, new_chunk_preds = self.get_diar_pred_out_stream(step_num) @@ -1176,12 +1191,14 @@ def perform_parallel_streaming_stt_spk( self.instance_manager.update_diar_state( diar_pred_out_stream=new_diar_pred_out_stream, previous_chunk_preds=new_chunk_preds, - diar_streaming_state=new_streaming_state + diar_streaming_state=new_streaming_state, ) # Step 4: find active speakers - diar_chunk_preds = new_diar_pred_out_stream[:, -self._nframes_per_chunk*self._cache_gating_buffer_size:] + diar_chunk_preds = new_diar_pred_out_stream[:, -self._nframes_per_chunk * self._cache_gating_buffer_size :] if self._cache_gating: - active_speakers = self._find_active_speakers(diar_chunk_preds, n_active_speakers_per_stream=self.n_active_speakers_per_stream) + active_speakers = self._find_active_speakers( + diar_chunk_preds, n_active_speakers_per_stream=self.n_active_speakers_per_stream + ) else: active_speakers = [list(range(self.n_active_speakers_per_stream)) for _ in range(chunk_audio.shape[0])] @@ -1206,7 +1223,7 @@ def perform_parallel_streaming_stt_spk( # skip current chunk if no active speakers are found if active_chunk_audio is None: return - + # Step 6: # 1. mask the non-active speakers for masked ASR # 2. set speaker targets for multitalker ASR @@ -1230,19 +1247,19 @@ def perform_parallel_streaming_stt_spk( cache_last_channel_len, previous_hypotheses, ) = self.asr_model.conformer_stream_step( - processed_signal=active_chunk_audio, - processed_signal_length=active_chunk_lengths, - cache_last_channel=self.instance_manager.active_cache_last_channel, - cache_last_time=self.instance_manager.active_cache_last_time, - cache_last_channel_len=self.instance_manager.active_cache_last_channel_len, - keep_all_outputs=is_buffer_empty, - previous_hypotheses=self.instance_manager.active_previous_hypotheses, - previous_pred_out=self.instance_manager.active_asr_pred_out_stream, - drop_extra_pre_encoded=drop_extra_pre_encoded, - return_transcription=True, - bypass_pre_encode=bypass_pre_encode - ) - + processed_signal=active_chunk_audio, + processed_signal_length=active_chunk_lengths, + cache_last_channel=self.instance_manager.active_cache_last_channel, + cache_last_time=self.instance_manager.active_cache_last_time, + cache_last_channel_len=self.instance_manager.active_cache_last_channel_len, + keep_all_outputs=is_buffer_empty, + previous_hypotheses=self.instance_manager.active_previous_hypotheses, + previous_pred_out=self.instance_manager.active_asr_pred_out_stream, + drop_extra_pre_encoded=drop_extra_pre_encoded, + return_transcription=True, + bypass_pre_encode=bypass_pre_encode, + ) + # Step 8: update ASR states active_id = 0 for batch_idx, speaker_ids in enumerate(active_speakers): @@ -1254,7 +1271,7 @@ def perform_parallel_streaming_stt_spk( cache_last_time[:, active_id], cache_last_channel_len[active_id], previous_hypotheses[active_id], - pred_out_stream[active_id] + pred_out_stream[active_id], ) active_id += 1 @@ -1265,10 +1282,12 @@ def perform_parallel_streaming_stt_spk( if self.cfg.generate_realtime_scripts: for session_idx in self.cfg.print_sample_indices: asr_state = self.instance_manager.batch_asr_states[session_idx] - transcribed_speaker_texts = print_sentences(sentences=asr_state.seglsts, - color_palette=get_color_palette(), - params=self.cfg) - write_txt(f'{self.cfg.print_path.replace(".sh", f"_{session_idx}.sh")}', transcribed_speaker_texts.strip()) + transcribed_speaker_texts = print_sentences( + sentences=asr_state.seglsts, color_palette=get_color_palette(), params=self.cfg + ) + write_txt( + f'{self.cfg.print_path.replace(".sh", f"_{session_idx}.sh")}', transcribed_speaker_texts.strip() + ) class MultiTalkerInstanceManager: @@ -1279,6 +1298,7 @@ class MultiTalkerInstanceManager: batch size for inference. If there are at most N speakers and the batch size is B, then the real batch size for inference is at most B * N. """ + class ASRState: """ ASR state for each instance. @@ -1288,12 +1308,8 @@ class ASRState: The goal of ASR-State class is to handle the ASR cache state between streaming steps. The ASR-states required to perform streaming inference are all included in this class. """ - def __init__( - self, - max_num_of_spks: int = 4, - frame_len_sec: float = 0.08, - sent_break_sec: float = 5.0 - ): + + def __init__(self, max_num_of_spks: int = 4, frame_len_sec: float = 0.08, sent_break_sec: float = 5.0): """ Initialize the ASR-State class with the initial parameters. @@ -1315,7 +1331,7 @@ def __init__( self._frame_len_sec = frame_len_sec self._sent_break_sec = sent_break_sec self._speaker_wise_sentences = {} - self._prev_history_speaker_texts = [ "" for _ in range(self.max_num_of_spks) ] + self._prev_history_speaker_texts = ["" for _ in range(self.max_num_of_spks)] self.seglsts = [] @@ -1324,8 +1340,7 @@ def _reset_speaker_wise_sentences(self): Reset the speaker-wise sentences which will be used to generate the SegLST transcription outputs. """ self._speaker_wise_sentences = {} - self._prev_history_speaker_texts = [ "" for _ in range(self.max_num_of_spks) ] - + self._prev_history_speaker_texts = ["" for _ in range(self.max_num_of_spks)] def reset(self, asr_cache_state: Tuple[torch.Tensor, torch.Tensor, torch.Tensor]): """ @@ -1343,15 +1358,16 @@ def reset(self, asr_cache_state: Tuple[torch.Tensor, torch.Tensor, torch.Tensor] self.previous_pred_out = [None] self.seglsts = [] self._speaker_wise_sentences = {} - self._prev_history_speaker_texts = [ "" for _ in range(self.max_num_of_spks) ] + self._prev_history_speaker_texts = ["" for _ in range(self.max_num_of_spks)] - def update_asr_state(self, + def update_asr_state( + self, speaker_id, cache_last_channel, cache_last_time, cache_last_channel_len, previous_hypothesis, - previous_pred_out + previous_pred_out, ): """ Update the ASR state with the new ASR cache state. @@ -1436,11 +1452,7 @@ def _is_new_text(self, spk_idx: int, text: str): else: return text.strip() - def _compute_hypothesis_timestamps( - self, - hypothesis: Hypothesis, - offset: float - ) -> Tuple[float, float, bool]: + def _compute_hypothesis_timestamps(self, hypothesis: Hypothesis, offset: float) -> Tuple[float, float, bool]: """ Compute start and end timestamps for a hypothesis based on available timing information. @@ -1498,8 +1510,7 @@ def update_sessionwise_seglsts_for_parallel(self, offset: float): if diff_text is not None: start_time, end_time, sep_flag = self._compute_hypothesis_timestamps( - hypothesis=hypothesis, - offset=offset + hypothesis=hypothesis, offset=offset ) # Get the last end time of the previous sentence or None if no sentences are present @@ -1514,10 +1525,11 @@ def update_sessionwise_seglsts_for_parallel(self, offset: float): # This handles the case where the first character should be assigned to the previous sentence. the_first_char, diff_text = diff_text.strip()[0], diff_text.strip()[1:] self._update_last_sentence(spk_idx=spk_idx, end_time=None, diff_text=the_first_char) - self._speaker_wise_sentences[spk_idx].append(get_new_sentence_dict(speaker=f"speaker_{spk_idx}", - start_time=start_time, - end_time=end_time, - text=diff_text)) + self._speaker_wise_sentences[spk_idx].append( + get_new_sentence_dict( + speaker=f"speaker_{spk_idx}", start_time=start_time, end_time=end_time, text=diff_text + ) + ) # Case 2 - If start_time is less than end_time + sent_break_sec, then we need to update the end_time else: self._update_last_sentence(spk_idx=spk_idx, end_time=end_time, diff_text=diff_text) @@ -1541,7 +1553,8 @@ class DiarState: There is no difference between serial and parallel mode for the diarization state. The goal of Diar-State class is to handle the diarization cache state between streaming steps. """ - def __init__(self, batch_size: int=1, max_num_of_spks: int=4): + + def __init__(self, batch_size: int = 1, max_num_of_spks: int = 4): """ Initialize the Diar-State class with the initial parameters. @@ -1565,12 +1578,13 @@ def to(self, device): self.previous_chunk_preds = self.previous_chunk_preds.to(device) self.streaming_state.to(device) - def __init__(self, + def __init__( + self, asr_model=None, diar_model=None, - batch_size: int=1, - max_num_of_spks: int=4, - sent_break_sec: float=5.0, + batch_size: int = 1, + max_num_of_spks: int = 4, + sent_break_sec: float = 5.0, ): """ Initialize the MultiTalkerInstanceManager class with the initial parameters. @@ -1649,7 +1663,9 @@ def reset(self, batch_size: Optional[int] = None, max_num_of_spks: Optional[int] if len(self.batch_asr_states) > 0: self.previous_asr_states.extend(deepcopy(self.batch_asr_states)) - self.batch_asr_states = [self.ASRState(self.max_num_of_spks, sent_break_sec=self._sent_break_sec) for _ in range(self.batch_size)] + self.batch_asr_states = [ + self.ASRState(self.max_num_of_spks, sent_break_sec=self._sent_break_sec) for _ in range(self.batch_size) + ] for i in range(self.batch_size): self.batch_asr_states[i].reset(self.asr_model.encoder.get_initial_cache_state(batch_size=1)) @@ -1668,11 +1684,11 @@ def add_speaker(self, batch_idx: int, speaker_id: int): speaker_id (int): The speaker id. """ speakers = self.batch_asr_states[batch_idx].get_speakers() - for speaker_index in range(0, speaker_id+1): + for speaker_index in range(0, speaker_id + 1): if speaker_index not in speakers: self.batch_asr_states[batch_idx].add_speaker( speaker_id=speaker_index, - asr_cache_state=self.asr_model.encoder.get_initial_cache_state(batch_size=1) + asr_cache_state=self.asr_model.encoder.get_initial_cache_state(batch_size=1), ) def get_speakers(self, batch_idx: int): @@ -1699,7 +1715,7 @@ def update_diar_state( self, diar_pred_out_stream: torch.Tensor, previous_chunk_preds: torch.Tensor, - diar_streaming_state: StreamingSortformerState + diar_streaming_state: StreamingSortformerState, ): """ Update the diarization state from the diarization step. @@ -1722,7 +1738,7 @@ def update_asr_state( cache_last_time, cache_last_channel_len, previous_hypotheses, - previous_pred_out + previous_pred_out, ): """ A function to update the ASR state with the new ASR cache state. @@ -1746,7 +1762,7 @@ def update_asr_state( cache_last_time, cache_last_channel_len, previous_hypotheses, - previous_pred_out + previous_pred_out, ) def get_active_speakers_info(self, active_speakers, chunk_audio, chunk_lengths): @@ -1769,15 +1785,23 @@ def get_active_speakers_info(self, active_speakers, chunk_audio, chunk_lengths): self._active_chunk_lengths.append(chunk_lengths[batch_idx]) self._active_speaker_targets.append(self.diar_states.previous_chunk_preds[batch_idx, :, speaker_id]) inactive_speaker_ids = [i for i in range(len(speaker_ids)) if i != speaker_id] - self._inactive_speaker_targets.append((self.diar_states.previous_chunk_preds[batch_idx, :, inactive_speaker_ids] > 0.5).sum(dim=-1) > 0) + self._inactive_speaker_targets.append( + (self.diar_states.previous_chunk_preds[batch_idx, :, inactive_speaker_ids] > 0.5).sum(dim=-1) > 0 + ) if speaker_id not in self.batch_asr_states[batch_idx].get_speakers(): self.add_speaker(batch_idx, speaker_id) - self._active_previous_hypotheses.append(self.batch_asr_states[batch_idx].previous_hypothesis[speaker_id]) + self._active_previous_hypotheses.append( + self.batch_asr_states[batch_idx].previous_hypothesis[speaker_id] + ) self._active_asr_pred_out_stream.append(self.batch_asr_states[batch_idx].previous_pred_out[speaker_id]) - self._active_cache_last_channel.append(self.batch_asr_states[batch_idx].cache_last_channel[:, speaker_id]) + self._active_cache_last_channel.append( + self.batch_asr_states[batch_idx].cache_last_channel[:, speaker_id] + ) self._active_cache_last_time.append(self.batch_asr_states[batch_idx].cache_last_time[:, speaker_id]) - self._active_cache_last_channel_len.append(self.batch_asr_states[batch_idx].cache_last_channel_len[speaker_id]) + self._active_cache_last_channel_len.append( + self.batch_asr_states[batch_idx].cache_last_channel_len[speaker_id] + ) if len(self._active_chunk_audio) == 0: return None, None, None, None @@ -1803,4 +1827,4 @@ def update_seglsts(self, offset: int): offset (int): The offset of the chunk. """ for asr_state in self.batch_asr_states: - asr_state.update_sessionwise_seglsts_for_parallel(offset=offset) \ No newline at end of file + asr_state.update_sessionwise_seglsts_for_parallel(offset=offset) diff --git a/nemo/collections/common/data/lhotse/cutset.py b/nemo/collections/common/data/lhotse/cutset.py index 7bbbf109b48e..3a5972c29711 100644 --- a/nemo/collections/common/data/lhotse/cutset.py +++ b/nemo/collections/common/data/lhotse/cutset.py @@ -26,6 +26,7 @@ from lhotse.cut import Cut, MixedCut, PaddingCut from omegaconf import DictConfig, ListConfig, OmegaConf +from nemo.collections.asr.parts.utils.asr_multispeaker_utils import MultiSpeakerMixtureGenerator from nemo.collections.common.data.lhotse.nemo_adapters import ( LazyNeMoIterator, LazyNeMoTarredIterator, @@ -42,7 +43,6 @@ TextTurn, ) from nemo.collections.common.parts.preprocessing.manifest import get_full_path -from nemo.collections.asr.parts.utils.asr_multispeaker_utils import MultiSpeakerMixtureGenerator def read_cutset_from_config(config: Union[DictConfig, dict]) -> Tuple[CutSet, bool]: @@ -833,7 +833,7 @@ def read_multi_speaker_simulator(config: DictConfig) -> tuple[CutSet, bool]: ) is_tarred = config.get("is_tarred", False) return multi_speaker_cuts, is_tarred - + def mux( *cutsets: CutSet, From 031e5e2e0374f13d7749a13b6c92c46e618d05bb Mon Sep 17 00:00:00 2001 From: Weiqing Wang Date: Thu, 9 Oct 2025 14:27:55 -0700 Subject: [PATCH 03/29] Solving CodeQL comments and remove some unused functions/classes Signed-off-by: Weiqing Wang --- ...ech_to_text_multitalker_streaming_infer.py | 24 +- .../asr/data/audio_to_text_lhotse_speaker.py | 9 - .../asr/models/multitalker_asr_models.py | 8 +- .../parts/mixins/multitalker_asr_mixins.py | 4 +- .../asr/parts/utils/asr_multispeaker_utils.py | 240 +----------------- .../asr/parts/utils/data_simulation_utils.py | 9 +- .../parts/utils/multispk_transcribe_utils.py | 1 - 7 files changed, 14 insertions(+), 281 deletions(-) diff --git a/examples/asr/asr_cache_aware_streaming/speech_to_text_multitalker_streaming_infer.py b/examples/asr/asr_cache_aware_streaming/speech_to_text_multitalker_streaming_infer.py index c2a7453db7ff..8ca29f34490f 100644 --- a/examples/asr/asr_cache_aware_streaming/speech_to_text_multitalker_streaming_infer.py +++ b/examples/asr/asr_cache_aware_streaming/speech_to_text_multitalker_streaming_infer.py @@ -12,39 +12,22 @@ # See the License for the specific language governing permissions and # limitations under the License. -import itertools import json import math -import os import time -from collections import OrderedDict -from copy import deepcopy from dataclasses import dataclass, field, is_dataclass -from functools import wraps -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import List, Optional, Union import pytorch_lightning as pl import torch -from lhotse.dataset.collation import collate_matrices from omegaconf import OmegaConf, open_dict import nemo.collections.asr as nemo_asr -from nemo.collections.asr.data.audio_to_diar_label import extract_frame_info_from_rttm, get_frame_targets_from_rttm from nemo.collections.asr.models.sortformer_diar_models import SortformerEncLabelModel -from nemo.collections.asr.parts.utils.diarization_utils import ( - OnlineEvaluation, - get_color_palette, - print_sentences, - read_seglst, - write_txt, -) from nemo.collections.asr.parts.utils.multispk_transcribe_utils import ( SpeakerTaggedASR, get_multi_talker_samples_from_manifest, ) -from nemo.collections.asr.parts.utils.rnnt_utils import Hypothesis -from nemo.collections.asr.parts.utils.speaker_utils import audio_rttm_map as get_audio_rttm_map -from nemo.collections.asr.parts.utils.speaker_utils import rttm_to_labels from nemo.collections.asr.parts.utils.streaming_utils import CacheAwareStreamingAudioBuffer from nemo.core.config import hydra_runner from nemo.utils import logging @@ -298,14 +281,11 @@ def main(cfg: DiarizationConfig) -> Union[DiarizationConfig]: diar_model.sortformer_modules.spkcache_refresh_rate = cfg.spkcache_refresh_rate if cfg.audio_file is not None and cfg.manifest_file is not None: - logging.warning("Both audio_file and manifest_file are specified. audio_file will be used with top priority.") - input_type = "audio_file" + logging.warning("Both audio_file and manifest_file are specified. Audio_file will be used with top priority.") elif cfg.audio_file is not None: logging.info("audio_file is specified. Using audio_file as input.") - input_type = "audio_file" elif cfg.manifest_file is not None: logging.info("manifest_file is specified. Using manifest_file as input.") - input_type = "manifest_file" else: raise ValueError("One of audio_file or manifest_file must be specified!") diff --git a/nemo/collections/asr/data/audio_to_text_lhotse_speaker.py b/nemo/collections/asr/data/audio_to_text_lhotse_speaker.py index 06a28769536d..04d43a5ecf93 100644 --- a/nemo/collections/asr/data/audio_to_text_lhotse_speaker.py +++ b/nemo/collections/asr/data/audio_to_text_lhotse_speaker.py @@ -13,25 +13,16 @@ # limitations under the License. import random -import re from typing import Dict, Optional, Tuple -import numpy as np -import soundfile import torch.utils.data -from lhotse import AudioSource, CutSet, MonoCut, Recording, SupervisionSegment, SupervisionSet -from lhotse.cut import MixedCut, MixTrack, MonoCut, PaddingCut from lhotse.dataset import AudioSamples from lhotse.dataset.collation import collate_matrices, collate_vectors -from lhotse.utils import compute_num_samples from nemo.collections.asr.data.audio_to_text_lhotse import TokenizerWrapper from nemo.collections.asr.parts.utils.asr_multispeaker_utils import ( - get_hidden_length_from_sample_length, speaker_to_target, ) -from nemo.collections.common.tokenizers.aggregate_tokenizer import AggregateTokenizer -from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec from nemo.core.neural_types import AudioSignal, LabelsType, LengthsType, NeuralType diff --git a/nemo/collections/asr/models/multitalker_asr_models.py b/nemo/collections/asr/models/multitalker_asr_models.py index 887736d88510..9422a21a935d 100644 --- a/nemo/collections/asr/models/multitalker_asr_models.py +++ b/nemo/collections/asr/models/multitalker_asr_models.py @@ -13,18 +13,16 @@ # limitations under the License. # import os -from typing import Any, Dict, List, Optional +from typing import Any, Dict, Optional import torch -import torch.nn.functional as F -from omegaconf import DictConfig, ListConfig, OmegaConf, open_dict +from omegaconf import DictConfig, open_dict from pytorch_lightning import Trainer from nemo.collections.asr.data.audio_to_text_lhotse_speaker import LhotseSpeechToTextSpkBpeDataset from nemo.collections.asr.models.rnnt_bpe_models import EncDecRNNTBPEModel -from nemo.collections.asr.parts.mixins import TranscribeConfig, TranscriptionReturnType +from nemo.collections.asr.parts.mixins import TranscribeConfig from nemo.collections.asr.parts.mixins.multitalker_asr_mixins import SpeakerKernelMixin -from nemo.collections.asr.parts.utils.rnnt_utils import Hypothesis from nemo.collections.common.data.lhotse import get_lhotse_dataloader_from_config diff --git a/nemo/collections/asr/parts/mixins/multitalker_asr_mixins.py b/nemo/collections/asr/parts/mixins/multitalker_asr_mixins.py index afbc74e60936..7bfc336d63b1 100644 --- a/nemo/collections/asr/parts/mixins/multitalker_asr_mixins.py +++ b/nemo/collections/asr/parts/mixins/multitalker_asr_mixins.py @@ -12,8 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from abc import ABC, abstractmethod -from typing import Any, Dict, List, Optional +from abc import ABC +from typing import Optional import torch import torch.nn as nn diff --git a/nemo/collections/asr/parts/utils/asr_multispeaker_utils.py b/nemo/collections/asr/parts/utils/asr_multispeaker_utils.py index dd1c95218013..e75344794a4b 100644 --- a/nemo/collections/asr/parts/utils/asr_multispeaker_utils.py +++ b/nemo/collections/asr/parts/utils/asr_multispeaker_utils.py @@ -11,250 +11,20 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import itertools import json import logging import math -import os import random -import re -import time from collections import defaultdict from copy import deepcopy -from dataclasses import dataclass, field -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Optional, Union -import numpy as np -import soundfile -import soundfile as sf import torch.utils.data from cytoolz import groupby -from lhotse import AudioSource, Recording, SupervisionSegment, SupervisionSet, dill_enabled -from lhotse.cut import Cut, CutSet, MixedCut, MixTrack, MonoCut -from lhotse.cut.set import mix -from lhotse.lazy import LazyIteratorChain, LazyJsonlIterator -from lhotse.utils import compute_num_samples, ifnone, uuid4 -from omegaconf import OmegaConf -from scipy.stats import norm -from tqdm import tqdm - -from nemo.collections.asr.data.data_simulation import MultiSpeakerSimulator -from nemo.collections.asr.parts.utils.data_simulation_utils import read_rir_manifest - - -@dataclass -class SessionConfig: - num_speakers: int = 1 - num_sessions: int = 1 - session_length: int = 15 - session_length_range: List[int] = field(default_factory=lambda: [10, 40]) - - -@dataclass -class SessionParams: - max_audio_read_sec: float = 20.0 - sentence_length_params: List[float] = field(default_factory=lambda: [0.4, 0.05]) - dominance_var: float = 0.11 - min_dominance: float = 0.05 - turn_prob: float = 0.875 - min_turn_prob: float = 0.5 - mean_silence: float = 0.15 - mean_silence_var: float = 0.01 - per_silence_var: int = 900 - per_silence_min: float = 0.0 - per_silence_max: float = -1.0 - mean_overlap: float = 0.1 - mean_overlap_var: float = 0.01 - per_overlap_var: int = 900 - per_overlap_min: float = 0.0 - per_overlap_max: float = -1.0 - start_window: bool = True - window_type: str = "hamming" - window_size: float = 0.02 - start_buffer: float = 0.0 - split_buffer: float = 0.01 - release_buffer: float = 0.0 - normalize: bool = True - normalization_type: str = "equal" - normalization_var: float = 0.1 - min_volume: float = 0.75 - max_volume: float = 1.25 - end_buffer: float = 0.5 - random_offset: bool = True - - -@dataclass -class OutputConfig: - output_dir: str = "" - output_filename: str = "multispeaker_session" - overwrite_output: bool = True - output_precision: int = 3 - - -@dataclass -class BackgroundNoise: - add_bg: bool = True - background_manifest: Optional[str] = None - rir_manifest: Optional[str] = None - num_noise_files: int = 10 - snr: int = 60 - snr_min: Optional[float] = None - snr_max: Optional[float] = None - - -@dataclass -class SegmentAugmentor: - add_seg_aug: bool = False - gain_prob: float = 0.5 - min_gain_dbfs: float = -10.0 - max_gain_dbfs: float = 10.0 - - -@dataclass -class SessionAugmentor: - add_sess_aug: bool = False - white_noise_prob: float = 1.0 - min_white_noise_level: int = -90 - max_white_noise_level: int = -46 - - -@dataclass -class SpeakerEnforcement: - enforce_num_speakers: bool = True - enforce_time: List[float] = field(default_factory=lambda: [0.25, 0.75]) - - -@dataclass -class SegmentManifest: - window: float = 0.5 - shift: float = 0.25 - step_count: int = 50 - deci: int = 3 - - -@dataclass -class RIRGeneration: - use_rir: bool = False - toolkit: str = "pyroomacoustics" - room_sz: List[List[int]] = field(default_factory=lambda: [[2, 3], [2, 3], [2, 3]]) - pos_src: List[List[List[float]]] = field(default_factory=lambda: [[[0.5, 1.5]] * 3] * 4) - noise_src_pos: List[float] = field(default_factory=lambda: [1.5, 1.5, 2]) - num_channels: int = 2 - pos_rcv: List[List[List[float]]] = field(default_factory=lambda: [[[0.5, 1.5]] * 3] * 2) - orV_rcv: Optional[List[List[float]]] = None - mic_pattern: str = "omni" - abs_weights: List[float] = field(default_factory=lambda: [0.9] * 6) - T60: float = 0.1 - att_diff: float = 15.0 - att_max: float = 60.0 - - -@dataclass -class DataSimConfig: - """Configuration for data simulation.""" - - manifest_filepath: str = "" - sr: int = 16000 - random_seed: int = 42 - multiprocessing_chunksize: int = 10000 - session_config: SessionConfig = field(default_factory=SessionConfig) - session_params: SessionParams = field(default_factory=SessionParams) - outputs: OutputConfig = field(default_factory=OutputConfig) - background_noise: BackgroundNoise = field(default_factory=BackgroundNoise) - background_manifest: str = "" - segment_augmentor: SegmentAugmentor = field(default_factory=SegmentAugmentor) - session_augmentor: SessionAugmentor = field(default_factory=SessionAugmentor) - speaker_enforcement: SpeakerEnforcement = field(default_factory=SpeakerEnforcement) - segment_manifest: SegmentManifest = field(default_factory=SegmentManifest) - rir_generation: RIRGeneration = field(default_factory=RIRGeneration) - - -@dataclass -class MultiSpeakerSimulatorConfig: - data_simulator: DataSimConfig = field(default_factory=DataSimConfig) - - -class Segment: - def __init__(self, start, end, speaker_id, text): - self.start = start - self.end = end - self.speaker_id = speaker_id - self.text = text - - def __str__(self): - return f"Segment(start={self.start}, end={self.end}, speaker_id={self.speaker_id}, text=\"{self.text}\")" - - -class SegList: - def __init__(self, segments: List[Segment] = None, seglst_filepath: str = None): - if segments is not None: - self.segments = segments - elif seglst_filepath is not None: - self._load_seglst(seglst_filepath) - else: - raise ValueError("Either segments or seglst_filepath must be provided") - - def _load_seglst(self, seglst_filepath: str | list[str]): - if isinstance(seglst_filepath, str): - with open(seglst_filepath, 'r', encoding='utf-8') as f: - seglst = json.load(f) - self.segments = [ - Segment(seg['start_time'], seg['end_time'], seg['speaker'], seg['words']) for seg in seglst - ] - elif isinstance(seglst_filepath, list): - for seglst_file in seglst_filepath: - with open(seglst_file, 'r', encoding='utf-8') as f: - seglst = json.load(f) - segments = [ - Segment(seg['start_time'], seg['end_time'], seg['speaker'], seg['words']) for seg in seglst - ] - self.segments.extend(segments) - else: - raise ValueError("seglst_filepath must be a string or a list of strings") - self.sort() - - def __len__(self): - return len(self.segments) - - def __getitem__(self, idx): - return self.segments[idx] - - def __iter__(self): - return iter(self.segments) - - def sort(self): - self.segments.sort(key=lambda x: x.start) - - def get_segments(self, min_duration: float, max_duration: float): - - duration = random.uniform(min_duration, max_duration) - - first_segment_idx = random.randint(0, len(self) - 1) - segments = [self[first_segment_idx]] - - offset = self[first_segment_idx].start - for i in range(first_segment_idx + 1, len(self)): - if self[i].end - offset <= duration: - segments.append(self[i]) - else: - break - - return segments - - def get_text_from_segments( - self, segments: list[Segment], speaker_token_style='<|spltoken*|>', speaker_token_position='sot' - ): - text = '' - speakers = set([segment.speaker_id for segment in segments]) - speaker2start = { - spk_id: min(segment.start for segment in segments if segment.speaker_id == spk_id) for spk_id in speakers - } - sorted_speakers = sorted(speakers, key=lambda x: speaker2start[x]) - speaker2token = {spk: speaker_token_style.replace('*', str(i)) for i, spk in enumerate(sorted_speakers)} - for segment in segments: - text += f'{speaker2token[segment.speaker_id]} ' - text += segment.text - return text.strip() +from lhotse import AudioSource, Recording, SupervisionSegment, SupervisionSet +from lhotse.cut import Cut, MixedCut, MixTrack, MonoCut +from lhotse.lazy import LazyJsonlIterator +from lhotse.utils import compute_num_samples, uuid4 def find_first_nonzero(mat: torch.Tensor, max_cap_val=-1, thres: float = 0.5) -> torch.Tensor: diff --git a/nemo/collections/asr/parts/utils/data_simulation_utils.py b/nemo/collections/asr/parts/utils/data_simulation_utils.py index e44bf63e11b9..834198a3dc93 100644 --- a/nemo/collections/asr/parts/utils/data_simulation_utils.py +++ b/nemo/collections/asr/parts/utils/data_simulation_utils.py @@ -521,18 +521,13 @@ def read_rir_manifest(rir_manifest: str): """ Read the rir manifest file and sample the rir manifest. """ - # if isinstance(rir_manifest, str): - # rir_manifest_list = [rir_manifest] - # elif isinstance(rir_manifest, list): - # rir_manifest_list = rir_manifest + rir_manifest_list = [rir_manifest] rir_loaded_list = [] for manifest_file in rir_manifest_list: - # try: if os.path.exists(manifest_file): rir_loaded_list.extend(read_manifest(manifest_file)) - # except: - # import ipdb; ipdb.set_trace() + return rir_loaded_list diff --git a/nemo/collections/asr/parts/utils/multispk_transcribe_utils.py b/nemo/collections/asr/parts/utils/multispk_transcribe_utils.py index 68e7a4d48434..e585dd66f947 100644 --- a/nemo/collections/asr/parts/utils/multispk_transcribe_utils.py +++ b/nemo/collections/asr/parts/utils/multispk_transcribe_utils.py @@ -736,7 +736,6 @@ def get_frame_and_words_online( new_tokens = list(itertools.chain(*new_token_group)) frame_inds_seq = (torch.tensor(previous_hypothesis.timestamp) + offset).tolist() frame_inds_seq = fix_frame_time_step(self.cfg, new_tokens, new_words, frame_inds_seq) - min_len = min(len(new_words), len(frame_inds_seq)) word_and_ts_seq['uniq_id'] = uniq_id min_len = min(len(new_words), len(frame_inds_seq)) From b1be3a04435431f90aad4cd3a5ad7efb8d9199a4 Mon Sep 17 00:00:00 2001 From: Weiqing Wang Date: Thu, 9 Oct 2025 14:50:13 -0700 Subject: [PATCH 04/29] adding list_available_models() method Signed-off-by: Weiqing Wang --- .../asr/data/audio_to_text_lhotse_speaker.py | 5 +++-- .../asr/models/multitalker_asr_models.py | 17 ++++++++++++++++- 2 files changed, 19 insertions(+), 3 deletions(-) diff --git a/nemo/collections/asr/data/audio_to_text_lhotse_speaker.py b/nemo/collections/asr/data/audio_to_text_lhotse_speaker.py index 04d43a5ecf93..6cc576b27b46 100644 --- a/nemo/collections/asr/data/audio_to_text_lhotse_speaker.py +++ b/nemo/collections/asr/data/audio_to_text_lhotse_speaker.py @@ -20,6 +20,7 @@ from lhotse.dataset.collation import collate_matrices, collate_vectors from nemo.collections.asr.data.audio_to_text_lhotse import TokenizerWrapper +from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec from nemo.collections.asr.parts.utils.asr_multispeaker_utils import ( speaker_to_target, ) @@ -47,10 +48,10 @@ def output_types(self) -> Optional[Dict[str, NeuralType]]: 'bg_spk_targets': NeuralType(('B', 'T'), LabelsType()), } - def __init__(self, cfg, tokenizer): + def __init__(self, cfg, tokenizer: TokenizerSpec): super().__init__() self.tokenizer = TokenizerWrapper(tokenizer) - self.load_audio = AudioSamples(fault_tolerant=True, num_workers=8) + self.load_audio = AudioSamples(fault_tolerant=True) self.cfg = cfg self.num_speakers = self.cfg.get('num_speakers', 4) self.num_sample_per_mel_frame = self.cfg.get('num_sample_per_mel_frame', 160) diff --git a/nemo/collections/asr/models/multitalker_asr_models.py b/nemo/collections/asr/models/multitalker_asr_models.py index 9422a21a935d..b93c017422c2 100644 --- a/nemo/collections/asr/models/multitalker_asr_models.py +++ b/nemo/collections/asr/models/multitalker_asr_models.py @@ -13,7 +13,7 @@ # limitations under the License. # import os -from typing import Any, Dict, Optional +from typing import Any, Dict, List, Optional import torch from omegaconf import DictConfig, open_dict @@ -24,6 +24,7 @@ from nemo.collections.asr.parts.mixins import TranscribeConfig from nemo.collections.asr.parts.mixins.multitalker_asr_mixins import SpeakerKernelMixin from nemo.collections.common.data.lhotse import get_lhotse_dataloader_from_config +from nemo.core.classes.common import PretrainedModelInfo class EncDecMultiTalkerRNNTBPEModel(EncDecRNNTBPEModel, SpeakerKernelMixin): @@ -34,6 +35,18 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): # Initialize speaker kernel functionality from mixin self._init_speaker_kernel_config(cfg) + @classmethod + def list_available_models(cls) -> List[PretrainedModelInfo]: + """ + This method returns a list of pre-trained model which can be instantiated directly from NVIDIA's NGC cloud. + + Returns: + List of available pre-trained models. + """ + results = [] + + return results + def _setup_dataloader_from_config(self, config: Optional[Dict]): if config.get("use_lhotse"): # Use open_dict to allow dynamic key addition @@ -50,6 +63,8 @@ def _setup_dataloader_from_config(self, config: Optional[Dict]): tokenizer=self.tokenizer, ), ) + else: + raise ValueError("Only lhotse dataloader is supported for multitalker models") def training_step(self, batch, batch_nb): """Training step with speaker targets.""" From 51243c44d4b51d850003777cc50e1b58f32b89e5 Mon Sep 17 00:00:00 2001 From: weiqingw4ng Date: Thu, 9 Oct 2025 21:50:56 +0000 Subject: [PATCH 05/29] Apply isort and black reformatting Signed-off-by: weiqingw4ng --- nemo/collections/asr/data/audio_to_text_lhotse_speaker.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/nemo/collections/asr/data/audio_to_text_lhotse_speaker.py b/nemo/collections/asr/data/audio_to_text_lhotse_speaker.py index 6cc576b27b46..82fe57b1fb2e 100644 --- a/nemo/collections/asr/data/audio_to_text_lhotse_speaker.py +++ b/nemo/collections/asr/data/audio_to_text_lhotse_speaker.py @@ -20,10 +20,8 @@ from lhotse.dataset.collation import collate_matrices, collate_vectors from nemo.collections.asr.data.audio_to_text_lhotse import TokenizerWrapper +from nemo.collections.asr.parts.utils.asr_multispeaker_utils import speaker_to_target from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec -from nemo.collections.asr.parts.utils.asr_multispeaker_utils import ( - speaker_to_target, -) from nemo.core.neural_types import AudioSignal, LabelsType, LengthsType, NeuralType From 3fe88c7dd2c0eb6d65c99156d65e5a6ef885cfca Mon Sep 17 00:00:00 2001 From: Weiqing Wang Date: Thu, 9 Oct 2025 15:23:09 -0700 Subject: [PATCH 06/29] solving bg_spk_targets default values Signed-off-by: Weiqing Wang --- .../asr/parts/mixins/multitalker_asr_mixins.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/nemo/collections/asr/parts/mixins/multitalker_asr_mixins.py b/nemo/collections/asr/parts/mixins/multitalker_asr_mixins.py index 7bfc336d63b1..77e3a2eb820c 100644 --- a/nemo/collections/asr/parts/mixins/multitalker_asr_mixins.py +++ b/nemo/collections/asr/parts/mixins/multitalker_asr_mixins.py @@ -152,7 +152,7 @@ def hook_fn(module, args, kwargs): # residual connection x = x + x_spk if self.add_bg_spk_kernel: - x_bg_spk = self.bg_spk_kernels[layer_idx](self.mask_with_speaker_targets(x, self.bg_spk_targets)) + x_bg_spk = self.bg_spk_kernels[layer_idx](self.mask_with_speaker_targets(x, self.bg_spk_targets, default_value=0.0)) x = x + x_bg_spk kwargs['x'] = x elif args: @@ -162,7 +162,7 @@ def hook_fn(module, args, kwargs): # residual connection x = x + x_spk if self.add_bg_spk_kernel: - x_bg_spk = self.bg_spk_kernels[layer_idx](self.mask_with_speaker_targets(x, self.bg_spk_targets)) + x_bg_spk = self.bg_spk_kernels[layer_idx](self.mask_with_speaker_targets(x, self.bg_spk_targets, default_value=0.0)) x = x + x_bg_spk args = (x, *rest) @@ -195,7 +195,7 @@ def hook_fn(module, input, output): x = x + x_spk if self.add_bg_spk_kernel: - x_bg_spk = self.bg_spk_kernels[layer_idx](self.mask_with_speaker_targets(x, self.bg_spk_targets)) + x_bg_spk = self.bg_spk_kernels[layer_idx](self.mask_with_speaker_targets(x, self.bg_spk_targets, default_value=0.0)) x = x + x_bg_spk if isinstance(output, tuple): @@ -238,30 +238,30 @@ def clear_speaker_targets(self): if self.add_bg_spk_kernel: self.bg_spk_targets = None - def solve_length_mismatch(self, x: torch.Tensor, mask: torch.Tensor): + def solve_length_mismatch(self, x: torch.Tensor, mask: torch.Tensor, default_value: float = 1.0): """ Solve length mismatch between x and mask. """ if mask is None: - mask = torch.ones_like(x[:, :, 0]) + mask = torch.ones_like(x[:, :, 0]) * default_value logging.warning( f"Mask is None, triggering single speaker mode and assigning all ones with shape: {mask.shape}" ) if mask.shape[1] < x.shape[1]: # pad zero to the left - mask = torch.nn.functional.pad(mask, (x.shape[1] - mask.shape[1], 0), mode='constant', value=1) + mask = torch.nn.functional.pad(mask, (x.shape[1] - mask.shape[1], 0), mode='constant', value=default_value) if mask.shape[1] > x.shape[1]: mask = mask[:, -x.shape[1] :] return mask - def mask_with_speaker_targets(self, x: torch.Tensor, spk_targets: torch.Tensor): + def mask_with_speaker_targets(self, x: torch.Tensor, spk_targets: torch.Tensor, default_value: float = 1.0): """ Mask the input with speaker targets. """ - mask = self.solve_length_mismatch(x, spk_targets) + mask = self.solve_length_mismatch(x, spk_targets, default_value) x_spk = x * mask.unsqueeze(2) return x_spk From 009c83aa277f8b64833c3e06e48dbe5f3f9df253 Mon Sep 17 00:00:00 2001 From: weiqingw4ng Date: Thu, 9 Oct 2025 22:24:56 +0000 Subject: [PATCH 07/29] Apply isort and black reformatting Signed-off-by: weiqingw4ng --- .../asr/parts/mixins/multitalker_asr_mixins.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/nemo/collections/asr/parts/mixins/multitalker_asr_mixins.py b/nemo/collections/asr/parts/mixins/multitalker_asr_mixins.py index 77e3a2eb820c..6a6f1ee3d73a 100644 --- a/nemo/collections/asr/parts/mixins/multitalker_asr_mixins.py +++ b/nemo/collections/asr/parts/mixins/multitalker_asr_mixins.py @@ -152,7 +152,9 @@ def hook_fn(module, args, kwargs): # residual connection x = x + x_spk if self.add_bg_spk_kernel: - x_bg_spk = self.bg_spk_kernels[layer_idx](self.mask_with_speaker_targets(x, self.bg_spk_targets, default_value=0.0)) + x_bg_spk = self.bg_spk_kernels[layer_idx]( + self.mask_with_speaker_targets(x, self.bg_spk_targets, default_value=0.0) + ) x = x + x_bg_spk kwargs['x'] = x elif args: @@ -162,7 +164,9 @@ def hook_fn(module, args, kwargs): # residual connection x = x + x_spk if self.add_bg_spk_kernel: - x_bg_spk = self.bg_spk_kernels[layer_idx](self.mask_with_speaker_targets(x, self.bg_spk_targets, default_value=0.0)) + x_bg_spk = self.bg_spk_kernels[layer_idx]( + self.mask_with_speaker_targets(x, self.bg_spk_targets, default_value=0.0) + ) x = x + x_bg_spk args = (x, *rest) @@ -195,7 +199,9 @@ def hook_fn(module, input, output): x = x + x_spk if self.add_bg_spk_kernel: - x_bg_spk = self.bg_spk_kernels[layer_idx](self.mask_with_speaker_targets(x, self.bg_spk_targets, default_value=0.0)) + x_bg_spk = self.bg_spk_kernels[layer_idx]( + self.mask_with_speaker_targets(x, self.bg_spk_targets, default_value=0.0) + ) x = x + x_bg_spk if isinstance(output, tuple): From 79d8d801005fb6e802f0a21f264745c289e13d9d Mon Sep 17 00:00:00 2001 From: taejinp Date: Thu, 9 Oct 2025 17:58:52 -0700 Subject: [PATCH 08/29] Fixing Flake and Lint issues Signed-off-by: taejinp --- .../collections/asr/parts/mixins/streaming.py | 1 - .../asr/parts/utils/data_simulation_utils.py | 3 +- .../asr/parts/utils/diarization_utils.py | 84 +++++++++---------- 3 files changed, 39 insertions(+), 49 deletions(-) diff --git a/nemo/collections/asr/parts/mixins/streaming.py b/nemo/collections/asr/parts/mixins/streaming.py index 04056540ac04..59bcb27cd398 100644 --- a/nemo/collections/asr/parts/mixins/streaming.py +++ b/nemo/collections/asr/parts/mixins/streaming.py @@ -14,7 +14,6 @@ from abc import ABC, abstractmethod -import torch class StreamingEncoder(ABC): diff --git a/nemo/collections/asr/parts/utils/data_simulation_utils.py b/nemo/collections/asr/parts/utils/data_simulation_utils.py index 834198a3dc93..260c1d3f2e45 100644 --- a/nemo/collections/asr/parts/utils/data_simulation_utils.py +++ b/nemo/collections/asr/parts/utils/data_simulation_utils.py @@ -16,7 +16,7 @@ import os import shutil from collections import defaultdict -from typing import IO, Dict, List, Optional, Tuple +from typing import Dict, List, Optional, Tuple import numpy as np import torch @@ -755,7 +755,6 @@ def create_new_rttm_entry( silence_length > 2 * self._params.data_simulator.session_params.split_buffer ): # split utterance on silence new_end = start + alignments[i - 1] - silence_duration = alignments[i] - alignments[i - 1] # new_end = start + alignments[i - 1] + self._params.data_simulator.session_params.split_buffer diff --git a/nemo/collections/asr/parts/utils/diarization_utils.py b/nemo/collections/asr/parts/utils/diarization_utils.py index 6d3ac54fdaa4..560ddbd4236f 100644 --- a/nemo/collections/asr/parts/utils/diarization_utils.py +++ b/nemo/collections/asr/parts/utils/diarization_utils.py @@ -16,6 +16,7 @@ import csv import json import os +import string from collections import OrderedDict as od from collections import defaultdict from datetime import datetime @@ -31,7 +32,6 @@ audio_rttm_map, generate_diarization_output_lines, get_uniqname_from_filepath, - labels_to_pyannote_object, labels_to_rttmfile, rttm_to_labels, write_rttm2manifest, @@ -237,38 +237,6 @@ def convert_word_dict_seq_to_ctm( ctm_lines.append(ctm_line_str) return ctm_lines - def break_transcript_lines(self, string_out: str, params: Dict[str, str], max_chars_in_line: int = 90) -> str: - """ - Break the lines in the transcript. - - Args: - string_out (str): - Input transcript with speaker labels - max_chars_in_line (int): - Maximum characters in each line - - Returns: - return_string_out (str): - String variable containing line breaking - """ - color_str_len = len('\033[1;00m') if self.params['colored_text'] else 0 - split_string_out = string_out.split('\n') - return_string_out = [] - for org_chunk in split_string_out: - buffer = [] - if len(org_chunk) - color_str_len > max_chars_in_line: - color_str = org_chunk[:color_str_len] if color_str_len > 0 else '' - for i in range(color_str_len, len(org_chunk), max_chars_in_line): - trans_str = org_chunk[i : i + max_chars_in_line] - if len(trans_str.strip()) > 0: - c_trans_str = color_str + trans_str - buffer.append(c_trans_str) - return_string_out.extend(buffer) - else: - return_string_out.append(org_chunk) - return_string_out = '\n'.join(return_string_out) - return return_string_out - def get_total_result_dict( der_results: Dict[str, Dict[str, float]], @@ -866,21 +834,11 @@ def evaluate(self, ref_seglst, hyp_seglst, chunk_size=10.0, verbose=True): hyp_speaker_timestamps, hyp_speaker_word = convert_seglst(hyp_seglst, hyp_speakers) ref_speaker_timestamps, ref_speaker_word = convert_seglst(ref_seglst, ref_speakers) - ref_labels = generate_diarization_output_lines( - speaker_timestamps=ref_speaker_timestamps, model_spk_num=len(ref_speakers) - ) - hyp_labels = generate_diarization_output_lines( - speaker_timestamps=hyp_speaker_timestamps, model_spk_num=len(hyp_speakers) - ) - reference = labels_to_pyannote_object(ref_labels, uniq_name=session_id) - hypothesis = labels_to_pyannote_object(hyp_labels, uniq_name=session_id) - for idx, speaker in enumerate(ref_speakers): ref_speaker_words[idx] += ref_speaker_word[idx] for idx, speaker in enumerate(hyp_speakers): hyp_speaker_words[idx] += hyp_speaker_word[idx] - der_instance = der_metric(reference, hypothesis) # Normalize the text for spk_idx in range(len(hyp_speaker_words)): hyp_speaker_words[spk_idx] = ( @@ -1317,7 +1275,7 @@ def get_transcript_with_speaker_labels( if self.fix_word_ts_with_VAD: if self.frame_VAD == {}: logging.warning( - f"VAD timestamps are not provided. Fixing word timestamps without VAD. Please check the hydra configurations." + "VAD timestamps are not provided. Fixing word timestamps without VAD. Please check the hydra configurations." ) word_ts_refined = self._compensate_word_ts_list(self.audio_file_list, word_ts_hyp) else: @@ -1468,7 +1426,7 @@ def _make_json_output( ] } """ - logging.info(f"Creating results for Session: {uniq_id} n_spk: {n_spk} ") + logging.info(f"Creating results for Session: {uniq_id}") session_trans_dict, gecko_dict, audacity_label_words, sentences = get_session_trans_dict( uniq_id, word_dict_seq_list, diar_labels ) @@ -1740,7 +1698,7 @@ def _write_and_log( # print the sentences in the .txt output string_out = print_sentences(sentences, color_palette=self.color_palette, params=self.params) if self.params['break_lines']: - string_out = break_transcript_lines(string_out, params=self.params) + string_out = self.break_transcript_lines(string_out, params=self.params) session_trans_dict["status"] = "success" ctm_lines_list = convert_word_dict_seq_to_ctm(session_trans_dict['words']) @@ -1751,6 +1709,40 @@ def _write_and_log( write_txt(f'{self.root_path}/pred_rttms/{uniq_id}.txt', string_out.strip()) write_txt(f'{self.root_path}/pred_rttms/{uniq_id}.w.label', '\n'.join(audacity_label_words)) + def break_transcript_lines(self, string_out: str, params: Dict[str, str], max_chars_in_line: int = 90) -> str: + """ + Break the lines in the transcript. + + Args: + string_out (str): + Input transcript with speaker labels + params (dict): + Parameters dictionary + max_chars_in_line (int): + Maximum characters in each line + + Returns: + return_string_out (str): + String variable containing line breaking + """ + color_str_len = len('\033[1;00m') if self.params['colored_text'] else 0 + split_string_out = string_out.split('\n') + return_string_out = [] + for org_chunk in split_string_out: + buffer = [] + if len(org_chunk) - color_str_len > max_chars_in_line: + color_str = org_chunk[:color_str_len] if color_str_len > 0 else '' + for i in range(color_str_len, len(org_chunk), max_chars_in_line): + trans_str = org_chunk[i : i + max_chars_in_line] + if len(trans_str.strip()) > 0: + c_trans_str = color_str + trans_str + buffer.append(c_trans_str) + return_string_out.extend(buffer) + else: + return_string_out.append(org_chunk) + return_string_out = '\n'.join(return_string_out) + return return_string_out + @staticmethod def print_errors(der_results: Dict[str, Dict[str, float]], wer_results: Dict[str, Dict[str, float]]): """ From 688359fa933a670255e6fd74b49a008d0760140c Mon Sep 17 00:00:00 2001 From: tango4j Date: Fri, 10 Oct 2025 00:59:44 +0000 Subject: [PATCH 09/29] Apply isort and black reformatting Signed-off-by: tango4j --- nemo/collections/asr/parts/mixins/streaming.py | 1 - 1 file changed, 1 deletion(-) diff --git a/nemo/collections/asr/parts/mixins/streaming.py b/nemo/collections/asr/parts/mixins/streaming.py index 59bcb27cd398..2db68da5a3d2 100644 --- a/nemo/collections/asr/parts/mixins/streaming.py +++ b/nemo/collections/asr/parts/mixins/streaming.py @@ -15,7 +15,6 @@ from abc import ABC, abstractmethod - class StreamingEncoder(ABC): @abstractmethod def setup_streaming_params( From 96a7d02c77fd9db21949d4e917ff1405d795b825 Mon Sep 17 00:00:00 2001 From: Weiqing Wang Date: Thu, 9 Oct 2025 18:15:28 -0700 Subject: [PATCH 10/29] Solving flake8 F401 issue Signed-off-by: Weiqing Wang --- nemo/collections/asr/models/__init__.py | 61 +++++++++++++++++++ nemo/collections/asr/parts/mixins/__init__.py | 13 ++++ 2 files changed, 74 insertions(+) diff --git a/nemo/collections/asr/models/__init__.py b/nemo/collections/asr/models/__init__.py index ee277158c9a0..6ea11ea37f03 100644 --- a/nemo/collections/asr/models/__init__.py +++ b/nemo/collections/asr/models/__init__.py @@ -43,3 +43,64 @@ SpeechEncDecSelfSupervisedModel, ) from nemo.collections.asr.models.transformer_bpe_models import EncDecTransfModelBPE + +from nemo.collections.asr.models.aed_multitask_models import EncDecMultiTaskModel +from nemo.collections.asr.models.asr_model import ASRModel +from nemo.collections.asr.models.classification_models import ( + ClassificationInferConfig, + EncDecClassificationModel, + EncDecFrameClassificationModel, +) +from nemo.collections.asr.models.clustering_diarizer import ClusteringDiarizer +from nemo.collections.asr.models.ctc_bpe_models import EncDecCTCModelBPE +from nemo.collections.asr.models.ctc_models import EncDecCTCModel +from nemo.collections.asr.models.hybrid_rnnt_ctc_bpe_models import EncDecHybridRNNTCTCBPEModel +from nemo.collections.asr.models.hybrid_rnnt_ctc_models import EncDecHybridRNNTCTCModel +from nemo.collections.asr.models.k2_sequence_models import ( + EncDecK2RnntSeqModel, + EncDecK2RnntSeqModelBPE, + EncDecK2SeqModel, + EncDecK2SeqModelBPE, +) +from nemo.collections.asr.models.label_models import EncDecSpeakerLabelModel +from nemo.collections.asr.models.msdd_models import EncDecDiarLabelModel, NeuralDiarizer +from nemo.collections.asr.models.multitalker_asr_models import EncDecMultiTalkerRNNTBPEModel +from nemo.collections.asr.models.rnnt_bpe_models import EncDecRNNTBPEModel +from nemo.collections.asr.models.rnnt_models import EncDecRNNTModel +from nemo.collections.asr.models.slu_models import SLUIntentSlotBPEModel +from nemo.collections.asr.models.sortformer_diar_models import SortformerEncLabelModel +from nemo.collections.asr.models.ssl_models import ( + EncDecDenoiseMaskedTokenPredModel, + EncDecMaskedTokenPredModel, + SpeechEncDecSelfSupervisedModel, +) +from nemo.collections.asr.models.transformer_bpe_models import EncDecTransfModelBPE + +__all__ = [ + 'ASRModel', + 'ClassificationInferConfig', + 'ClusteringDiarizer', + 'EncDecCTCModel', + 'EncDecCTCModelBPE', + 'EncDecClassificationModel', + 'EncDecDenoiseMaskedTokenPredModel', + 'EncDecDiarLabelModel', + 'EncDecFrameClassificationModel', + 'EncDecHybridRNNTCTCBPEModel', + 'EncDecHybridRNNTCTCModel', + 'EncDecK2RnntSeqModel', + 'EncDecK2RnntSeqModelBPE', + 'EncDecK2SeqModel', + 'EncDecK2SeqModelBPE', + 'EncDecMaskedTokenPredModel', + 'EncDecMultiTaskModel', + 'EncDecMultiTalkerRNNTBPEModel', + 'EncDecRNNTBPEModel', + 'EncDecRNNTModel', + 'EncDecSpeakerLabelModel', + 'EncDecTransfModelBPE', + 'NeuralDiarizer', + 'SLUIntentSlotBPEModel', + 'SortformerEncLabelModel', + 'SpeechEncDecSelfSupervisedModel', +] \ No newline at end of file diff --git a/nemo/collections/asr/parts/mixins/__init__.py b/nemo/collections/asr/parts/mixins/__init__.py index 3c4f837dbf5a..74ca26ecf265 100644 --- a/nemo/collections/asr/parts/mixins/__init__.py +++ b/nemo/collections/asr/parts/mixins/__init__.py @@ -27,3 +27,16 @@ TranscriptionMixin, TranscriptionReturnType, ) + +__all__ = [ + 'ASRAdapterModelMixin', + 'ASRBPEMixin', + 'ASRModuleMixin', + 'ASRTranscriptionMixin', + 'DiarizationMixin', + 'InterCTCMixin', + 'SpeakerKernelMixin', + 'TranscribeConfig', + 'TranscriptionMixin', + 'TranscriptionReturnType', +] \ No newline at end of file From 22e579e5591a6ca522ab983a105a4b33ca5a6da5 Mon Sep 17 00:00:00 2001 From: weiqingw4ng Date: Fri, 10 Oct 2025 01:16:35 +0000 Subject: [PATCH 11/29] Apply isort and black reformatting Signed-off-by: weiqingw4ng --- nemo/collections/asr/models/__init__.py | 34 +------------------ nemo/collections/asr/parts/mixins/__init__.py | 2 +- 2 files changed, 2 insertions(+), 34 deletions(-) diff --git a/nemo/collections/asr/models/__init__.py b/nemo/collections/asr/models/__init__.py index 6ea11ea37f03..4680fbaa4c28 100644 --- a/nemo/collections/asr/models/__init__.py +++ b/nemo/collections/asr/models/__init__.py @@ -44,38 +44,6 @@ ) from nemo.collections.asr.models.transformer_bpe_models import EncDecTransfModelBPE -from nemo.collections.asr.models.aed_multitask_models import EncDecMultiTaskModel -from nemo.collections.asr.models.asr_model import ASRModel -from nemo.collections.asr.models.classification_models import ( - ClassificationInferConfig, - EncDecClassificationModel, - EncDecFrameClassificationModel, -) -from nemo.collections.asr.models.clustering_diarizer import ClusteringDiarizer -from nemo.collections.asr.models.ctc_bpe_models import EncDecCTCModelBPE -from nemo.collections.asr.models.ctc_models import EncDecCTCModel -from nemo.collections.asr.models.hybrid_rnnt_ctc_bpe_models import EncDecHybridRNNTCTCBPEModel -from nemo.collections.asr.models.hybrid_rnnt_ctc_models import EncDecHybridRNNTCTCModel -from nemo.collections.asr.models.k2_sequence_models import ( - EncDecK2RnntSeqModel, - EncDecK2RnntSeqModelBPE, - EncDecK2SeqModel, - EncDecK2SeqModelBPE, -) -from nemo.collections.asr.models.label_models import EncDecSpeakerLabelModel -from nemo.collections.asr.models.msdd_models import EncDecDiarLabelModel, NeuralDiarizer -from nemo.collections.asr.models.multitalker_asr_models import EncDecMultiTalkerRNNTBPEModel -from nemo.collections.asr.models.rnnt_bpe_models import EncDecRNNTBPEModel -from nemo.collections.asr.models.rnnt_models import EncDecRNNTModel -from nemo.collections.asr.models.slu_models import SLUIntentSlotBPEModel -from nemo.collections.asr.models.sortformer_diar_models import SortformerEncLabelModel -from nemo.collections.asr.models.ssl_models import ( - EncDecDenoiseMaskedTokenPredModel, - EncDecMaskedTokenPredModel, - SpeechEncDecSelfSupervisedModel, -) -from nemo.collections.asr.models.transformer_bpe_models import EncDecTransfModelBPE - __all__ = [ 'ASRModel', 'ClassificationInferConfig', @@ -103,4 +71,4 @@ 'SLUIntentSlotBPEModel', 'SortformerEncLabelModel', 'SpeechEncDecSelfSupervisedModel', -] \ No newline at end of file +] diff --git a/nemo/collections/asr/parts/mixins/__init__.py b/nemo/collections/asr/parts/mixins/__init__.py index 74ca26ecf265..c8bfd1503454 100644 --- a/nemo/collections/asr/parts/mixins/__init__.py +++ b/nemo/collections/asr/parts/mixins/__init__.py @@ -39,4 +39,4 @@ 'TranscribeConfig', 'TranscriptionMixin', 'TranscriptionReturnType', -] \ No newline at end of file +] From 303bc490cbdd5c83a0d5f293f3d6360fb33dbae3 Mon Sep 17 00:00:00 2001 From: taejinp Date: Thu, 9 Oct 2025 18:25:57 -0700 Subject: [PATCH 12/29] Fixing linting issues in diarization_utils.py Signed-off-by: taejinp --- .../asr/parts/utils/diarization_utils.py | 177 ++---------------- 1 file changed, 15 insertions(+), 162 deletions(-) diff --git a/nemo/collections/asr/parts/utils/diarization_utils.py b/nemo/collections/asr/parts/utils/diarization_utils.py index 560ddbd4236f..662200a34802 100644 --- a/nemo/collections/asr/parts/utils/diarization_utils.py +++ b/nemo/collections/asr/parts/utils/diarization_utils.py @@ -30,7 +30,6 @@ from nemo.collections.asr.models import ClusteringDiarizer from nemo.collections.asr.parts.utils.speaker_utils import ( audio_rttm_map, - generate_diarization_output_lines, get_uniqname_from_filepath, labels_to_rttmfile, rttm_to_labels, @@ -311,61 +310,6 @@ def get_num_of_spk_from_labels(labels: List[str]) -> int: spk_set = [x.split(' ')[-1].strip() for x in labels] return len(set(spk_set)) - -def read_seglst(seglst_filepath, round_digits=3, return_rttm=False, sort_by_start_time=False): - """ - Read a seglst file and return the speaker & text information dictionary. - - Args: - seglst_filepath: path to the seglst file - seglst format: - [ - { - "session_id": "Bed008", - "words": "alright so i'm i should read all of these numbers", - "speaker": "me045", - "start_time": "53.814", - "end_time": "56.753" - } - ] - round_digits (int): number of digits to round the timestamps - return_rttm (bool): Whether to return RTTM lines - - Returns: - seglst_dict (dict): - A dictionary containing speaker and text information for each segment. - rttm_lines (list): - A list containing RTTM lines. - """ - rttm_lines = [] - seglst = [] - with open(seglst_filepath, 'r') as f: - seglst_lines = json.loads(f.read()) - - for idx, line in enumerate(seglst_lines): - spk, start, end = line['speaker'], float(line['start_time']), float(line['end_time']) - dur = round(end - start, round_digits) - - if return_rttm: - rttm_line_str = f'SPEAKER {line["session_id"]} 1 {start:.3f} {end-start:.3f} {spk} ' - rttm_lines.append(rttm_line_str) - seglst.append( - { - 'session_id': line['session_id'], - 'speaker': spk, - 'words': line['words'], - 'start_time': start, - 'end_time': end, - 'duration': dur, - } - ) - if sort_by_start_time: - seglst = sorted(seglst, key=lambda x: (x['start_time'], x['end_time'])) - if return_rttm: - return seglst, rttm_lines - return seglst - - def convert_seglst(seglst, all_speakers): ''' convert the seglst to a format that can be used for scoring @@ -394,7 +338,21 @@ def convert_seglst(seglst, all_speakers): return timestamps, words -def get_session_trans_dict(uniq_id, word_dict_seq_list, diar_labels): +def get_session_trans_dict(uniq_id: str, word_dict_seq_list: List[Dict[str, float]], diar_labels: List[str]): + """ + Get the session transcription dictionary. + + Args: + uniq_id (str): the unique id of the session + word_dict_seq_list (list): list of word dictionaries + diar_labels (list): list of diarization labels + + Returns: + session_trans_dict (dict): the session transcription dictionary + gecko_dict (dict): the gecko dictionary + audacity_label_words (list): the audacity label words + sentences (list): the sentences + """ n_spk = get_num_of_spk_from_labels(diar_labels) session_trans_dict = init_session_trans_dict(uniq_id=uniq_id, n_spk=n_spk) gecko_dict = init_session_gecko_dict() @@ -563,34 +521,6 @@ def read_seglst(seglst_filepath, round_digits=3, return_rttm=False, sort_by_star return seglst -def convert_seglst(seglst, all_speakers): - ''' - convert the seglst to a format that can be used for scoring - - Args: - seglst (list): list of seglst dictionaries - all_speakers (list): list of all active speakers - Returns: - timestamps: (list of list) - [ - [[st1, et1], [st2, et2]], # timestamps list for speaker 1 - [[st1, et1], ...], # timestamps list for speaker 2 - ...] - words (list[[s1], [s2], [s3], [s4]]): list of words for each speaker 1 to 4 - ''' - - timestamps = [[] for _ in all_speakers] - words = ['' for _ in all_speakers] - - spk2id = {spk: idx for idx, spk in enumerate(all_speakers)} - seglst = sorted(seglst, key=lambda x: (x['start_time'], x['end_time'])) - for seg in seglst: - timestamps[spk2id[seg['speaker']]].append((seg['start_time'], seg['end_time'])) - words[spk2id[seg['speaker']]] += seg['words'] + ' ' - - return timestamps, words - - def chunk_seglst(seglst: List[Dict], chunk_size: float = 10.0): ''' Get chunked timestamps and words for each speaker @@ -650,83 +580,6 @@ def chunk_seglst(seglst: List[Dict], chunk_size: float = 10.0): return chunk_id2timestamps, speakers, session_id - -# def streaming_evaluation( -# ref_seglst: List[Dict], -# ref_rttm_labels: List[str], -# hyp_seglst: List[Dict], -# collar: float = 0.25, -# ignore_overlap: bool = False, -# verbose: bool = True, -# chunk_size: float = 10.0, -# ): -# """ -# Perform streaming evaluation of diarization and ASR for one session - -# Args: -# ref_seglst (list): list of reference seglst dictionaries -# hyp_seglst (list): list of hypothesis seglst dictionaries -# collar (float): collar for DER calculation -# ignore_overlap (bool): whether to ignore overlapping segments -# verbose (bool): whether to print verbose output -# chunk_size (float): how frequently to chunk and evaluate the session -# """ -# max_duration = max([seg['end_time'] for seg in ref_seglst + hyp_seglst]) -# max_idx = int(max_duration // chunk_size) + 1 - -# chunked_ref_seglst, ref_speakers, ref_session_id = chunk_seglst(ref_seglst, chunk_size=chunk_size) -# chunked_hyp_seglst, hyp_speakers, hyp_session_id = chunk_seglst(hyp_seglst, chunk_size=chunk_size) - -# if ref_session_id is None: -# ref_session_id = hyp_session_id - -# assert ref_session_id == hyp_session_id, "Session IDs of reference and hypothesis should match" - -# # Only care about the sessions in reference only -# session_id = ref_session_id -# ref_speaker_words = defaultdict(list) -# hyp_speaker_words = defaultdict(list) - -# der_metric = DiarizationErrorRate(collar=2 * collar, skip_overlap=ignore_overlap) -# cpwer_metric = calculate_session_cpWER -# der_list, cpwer_list = [], [] -# for chunk_idx in range(max_idx): -# ref_seglst = chunked_ref_seglst[chunk_idx] -# hyp_seglst = chunked_hyp_seglst[chunk_idx] - -# if len(ref_speaker_words) == 0: -# ref_speaker_words = ['' for _ in ref_speakers] -# if len(hyp_speaker_words) == 0: -# hyp_speaker_words = ['' for _ in hyp_speakers] -# if self.ref_rttm_labels is not None: -# ref_labels = self.ref_rttm_labels -# else: -# ref_speaker_timestamps, ref_speaker_word = convert_seglst(ref_seglst, ref_speakers) -# ref_labels = generate_diarization_output_lines(speaker_timestamps=ref_speaker_timestamps, model_spk_num=len(ref_speakers)) -# hyp_speaker_timestamps, hyp_speaker_word = convert_seglst(hyp_seglst, hyp_speakers) - -# hyp_labels = generate_diarization_output_lines(speaker_timestamps=hyp_speaker_timestamps, model_spk_num=len(hyp_speakers)) -# reference = labels_to_pyannote_object(ref_labels, uniq_name=session_id) -# hypothesis = labels_to_pyannote_object(hyp_labels, uniq_name=session_id) - -# for idx, speaker in enumerate(ref_speakers): -# ref_speaker_words[idx] += ref_speaker_word[idx] -# for idx, speaker in enumerate(hyp_speakers): -# hyp_speaker_words[idx] += hyp_speaker_word[idx] - -# der_met = der_metric(reference, hypothesis) -# cpWER, min_perm_hyp_trans, ref_trans = cpwer_metric(ref_speaker_words, hyp_speaker_words) - -# if verbose: -# logging.info(f"Session ID: {session_id} Chunk ID: {chunk_idx} from {chunk_idx*chunk_size}s to {(chunk_idx+1)*chunk_size}s") -# logging.info(f"DER: {abs(der_metric)*100:.2f}%, cpWER: {cpWER*100:.2f}%") - -# der_list.append(abs(der_metric) * 100) -# cpwer_list.append(cpWER) - -# return der_list, cpwer_list - - class OnlineEvaluation: """ A class designed for performing online evaluation of diarization and ASR. From d879d486259de67fde56d0e3dac0cbb9008240a4 Mon Sep 17 00:00:00 2001 From: tango4j Date: Fri, 10 Oct 2025 01:26:51 +0000 Subject: [PATCH 13/29] Apply isort and black reformatting Signed-off-by: tango4j --- nemo/collections/asr/parts/utils/diarization_utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/nemo/collections/asr/parts/utils/diarization_utils.py b/nemo/collections/asr/parts/utils/diarization_utils.py index 662200a34802..a99f86d6a634 100644 --- a/nemo/collections/asr/parts/utils/diarization_utils.py +++ b/nemo/collections/asr/parts/utils/diarization_utils.py @@ -310,6 +310,7 @@ def get_num_of_spk_from_labels(labels: List[str]) -> int: spk_set = [x.split(' ')[-1].strip() for x in labels] return len(set(spk_set)) + def convert_seglst(seglst, all_speakers): ''' convert the seglst to a format that can be used for scoring @@ -580,6 +581,7 @@ def chunk_seglst(seglst: List[Dict], chunk_size: float = 10.0): return chunk_id2timestamps, speakers, session_id + class OnlineEvaluation: """ A class designed for performing online evaluation of diarization and ASR. From c72837e7359894316a4e2fd322c1422acbc071b7 Mon Sep 17 00:00:00 2001 From: Weiqing Wang Date: Fri, 10 Oct 2025 17:59:23 -0700 Subject: [PATCH 14/29] Solving unitest issue Signed-off-by: Weiqing Wang --- ...eech_to_text_multitalker_streaming_infer.py | 4 +++- .../asr/data/audio_to_diar_label_lhotse.py | 11 +++++++---- .../asr/data/audio_to_text_lhotse_speaker.py | 2 +- .../asr/parts/utils/asr_multispeaker_utils.py | 15 +++------------ .../asr/parts/utils/data_simulation_utils.py | 4 ---- .../parts/utils/multispk_transcribe_utils.py | 18 ++++++++++++++++-- nemo/collections/common/data/lhotse/cutset.py | 4 +++- .../utils/test_data_simul_utils_speaker.py | 4 +++- 8 files changed, 36 insertions(+), 26 deletions(-) diff --git a/examples/asr/asr_cache_aware_streaming/speech_to_text_multitalker_streaming_infer.py b/examples/asr/asr_cache_aware_streaming/speech_to_text_multitalker_streaming_infer.py index 8ca29f34490f..2ac2449b396f 100644 --- a/examples/asr/asr_cache_aware_streaming/speech_to_text_multitalker_streaming_infer.py +++ b/examples/asr/asr_cache_aware_streaming/speech_to_text_multitalker_streaming_infer.py @@ -40,6 +40,9 @@ class DiarizationConfig: diar_pretrained_name: Optional[str] = None # Name of a pretrained model max_num_of_spks: Optional[int] = 4 parallel_speaker_strategy: bool = True + masked_asr: bool = True + mask_preencode: bool = False + single_speaker_model: bool = False # General configs session_len_sec: float = -1 # End-to-end diarization session length in seconds @@ -102,7 +105,6 @@ class DiarizationConfig: spk_supervision: str = "diar" # ["diar", "rttm"] binary_diar_preds: bool = False - def format_time(seconds): minutes = math.floor(seconds / 60) sec = seconds % 60 diff --git a/nemo/collections/asr/data/audio_to_diar_label_lhotse.py b/nemo/collections/asr/data/audio_to_diar_label_lhotse.py index d826fb13dc3a..ea939ba1e610 100644 --- a/nemo/collections/asr/data/audio_to_diar_label_lhotse.py +++ b/nemo/collections/asr/data/audio_to_diar_label_lhotse.py @@ -55,7 +55,6 @@ def __init__(self, cfg): self.cfg.get('window_stride', 0.01) * self.cfg.get('sample_rate', 16000) ) # 160 samples for every 1ms by default self.num_mel_frame_per_target_frame = int(self.cfg.get('subsampling_factor', 8)) - self.spk_tar_all_zero = self.cfg.get('spk_tar_all_zero', False) def __getitem__(self, cuts) -> Tuple[torch.Tensor, ...]: audio, audio_lens, cuts = self.load_audio(cuts) @@ -63,14 +62,18 @@ def __getitem__(self, cuts) -> Tuple[torch.Tensor, ...]: for cut in cuts: speaker_activity = speaker_to_target( a_cut=cut, - num_speakers=self.num_speakers, num_sample_per_mel_frame=self.num_sample_per_mel_frame, num_mel_frame_per_asr_frame=self.num_mel_frame_per_target_frame, - spk_tar_all_zero=self.spk_tar_all_zero, boundary_segments=True, ) speaker_activities.append(speaker_activity) - targets = collate_matrices(speaker_activities).to(audio.dtype) + targets = collate_matrices(speaker_activities).to(audio.dtype) # (B, T, N) + + if targets.shape[2] > self.num_speakers: + targets = targets[:, :, :self.num_speakers] + elif targets.shape[2] < self.num_speakers: + targets = torch.nn.functional.pad(targets, (0, self.num_speakers - targets.shape[2]), mode='constant', value=0) + target_lens_list = [] for audio_len in audio_lens: target_fr_len = get_hidden_length_from_sample_length( diff --git a/nemo/collections/asr/data/audio_to_text_lhotse_speaker.py b/nemo/collections/asr/data/audio_to_text_lhotse_speaker.py index 82fe57b1fb2e..dcd8344f730e 100644 --- a/nemo/collections/asr/data/audio_to_text_lhotse_speaker.py +++ b/nemo/collections/asr/data/audio_to_text_lhotse_speaker.py @@ -70,7 +70,7 @@ def __getitem__(self, cuts) -> Tuple[torch.Tensor, ...]: speaker_to_target(cut, self.num_sample_per_mel_frame, self.num_mel_frame_per_asr_frame) for cut in cuts ] spk_targets = collate_matrices(speaker_targets, padding_value=0) - return audio, audio_lens, None, None, spk_targets + return audio, audio_lens, None, None, None, None for idx, cut in enumerate(cuts): diff --git a/nemo/collections/asr/parts/utils/asr_multispeaker_utils.py b/nemo/collections/asr/parts/utils/asr_multispeaker_utils.py index e75344794a4b..9385fbe4696e 100644 --- a/nemo/collections/asr/parts/utils/asr_multispeaker_utils.py +++ b/nemo/collections/asr/parts/utils/asr_multispeaker_utils.py @@ -226,7 +226,6 @@ def get_mask_from_segments( speaker_to_idx_map: torch.Tensor, num_speakers: int = 4, feat_per_sec: int = 100, - ignore_num_spk_mismatch: bool = False, ): """ Generate mask matrix from segments list. @@ -238,8 +237,6 @@ def get_mask_from_segments( speaker_to_idx_map (dict): A dictionary mapping speaker names to indices. num_speakers (int): max number of speakers for all cuts ("mask" dim0), 4 by default feat_per_sec (int): number of frames per second, 100 by default, 0.01s frame rate - ignore_num_spk_mismatch (bool): This is a temporary solution to handle speaker mismatch. - Will be removed in the future. Returns: mask (Tensor): A numpy array of shape (num_speakers, encoder_hidden_len). @@ -250,11 +247,7 @@ def get_mask_from_segments( mask = torch.zeros((num_samples, num_speakers)) for rttm_sup in segments: speaker_idx = speaker_to_idx_map[rttm_sup.speaker] - if speaker_idx >= num_speakers: - if ignore_num_spk_mismatch: - continue - else: - raise ValueError(f"Speaker Index {speaker_idx} exceeds the max index: {num_speakers-1}") + stt = max(rttm_sup.start, 0) ent = min(rttm_sup.end, a_cut.duration) stf = int(stt * feat_per_sec) @@ -328,7 +321,6 @@ def speaker_to_target( boundary_segments: bool = False, soft_label: bool = False, soft_thres: float = 0.5, - ignore_num_spk_mismatch: bool = True, return_text: bool = False, ): ''' @@ -343,7 +335,6 @@ def speaker_to_target( boundary_segments (bool): set to True to include segments containing the boundary of the cut, False by default for multi-speaker ASR training soft_label (bool): set to True to use soft label that enables values in [0, 1] range, False by default and leads to binary labels. soft_thres (float): the threshold for the soft label, 0.5 by default. - ignore_num_spk_mismatch (bool): This is a temporary solution to handle speaker mismatch. Will be removed in the future. return_text (bool): set to True to return the text of the speakers (if it is available), False by default. Returns: @@ -407,7 +398,7 @@ def speaker_to_target( a_cut.num_samples, num_sample_per_mel_frame, num_mel_frame_per_asr_frame ) frame_mask = get_mask_from_segments( - segments_total, a_cut, speaker_to_idx_map, num_speakers, feat_per_sec, ignore_num_spk_mismatch + segments_total, a_cut, speaker_to_idx_map, num_speakers, feat_per_sec ) soft_mask = get_soft_mask(frame_mask, num_samples, num_mel_frame_per_asr_frame) @@ -428,7 +419,7 @@ def speaker_to_target( def read_seglst(seglst_filepath: str, session_id: Optional[str] = None): """ - Read the seglst file and return a list of segments. + Read the seglst file and return a list of SupervisionSegment. """ with open(seglst_filepath, 'r', encoding='utf-8') as f: seglst = json.load(f) diff --git a/nemo/collections/asr/parts/utils/data_simulation_utils.py b/nemo/collections/asr/parts/utils/data_simulation_utils.py index 260c1d3f2e45..14a853bcb6a3 100644 --- a/nemo/collections/asr/parts/utils/data_simulation_utils.py +++ b/nemo/collections/asr/parts/utils/data_simulation_utils.py @@ -895,9 +895,7 @@ def create_new_ctm_entry( ): # note that using the current alignments the first word is always empty, so there is no error from indexing the array with i-1 prev_align = 0 if i == 0 else alignments[i - 1] align1 = round(float(prev_align + start), self._params.data_simulator.outputs.output_precision) - # align1 = round(float(start), self._params.data_simulator.outputs.output_precision) align2 = round(float(alignments[i] - prev_align), self._params.data_simulator.outputs.output_precision) - # align2 = round(float(alignments[i+1] - start), self._params.data_simulator.outputs.output_precision) end_time = round(align1 + align2, self._params.data_simulator.outputs.output_precision) text = get_ctm_line( source=session_name, @@ -910,8 +908,6 @@ def create_new_ctm_entry( speaker=speaker_id, output_precision=self._params.data_simulator.outputs.output_precision, ) - # if word == "it" and align1 == 3.169: - # import ipdb; ipdb.set_trace() word_and_ts_list.append((word, align1, end_time)) arr.append((align1, text)) return arr, word_and_ts_list diff --git a/nemo/collections/asr/parts/utils/multispk_transcribe_utils.py b/nemo/collections/asr/parts/utils/multispk_transcribe_utils.py index e585dd66f947..56b1e6480d92 100644 --- a/nemo/collections/asr/parts/utils/multispk_transcribe_utils.py +++ b/nemo/collections/asr/parts/utils/multispk_transcribe_utils.py @@ -494,6 +494,7 @@ def __init__( self._masked_asr = cfg.get("masked_asr", True) self._use_mask_preencode = cfg.get("mask_preencode", False) + self._single_speaker_model = cfg.get("single_speaker_model", False) self.instance_manager = MultiTalkerInstanceManager( asr_model=self.asr_model, @@ -1192,6 +1193,19 @@ def perform_parallel_streaming_stt_spk( previous_chunk_preds=new_chunk_preds, diar_streaming_state=new_streaming_state, ) + + # For a session, if no second speaker is detected, + # the spk_targets will be set to all ones in the single speaker mode + if self._single_speaker_model: + if self._max_num_of_spks == 1: + is_single_speaker = [True] * chunk_audio.shape[0] + else: + is_single_speaker = (new_diar_pred_out_stream[:,:,:self._max_num_of_spks] > 0.5).any(1).sum(-1) <= 1.0 + for i in range(chunk_audio.shape[0]): + if is_single_speaker[i]: + new_diar_pred_out_stream[i, :, 0] = 1.0 + new_diar_pred_out_stream[i, :, 1:] = 0.0 + # Step 4: find active speakers diar_chunk_preds = new_diar_pred_out_stream[:, -self._nframes_per_chunk * self._cache_gating_buffer_size :] if self._cache_gating: @@ -1224,8 +1238,8 @@ def perform_parallel_streaming_stt_spk( return # Step 6: - # 1. mask the non-active speakers for masked ASR - # 2. set speaker targets for multitalker ASR + # 1) mask the non-active speakers for masked ASR; or + # 2) set speaker targets for multitalker ASR if self._masked_asr: if self._use_mask_preencode: active_chunk_audio = self.mask_preencode(chunk_audio=active_chunk_audio, mask=active_speaker_targets) diff --git a/nemo/collections/common/data/lhotse/cutset.py b/nemo/collections/common/data/lhotse/cutset.py index 44164dac191d..25fb9ba812f6 100644 --- a/nemo/collections/common/data/lhotse/cutset.py +++ b/nemo/collections/common/data/lhotse/cutset.py @@ -27,7 +27,6 @@ from lhotse.serialization import load_yaml from omegaconf import DictConfig, ListConfig, OmegaConf -from nemo.collections.asr.parts.utils.asr_multispeaker_utils import MultiSpeakerMixtureGenerator from nemo.collections.common.data.lhotse.nemo_adapters import ( LazyNeMoIterator, LazyNeMoTarredIterator, @@ -818,6 +817,9 @@ def read_nemo_manifest(config) -> tuple[CutSet, bool]: @data_type_parser("multi_speaker_simulator") def read_multi_speaker_simulator(config: DictConfig) -> tuple[CutSet, bool]: + # Import here to avoid circular dependency + from nemo.collections.asr.parts.utils.asr_multispeaker_utils import MultiSpeakerMixtureGenerator + multi_speaker_cuts = CutSet( MultiSpeakerMixtureGenerator( manifest_filepath=config.manifest_filepath, diff --git a/tests/collections/speaker_tasks/utils/test_data_simul_utils_speaker.py b/tests/collections/speaker_tasks/utils/test_data_simul_utils_speaker.py index 9a27820cdfa1..8c0fc1e622ae 100644 --- a/tests/collections/speaker_tasks/utils/test_data_simul_utils_speaker.py +++ b/tests/collections/speaker_tasks/utils/test_data_simul_utils_speaker.py @@ -410,7 +410,7 @@ def test_create_new_json_entry(self, annotator): def test_create_new_ctm_entry(self, annotator): words, alignments, speaker_id = generate_words_and_alignments(sample_index=0) session_name = 'test_session' - ctm_list = annotator.create_new_ctm_entry( + ctm_list, _ = annotator.create_new_ctm_entry( words=words, alignments=alignments, session_name=session_name, speaker_id=speaker_id, start=alignments[0] ) assert ctm_list[0] == ( @@ -424,6 +424,7 @@ def test_create_new_ctm_entry(self, annotator): conf=None, type_of_token='lex', speaker=speaker_id, + output_precision=annotator._params.data_simulator.outputs.output_precision, ), ) assert ctm_list[1] == ( @@ -437,6 +438,7 @@ def test_create_new_ctm_entry(self, annotator): conf=None, type_of_token='lex', speaker=speaker_id, + output_precision=annotator._params.data_simulator.outputs.output_precision, ), ) From 076396782238718b5edc291e4c1f0a6117af41b4 Mon Sep 17 00:00:00 2001 From: weiqingw4ng Date: Sat, 11 Oct 2025 01:00:35 +0000 Subject: [PATCH 15/29] Apply isort and black reformatting Signed-off-by: weiqingw4ng --- .../speech_to_text_multitalker_streaming_infer.py | 1 + nemo/collections/asr/data/audio_to_diar_label_lhotse.py | 8 +++++--- .../collections/asr/parts/utils/asr_multispeaker_utils.py | 4 +--- .../asr/parts/utils/multispk_transcribe_utils.py | 6 ++++-- nemo/collections/common/data/lhotse/cutset.py | 2 +- 5 files changed, 12 insertions(+), 9 deletions(-) diff --git a/examples/asr/asr_cache_aware_streaming/speech_to_text_multitalker_streaming_infer.py b/examples/asr/asr_cache_aware_streaming/speech_to_text_multitalker_streaming_infer.py index 2ac2449b396f..5928a3f6b62a 100644 --- a/examples/asr/asr_cache_aware_streaming/speech_to_text_multitalker_streaming_infer.py +++ b/examples/asr/asr_cache_aware_streaming/speech_to_text_multitalker_streaming_infer.py @@ -105,6 +105,7 @@ class DiarizationConfig: spk_supervision: str = "diar" # ["diar", "rttm"] binary_diar_preds: bool = False + def format_time(seconds): minutes = math.floor(seconds / 60) sec = seconds % 60 diff --git a/nemo/collections/asr/data/audio_to_diar_label_lhotse.py b/nemo/collections/asr/data/audio_to_diar_label_lhotse.py index ea939ba1e610..260fe884db9b 100644 --- a/nemo/collections/asr/data/audio_to_diar_label_lhotse.py +++ b/nemo/collections/asr/data/audio_to_diar_label_lhotse.py @@ -67,12 +67,14 @@ def __getitem__(self, cuts) -> Tuple[torch.Tensor, ...]: boundary_segments=True, ) speaker_activities.append(speaker_activity) - targets = collate_matrices(speaker_activities).to(audio.dtype) # (B, T, N) + targets = collate_matrices(speaker_activities).to(audio.dtype) # (B, T, N) if targets.shape[2] > self.num_speakers: - targets = targets[:, :, :self.num_speakers] + targets = targets[:, :, : self.num_speakers] elif targets.shape[2] < self.num_speakers: - targets = torch.nn.functional.pad(targets, (0, self.num_speakers - targets.shape[2]), mode='constant', value=0) + targets = torch.nn.functional.pad( + targets, (0, self.num_speakers - targets.shape[2]), mode='constant', value=0 + ) target_lens_list = [] for audio_len in audio_lens: diff --git a/nemo/collections/asr/parts/utils/asr_multispeaker_utils.py b/nemo/collections/asr/parts/utils/asr_multispeaker_utils.py index 9385fbe4696e..2c023f7db91b 100644 --- a/nemo/collections/asr/parts/utils/asr_multispeaker_utils.py +++ b/nemo/collections/asr/parts/utils/asr_multispeaker_utils.py @@ -397,9 +397,7 @@ def speaker_to_target( num_samples = get_hidden_length_from_sample_length( a_cut.num_samples, num_sample_per_mel_frame, num_mel_frame_per_asr_frame ) - frame_mask = get_mask_from_segments( - segments_total, a_cut, speaker_to_idx_map, num_speakers, feat_per_sec - ) + frame_mask = get_mask_from_segments(segments_total, a_cut, speaker_to_idx_map, num_speakers, feat_per_sec) soft_mask = get_soft_mask(frame_mask, num_samples, num_mel_frame_per_asr_frame) if soft_label: diff --git a/nemo/collections/asr/parts/utils/multispk_transcribe_utils.py b/nemo/collections/asr/parts/utils/multispk_transcribe_utils.py index 56b1e6480d92..180ede5d705f 100644 --- a/nemo/collections/asr/parts/utils/multispk_transcribe_utils.py +++ b/nemo/collections/asr/parts/utils/multispk_transcribe_utils.py @@ -1194,13 +1194,15 @@ def perform_parallel_streaming_stt_spk( diar_streaming_state=new_streaming_state, ) - # For a session, if no second speaker is detected, + # For a session, if no second speaker is detected, # the spk_targets will be set to all ones in the single speaker mode if self._single_speaker_model: if self._max_num_of_spks == 1: is_single_speaker = [True] * chunk_audio.shape[0] else: - is_single_speaker = (new_diar_pred_out_stream[:,:,:self._max_num_of_spks] > 0.5).any(1).sum(-1) <= 1.0 + is_single_speaker = (new_diar_pred_out_stream[:, :, : self._max_num_of_spks] > 0.5).any(1).sum( + -1 + ) <= 1.0 for i in range(chunk_audio.shape[0]): if is_single_speaker[i]: new_diar_pred_out_stream[i, :, 0] = 1.0 diff --git a/nemo/collections/common/data/lhotse/cutset.py b/nemo/collections/common/data/lhotse/cutset.py index 25fb9ba812f6..4ba644c97002 100644 --- a/nemo/collections/common/data/lhotse/cutset.py +++ b/nemo/collections/common/data/lhotse/cutset.py @@ -819,7 +819,7 @@ def read_nemo_manifest(config) -> tuple[CutSet, bool]: def read_multi_speaker_simulator(config: DictConfig) -> tuple[CutSet, bool]: # Import here to avoid circular dependency from nemo.collections.asr.parts.utils.asr_multispeaker_utils import MultiSpeakerMixtureGenerator - + multi_speaker_cuts = CutSet( MultiSpeakerMixtureGenerator( manifest_filepath=config.manifest_filepath, From fc7d1bcf0b9fe2e1de0d3bed198b07c48ded108e Mon Sep 17 00:00:00 2001 From: Weiqing Wang Date: Sun, 12 Oct 2025 04:45:26 -0700 Subject: [PATCH 16/29] solving ctm return issue in test function Signed-off-by: Weiqing Wang --- nemo/collections/asr/data/audio_to_text_lhotse_speaker.py | 4 ---- .../collections/asr/parts/mixins/multitalker_asr_mixins.py | 7 +++---- tests/collections/asr/utils/test_data_simul_utils_asr.py | 4 +++- 3 files changed, 6 insertions(+), 9 deletions(-) diff --git a/nemo/collections/asr/data/audio_to_text_lhotse_speaker.py b/nemo/collections/asr/data/audio_to_text_lhotse_speaker.py index dcd8344f730e..3eae47e099cf 100644 --- a/nemo/collections/asr/data/audio_to_text_lhotse_speaker.py +++ b/nemo/collections/asr/data/audio_to_text_lhotse_speaker.py @@ -66,10 +66,6 @@ def __getitem__(self, cuts) -> Tuple[torch.Tensor, ...]: bg_spk_targets = [] if self.inference_mode: - speaker_targets = [ - speaker_to_target(cut, self.num_sample_per_mel_frame, self.num_mel_frame_per_asr_frame) for cut in cuts - ] - spk_targets = collate_matrices(speaker_targets, padding_value=0) return audio, audio_lens, None, None, None, None for idx, cut in enumerate(cuts): diff --git a/nemo/collections/asr/parts/mixins/multitalker_asr_mixins.py b/nemo/collections/asr/parts/mixins/multitalker_asr_mixins.py index 6a6f1ee3d73a..d9629e4ab898 100644 --- a/nemo/collections/asr/parts/mixins/multitalker_asr_mixins.py +++ b/nemo/collections/asr/parts/mixins/multitalker_asr_mixins.py @@ -29,10 +29,9 @@ def get_spk_kernel_class(spk_kernel_type, input_size, d_model, dropout=0.5): return nn.Sequential( nn.Linear(input_size, d_model), nn.ReLU(), nn.Dropout(dropout), nn.Linear(d_model, input_size) ) - elif spk_kernel_type == 'conv2d': - return - elif spk_kernel_type == 'mha': - return + else: + raise ValueError(f"Invalid speaker kernel type: {spk_kernel_type}") + # TODO: conv2d and mha speaker kernel classes class SpeakerKernelMixin(ABC): diff --git a/tests/collections/asr/utils/test_data_simul_utils_asr.py b/tests/collections/asr/utils/test_data_simul_utils_asr.py index 9a27820cdfa1..40c1c49f912f 100644 --- a/tests/collections/asr/utils/test_data_simul_utils_asr.py +++ b/tests/collections/asr/utils/test_data_simul_utils_asr.py @@ -410,7 +410,7 @@ def test_create_new_json_entry(self, annotator): def test_create_new_ctm_entry(self, annotator): words, alignments, speaker_id = generate_words_and_alignments(sample_index=0) session_name = 'test_session' - ctm_list = annotator.create_new_ctm_entry( + ctm_list, word_and_ts_list = annotator.create_new_ctm_entry( words=words, alignments=alignments, session_name=session_name, speaker_id=speaker_id, start=alignments[0] ) assert ctm_list[0] == ( @@ -424,6 +424,7 @@ def test_create_new_ctm_entry(self, annotator): conf=None, type_of_token='lex', speaker=speaker_id, + output_precision=annotator._params.data_simulator.outputs.output_precision, ), ) assert ctm_list[1] == ( @@ -437,6 +438,7 @@ def test_create_new_ctm_entry(self, annotator): conf=None, type_of_token='lex', speaker=speaker_id, + output_precision=annotator._params.data_simulator.outputs.output_precision, ), ) From 8183105ce7ac45a1b09e7fefb74505317dae2d57 Mon Sep 17 00:00:00 2001 From: Weiqing Wang Date: Sun, 12 Oct 2025 06:28:03 -0700 Subject: [PATCH 17/29] remove collate_matrices Signed-off-by: Weiqing Wang --- nemo/collections/asr/data/audio_to_text_lhotse_speaker.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nemo/collections/asr/data/audio_to_text_lhotse_speaker.py b/nemo/collections/asr/data/audio_to_text_lhotse_speaker.py index 3eae47e099cf..04302c9cb611 100644 --- a/nemo/collections/asr/data/audio_to_text_lhotse_speaker.py +++ b/nemo/collections/asr/data/audio_to_text_lhotse_speaker.py @@ -17,7 +17,7 @@ import torch.utils.data from lhotse.dataset import AudioSamples -from lhotse.dataset.collation import collate_matrices, collate_vectors +from lhotse.dataset.collation import collate_vectors from nemo.collections.asr.data.audio_to_text_lhotse import TokenizerWrapper from nemo.collections.asr.parts.utils.asr_multispeaker_utils import speaker_to_target From 2c84cfe19b692533e9e87ce6432e2a6fb5de58d9 Mon Sep 17 00:00:00 2001 From: Weiqing Wang Date: Sun, 12 Oct 2025 19:25:21 -0700 Subject: [PATCH 18/29] adding training script for mt-parakeet Signed-off-by: Weiqing Wang --- .../speech_to_text_mt_rnnt_bpe.py | 89 +++++++++++++++++++ 1 file changed, 89 insertions(+) create mode 100644 examples/asr/asr_transducer/speech_to_text_mt_rnnt_bpe.py diff --git a/examples/asr/asr_transducer/speech_to_text_mt_rnnt_bpe.py b/examples/asr/asr_transducer/speech_to_text_mt_rnnt_bpe.py new file mode 100644 index 000000000000..8d77ecfad602 --- /dev/null +++ b/examples/asr/asr_transducer/speech_to_text_mt_rnnt_bpe.py @@ -0,0 +1,89 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +# Preparing the Tokenizer for the dataset +Use the `process_asr_text_tokenizer.py` script under /scripts/tokenizers/ in order to prepare the tokenizer. + +```sh +python /scripts/tokenizers/process_asr_text_tokenizer.py \ + --manifest= + OR + --data_file= \ + --data_root="" \ + --vocab_size= \ + --tokenizer=<"spe" or "wpe"> \ + --no_lower_case \ + --spe_type=<"unigram", "bpe", "char" or "word"> \ + --spe_character_coverage=1.0 \ + --log +``` + +# Training the model +```sh +python speech_to_text_rnnt_spk_bpe.py \ + # (Optional: --config-path= --config-name=) \ + model.train_ds.manifest_filepath= \ + model.validation_ds.manifest_filepath= \ + model.tokenizer.dir= \ + model.tokenizer.type= \ + trainer.devices=-1 \ + trainer.accelerator="gpu" \ + trainer.strategy="ddp" \ + trainer.max_epochs=100 \ + model.optim.name="adamw" \ + model.optim.lr=0.001 \ + model.optim.betas=[0.9,0.999] \ + model.optim.weight_decay=0.0001 \ + model.optim.sched.warmup_steps=2000 + exp_manager.create_wandb_logger=True \ + exp_manager.wandb_logger_kwargs.name="" \ + exp_manager.wandb_logger_kwargs.project="" +``` + +# Fine-tune a model + +For documentation on fine-tuning this model, please visit - +https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/main/asr/configs.html#fine-tuning-configurations + +""" + +import lightning as pl +from omegaconf import OmegaConf + +from nemo.collections.asr.models import EncDecMultiTalkerRNNTBPEModel +from nemo.core.config import hydra_runner +from nemo.utils import logging +from nemo.utils.exp_manager import exp_manager + + +@hydra_runner(config_path="experimental/contextnet_rnnt", config_name="config_rnnt_bpe") +def main(cfg): + logging.info(f'Hydra config: {OmegaConf.to_yaml(cfg)}') + + trainer = pl.Trainer(**cfg.trainer) + exp_manager(trainer, cfg.get("exp_manager", None)) + asr_model = EncDecMultiTalkerRNNTBPEModel(cfg=cfg.model, trainer=trainer) + + # Initialize the weights of the model from another model, if provided via config + asr_model.maybe_init_from_pretrained_checkpoint(cfg) + trainer.fit(asr_model) + + if hasattr(cfg.model, 'test_ds') and cfg.model.test_ds.manifest_filepath is not None: + if asr_model.prepare_test(trainer): + trainer.test(asr_model) + + +if __name__ == '__main__': + main() # noqa pylint: disable=no-value-for-parameter \ No newline at end of file From b433f9f2f65a179f9e866eb43cd07a41f3fb5ff3 Mon Sep 17 00:00:00 2001 From: weiqingw4ng Date: Mon, 13 Oct 2025 02:26:07 +0000 Subject: [PATCH 19/29] Apply isort and black reformatting Signed-off-by: weiqingw4ng --- examples/asr/asr_transducer/speech_to_text_mt_rnnt_bpe.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/asr/asr_transducer/speech_to_text_mt_rnnt_bpe.py b/examples/asr/asr_transducer/speech_to_text_mt_rnnt_bpe.py index 8d77ecfad602..2eb58e5617f8 100644 --- a/examples/asr/asr_transducer/speech_to_text_mt_rnnt_bpe.py +++ b/examples/asr/asr_transducer/speech_to_text_mt_rnnt_bpe.py @@ -86,4 +86,4 @@ def main(cfg): if __name__ == '__main__': - main() # noqa pylint: disable=no-value-for-parameter \ No newline at end of file + main() # noqa pylint: disable=no-value-for-parameter From 5db7c27fea10e82da0b1ec542bb94cb00348dbe2 Mon Sep 17 00:00:00 2001 From: Weiqing Wang Date: Mon, 13 Oct 2025 13:52:43 -0700 Subject: [PATCH 20/29] fixing create_new_ctm_entry & adding support for batch inference Signed-off-by: Weiqing Wang --- .../speech_to_text_multitalker_streaming_infer.py | 14 +++++++++----- nemo/collections/asr/data/data_simulation.py | 4 ++-- 2 files changed, 11 insertions(+), 7 deletions(-) diff --git a/examples/asr/asr_cache_aware_streaming/speech_to_text_multitalker_streaming_infer.py b/examples/asr/asr_cache_aware_streaming/speech_to_text_multitalker_streaming_infer.py index 5928a3f6b62a..a7a3c98a32ee 100644 --- a/examples/asr/asr_cache_aware_streaming/speech_to_text_multitalker_streaming_infer.py +++ b/examples/asr/asr_cache_aware_streaming/speech_to_text_multitalker_streaming_infer.py @@ -395,7 +395,6 @@ def main(cfg: DiarizationConfig) -> Union[DiarizationConfig]: samples, rttms_mask_mats = get_multi_talker_samples_from_manifest( cfg, manifest_file=cfg.manifest_file, feat_per_sec=feat_per_sec, max_spks=cfg.max_num_of_spks ) - cfg.batch_size = len(samples) # Note: rttms_mask_mats contains PyTorch tensors, so we pass it directly instead of storing in config if cfg.spk_supervision == "rttm": diar_model.add_rttms_mask_mats(rttms_mask_mats, device=asr_model.device) @@ -408,7 +407,10 @@ def main(cfg: DiarizationConfig) -> Union[DiarizationConfig]: pad_and_drop_preencoded=cfg.pad_and_drop_preencoded, ) + seglst_dict_list = [] + batch_samples = [] for sample_idx, sample in enumerate(samples): + batch_samples.append(sample) streaming_buffer.append_audio_file(sample['audio_filepath'], stream_id=-1) logging.info(f'Added this sample to the buffer: {sample["audio_filepath"]}') @@ -422,6 +424,7 @@ def main(cfg: DiarizationConfig) -> Union[DiarizationConfig]: streaming_buffer=streaming_buffer, pad_and_drop_preencoded=cfg.pad_and_drop_preencoded, ) + multispk_asr_streamer.generate_seglst_dicts_from_parallel_streaming(samples=batch_samples) else: multispk_asr_streamer = launch_serial_streaming( cfg=cfg, @@ -429,18 +432,19 @@ def main(cfg: DiarizationConfig) -> Union[DiarizationConfig]: diar_model=diar_model, streaming_buffer=streaming_buffer, ) + multispk_asr_streamer.generate_seglst_dicts_from_serial_streaming(samples=batch_samples) + seglst_dict_list.extend(multispk_asr_streamer.instance_manager.seglst_dict_list) streaming_buffer.reset_buffer() + batch_samples = [] if cfg.output_path is not None and multispk_asr_streamer is not None: if cfg.parallel_speaker_strategy: - multispk_asr_streamer.generate_seglst_dicts_from_parallel_streaming(samples=samples) write_seglst_file( - seglst_dict_list=multispk_asr_streamer.instance_manager.seglst_dict_list, output_path=cfg.output_path + seglst_dict_list=seglst_dict_list, output_path=cfg.output_path ) else: - multispk_asr_streamer.generate_seglst_dicts_from_serial_streaming(samples=samples) write_seglst_file( - seglst_dict_list=multispk_asr_streamer.instance_manager.seglst_dict_list, output_path=cfg.output_path + seglst_dict_list=seglst_dict_list, output_path=cfg.output_path ) diff --git a/nemo/collections/asr/data/data_simulation.py b/nemo/collections/asr/data/data_simulation.py index 369a95d9ee9a..1ff18528ff24 100644 --- a/nemo/collections/asr/data/data_simulation.py +++ b/nemo/collections/asr/data/data_simulation.py @@ -1117,7 +1117,7 @@ def _generate_session( ) self.annotator.annote_lists['json'].append(new_json_entry) - new_ctm_entries = self.annotator.create_new_ctm_entry( + new_ctm_entries, _ = self.annotator.create_new_ctm_entry( words=self._words, alignments=self._alignments, session_name=filename, @@ -1643,7 +1643,7 @@ def _generate_session( ) self.annotator.annote_lists['json'].append(new_json_entry) - new_ctm_entries = self.annotator.create_new_ctm_entry( + new_ctm_entries, _ = self.annotator.create_new_ctm_entry( filename, speaker_ids[speaker_turn], start / self._params.data_simulator.sr ) self.annotator.annote_lists['ctm'].extend(new_ctm_entries) From 07600f8a63b5a24e3eb822d62f44c743ee4f980a Mon Sep 17 00:00:00 2001 From: weiqingw4ng Date: Mon, 13 Oct 2025 20:53:27 +0000 Subject: [PATCH 21/29] Apply isort and black reformatting Signed-off-by: weiqingw4ng --- .../speech_to_text_multitalker_streaming_infer.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/examples/asr/asr_cache_aware_streaming/speech_to_text_multitalker_streaming_infer.py b/examples/asr/asr_cache_aware_streaming/speech_to_text_multitalker_streaming_infer.py index a7a3c98a32ee..418729846562 100644 --- a/examples/asr/asr_cache_aware_streaming/speech_to_text_multitalker_streaming_infer.py +++ b/examples/asr/asr_cache_aware_streaming/speech_to_text_multitalker_streaming_infer.py @@ -439,13 +439,9 @@ def main(cfg: DiarizationConfig) -> Union[DiarizationConfig]: if cfg.output_path is not None and multispk_asr_streamer is not None: if cfg.parallel_speaker_strategy: - write_seglst_file( - seglst_dict_list=seglst_dict_list, output_path=cfg.output_path - ) + write_seglst_file(seglst_dict_list=seglst_dict_list, output_path=cfg.output_path) else: - write_seglst_file( - seglst_dict_list=seglst_dict_list, output_path=cfg.output_path - ) + write_seglst_file(seglst_dict_list=seglst_dict_list, output_path=cfg.output_path) if __name__ == '__main__': From ba9a3714600da7c55229e9e5de859cd995f9ddbe Mon Sep 17 00:00:00 2001 From: Weiqing Wang Date: Mon, 13 Oct 2025 14:37:25 -0700 Subject: [PATCH 22/29] solving F541 f-string is missing placeholders Signed-off-by: Weiqing Wang --- nemo/collections/asr/data/data_simulation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nemo/collections/asr/data/data_simulation.py b/nemo/collections/asr/data/data_simulation.py index 1ff18528ff24..921d733bce5e 100644 --- a/nemo/collections/asr/data/data_simulation.py +++ b/nemo/collections/asr/data/data_simulation.py @@ -1190,7 +1190,7 @@ def generate_sessions(self, random_seed: int = None): Args: random_seed (int): random seed for reproducibility """ - logging.info(f"Generating Diarization Sessions") + logging.info("Generating Diarization Sessions") if random_seed is None: random_seed = self._params.data_simulator.random_seed np.random.seed(random_seed) From a93382caaa89305416c1495ce9559defe9416bb1 Mon Sep 17 00:00:00 2001 From: Weiqing Wang Date: Mon, 13 Oct 2025 15:01:59 -0700 Subject: [PATCH 23/29] solving F541 f-string is missing placeholders Signed-off-by: Weiqing Wang --- nemo/collections/asr/data/data_simulation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nemo/collections/asr/data/data_simulation.py b/nemo/collections/asr/data/data_simulation.py index 921d733bce5e..215571e4a6c9 100644 --- a/nemo/collections/asr/data/data_simulation.py +++ b/nemo/collections/asr/data/data_simulation.py @@ -629,7 +629,7 @@ def _check_missing_speakers(self, num_missing: int = 0): if num_missing != 0: warnings.warn( f"{self._params.data_simulator.session_config.num_speakers - num_missing}" - f"speakers were included in the clip instead of the requested amount of " + "speakers were included in the clip instead of the requested amount of " f"{self._params.data_simulator.session_config.num_speakers}" ) From 59dd795bfec75169a96f03f4229ee1051d1bcb96 Mon Sep 17 00:00:00 2001 From: Weiqing Wang Date: Tue, 14 Oct 2025 13:56:03 -0700 Subject: [PATCH 24/29] fix missing part in DiarizationConfig Signed-off-by: Weiqing Wang --- ...ech_to_text_multitalker_streaming_infer.py | 28 ++++++++++--------- .../parts/utils/multispk_transcribe_utils.py | 10 +++---- 2 files changed, 20 insertions(+), 18 deletions(-) diff --git a/examples/asr/asr_cache_aware_streaming/speech_to_text_multitalker_streaming_infer.py b/examples/asr/asr_cache_aware_streaming/speech_to_text_multitalker_streaming_infer.py index 418729846562..2813ae64ef1b 100644 --- a/examples/asr/asr_cache_aware_streaming/speech_to_text_multitalker_streaming_infer.py +++ b/examples/asr/asr_cache_aware_streaming/speech_to_text_multitalker_streaming_infer.py @@ -36,13 +36,15 @@ @dataclass class DiarizationConfig: # Required configs - diar_model_path: Optional[str] = None # Path to a .nemo file + diar_model: Optional[str] = None # Path to a .nemo file diar_pretrained_name: Optional[str] = None # Name of a pretrained model - max_num_of_spks: Optional[int] = 4 - parallel_speaker_strategy: bool = True - masked_asr: bool = True - mask_preencode: bool = False - single_speaker_model: bool = False + max_num_of_spks: Optional[int] = 4 # maximum number of speakers + parallel_speaker_strategy: bool = True # whether to use parallel speaker strategy + masked_asr: bool = True # whether to use masked ASR + mask_preencode: bool = False # whether to mask preencode or mask features + cache_gating: bool = True # whether to use cache gating + cache_gating_buffer_size: int = 2 # buffer size for cache gating + single_speaker_mode: bool = False # whether to use single speaker mode # General configs session_len_sec: float = -1 # End-to-end diarization session length in seconds @@ -229,8 +231,8 @@ def main(cfg: DiarizationConfig) -> Union[DiarizationConfig]: if cfg.random_seed: pl.seed_everything(cfg.random_seed) - if cfg.diar_model_path is None and cfg.diar_pretrained_name is None: - raise ValueError("Both cfg.diar_model_path and cfg.pretrained_name cannot be None!") + if cfg.diar_model is None and cfg.diar_pretrained_name is None: + raise ValueError("Both cfg.diar_model and cfg.pretrained_name cannot be None!") if cfg.audio_file is None and cfg.manifest_file is None: raise ValueError("Both cfg.audio_file and cfg.manifest_file cannot be None!") @@ -254,14 +256,14 @@ def main(cfg: DiarizationConfig) -> Union[DiarizationConfig]: accelerator = 'gpu' map_location = torch.device(f'cuda:{cfg.cuda}') - if cfg.diar_model_path.endswith(".ckpt"): + if cfg.diar_model.endswith(".ckpt"): diar_model = SortformerEncLabelModel.load_from_checkpoint( - checkpoint_path=cfg.diar_model_path, map_location=map_location, strict=False + checkpoint_path=cfg.diar_model, map_location=map_location, strict=False ) - elif cfg.diar_model_path.endswith(".nemo"): - diar_model = SortformerEncLabelModel.restore_from(restore_path=cfg.diar_model_path, map_location=map_location) + elif cfg.diar_model.endswith(".nemo"): + diar_model = SortformerEncLabelModel.restore_from(restore_path=cfg.diar_model, map_location=map_location) else: - raise ValueError("cfg.diar_model_path must end with.ckpt or.nemo!") + raise ValueError("cfg.diar_model must end with.ckpt or.nemo!") # Model setup for inference trainer = pl.Trainer(devices=device, accelerator=accelerator) diff --git a/nemo/collections/asr/parts/utils/multispk_transcribe_utils.py b/nemo/collections/asr/parts/utils/multispk_transcribe_utils.py index 180ede5d705f..6f9f69c59b13 100644 --- a/nemo/collections/asr/parts/utils/multispk_transcribe_utils.py +++ b/nemo/collections/asr/parts/utils/multispk_transcribe_utils.py @@ -494,7 +494,7 @@ def __init__( self._masked_asr = cfg.get("masked_asr", True) self._use_mask_preencode = cfg.get("mask_preencode", False) - self._single_speaker_model = cfg.get("single_speaker_model", False) + self._single_speaker_mode = cfg.get("single_speaker_mode", False) self.instance_manager = MultiTalkerInstanceManager( asr_model=self.asr_model, @@ -1196,13 +1196,13 @@ def perform_parallel_streaming_stt_spk( # For a session, if no second speaker is detected, # the spk_targets will be set to all ones in the single speaker mode - if self._single_speaker_model: + if self._single_speaker_mode: if self._max_num_of_spks == 1: is_single_speaker = [True] * chunk_audio.shape[0] else: - is_single_speaker = (new_diar_pred_out_stream[:, :, : self._max_num_of_spks] > 0.5).any(1).sum( - -1 - ) <= 1.0 + is_single_speaker = ( + new_diar_pred_out_stream > 0.5 + ).any(1).sum(-1) <= 1.0 for i in range(chunk_audio.shape[0]): if is_single_speaker[i]: new_diar_pred_out_stream[i, :, 0] = 1.0 From cb24acedbbb0534e2e8c128de330eb10b83fd6ef Mon Sep 17 00:00:00 2001 From: weiqingw4ng Date: Tue, 14 Oct 2025 20:56:44 +0000 Subject: [PATCH 25/29] Apply isort and black reformatting Signed-off-by: weiqingw4ng --- .../speech_to_text_multitalker_streaming_infer.py | 14 +++++++------- .../asr/parts/utils/multispk_transcribe_utils.py | 4 +--- 2 files changed, 8 insertions(+), 10 deletions(-) diff --git a/examples/asr/asr_cache_aware_streaming/speech_to_text_multitalker_streaming_infer.py b/examples/asr/asr_cache_aware_streaming/speech_to_text_multitalker_streaming_infer.py index 2813ae64ef1b..aba8a2c31de1 100644 --- a/examples/asr/asr_cache_aware_streaming/speech_to_text_multitalker_streaming_infer.py +++ b/examples/asr/asr_cache_aware_streaming/speech_to_text_multitalker_streaming_infer.py @@ -38,13 +38,13 @@ class DiarizationConfig: # Required configs diar_model: Optional[str] = None # Path to a .nemo file diar_pretrained_name: Optional[str] = None # Name of a pretrained model - max_num_of_spks: Optional[int] = 4 # maximum number of speakers - parallel_speaker_strategy: bool = True # whether to use parallel speaker strategy - masked_asr: bool = True # whether to use masked ASR - mask_preencode: bool = False # whether to mask preencode or mask features - cache_gating: bool = True # whether to use cache gating - cache_gating_buffer_size: int = 2 # buffer size for cache gating - single_speaker_mode: bool = False # whether to use single speaker mode + max_num_of_spks: Optional[int] = 4 # maximum number of speakers + parallel_speaker_strategy: bool = True # whether to use parallel speaker strategy + masked_asr: bool = True # whether to use masked ASR + mask_preencode: bool = False # whether to mask preencode or mask features + cache_gating: bool = True # whether to use cache gating + cache_gating_buffer_size: int = 2 # buffer size for cache gating + single_speaker_mode: bool = False # whether to use single speaker mode # General configs session_len_sec: float = -1 # End-to-end diarization session length in seconds diff --git a/nemo/collections/asr/parts/utils/multispk_transcribe_utils.py b/nemo/collections/asr/parts/utils/multispk_transcribe_utils.py index 6f9f69c59b13..4b55e7f1c7cd 100644 --- a/nemo/collections/asr/parts/utils/multispk_transcribe_utils.py +++ b/nemo/collections/asr/parts/utils/multispk_transcribe_utils.py @@ -1200,9 +1200,7 @@ def perform_parallel_streaming_stt_spk( if self._max_num_of_spks == 1: is_single_speaker = [True] * chunk_audio.shape[0] else: - is_single_speaker = ( - new_diar_pred_out_stream > 0.5 - ).any(1).sum(-1) <= 1.0 + is_single_speaker = (new_diar_pred_out_stream > 0.5).any(1).sum(-1) <= 1.0 for i in range(chunk_audio.shape[0]): if is_single_speaker[i]: new_diar_pred_out_stream[i, :, 0] = 1.0 From 626c4745b46a84bb1c760f6c774ba54a4a1fbdd6 Mon Sep 17 00:00:00 2001 From: Weiqing Wang Date: Tue, 14 Oct 2025 15:40:03 -0700 Subject: [PATCH 26/29] solving empty strings Signed-off-by: Weiqing Wang --- .../collections/asr/parts/utils/multispk_transcribe_utils.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/nemo/collections/asr/parts/utils/multispk_transcribe_utils.py b/nemo/collections/asr/parts/utils/multispk_transcribe_utils.py index 6f9f69c59b13..dbc8772a3df4 100644 --- a/nemo/collections/asr/parts/utils/multispk_transcribe_utils.py +++ b/nemo/collections/asr/parts/utils/multispk_transcribe_utils.py @@ -1536,9 +1536,10 @@ def update_sessionwise_seglsts_for_parallel(self, offset: float): # Case 1 - If start_tiime is greater than end_time + sent_break_sec, then we need to add the sentence if sep_flag or (last_end_time == 0.0 or start_time > last_end_time + self._sent_break_sec): - if len(diff_text) > 0 and diff_text.strip()[0] in ['.', ',', '?', '!']: + stripped_text = diff_text.strip() + if len(stripped_text) > 0 and stripped_text[0] in ['.', ',', '?', '!']: # This handles the case where the first character should be assigned to the previous sentence. - the_first_char, diff_text = diff_text.strip()[0], diff_text.strip()[1:] + the_first_char, diff_text = stripped_text[0], stripped_text[1:] self._update_last_sentence(spk_idx=spk_idx, end_time=None, diff_text=the_first_char) self._speaker_wise_sentences[spk_idx].append( get_new_sentence_dict( From 01353a58082caffd5f51b69ea74f871f2b008443 Mon Sep 17 00:00:00 2001 From: taejinp Date: Tue, 14 Oct 2025 16:38:55 -0700 Subject: [PATCH 27/29] Adding bugfix on duplicated launcher function Signed-off-by: taejinp --- ...ech_to_text_multitalker_streaming_infer.py | 106 +++++++----------- .../parts/utils/multispk_transcribe_utils.py | 100 +++++++++++++---- 2 files changed, 120 insertions(+), 86 deletions(-) diff --git a/examples/asr/asr_cache_aware_streaming/speech_to_text_multitalker_streaming_infer.py b/examples/asr/asr_cache_aware_streaming/speech_to_text_multitalker_streaming_infer.py index aba8a2c31de1..78d488663eb8 100644 --- a/examples/asr/asr_cache_aware_streaming/speech_to_text_multitalker_streaming_infer.py +++ b/examples/asr/asr_cache_aware_streaming/speech_to_text_multitalker_streaming_infer.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import json -import math import time from dataclasses import dataclass, field, is_dataclass from typing import List, Optional, Union @@ -27,6 +25,8 @@ from nemo.collections.asr.parts.utils.multispk_transcribe_utils import ( SpeakerTaggedASR, get_multi_talker_samples_from_manifest, + add_delay_for_real_time, + write_seglst_file, ) from nemo.collections.asr.parts.utils.streaming_utils import CacheAwareStreamingAudioBuffer from nemo.core.config import hydra_runner @@ -34,7 +34,10 @@ @dataclass -class DiarizationConfig: +class MultitalkerTranscriptionConfig: + """ + Configuration for Multi-talker transcription with an ASR model and a diarization model. + """ # Required configs diar_model: Optional[str] = None # Path to a .nemo file diar_pretrained_name: Optional[str] = None # Name of a pretrained model @@ -107,53 +110,6 @@ class DiarizationConfig: spk_supervision: str = "diar" # ["diar", "rttm"] binary_diar_preds: bool = False - -def format_time(seconds): - minutes = math.floor(seconds / 60) - sec = seconds % 60 - return f"{minutes}:{sec:05.2f}" - - -def calc_drop_extra_pre_encoded(asr_model, step_num, pad_and_drop_preencoded): - # for the first step there is no need to drop any tokens after the downsampling as no caching is being used - if step_num == 0 and not pad_and_drop_preencoded: - return 0 - else: - return asr_model.encoder.streaming_cfg.drop_extra_pre_encoded - - -def add_delay_for_real_time(cfg, chunk_audio, session_start_time, feat_frame_count, loop_end_time, loop_start_time): - """ - Add artificial delay for real-time mode by calculating the time difference between - the current time and the session start time.. - - Args: - cfg (DiarizationConfig): The configuration object. - """ - time_diff = max(0, (time.time() - session_start_time) - feat_frame_count * cfg.feat_len_sec) - eta_min_sec = format_time(time.time() - session_start_time) - logging.info( - f"[ REAL TIME MODE ] min:sec - {eta_min_sec} " - f"Time difference for real-time mode: {time_diff:.4f} seconds" - ) - time.sleep( - max( - 0, - (chunk_audio.shape[-1] - cfg.discarded_frames) * cfg.feat_len_sec - - (loop_end_time - loop_start_time) - - time_diff * cfg.finetune_realtime_ratio, - ) - ) - - -def write_seglst_file(seglst_dict_list, output_path): - if len(seglst_dict_list) == 0: - raise ValueError("seglst_dict_list is empty. No transcriptions were generated.") - with open(output_path, 'w') as f: - f.write(json.dumps(seglst_dict_list, indent=4) + '\n') - logging.info(f"Saved the transcriptions of the streaming inference in\n:{output_path}") - - def launch_serial_streaming( cfg, asr_model, @@ -161,14 +117,24 @@ def launch_serial_streaming( streaming_buffer, pad_and_drop_preencoded=False, ): + """ + Launch the serial streaming inference with ASR model and diarization model. + + Args: + cfg (Any): The configuration object containing the parameters for the streaming inference. + asr_model (Any): The ASR model loaded from the path provided in MultitalkerTranscriptionConfig. + diar_model (Any): The diarization model loadded from the path provided in MultitalkerTranscriptionConfig. + streaming_buffer: An iterator that yields chunks of audio data and their lengths. + pad_and_drop_preencoded: A boolean flag indicating whether to pad and drop the extra pre-encoded tokens. + """ streaming_buffer_iter = iter(streaming_buffer) multispk_asr_streamer = SpeakerTaggedASR(cfg, asr_model, diar_model) feat_frame_count = 0 - session_start_time = time.time() for step_num, (chunk_audio, chunk_lengths) in enumerate(streaming_buffer_iter): - drop_extra_pre_encoded = calc_drop_extra_pre_encoded(asr_model, step_num, pad_and_drop_preencoded) + drop_extra_pre_encoded = (0 if step_num == 0 and not pad_and_drop_preencoded + else asr_model.encoder.streaming_cfg.drop_extra_pre_encoded) loop_start_time = time.time() with torch.inference_mode(): with autocast: @@ -180,17 +146,16 @@ def launch_serial_streaming( is_buffer_empty=streaming_buffer.is_buffer_empty(), drop_extra_pre_encoded=drop_extra_pre_encoded, ) - - feat_frame_count += chunk_audio.shape[-1] - cfg.discarded_frames if cfg.real_time_mode: add_delay_for_real_time( - cfg, + cfg=cfg, chunk_audio=chunk_audio, session_start_time=session_start_time, feat_frame_count=feat_frame_count, loop_end_time=time.time(), loop_start_time=loop_start_time, ) + feat_frame_count += chunk_audio.shape[-1] - cfg.discarded_frames return multispk_asr_streamer @@ -203,13 +168,15 @@ def launch_parallel_streaming( ): streaming_buffer_iter = iter(streaming_buffer) multispk_asr_streamer = SpeakerTaggedASR(cfg, asr_model, diar_model) - + feat_frame_count = 0 + session_start_time = time.time() for step_num, (chunk_audio, chunk_lengths) in enumerate(streaming_buffer_iter): - # logging.info(f"Step ID: {step_num}") + drop_extra_pre_encoded = (0 if step_num == 0 and not pad_and_drop_preencoded + else asr_model.encoder.streaming_cfg.drop_extra_pre_encoded) + loop_start_time = time.time() with torch.inference_mode(): with autocast: with torch.no_grad(): - drop_extra_pre_encoded = calc_drop_extra_pre_encoded(asr_model, step_num, pad_and_drop_preencoded) multispk_asr_streamer.perform_parallel_streaming_stt_spk( step_num=step_num, chunk_audio=chunk_audio, @@ -217,11 +184,21 @@ def launch_parallel_streaming( is_buffer_empty=streaming_buffer.is_buffer_empty(), drop_extra_pre_encoded=drop_extra_pre_encoded, ) + if cfg.real_time_mode: + add_delay_for_real_time( + cfg=cfg, + chunk_audio=chunk_audio, + session_start_time=session_start_time, + feat_frame_count=feat_frame_count, + loop_end_time=time.time(), + loop_start_time=loop_start_time, + ) + feat_frame_count += chunk_audio.shape[-1] - cfg.discarded_frames return multispk_asr_streamer -@hydra_runner(config_name="DiarizationConfig", schema=DiarizationConfig) -def main(cfg: DiarizationConfig) -> Union[DiarizationConfig]: +@hydra_runner(config_name="MultitalkerTranscriptionConfig", schema=MultitalkerTranscriptionConfig) +def main(cfg: MultitalkerTranscriptionConfig) -> Union[MultitalkerTranscriptionConfig]: for key in cfg: cfg[key] = None if cfg[key] == 'None' else cfg[key] @@ -361,6 +338,7 @@ def main(cfg: DiarizationConfig) -> Union[DiarizationConfig]: else: online_normalization = False + seglst_dict_list = [] if cfg.audio_file is not None: # Stream a single audio file samples = [ @@ -376,14 +354,14 @@ def main(cfg: DiarizationConfig) -> Union[DiarizationConfig]: cfg.batch_size = len(samples) streaming_buffer.append_audio_file(audio_filepath=cfg.audio_file, stream_id=-1) if cfg.parallel_speaker_strategy: - multispk_asr_streamer = launch_serial_streaming( + multispk_asr_streamer = launch_parallel_streaming( cfg=cfg, asr_model=asr_model, diar_model=diar_model, streaming_buffer=streaming_buffer, pad_and_drop_preencoded=cfg.pad_and_drop_preencoded, ) - + multispk_asr_streamer.generate_seglst_dicts_from_parallel_streaming(samples=samples) else: multispk_asr_streamer = launch_serial_streaming( cfg=cfg, @@ -391,6 +369,9 @@ def main(cfg: DiarizationConfig) -> Union[DiarizationConfig]: diar_model=diar_model, streaming_buffer=streaming_buffer, ) + multispk_asr_streamer.generate_seglst_dicts_from_serial_streaming(samples=samples) + seglst_dict_list.extend(multispk_asr_streamer.instance_manager.seglst_dict_list) + else: # Stream audio files in a manifest file in batched mode feat_per_sec = round(asr_model.cfg.preprocessor.window_stride * asr_model.cfg.encoder.subsampling_factor, 2) @@ -409,7 +390,6 @@ def main(cfg: DiarizationConfig) -> Union[DiarizationConfig]: pad_and_drop_preencoded=cfg.pad_and_drop_preencoded, ) - seglst_dict_list = [] batch_samples = [] for sample_idx, sample in enumerate(samples): batch_samples.append(sample) diff --git a/nemo/collections/asr/parts/utils/multispk_transcribe_utils.py b/nemo/collections/asr/parts/utils/multispk_transcribe_utils.py index 4b55e7f1c7cd..5301ca9b2863 100644 --- a/nemo/collections/asr/parts/utils/multispk_transcribe_utils.py +++ b/nemo/collections/asr/parts/utils/multispk_transcribe_utils.py @@ -16,6 +16,7 @@ import json import os import time +import math from collections import OrderedDict from copy import deepcopy from functools import wraps @@ -51,7 +52,6 @@ def measure_eta(func): Returns: callable: The wrapped function. """ - @wraps(func) def wrapper(*args, **kwargs): start_time = time.time() # Record the start time @@ -64,6 +64,81 @@ def wrapper(*args, **kwargs): return wrapper + +def format_time(seconds: float) -> str: + """ + Format the time in minutes and seconds. + + Args: + seconds (float): The time in seconds. + + Returns: + str: The time in minutes and seconds. + """ + minutes = math.floor(seconds / 60) + sec = seconds % 60 + return f"{minutes}:{sec:05.2f}" + +def add_delay_for_real_time( + cfg: Any, + chunk_audio: torch.Tensor, + session_start_time: float, + feat_frame_count: int, + loop_end_time: float, + loop_start_time: float +): + """ + Add artificial delay for real-time mode by calculating the time difference between + the current time and the session start time.. + + Args: + cfg (Any): The configuration object containing the parameters for the delay calculation. + chunk_audio (torch.Tensor): The chunk audio tensor containing time-series audio data. + session_start_time (float): The session start time in seconds. + feat_frame_count (int): The number of features per second. + loop_end_time (float): The loop end time in seconds. + loop_start_time (float): The loop start time in seconds. + """ + time_diff = max(0, (time.time() - session_start_time) - feat_frame_count * cfg.feat_len_sec) + eta_min_sec = format_time(time.time() - session_start_time) + logging.info( + f"[ REAL TIME MODE ] min:sec - {eta_min_sec} " + f"Time difference for real-time mode: {time_diff:.4f} seconds" + ) + time.sleep( + max( + 0, + (chunk_audio.shape[-1] - cfg.discarded_frames) * cfg.feat_len_sec + - (loop_end_time - loop_start_time) + - time_diff * cfg.finetune_realtime_ratio, + ) + ) + +def write_seglst_file(seglst_dict_list: List[Dict[str, Any]], output_path: str): + """ + Write a seglst file from the seglst dictionary list. + + Args: + seglst_dict_list (List[Dict[str, Any]]): The list of seglst dictionaries. + Example: + [ + { + "session_id": "session_001", + "speaker": "speaker_1", + "words": "Write this to a SegLST file.", + "start_time": 12.34, + "end_time": 23.45, + }, ... + ] + output_path (str): The path to the output file. + """ + if len(seglst_dict_list) == 0: + raise ValueError("seglst_dict_list is empty. No transcriptions were generated.") + with open(output_path, 'w') as f: + f.write(json.dumps(seglst_dict_list, indent=4) + '\n') + logging.info(f"Saved the transcriptions of the streaming inference in\n:{output_path}") + + def get_multi_talker_samples_from_manifest(cfg, manifest_file: str, feat_per_sec: float, max_spks: int): """ Get the multi-talker samples from the manifest file and save it to a list named 'samples'. @@ -176,27 +251,6 @@ def get_new_sentence_dict( 'session_id': session_id, } - -def calc_drop_extra_pre_encoded(asr_model: SortformerEncLabelModel, step_num: int, pad_and_drop_preencoded: bool): - """ - Calculate the number of extra tokens to drop after the downsampling. - - Args: - asr_model (SortformerEncLabelModel): The ASR model. - step_num (int): The step number. - pad_and_drop_preencoded (bool): Whether to pad and drop the extra pre-encoded tokens. - - Returns: - int: The number of extra tokens to drop. - """ - # for the first step there is no need to drop any tokens - # after the downsampling as no caching is being used - if step_num == 0 and not pad_and_drop_preencoded: - return 0 - else: - return asr_model.encoder.streaming_cfg.drop_extra_pre_encoded - - def fix_frame_time_step(cfg: Any, new_tokens: List[str], new_words: List[str], frame_inds_seq: List[int]) -> List[int]: """ Adjust the frame indices sequence to match the length of new tokens. @@ -1144,7 +1198,7 @@ def perform_serial_streaming_stt_spk( previous_hypotheses=previous_hypotheses[batch_idx], previous_pred_out=asr_pred_out_stream[batch_idx], ) - + @measure_eta def perform_parallel_streaming_stt_spk( self, From 8009d9bc9d950efee9353b2ac8a257de6ca1f3f3 Mon Sep 17 00:00:00 2001 From: tango4j Date: Tue, 14 Oct 2025 23:39:49 +0000 Subject: [PATCH 28/29] Apply isort and black reformatting Signed-off-by: tango4j --- ...ech_to_text_multitalker_streaming_infer.py | 20 +++++++++++++------ .../parts/utils/multispk_transcribe_utils.py | 15 ++++++++------ 2 files changed, 23 insertions(+), 12 deletions(-) diff --git a/examples/asr/asr_cache_aware_streaming/speech_to_text_multitalker_streaming_infer.py b/examples/asr/asr_cache_aware_streaming/speech_to_text_multitalker_streaming_infer.py index 78d488663eb8..7ff0506c158f 100644 --- a/examples/asr/asr_cache_aware_streaming/speech_to_text_multitalker_streaming_infer.py +++ b/examples/asr/asr_cache_aware_streaming/speech_to_text_multitalker_streaming_infer.py @@ -24,8 +24,8 @@ from nemo.collections.asr.models.sortformer_diar_models import SortformerEncLabelModel from nemo.collections.asr.parts.utils.multispk_transcribe_utils import ( SpeakerTaggedASR, - get_multi_talker_samples_from_manifest, add_delay_for_real_time, + get_multi_talker_samples_from_manifest, write_seglst_file, ) from nemo.collections.asr.parts.utils.streaming_utils import CacheAwareStreamingAudioBuffer @@ -38,6 +38,7 @@ class MultitalkerTranscriptionConfig: """ Configuration for Multi-talker transcription with an ASR model and a diarization model. """ + # Required configs diar_model: Optional[str] = None # Path to a .nemo file diar_pretrained_name: Optional[str] = None # Name of a pretrained model @@ -110,6 +111,7 @@ class MultitalkerTranscriptionConfig: spk_supervision: str = "diar" # ["diar", "rttm"] binary_diar_preds: bool = False + def launch_serial_streaming( cfg, asr_model, @@ -117,7 +119,7 @@ def launch_serial_streaming( streaming_buffer, pad_and_drop_preencoded=False, ): - """ + """ Launch the serial streaming inference with ASR model and diarization model. Args: @@ -133,8 +135,11 @@ def launch_serial_streaming( feat_frame_count = 0 session_start_time = time.time() for step_num, (chunk_audio, chunk_lengths) in enumerate(streaming_buffer_iter): - drop_extra_pre_encoded = (0 if step_num == 0 and not pad_and_drop_preencoded - else asr_model.encoder.streaming_cfg.drop_extra_pre_encoded) + drop_extra_pre_encoded = ( + 0 + if step_num == 0 and not pad_and_drop_preencoded + else asr_model.encoder.streaming_cfg.drop_extra_pre_encoded + ) loop_start_time = time.time() with torch.inference_mode(): with autocast: @@ -171,8 +176,11 @@ def launch_parallel_streaming( feat_frame_count = 0 session_start_time = time.time() for step_num, (chunk_audio, chunk_lengths) in enumerate(streaming_buffer_iter): - drop_extra_pre_encoded = (0 if step_num == 0 and not pad_and_drop_preencoded - else asr_model.encoder.streaming_cfg.drop_extra_pre_encoded) + drop_extra_pre_encoded = ( + 0 + if step_num == 0 and not pad_and_drop_preencoded + else asr_model.encoder.streaming_cfg.drop_extra_pre_encoded + ) loop_start_time = time.time() with torch.inference_mode(): with autocast: diff --git a/nemo/collections/asr/parts/utils/multispk_transcribe_utils.py b/nemo/collections/asr/parts/utils/multispk_transcribe_utils.py index 5841050f9b66..4168e49c9734 100644 --- a/nemo/collections/asr/parts/utils/multispk_transcribe_utils.py +++ b/nemo/collections/asr/parts/utils/multispk_transcribe_utils.py @@ -14,9 +14,9 @@ import itertools import json +import math import os import time -import math from collections import OrderedDict from copy import deepcopy from functools import wraps @@ -52,6 +52,7 @@ def measure_eta(func): Returns: callable: The wrapped function. """ + @wraps(func) def wrapper(*args, **kwargs): start_time = time.time() # Record the start time @@ -64,9 +65,8 @@ def wrapper(*args, **kwargs): return wrapper - def format_time(seconds: float) -> str: - """ + """ Format the time in minutes and seconds. Args: @@ -79,13 +79,14 @@ def format_time(seconds: float) -> str: sec = seconds % 60 return f"{minutes}:{sec:05.2f}" + def add_delay_for_real_time( cfg: Any, chunk_audio: torch.Tensor, session_start_time: float, feat_frame_count: int, loop_end_time: float, - loop_start_time: float + loop_start_time: float, ): """ Add artificial delay for real-time mode by calculating the time difference between @@ -114,8 +115,9 @@ def add_delay_for_real_time( ) ) + def write_seglst_file(seglst_dict_list: List[Dict[str, Any]], output_path: str): - """ + """ Write a seglst file from the seglst dictionary list. Args: @@ -251,6 +253,7 @@ def get_new_sentence_dict( 'session_id': session_id, } + def fix_frame_time_step(cfg: Any, new_tokens: List[str], new_words: List[str], frame_inds_seq: List[int]) -> List[int]: """ Adjust the frame indices sequence to match the length of new tokens. @@ -1198,7 +1201,7 @@ def perform_serial_streaming_stt_spk( previous_hypotheses=previous_hypotheses[batch_idx], previous_pred_out=asr_pred_out_stream[batch_idx], ) - + @measure_eta def perform_parallel_streaming_stt_spk( self, From d5fdcbfd887301c069637fe407d1ed5fe09d508c Mon Sep 17 00:00:00 2001 From: Weiqing Wang Date: Tue, 14 Oct 2025 17:08:37 -0700 Subject: [PATCH 29/29] Won't raise an error if use_lhotse is set to false Signed-off-by: Weiqing Wang --- nemo/collections/asr/models/multitalker_asr_models.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/nemo/collections/asr/models/multitalker_asr_models.py b/nemo/collections/asr/models/multitalker_asr_models.py index b93c017422c2..238db5178def 100644 --- a/nemo/collections/asr/models/multitalker_asr_models.py +++ b/nemo/collections/asr/models/multitalker_asr_models.py @@ -63,8 +63,6 @@ def _setup_dataloader_from_config(self, config: Optional[Dict]): tokenizer=self.tokenizer, ), ) - else: - raise ValueError("Only lhotse dataloader is supported for multitalker models") def training_step(self, batch, batch_nb): """Training step with speaker targets."""