|
| 1 | +<font size=4><b>Train Wide-ResNet, Shake-Shake and ShakeDrop models on CIFAR-10 |
| 2 | +and CIFAR-100 dataset with AutoAugment.</b></font> |
| 3 | + |
| 4 | +The CIFAR-10/CIFAR-100 data can be downloaded from: |
| 5 | +https://www.cs.toronto.edu/~kriz/cifar.html. |
| 6 | + |
| 7 | +The code replicates the results from Tables 1 and 2 on CIFAR-10/100 with the |
| 8 | +following models: Wide-ResNet-28-10, Shake-Shake (26 2x32d), Shake-Shake (26 |
| 9 | +2x96d) and PyramidNet+ShakeDrop. |
| 10 | + |
| 11 | +<b>Related papers:</b> |
| 12 | + |
| 13 | +AutoAugment: Learning Augmentation Policies from Data |
| 14 | + |
| 15 | +https://arxiv.org/abs/1805.09501 |
| 16 | + |
| 17 | +Wide Residual Networks |
| 18 | + |
| 19 | +https://arxiv.org/abs/1605.07146 |
| 20 | + |
| 21 | +Shake-Shake regularization |
| 22 | + |
| 23 | +https://arxiv.org/abs/1705.07485 |
| 24 | + |
| 25 | +ShakeDrop regularization |
| 26 | + |
| 27 | +https://arxiv.org/abs/1802.02375 |
| 28 | + |
| 29 | +<b>Settings:</b> |
| 30 | + |
| 31 | +CIFAR-10 Model | Learning Rate | Weight Decay | Num. Epochs | Batch Size |
| 32 | +---------------------- | ------------- | ------------ | ----------- | ---------- |
| 33 | +Wide-ResNet-28-10 | 0.1 | 5e-4 | 200 | 128 |
| 34 | +Shake-Shake (26 2x32d) | 0.01 | 1e-3 | 1800 | 128 |
| 35 | +Shake-Shake (26 2x96d) | 0.01 | 1e-3 | 1800 | 128 |
| 36 | +PyramidNet + ShakeDrop | 0.05 | 5e-5 | 1800 | 64 |
| 37 | + |
| 38 | +<b>Prerequisite:</b> |
| 39 | + |
| 40 | +1. Install TensorFlow. |
| 41 | + |
| 42 | +2. Download CIFAR-10/CIFAR-100 dataset. |
| 43 | + |
| 44 | +```shell |
| 45 | +curl -o cifar-10-binary.tar.gz https://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz |
| 46 | +curl -o cifar-100-binary.tar.gz https://www.cs.toronto.edu/~kriz/cifar-100-binary.tar.gz |
| 47 | +``` |
| 48 | + |
| 49 | +<b>How to run:</b> |
| 50 | + |
| 51 | +```shell |
| 52 | +# cd to the your workspace. |
| 53 | +# Specify the directory where dataset is located using the data_path flag. |
| 54 | +# Note: User can split samples from training set into the eval set by changing train_size and validation_size. |
| 55 | + |
| 56 | +# For example, to train the Wide-ResNet-28-10 model on a GPU. |
| 57 | +python train_cifar.py --model_name=wrn \ |
| 58 | + --checkpoint_dir=/tmp/training \ |
| 59 | + --data_path=/tmp/data \ |
| 60 | + --dataset='cifar10' \ |
| 61 | + --use_cpu=0 |
| 62 | +``` |
| 63 | + |
| 64 | +## Contact for Issues |
| 65 | + |
| 66 | +* Barret Zoph, @barretzoph <[email protected]> |
| 67 | +* Ekin Dogus Cubuk, <[email protected]> |
0 commit comments