# Advancing Spatio-Temporal Processing in Spiking Neural Networks through Adaptation
To set up the required environment, run:

```shell
conda env create -f environment.yml
```

To start an experiment, use:

```shell
python run.py experiment=<experiment_name> ++logdir=path/to/my/logdir ++datadir=path/to/my/datadir
```

Notes:

- `datadir` is mandatory and should contain the datasets.
- For SHD and SSC, data is downloaded automatically if not found at `datadir/SHDWrapper`.
- For the BSD and oscillation toy tasks, datasets are created on the fly, so `datadir` can point to an empty directory.
- Results are stored in a local `results` folder unless `resultdir` is specified.
- `<experiment_name>` refers to configurations in `./config/experiment/`.
We use Hydra for configuration management. To override parameters, use the `++` syntax. For example, to change the number of training epochs:
```shell
python run.py experiment=SHD_SE_adLIF_small ++logdir=path/to/my/logdir ++datadir=path/to/my/datadir ++n_epochs=10
```

For the BSD task with a different number of classes (Figure 6b):

```shell
python run.py experiment=BSD_SE_adLIF ++logdir=path/to/my/logdir ++datadir=path/to/my/datadir ++dataset.num_classes=10
```

To start an audio compression experiment, use:

```shell
python run_compress.py experiment=<experiment_name> ++logdir=path/to/my/logdir ++datadir=path/to/my/datadir
```

Available configurations:

- SE-adLIF: `compress_libri_SE_adLIF`
- EF-adLIF: `compress_libri_EF_adLIF`
- LIF: `compress_libri_LIF`
Model checkpoints for each configuration are available at checkpoints.
To generate wave files from a checkpoint:

```shell
generate_waves.py ckpt_path=/path/to/ckpt/example.ckpt source_wave_path=/path/to/libritts/location/ pred_wave_path=/path/to/prediction/ encoder_only=$encoder_flag
```

- `$encoder_flag` is `true` or `false`.
- `source_wave_path` can be a single `.wav` file or a directory containing `.wav` files.
- If no valid `.wav` files exist, the clean test set from LibriTTS (~9h of audio) is used.
Use `evaluate_metrics.py` to compute SI-SNR or Visqol:

```shell
evaluate_metrics.py metric=$metric source_wave_path=path/to/source/waves pred_wave_path=path/to/model/predictions
```

- `$metric` can be `si_snr` or `visqol`.
- Note: Visqol must be compiled manually following these instructions. Additionally, the project requires either `gcc-9`/`g++-9` or `gcc-10`/`g++-10`. Set the compiler using:

```shell
export CC=gcc-9 CXX=g++-9
```

Furthermore, Visqol relies on Bazel but references an outdated HTTP resource (Armadillo) in its WORKSPACE file. The resource has been moved here. You should modify the WORKSPACE file to reference your local copy as instructed here.
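For reference, the SI-SNR metric can be sketched as below. This is an illustrative NumPy implementation of the standard scale-invariant signal-to-noise ratio formula, not the code used in `evaluate_metrics.py`:

```python
# Illustrative SI-SNR (scale-invariant signal-to-noise ratio) in dB.
# Not the repository's implementation; function name and eps are assumptions.
import numpy as np

def si_snr(source: np.ndarray, estimate: np.ndarray, eps: float = 1e-8) -> float:
    """SI-SNR between a reference waveform and an estimated waveform."""
    source = source - source.mean()
    estimate = estimate - estimate.mean()
    # Project the estimate onto the reference: the scale-invariant target.
    s_target = (np.dot(estimate, source) / (np.dot(source, source) + eps)) * source
    e_noise = estimate - s_target
    return float(10.0 * np.log10(np.dot(s_target, s_target) / (np.dot(e_noise, e_noise) + eps)))
```

Because of the projection step, rescaling the estimate leaves the metric unchanged; only distortion orthogonal to the reference lowers the score.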
Global parameters (e.g., `device: 'cpu'` or `device: 'cuda:0'`) can be set in `config/main.yaml`. These settings are used by PyTorch Lightning's `SingleDeviceStrategy`.
For variable-length sequences (e.g., SHD, SSC), a custom masking procedure is used:

- Data vector: contains the actual data, padded with zeros.
- Block index (`block_idx`): indicates valid data (1s) and padding (0s).
- Target vector: maps block indices to the corresponding labels.

```
data vector: |1011010100101001010000000000000|
              |-----data---------|--padding--|
               ---> time
block_idx:   |1111111111111111111100000000000|
target:      [-1, 3]
```
Explanation:

- Block `0` has target `-1` (ignored).
- Block `1` has target class `3`.
Example with multiple blocks per sequence:

```
data vector: |1 0 1 1 0 0 1 0 0 0 0 0 0|
              |-----data---|--padding--|
               ---> time
block_idx:   |1 2 3 4 5 6 7 0 0 0 0 0 0|
target:      [-1, 4, 3, 1, 3, 4, 6, 3]
```
Explanation:

- Blocks `1`–`7` each have a corresponding target label.
- Padding (`0`s) is ignored during loss computation.
Using this method, per-block predictions can be efficiently gathered using `torch.scatter_reduce`, ignoring padded time steps.
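A minimal sketch of this gathering step, reusing the second example above. The shapes, variable names, and the loss call are illustrative assumptions, not the repository's actual code:

```python
# Sketch: aggregate per-timestep logits into per-block logits with
# torch.scatter_reduce, then mask padding via ignore_index. Illustrative only.
import torch
import torch.nn.functional as F

T, C, n_blocks = 13, 8, 8          # timesteps, classes, blocks (block 0 = padding)
logits = torch.randn(T, C)         # hypothetical per-timestep model outputs
block_idx = torch.tensor([1, 2, 3, 4, 5, 6, 7, 0, 0, 0, 0, 0, 0])
target = torch.tensor([-1, 4, 3, 1, 3, 4, 6, 3])  # block 0 is padding

# Sum the logits of each block; all padded timesteps land in block 0.
index = block_idx.unsqueeze(1).expand(T, C)
block_logits = torch.zeros(n_blocks, C).scatter_reduce(0, index, logits, reduce="sum")

# Block 0 carries target -1 and is excluded from the loss via ignore_index.
loss = F.cross_entropy(block_logits, target, ignore_index=-1)
```

Because padded timesteps are routed into the ignored block `0`, no explicit per-step mask is needed inside the loss.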
This work is licensed under a Creative Commons Attribution-ShareAlike 4.0 International License.