This repo contains the original PyTorch implementation of the following paper:
On the Surprising Efficacy of Distillation as an Alternative to Pre-Training Small Models
Sean Farhat, Deming Chen
University of Illinois, Urbana-Champaign
which will appear at the 5th Practical ML for Low Resource Settings (PML4LRS) Workshop at ICLR 2024.
Go into `paths.json` and edit the paths to point to where you'd like the datasets (`data_path`, `syn_data_path`) and checkpoints (`cpt_path`) to be saved and loaded.
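For illustration, the keys you need to edit are the three mentioned above. The snippet below is a minimal sketch, assuming a flat JSON schema; the directory values are placeholders, and the actual `paths.json` shipped with the repo should be used as the template.

```bash
# Sketch only: the key names come from this README, but the exact schema of
# paths.json is an assumption -- use the copy shipped with the repo as the template.
cat > paths.json <<'EOF'
{
  "data_path": "/path/to/datasets",
  "syn_data_path": "/path/to/synthetic_data",
  "cpt_path": "/path/to/checkpoints"
}
EOF
```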
To obtain the desired teachers, we must first train them on the task of interest. This can be done in one of three ways, depending on how we wish the teacher to learn:
- Scratch (FR): The network is randomly initialized and fully trained on the task end-to-end.
- Linear Probed (LP): The network is initialized with a feature backbone pre-trained on ImageNet. Then, only its task-specific head is trained on the task.
- Full Finetuning (FT): The network is initialized with a feature backbone pre-trained on ImageNet. Then, it is fully trained on the task end-to-end, with its task-specific head using a higher learning rate than the body.
Note: For our experiments, we take the ImageNet pre-trained weights from PyTorch's model hub.
To create these teachers, edit and use the appropriate `scripts/train_(fr|lp|ft).sh` script.
All possible command-line arguments can be found by running `python train_(fr|lp|ft).py --help`.
These will create and save the best and last model checkpoints in `<cpt_path>/<model>_<fr|lp|ft>_<optimizer>`.
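For example, training a full-finetuned (FT) teacher boils down to the sketch below; the training hyperparameters themselves live inside the shell script you edited, and `--help` lists the exact argument names:

```bash
# List every supported argument for the full-finetuning trainer.
python train_ft.py --help

# Edit scripts/train_ft.sh (dataset, model, optimizer, ...) to your liking, then launch it.
bash scripts/train_ft.sh
```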
Edit and run `scripts/distill.sh`. In this script, we have several options to control the assistance process:
- `--dataset` chooses the task. Options: `cifar100`, `cifar10`, `mit_indoor`, `cub_2011`, `caltech101`, `dtd`
- `--teacher_model` chooses which teacher model we use. Step 1 must be completed for this script to find the desired model. Options: `resnet50`, `vit-b-16`
- `--teacher_init` chooses the initialization of the teacher. Options: `fr`, `lp`, `ft`
- `--student_model` chooses the student model, which is initialized randomly. Options: `mobilenetv2`, `resnet18`
- `--distill` chooses the distillation algorithm. Options: `align_uniform`, `crd`, `kd`, `srrl`

All possible command-line arguments can be found by running `python distill.py --help`.
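Putting these flags together, a single distillation run might look like the sketch below. Only the flags documented above are shown; any remaining hyperparameters are whatever `scripts/distill.sh` or the defaults reported by `--help` provide.

```bash
# Sketch: distill an FT-initialized ResNet-50 teacher into a randomly initialized
# MobileNetV2 student on CIFAR-100 using vanilla knowledge distillation (kd).
python distill.py \
  --dataset cifar100 \
  --teacher_model resnet50 \
  --teacher_init ft \
  --student_model mobilenetv2 \
  --distill kd
```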
First, we have to generate the synthetic data. To do this, edit and run `scripts/generate.sh`.
This will save the synthetic data in the `syn_data_path` from `paths.json`.
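In other words (a sketch; the generation settings themselves are configured inside the script):

```bash
# Edit scripts/generate.sh first, then run it; the synthetic images are saved
# under the syn_data_path configured in paths.json.
bash scripts/generate.sh
```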
Then, edit and run `scripts/gen_distill.sh`. It works similarly to `distill.sh`, with the addition of the following options (see the sketch after this list for an example invocation):
- `--synset_size` chooses how much of the synthetic dataset we wish to use, as a fraction of the training set size. Options: `1x`, `2x` (Note: enough synthetic images must be generated for these to work correctly.)
- `--aug` enables image augmentations
- `--aug_mode` chooses whether to apply a single augmentation or multiple augmentations. Options: `S`, `M`

All possible command-line arguments can be found by running `python generated_distill.py --help`.
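For example, a generated-data distillation run combining the shared flags with the new ones might look like the following sketch (treating `--aug` as a boolean switch is an assumption; confirm with `--help`):

```bash
# Sketch: same teacher/student pairing as before, distilled on a synthetic set
# equal in size to the training set (1x), with a single augmentation mode.
python generated_distill.py \
  --dataset cifar100 \
  --teacher_model resnet50 \
  --teacher_init ft \
  --student_model mobilenetv2 \
  --distill kd \
  --synset_size 1x \
  --aug \
  --aug_mode S
```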
Weights and Biases (wandb) integration is included for all of the scripts above. Assuming you have wandb set up on your machine, simply add the `--logging` flag to each script.
We have also included a convenient `--timing` flag for all scripts; it runs the task for one epoch and reports how long it took.
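For instance, reusing the distillation command from earlier (a sketch; both flags are simply appended to whatever command you are already running):

```bash
# Sketch: append --logging (wandb) and/or --timing (one-epoch benchmark) to any script's command.
python distill.py --dataset cifar100 --teacher_model resnet50 --teacher_init ft \
  --student_model mobilenetv2 --distill kd --logging --timing
```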
If you find this work useful, please cite:

```bibtex
@misc{farhat2024surprising,
  title={On the Surprising Efficacy of Distillation as an Alternative to Pre-Training Small Models},
  author={Sean Farhat and Deming Chen},
  year={2024},
  eprint={2404.03263},
  archivePrefix={arXiv},
  primaryClass={cs.LG}
}
```