[SpeechLM] model, preprocessor and collect_stats #6279
jctian98 merged 7 commits into espnet:master from
Conversation
for more information, see https://pre-commit.ci
Code Review
This pull request introduces several core components for the SpeechLM module, including abstract job templates, a SpeechLM-specific job implementation, a parallel HuggingFace model wrapper, and data processing utilities. The changes are substantial and lay the groundwork for multimodal training. My review identified several critical issues in the new parallel.py and speechlm_job.py files. These include bugs that can lead to application crashes during model initialization or data preprocessing, as well as incorrect logic in the loss and accuracy calculations. Addressing these issues is crucial for the stability and correctness of the new module.
```python
cur_start, _ = intervals[0]
# Split intervals if they exceed max_loss_interval size
for _, end in intervals[1:]:
    if end - cur_start <= max_loss_interval:
        continue
    else:
        model.loss_intervals.append((cur_start, end))
        cur_start = end

# Add final interval if any tokens remain
if end > cur_start:
    model.loss_intervals.append((cur_start, end))
```
The logic for creating loss_intervals is buggy and will crash with an UnboundLocalError if a discrete modality has only one stream (i.e., len(intervals) == 1). This happens because the loop over intervals[1:] is not entered, leaving the end variable undefined for line 159. Furthermore, the current logic is flawed as it seems to incorrectly merge non-contiguous vocabulary intervals, which is not what the comment "Split large vocabularies into smaller intervals" implies. A more robust approach is to iterate through each vocabulary interval and split it if it exceeds max_loss_interval.
```python
for start, end in intervals:
    # Split a large interval into smaller chunks
    while end - start > max_loss_interval:
        model.loss_intervals.append((start, start + max_loss_interval))
        start += max_loss_interval
    if end > start:
        model.loss_intervals.append((start, end))
```

```python
if this_mask.int().sum() == 0:
    continue
# Compute loss only for vocabulary subset [start:end]
this_logits = hidden_states[this_mask]
```
In the _loss method, you are indexing hidden_states with this_mask. However, hidden_states has a shape of [B, T-1, S, H], while this_mask (for streams > 0) has a shape of [B, T-1, S-1]. This dimension mismatch will cause an indexing error. You should first slice hidden_states to select the streams from 1 onwards before applying the mask.
```python
this_logits = hidden_states[:, :, 1:][this_mask]
```

```python
self.num_stream = max(
    [io.num_stream() for io in multimodal_io.values() if io.is_discrete]
)
```
The calculation of self.num_stream will raise a ValueError if multimodal_io does not contain any discrete IOs. In this scenario, the list comprehension will be empty, and calling max() on an empty sequence is an error. You should handle this case to avoid a crash, for instance by providing a default value if no discrete IOs are present.
```diff
-self.num_stream = max(
-    [io.num_stream() for io in multimodal_io.values() if io.is_discrete]
-)
+self.num_stream = max(
+    [io.num_stream() for io in multimodal_io.values() if io.is_discrete] or [0]
+)
```
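The `or [0]` fallback works because an empty list is falsy in Python. A minimal sketch of the behavior, using made-up `(num_stream, is_discrete)` pairs in place of the real IO objects:

```python
# Hypothetical stand-ins for multimodal_io values: (num_stream, is_discrete) pairs
ios = [(4, True), (1, False), (8, True)]
num_stream = max([n for n, is_discrete in ios if is_discrete] or [0])
print(num_stream)  # 8

# With no discrete IOs the comprehension is empty, so `or [0]` kicks in and
# max() sees [0] instead of raising "max() arg is an empty sequence"
ios = [(1, False), (2, False)]
num_stream = max([n for n, is_discrete in ios if is_discrete] or [0])
print(num_stream)  # 0
```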
```python
apply_eots = [
    msg1[0] == msg2[0] for msg1, msg2 in zip(messages[:1], messages[1:])
] + [False]
```
The logic for generating apply_eots is incorrect. zip(messages[:1], messages[1:]) only compares the first two messages in the sequence. Consequently, apply_eots will have a length of at most 2. When this is zipped with messages, the loop will only process the first two messages, and any subsequent messages in the dialogue will be ignored. To compare all adjacent messages, you should zip messages[:-1] with messages[1:].
```diff
-apply_eots = [
-    msg1[0] == msg2[0] for msg1, msg2 in zip(messages[:1], messages[1:])
-] + [False]
+apply_eots = [
+    msg1[0] == msg2[0] for msg1, msg2 in zip(messages[:-1], messages[1:])
+] + [False]
```
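For illustration, here is how the two zips differ on a toy dialogue (the roles and contents are made up; each message is assumed to be a `(role, content)` tuple):

```python
messages = [("user", "hi"), ("user", "more"), ("assistant", "hello"), ("user", "bye")]

# Buggy: zip(messages[:1], messages[1:]) yields only one pair (msg 0 vs msg 1),
# so the flag list has length 2 and later messages are silently ignored
buggy = [m1[0] == m2[0] for m1, m2 in zip(messages[:1], messages[1:])] + [False]
print(buggy)  # [True, False]

# Fixed: zip(messages[:-1], messages[1:]) compares every adjacent pair,
# producing exactly one flag per message
fixed = [m1[0] == m2[0] for m1, m2 in zip(messages[:-1], messages[1:])] + [False]
print(fixed)  # [True, False, False, False]
```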
```python
)
loss[:, :, 1:].masked_scatter_(this_mask, this_loss)
if not self.training:
    this_acc = this_logits.argmax(-1) == this_targets - start
```
The accuracy calculation for streams 1 and onwards is incorrect. The this_targets variable is already adjusted to be relative to the interval's start (residual_ids[this_mask] - start). Subtracting start again when comparing with the argmax result will lead to incorrect accuracy metrics during evaluation.
```python
this_acc = this_logits.argmax(-1) == this_targets
```
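A small self-contained sketch of the double-offset problem, using made-up values for `start`, the targets, and the argmax predictions:

```python
start = 100                    # hypothetical vocabulary-interval offset
residual_ids = [100, 102, 101] # absolute token ids within [start, end)
predicted_local = [0, 2, 1]    # argmax over the interval's logits (interval-local ids)

# this_targets is already shifted into interval-local indices once
this_targets = [t - start for t in residual_ids]  # [0, 2, 1]

# Comparing against the local targets gives the true accuracy
acc_correct = sum(p == t for p, t in zip(predicted_local, this_targets)) / 3
print(acc_correct)  # 1.0

# Subtracting start a second time misaligns every comparison
acc_buggy = sum(p == (t - start) for p, t in zip(predicted_local, this_targets)) / 3
print(acc_buggy)  # 0.0
```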
Codecov Report
❌ Patch coverage is

Additional details and impacted files

```
@@           Coverage Diff           @@
##           master    #6279   +/- ##
=======================================
  Coverage   56.48%   56.48%
=======================================
  Files         898      898
  Lines       84823    84823
=======================================
  Hits        47914    47914
  Misses      36909    36909
```

Flags with carried forward coverage won't be shown. View full report in Codecov by Sentry.
…nto speechlm_model merge remove
I have addressed Gemini's review comments.
@jctian98 it might be helpful to add a line at the beginning of each file (including [espnet2/speechlm/model/speechlm/lm/parallel.py]) describing what the file does, for future development.
See https://github.com/espnet/espnet/tree/master/espnet2/tts as a reference.
This PR adds several core components of the SpeechLM module:
(1) abs_job.py: the overall job-template abstract class, in which users define the model and preprocessors. By defining different models and preprocessors, the codebase can support various tasks (e.g., SpeechLM, diffusion, etc.).
(2) speechlm_job.py: the inherited SpeechLM job template, which provides the building methods for the SpeechLM model and preprocessor. Specifically, it describes how the data dict is converted into multi-stream training sequences (i.e., the overall preprocessing logic).
(3) parallel.py: the modeling file built on top of any given HF LLM. It implements the multimodal forward pass and loss computation, but does not yet implement inference.
(4) task_conf_speechlm.py: the extended task definition, moved directly from #6257 with some changes to the folder structure.
(5) prepare_length_stats.py: the bin file that collects length stats before training is launched; these stats are then used for batching.
Other revisions are minor.
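For readers unfamiliar with the idea behind item (5): collecting length stats ahead of training usually amounts to one pass over the data that records per-utterance sequence lengths for the batch sampler. A generic sketch of that idea (not the actual prepare_length_stats.py; the function name and JSON output format are assumptions):

```python
import json

def collect_length_stats(examples, out_path=None):
    """One pass over (utt_id, sequence) pairs, recording each sequence length."""
    stats = {utt_id: len(seq) for utt_id, seq in examples}
    if out_path is not None:  # dump for the batch sampler to read later
        with open(out_path, "w") as f:
            json.dump(stats, f, indent=2)
    return stats

# Toy usage with made-up utterances
stats = collect_length_stats([("utt1", [1, 2, 3]), ("utt2", [4, 5])])
print(stats)  # {'utt1': 3, 'utt2': 2}
```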
Prior PRs: #6257, #6258, #6260, #6278
Request review: @Masao-Someki @wanchichen @siddhu001