This repo is based on Representations for Few-Shot Learning (RFS). It covers the implementation of the following paper:
"Rethinking few-shot image classification: a good embedding is all you need?" (Paper, Project Page)
If you find this repo useful for your research, please consider citing the paper:

@article{tian2020rethink,
  title={Rethinking few-shot image classification: a good embedding is all you need?},
  author={Tian, Yonglong and Wang, Yue and Krishnan, Dilip and Tenenbaum, Joshua B and Isola, Phillip},
  journal={arXiv preprint arXiv:2003.11539},
  year={2020}
}
This repo was tested with Ubuntu 16.04.5 LTS, Python 3.5, PyTorch 0.4.0, and CUDA 9.0. However, it should be compatible with recent PyTorch versions (>=0.4.0).
The data we used here was preprocessed by the MetaOptNet repo, but we have renamed the files. Our version of the data can be downloaded from here:
Exemplar commands for running the code can be found in scripts/run.sh.
For unsupervised learning methods CMC and MoCo, please refer to the CMC repo.
For any questions, please contact:
Yonglong Tian ([email protected])
Yue Wang ([email protected])
Part of the code for distillation is from RepDistiller repo.
In order to apply this repo to our medical images for few-shot learning, we modified the code in several places. The usage is as follows.
(1) Generate the data in the specified format.
First, we need to generate the data as .pickle files, each containing a dict:
- data['data']: images (type: numpy.array, shape (batch_size, width, height, channels))
- data['labels']: labels (type: list)
As shown in utils/create_dataset.py, we split the data into train.pickle, val.pickle, test.pickle, and trainval.pickle. The original medical images should be organized like this:
directory/
├── class_x
│   ├── xxx.tif
│   ├── xxy.tif
│   └── ...
└── class_y
    ├── 123.tif
    ├── nsdf3.tif
    ├── ...
    └── asd932_.tif
In the function load_data, the parameter numPerClass denotes the number of images sampled from each class of the original medical images.
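The packing step described above can be sketched as follows. This is a minimal, self-contained illustration of the dict format, not the repo's actual utils/create_dataset.py: the function name create_split, the num_per_class argument (mirroring numPerClass), and the fixed img_size are assumptions for this sketch.

```python
import os
import pickle
import random

import numpy as np
from PIL import Image


def create_split(directory, out_file, num_per_class, img_size=84):
    """Pack a class-per-folder image directory into the dict format above:
    data['data'] is a (batch_size, width, height, channels) uint8 array,
    data['labels'] is a flat list of integer class ids."""
    imgs, labels = [], []
    classes = sorted(
        d for d in os.listdir(directory)
        if os.path.isdir(os.path.join(directory, d))
    )
    for label, cls in enumerate(classes):
        files = sorted(os.listdir(os.path.join(directory, cls)))
        # num_per_class: how many images to sample from each class
        for fname in random.sample(files, min(num_per_class, len(files))):
            img = Image.open(os.path.join(directory, cls, fname)).convert('RGB')
            imgs.append(np.asarray(img.resize((img_size, img_size)), dtype=np.uint8))
            labels.append(label)
    data = {'data': np.stack(imgs), 'labels': labels}
    with open(out_file, 'wb') as f:
        pickle.dump(data, f)
    return data
```

Running it once per split (train, val, test, trainval) with the corresponding image subsets yields the four .pickle files the repo expects.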
(2) Choose the correct Dataset and transform for training.
In train_supervised.py, train_distillation.py, and eval_fewshot.py, we should set the dataset value to customDataset and set the n-ways value in the argument parser.
Note that n must be less than the total number of classes N in train_dataset.
For medical images of different sizes, we should modify transform_E in dataset/transform_cfg.py.
(3) Set the appropriate top-k.
In the functions train, validate, and accuracy, we should set topk=(1, k), where k must be less than the total number of classes N in train_dataset. Note that when k == N, the Top-k accuracy in the training logs is always 100%, whereas Top-1 accuracy gives accuracy in the usual sense.
(4) Set up run.sh to train.
Before running the program, we should create the directories: checkpoints and tensorboardlogs for train_supervised.py, and dis_checkpoints and dis_tensorboardlogs for train_distillation.py.
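The four directories above can be created in one command from the repo root:

```shell
# Output dirs for train_supervised.py and train_distillation.py
mkdir -p checkpoints tensorboardlogs dis_checkpoints dis_tensorboardlogs
```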
run.sh: See code for specific parameters.
# ======================
# example commands on custom datasets
# ======================
# supervised pre-training
python train_supervised.py --trial pretrain --model_path ./checkpoints --tb_path ./tensorboardlogs --data_root ./data
# distillation
# setting '-a 1.0' should give similar performance
# python train_distillation.py -r 0.5 -a 0.5 --path_t ./checkpoints/resnet12_customDataset_lr_0.05_decay_0.0005_trans_A_trial_pretrain/resnet12_last.pth --trial born1 --model_path ./dis_checkpoints --tb_path ./dis_tensorboardlogs --data_root ./data/
# evaluation
# python eval_fewshot.py --model_path ./dis_checkpoints/S:resnet12_T:resnet12_customDataset_kd_r:0.5_a:0.5_b:0_trans_A_born1/resnet12_last.pth --data_root ./data/customDataset/