 from yacs.config import CfgNode
-import os.path as osp
-
-_C = CfgNode()
-
-_C.MODEL = CfgNode()
-_C.MODEL.ARCHITECTURE = "NORM" # GOOD, NORM
-_C.MODEL.DIM_G = 128 # generator dimensionality
-_C.MODEL.DIM_D = 128 # Critic dimensionality
-_C.MODEL.DIM = 64 # DIM for good generator and discriminator
-_C.MODEL.HASH_DIM = 64
-_C.MODEL.PRETRAINED_MODEL_PATH = ""
-_C.MODEL.ALEXNET_PRETRAINED_MODEL_PATH = "pretrained_models/reference_pretrain.npy"
-
-_C.DATA = CfgNode()
-_C.DATA.USE_DATASET = "cifar10" # "cifar10", "nuswide81", "coco"
-_C.DATA.LIST_ROOT = "./data/cifar10"
-_C.DATA.DATA_ROOT = "./data_list/cifar10"
-_C.DATA.LABEL_DIM = 10
-_C.DATA.DB_SIZE = 54000
-_C.DATA.TEST_SIZE = 1000
-_C.DATA.WIDTH_HEIGHT = 32
-_C.DATA.OUTPUT_DIM = 32 * 32 * 3 # Number of pixels (32*32*3)
-_C.DATA.MAP_R = 54000
-
-_C.DATA.OUTPUT_DIR = "./output/cifar10_step_1"
-_C.DATA.IMAGE_DIR = osp.join(_C.DATA.OUTPUT_DIR, "images")
-_C.DATA.MODEL_DIR = osp.join(_C.DATA.OUTPUT_DIR, "models")
-_C.DATA.LOG_DIR = osp.join(_C.DATA.OUTPUT_DIR, "logs")
-
-_C.TRAIN = CfgNode()
-_C.TRAIN.BATCH_SIZE = 64
-_C.TRAIN.ITERS = 100000
-_C.TRAIN.CROSS_ENTROPY_ALPHA = 5
-_C.TRAIN.LR = 1e-4 # Initial learning rate
-_C.TRAIN.G_LR = 1e-4 # 1e-4
-_C.TRAIN.DECAY = True # Whether to decay LR over learning
-_C.TRAIN.N_CRITIC = 5 # Critic steps per generator steps
-_C.TRAIN.EVAL_FREQUENCY = 20000 # How frequently to evaluate and save model
-_C.TRAIN.RUNTIME_MEASURE_FREQUENCY = 20 # How frequently to evaluate and save model
-_C.TRAIN.SAMPLE_FREQUENCY = 1000 # How frequently to evaluate and save model
-_C.TRAIN.ACGAN_SCALE = 1.0
-_C.TRAIN.ACGAN_SCALE_G = 0.1
-_C.TRAIN.WGAN_SCALE = 1.0
-_C.TRAIN.WGAN_SCALE_G = 1.0
-_C.TRAIN.NORMED_CROSS_ENTROPY = True
-_C.TRAIN.FAKE_RATIO = 1.0
-
-config = _C
+import os
+
+config = CfgNode()
+
+config.MODEL = CfgNode()
+config.MODEL.DIM_G = 128 # generator dimensionality
+config.MODEL.DIM_D = 128 # Critic dimensionality
+config.MODEL.DIM = 64 # DIM for good generator and discriminator
+config.MODEL.HASH_DIM = 64
+config.MODEL.G_ARCHITECTURE = "NORM" # GOOD, NORM
+config.MODEL.D_ARCHITECTURE = "NORM" # GOOD, NORM, ALEXNET
+config.MODEL.G_PRETRAINED_MODEL_PATH = ""
+config.MODEL.D_PRETRAINED_MODEL_PATH = ""
+# TODO: merge ALEXNET_PRETRAINED_MODEL_PATH and D_PRETRAINED_MODEL_PATH
+config.MODEL.ALEXNET_PRETRAINED_MODEL_PATH = "./pretrained_models/reference_pretrain.npy"
+
+config.DATA = CfgNode()
+config.DATA.USE_DATASET = "cifar10" # "cifar10", "nuswide81", "coco"
+config.DATA.LIST_ROOT = "./data/cifar10"
+config.DATA.DATA_ROOT = "./data_list/cifar10"
+config.DATA.LABEL_DIM = 10
+config.DATA.DB_SIZE = 54000
+config.DATA.TEST_SIZE = 1000
+config.DATA.WIDTH_HEIGHT = 32
+config.DATA.OUTPUT_DIM = 3 * (config.DATA.WIDTH_HEIGHT ** 2) # Number of pixels (32*32*3)
+config.DATA.MAP_R = 54000
+config.DATA.OUTPUT_DIR = "./output/cifar10_step_1"
+config.DATA.IMAGE_DIR = os.path.join(config.DATA.OUTPUT_DIR, "images")
+config.DATA.MODEL_DIR = os.path.join(config.DATA.OUTPUT_DIR, "models")
+config.DATA.LOG_DIR = os.path.join(config.DATA.OUTPUT_DIR, "logs")
+
+config.TRAIN = CfgNode()
+config.TRAIN.EVALUATE_MODE = False
+config.TRAIN.BATCH_SIZE = 64
+config.TRAIN.ITERS = 100000
+config.TRAIN.CROSS_ENTROPY_ALPHA = 5
+config.TRAIN.LR = 1e-4 # Initial learning rate
+config.TRAIN.G_LR = 1e-4 # Generator learning rate
+config.TRAIN.DECAY = True # Whether to decay LR over learning
+config.TRAIN.N_CRITIC = 5 # Critic steps per generator step
+config.TRAIN.EVAL_FREQUENCY = 20000 # How frequently to evaluate and save model
+config.TRAIN.CHECKPOINT_FREQUENCY = 2000 # How frequently to save a model checkpoint
+config.TRAIN.RUNTIME_MEASURE_FREQUENCY = 200 # How frequently to log runtime measurements
+config.TRAIN.SAMPLE_FREQUENCY = 1000 # How frequently to generate image samples
+config.TRAIN.ACGAN_SCALE = 1.0
+config.TRAIN.ACGAN_SCALE_FAKE = 1.0
+config.TRAIN.WGAN_SCALE = 1.0
+config.TRAIN.WGAN_SCALE_GP = 10.0
+config.TRAIN.ACGAN_SCALE_G = 0.1
+config.TRAIN.WGAN_SCALE_G = 1.0
+config.TRAIN.NORMED_CROSS_ENTROPY = True
+config.TRAIN.FAKE_RATIO = 1.0
+
+
+def update_and_inference_config(cfg_file):
+    config.merge_from_file(cfg_file)
+
+    config.DATA.IMAGE_DIR = os.path.join(config.DATA.OUTPUT_DIR, "images")
+    config.DATA.MODEL_DIR = os.path.join(config.DATA.OUTPUT_DIR, "models")
+    config.DATA.LOG_DIR = os.path.join(config.DATA.OUTPUT_DIR, "logs")
+    config.DATA.OUTPUT_DIM = 3 * (config.DATA.WIDTH_HEIGHT ** 2) # Number of pixels (32*32*3)
+
+    os.makedirs(config.DATA.IMAGE_DIR, exist_ok=True)
+    os.makedirs(config.DATA.MODEL_DIR, exist_ok=True)
+    os.makedirs(config.DATA.LOG_DIR, exist_ok=True)
+
+    config.freeze()
+    return config
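
For context, here is a minimal sketch of how the new `update_and_inference_config` helper is meant to be called. It is illustrative only and not part of the change: the module name `config`, the override file `configs/cifar10.yaml`, and the override values are assumptions; the YAML keys just need to match the `CfgNode` tree so yacs can merge them via `merge_from_file`.

```python
# Hypothetical override file configs/cifar10.yaml (contents are illustrative):
#
#   DATA:
#     OUTPUT_DIR: "./output/cifar10_step_2"
#   TRAIN:
#     BATCH_SIZE: 128

from config import update_and_inference_config  # module name is an assumption

# Merges the YAML overrides into the defaults, recomputes the derived paths
# and output dimension, creates the output directories, and returns the
# frozen CfgNode (further attribute assignment would raise an error).
cfg = update_and_inference_config("configs/cifar10.yaml")

print(cfg.DATA.IMAGE_DIR)   # ./output/cifar10_step_2/images
print(cfg.DATA.OUTPUT_DIM)  # 3 * 32 * 32 = 3072
```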