@Adel-Moumen commented on Jul 2, 2025

What does this PR do?

This PR aims at providing an alternative proof-of-concept for saving/loading features in SpeechBrain.

Background

SpeechBrain has primarily focused on extracting features on the fly (FBanks, SSL representations, etc.) as part of its philosophy of doing everything in a single train.py file driven by a YAML configuration. This enables rapid prototyping with tight feedback loops.

However, recent years have seen a trend toward ever-larger datasets (e.g., GigaSpeech's 10k hours, LibriHeavy's 50k hours) becoming the de facto benchmarks for training models (farewell to our old LibriSpeech on a V100). The cost of on-the-fly feature extraction grows with multiple epochs over such large corpora. Moreover, a new form of representation has emerged: speech tokens. These discrete representations, often extracted from SSL encoders or VQ-VAE models, are fixed once extracted and never change, and are consumed by medium- to large-scale autoregressive models. Because these encoders are heavy, extracting tokens on the fly is prohibitively expensive. Instead, tokens are typically extracted offline (much like SentencePiece tokens) and then loaded at training time, so that only the decoder (and the tokens) reside in VRAM.

This growing use of frozen features and discrete representations renders SpeechBrain’s current workflow impractical at scale. The community’s embrace of SpeechLMs and SpeechLLMs marks a paradigm shift in which on-the-fly feature extraction is no longer feasible. This PR addresses that challenge by providing a proof of concept for saving and loading pre-extracted features in SpeechBrain.

Note: I welcome discussion and would be very happy to collaboratively develop a SpeechBrain-style prototype.

Description of the Prototype

I extended the Brain class with two new methods: compute_features and cache_features.

  • compute_features(batch, stage)
    Similar to fit_batch, this method takes a batch and a stage, extracts the required features, and returns a list of dictionaries. Each dictionary must include the utterance id plus any feature key/value pairs you want to save. For example, to save id, ssl_feats, and tokens, return:

    [
      {"id": "utt1", "ssl_feats": <tensor>, "tokens": <ndarray>},
      {"id": "utt2", "ssl_feats": <tensor>, "tokens": <ndarray>},
      …
    ]
  • cache_features(...)
    Analogous to fit() or evaluate(), this method iterates over a dataset (or dataloader), calls compute_features on each batch, and writes the returned feature dictionaries to disk (see the sketch just below).
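
For concreteness, here is a minimal sketch of what the cache_features driver loop could look like. This is illustrative only: the writer interface (write/close) is an assumption for the purposes of this sketch, not the final API.

import torch
import speechbrain as sb

def cache_features(self, writers, dataset, loader_kwargs=None, stage=sb.Stage.TRAIN):
    # (Hypothetical method on sb.core.Brain.) Iterate over the data,
    # extract features, and persist them through the configured writers.
    dataloader = self.make_dataloader(dataset, stage=stage, **(loader_kwargs or {}))
    self.modules.eval()
    with torch.no_grad():
        for batch in dataloader:
            for item in self.compute_features(batch, stage):
                utt_id = item.pop("id")
                for key, value in item.items():
                    # Each feature is stored under its utterance key.
                    writers[key].write(utt_id, value)
    for writer in writers.values():
        writer.close()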


I/O Backends & Configuration

Inspired by [lhotse’s I/O module](https://github.com/lhotse-speech/lhotse/blob/fda1a986e5e1e72a14c82049b4ee709fc09a81e6/lhotse/features/io.py#L494), I added a feature_io.py file defining reader and writer classes, plus a simple factory (an illustrative sketch of a backend pair follows the list below). Key points:

  • Pluggable backends: HDF5, NumPy, etc.
  • Compression vs. speed: Choose different HDF5 compressors to trade storage size against I/O throughput.
  • By-utterance storage: Each feature is stored under its utterance key, so you can load only what you need.
  • Memory mapping: Readers use np.memmap-style access to avoid loading everything into RAM.
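
As a rough illustration of what such a backend pair can look like (this sketch uses h5py directly and is not the actual feature_io.py implementation; dtype and compression handling are simplified):

import h5py
import numpy as np

class NumpyHdf5Writer:
    # Illustrative writer: one HDF5 dataset per utterance key.
    def __init__(self, storage_path, dtype="float32", compression=None):
        self.file = h5py.File(storage_path, "w")
        self.dtype = dtype
        self.compression = compression  # e.g. "gzip": smaller files, slower I/O

    def write(self, key, value):
        # Accept torch tensors or numpy arrays.
        array = np.asarray(value.detach().cpu() if hasattr(value, "detach") else value)
        self.file.create_dataset(key, data=array.astype(self.dtype),
                                 compression=self.compression)

    def close(self):
        self.file.close()

class NumpyHdf5Reader:
    # Illustrative reader: loads a single utterance's feature on demand,
    # so the full file never has to fit in RAM.
    def __init__(self, storage_path):
        self.file = h5py.File(storage_path, "r")

    def read(self, key):
        return self.file[key][()]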

All configuration lives in YAML via a FeatureStorageConfig section that specifies the following fields for each feature (a minimal Python sketch of this config follows the list):

  • name: the key under which to store it (e.g., ssl_feats)
  • dtype: e.g., float32
  • writer_class: e.g., NumpyHdf5Writer
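
In Python terms, this config could be little more than a small dataclass; a hypothetical sketch mirroring the YAML fields:

from dataclasses import dataclass
from typing import Callable

@dataclass
class FeatureStorageConfig:
    name: str                # storage key, e.g. "ssl_feats"
    dtype: str               # storage dtype, e.g. "float32"
    writer_class: Callable   # writer factory, e.g. NumpyHdf5Writer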

YAML Example

feature_configs:
  ssl_feats: !new:speechbrain.dataio.feature_io.FeatureStorageConfig
    name: ssl_feats
    dtype: float32
    writer_class: !name:speechbrain.dataio.feature_io.NumpyHdf5Writer

train_feature_storage_writers: !apply:speechbrain.dataio.feature_io.create_feature_storage_writers
  feature_configs: !ref <feature_configs>
  base_path: !ref <ssl_features_folder>
  prefix: "train_960h"

Usage Example

from dataclasses import dataclass
import speechbrain as sb
from speechbrain.dataio.feature_io import create_feature_storage_writers

@dataclass
class FeatureExtractionConfig:
    utterance_id_key: str = "id"
    ssl_key: str = "ssl_feats"

class ExtractFeatures(sb.core.Brain):
    def __init__(self, modules, hparams, run_opts, feature_extraction_config: FeatureExtractionConfig):
        super().__init__(modules=modules, hparams=hparams, run_opts=run_opts)
        self.feature_extraction_config = feature_extraction_config

    def compute_features(self, batch, stage):
        batch = batch.to(self.device)
        wavs, wav_lens = batch.sig
        batch_size = wavs.shape[0]

        # Extract features
        feats = self.modules.wav2vec2(wavs, wav_lens)

        return [
            {
                self.feature_extraction_config.utterance_id_key: batch.id[i],
                self.feature_extraction_config.ssl_key: feats[i],
            }
            for i in range(batch_size)
        ]

if __name__ == "__main__":
    # Prepare hparams, run_opts, etc.
    feature_extractor = ExtractFeatures(
        modules=hparams["modules"],
        hparams=hparams,
        run_opts=run_opts,
        feature_extraction_config=FeatureExtractionConfig(
            utterance_id_key="id",
            ssl_key="ssl_feats",
        ),
    )

    # Initialize writers from YAML config
    writers = create_feature_storage_writers(
        base_path=hparams["ssl_features_folder"],
        prefix="train_960h",
        feature_configs=hparams["feature_configs"],
    )

    # Cache the features
    feature_extractor.cache_features(
        writers,
        train_data,
        loader_kwargs=hparams["dataloader_opts"],
        stage=sb.Stage.TRAIN,
    )
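
With the writer configuration above, each feature ends up in its own file whose name is derived from the prefix and the feature name (here, train_960h_ssl_feats.h5); that file is what the reader's storage_path must point to, as shown next.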

Reading Cached Features in train.py

Define your readers in YAML:

train_feature_readers:
  ssl_feats: !new:speechbrain.dataio.feature_io.FeatureStorageReaderConfig
    name: ssl_feats
    reader_class: !new:speechbrain.dataio.feature_io.NumpyHdf5Reader
      storage_path: !ref <extracted_features_folder>/train_960h_ssl_feats.h5
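
Note the asymmetry with the writer config: writer_class uses !name: (a reference to the class, which the factory instantiates later with the right output path), while the reader is built with !new:, instantiating it immediately with its storage_path.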

And use them in your data pipeline:

@sb.utils.data_pipeline.takes("id")
@sb.utils.data_pipeline.provides("feats")
def train_audio_pipeline(utt_id):
    # Look up the cached feature for this utterance.
    return hparams["train_feature_readers"]["ssl_feats"].reader_class.read(utt_id)
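
The pipeline is then registered on the dataset as in any SpeechBrain recipe:

sb.dataio.dataset.add_dynamic_item([train_data], train_audio_pipeline)
sb.dataio.dataset.set_output_keys([train_data], ["id", "feats"])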

That’s all! Implement compute_features, configure your writers and readers in YAML, and call cache_features.

Room for Improvements

When handling multiple dataset splits, the dataio_prepare stage can become verbose. For example:

# 2. Define audio pipelines for each split:
@sb.utils.data_pipeline.takes("id")
@sb.utils.data_pipeline.provides("feats")
def train_audio_pipeline(utt_id):
    return hparams["train_feature_readers"]["ssl_feats"].reader_class.read(utt_id)

@sb.utils.data_pipeline.takes("id")
@sb.utils.data_pipeline.provides("feats")
def valid_audio_pipeline(utt_id):
    return hparams["valid_feature_readers"]["ssl_feats"].reader_class.read(utt_id)

@sb.utils.data_pipeline.takes("id")
@sb.utils.data_pipeline.provides("feats")
def test_clean_audio_pipeline(utt_id):
    return hparams["test_clean_feature_readers"]["ssl_feats"].reader_class.read(utt_id)

@sb.utils.data_pipeline.takes("id")
@sb.utils.data_pipeline.provides("feats")
def test_other_audio_pipeline(utt_id):
    return hparams["test_other_feature_readers"]["ssl_feats"].reader_class.read(utt_id)

# Register dynamic items for each split:
sb.dataio.dataset.add_dynamic_item([train_data], train_audio_pipeline)
sb.dataio.dataset.add_dynamic_item([valid_data], valid_audio_pipeline)
sb.dataio.dataset.add_dynamic_item([test_datasets["test-clean"]], test_clean_audio_pipeline)
sb.dataio.dataset.add_dynamic_item([test_datasets["test-other"]], test_other_audio_pipeline)

One way to simplify this would be to move the writer (and reader) instantiations into the Brain class itself, rather than defining them in YAML. That way, you wouldn’t need to clutter your config with:

train_feature_storage_writers: !apply:speechbrain.dataio.feature_io.create_feature_storage_writers
  feature_configs: !ref <feature_configs>
  base_path: !ref <ssl_features_folder>
  prefix: "train_960h"

The Brain subclass could then automatically create and expose feature_writers and feature_readers for each split based on a single feature_configs entry, as sketched below.
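
A rough sketch of that idea (entirely hypothetical; the splits argument and the way the Brain discovers the output folder are placeholders, not an implemented API):

class ExtractFeatures(sb.core.Brain):
    def __init__(self, modules, hparams, run_opts, feature_configs, splits):
        super().__init__(modules=modules, hparams=hparams, run_opts=run_opts)
        # One writer dict per split, all derived from a single
        # feature_configs entry (readers could be built analogously),
        # e.g. splits=["train_960h", "dev_clean", "test_clean", "test_other"].
        self.feature_writers = {
            split: create_feature_storage_writers(
                feature_configs=feature_configs,
                base_path=hparams["ssl_features_folder"],
                prefix=split,
            )
            for split in splits
        }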

NOTE: Please don't ask me about fixing tests, etc., for now. The goal of this PR is to provide a PoC; I will clean things up once we converge on a general design.
