We explore the possibility of maximizing the information represented in spectrograms by making the spectrogram basis functions trainable.
A number of experiments are conducted in which we compare the performance of trainable short-time Fourier transform (STFT) and Mel basis functions provided by FastAudio and nnAudio on two tasks: keyword spotting (KWS) and automatic speech recognition (ASR).
Both a broadcasting residual network (BC-ResNet) and a Simple model (constructed with a single linear layer) are used for these two tasks.
In our experiments, we explore four different training settings:
- A: both gMel and gSTFT are non-trainable.
- B: gMel is trainable while gSTFT is fixed.
- C: gMel is fixed while gSTFT is trainable.
- D: both gMel and gSTFT are trainable.
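As an illustration, these four settings map onto nnAudio's `trainable_mel` and `trainable_STFT` keyword arguments, which the Hydra overrides `model.spec_args.trainable_mel` / `model.spec_args.trainable_STFT` are passed through to. The snippet below is a minimal sketch, not the repo's model code; the import path and the `sr` / `n_mels` values are assumptions.

```python
# Minimal sketch (not the repo's model code): how settings A-D map onto
# nnAudio's trainable_mel / trainable_STFT flags. The import path and the
# sr / n_mels values below are assumptions and may differ from the repo.
import torch
from nnAudio.features import MelSpectrogram  # older nnAudio: from nnAudio import Spectrogram

settings = {
    "A": dict(trainable_mel=False, trainable_STFT=False),  # both fixed
    "B": dict(trainable_mel=True,  trainable_STFT=False),  # gMel trainable
    "C": dict(trainable_mel=False, trainable_STFT=True),   # gSTFT trainable
    "D": dict(trainable_mel=True,  trainable_STFT=True),   # both trainable
}

wav = torch.randn(1, 16000)  # 1 s of dummy audio at 16 kHz
for name, flags in settings.items():
    spec_layer = MelSpectrogram(sr=16000, n_mels=40, **flags)
    mel = spec_layer(wav)  # (batch, n_mels, time)
    n_learnable = sum(p.numel() for p in spec_layer.parameters() if p.requires_grad)
    print(name, tuple(mel.shape), "learnable front-end params:", n_learnable)
```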
The repository is organised as follows:

```
trainable-STFT-Mel
├── conf
│   ├── model
│   │   ├── BC_ResNet.yaml
│   │   ├── BC_ResNet_ASR.yaml
│   │   ├── BC_ResNet_maskout.yaml
│   │   ├── Linearmodel.yaml
│   │   ├── Linearmodel_ASR.yaml
│   │   └── Linearmodel_maskout.yaml
│   ├── ASR_config.yaml
│   └── KWS_config.yaml
├── models
│   ├── nnAudio_model.py
│   └── fastaudio_model.py
├── tasks
│   ├── speechcommand.py
│   ├── speechcommand_maskout.py
│   ├── Timit.py
│   └── Timit_maskout.py
├── train_KWS_hydra.py
├── train_ASR_hydra.py
├── phonemics_dict
└── requirements.txt
```
`conf` contains the `.yaml` configuration files. `models` contains the model architectures. `tasks` contains the PyTorch Lightning modules for KWS and ASR. `train_KWS_hydra.py` and `train_ASR_hydra.py` are the training scripts for KWS and ASR, respectively. `phonemics_dict` contains the phoneme labels provided in TIMIT, which are used for phoneme recognition.
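For orientation, a Hydra entry point wired to this layout could look roughly like the sketch below. This is hypothetical code, not the contents of `train_KWS_hydra.py`; only the `conf/KWS_config.yaml` path and the command-line overrides are taken from this README.

```python
# Hypothetical sketch of a Hydra entry point for this layout; the real
# train_KWS_hydra.py may be organised differently.
import hydra
from omegaconf import DictConfig, OmegaConf

@hydra.main(config_path="conf", config_name="KWS_config")
def main(cfg: DictConfig) -> None:
    # Overrides such as `model=BC_ResNet gpus=0 download=True` are merged
    # into cfg by Hydra before this function is called.
    print(OmegaConf.to_yaml(cfg))
    # ... build the model from cfg.model and train on cfg.gpus devices ...

if __name__ == "__main__":
    main()
```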
Python 3.8.10 is required to run this repo.
You can install all required libraries at once via:

```
pip install -r requirements.txt
```

To train the KWS model:

```
python train_KWS_hydra.py
```

To train the ASR model:

```
python train_ASR_hydra.py
```

Note:
- If this is your first time training the model, you need to set `download` to `True` via `python train_KWS_hydra.py download=True`.
- If you use a CPU instead of a GPU to train the model, set `gpus` to 0 via `python train_KWS_hydra.py gpus=0`.

Default:
- nnAudio BC_ResNet model: `model=BC_ResNet`
- setting A (both gMel and gSTFT are non-trainable): `model.spec_args.trainable_mel=False model.spec_args.trainable_STFT=False`
- 40 Mel bases: `model.spec_args.n_mels=40`
- 1 GPU (`gpus=1`)
For models using the nnAudio front-end, the four training settings can be swept in one multirun:

```
python train_KWS_hydra.py -m gpus=<arg> model=<arg> model.spec_args.trainable_mel=True,False model.spec_args.trainable_STFT=True,False
```

For models using the FastAudio front-end:

```
python train_KWS_hydra.py -m gpus=<arg> model=<arg> model.fastaudio.freeze=True,False model.spec_args.trainable=True,False
```

`model.fastaudio.freeze` controls the Mel basis functions:

- `model.fastaudio.freeze=True` means the Mel bases are non-trainable
- `model.fastaudio.freeze=False` means the Mel bases are trainable

`model.spec_args.trainable` controls the STFT (a PyTorch sketch of what trainable vs. frozen means follows the note below):

- `model.spec_args.trainable=True` means the STFT is trainable
- `model.spec_args.trainable=False` means the STFT is non-trainable
Note: simply replace `train_KWS_hydra.py` with `train_ASR_hydra.py` for the ASR task.
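In plain PyTorch terms, "trainable" versus "frozen" comes down to whether the basis kernels receive gradient updates. The toy module below is only a conceptual sketch (it is neither FastAudio's nor nnAudio's implementation), and its filterbank values are random placeholders.

```python
# Conceptual sketch only: a frozen basis is stored as a buffer (no gradients),
# a trainable basis as an nn.Parameter that the optimizer updates.
import torch
import torch.nn as nn

class ToyMelFrontEnd(nn.Module):
    def __init__(self, n_freq=201, n_mels=40, trainable_mel=False):
        super().__init__()
        fb = torch.rand(n_freq, n_mels)          # placeholder filterbank values
        if trainable_mel:
            self.mel_fb = nn.Parameter(fb)       # updated during training
        else:
            self.register_buffer("mel_fb", fb)   # fixed during training

    def forward(self, power_spec):               # power_spec: (batch, n_freq, time)
        return torch.einsum("bft,fm->bmt", power_spec, self.mel_fb)

frozen    = ToyMelFrontEnd(trainable_mel=False)
trainable = ToyMelFrontEnd(trainable_mel=True)
print(sum(p.numel() for p in frozen.parameters()),     # 0
      sum(p.numel() for p in trainable.parameters()))  # 201 * 40 = 8040
```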
To sweep over the number of Mel bases with the nnAudio front-end:

```
python train_KWS_hydra.py -m gpus=<arg> model=<arg> model.spec_args.n_mels=10,20,30,40
```

With the FastAudio front-end:

```
python train_KWS_hydra.py -m gpus=<arg> model=<arg> model.fastaudio.n_mels=10,20,30,40
```

Note: simply replace `train_KWS_hydra.py` with `train_ASR_hydra.py` for the ASR task.
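To see what the `n_mels` sweep changes, the sketch below prints the shape of a standard Mel filterbank for each value, using librosa's filterbank as a stand-in; the `sr` and `n_fft` numbers are illustrative assumptions, not necessarily the repo's settings.

```python
# Illustrative only: n_mels sets the number of Mel bases, i.e. the number of
# rows in the filterbank that projects (n_fft // 2 + 1) STFT bins onto Mel bins.
import librosa

for n_mels in (10, 20, 30, 40):
    fb = librosa.filters.mel(sr=16000, n_fft=480, n_mels=n_mels)
    print(n_mels, fb.shape)  # e.g. (40, 241) when n_mels=40 and n_fft=480
```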
```
python train_KWS_hydra.py gpus=<arg> model=<arg> model.maskout_start=<arg> model.maskout_end=<arg>
```

Applicable models:
- KWS nnAudio BC_ResNet
- KWS nnAudio Simple
- ASR nnAudio Simple
Note: simply replace `train_KWS_hydra.py` with `train_ASR_hydra.py` for the ASR task.
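The exact semantics of `maskout_start` / `maskout_end` are defined by the repo's maskout models; assuming they delimit a contiguous range of spectrogram bins to be zeroed out, the operation is conceptually similar to the hypothetical sketch below.

```python
# Hypothetical sketch: zero out spectrogram bins in [maskout_start, maskout_end).
# The repo's actual maskout logic (axis, inclusivity, etc.) may differ.
import torch

def maskout(spec: torch.Tensor, start: int, end: int) -> torch.Tensor:
    """spec: (batch, n_bins, time); returns a copy with bins start..end-1 zeroed."""
    masked = spec.clone()
    masked[:, start:end, :] = 0.0
    return masked

spec = torch.rand(2, 40, 101)                        # dummy Mel spectrogram
print(maskout(spec, 10, 20).sum(dim=(0, 2))[8:22])   # bins 10-19 sum to zero
```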
```
python train_KWS_hydra.py gpus=<arg> model=<arg> model.random_mel=True
```

Applicable models:
- KWS nnAudio BC_ResNet
- ASR nnAudio BC_ResNet
- KWS nnAudio Simple
- ASR nnAudio Simple
Note: simply replace `train_KWS_hydra.py` with `train_ASR_hydra.py` for the ASR task.
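Assuming `model.random_mel=True` swaps the standard triangular Mel filterbank initialisation for a random one before training (an assumption on our part, not confirmed by this README), the difference in starting point can be sketched as follows; the `sr` / `n_fft` / `n_mels` values are illustrative only.

```python
# Hypothetical sketch: triangular Mel initialisation vs. a random starting
# point for a trainable Mel filterbank.
import torch
import librosa

sr, n_fft, n_mels = 16000, 480, 40
mel_init = torch.tensor(librosa.filters.mel(sr=sr, n_fft=n_fft, n_mels=n_mels))
rand_init = torch.rand_like(mel_init)            # random Mel bases
mel_basis = torch.nn.Parameter(rand_init)        # trained like any other weight
print(mel_init.shape, mel_basis.shape)           # both (40, 241)
```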