
Commit 89fc206

refactor dataloader
1 parent 63294dd commit 89fc206

8 files changed (+70, -152 lines)


README.md

Lines changed: 2 additions & 2 deletions
@@ -33,15 +33,15 @@ If you need run on NUSWIDE_81 and COCO, we recommend you to follow https://githu
 
 - [ ] Pretrain model of Alexnet
 - [ ] pretrained G model
-- [ ] resume training
 - [ ] eval frequence & eval at last iter
-- [ ] training longger
 - [ ] refactor all
 - [ ] use config instead of constant
 - [ ] use no split
 - [ ] evaluate mode
 - [ ] output dir which contains images, models, logs
 - [ ] mkdir automatically
+- [ ] training longger
+- [ ] resume training
 - [ ] rerun all process on a fresh machine
 
 Configuration for th models is specified in a list of constants at the top of
dataloader.py

Lines changed: 61 additions & 39 deletions

@@ -87,43 +87,65 @@ def get_labels(self):
         return np.asarray(self._label)
 
 
-def data_generator(batch_size, width_height, file_name):
-    _dataset = Dataset(file_name, True, width_height)
-
-    def get_epoch():
-
-        _index_in_epoch = 0
-        _perm = np.arange(_dataset.n_samples)
-        np.random.shuffle(_perm)
-        for _ in range(int(math.ceil(_dataset.n_samples / batch_size))):
-            start = _index_in_epoch
-            _index_in_epoch += batch_size
-            # finish one epoch
-            if _index_in_epoch > _dataset.n_samples:
-                data, label = _dataset.data(_perm[start:])
-                data1, label1 = _dataset.data(
-                    _perm[:_index_in_epoch - _dataset.n_samples])
-                data = np.concatenate([data, data1], axis=0)
-                label = np.concatenate([label, label1], axis=0)
-            else:
-                end = _index_in_epoch
-                data, label = _dataset.data(_perm[start:end])
-
-            # n*h*w*c -> n*c*h*w
-            data = np.transpose(data, (0, 3, 1, 2))
-            # bgr -> rgb
-            data = data[:, ::-1, :, :]
-            data = np.reshape(data, (batch_size, -1))
-            yield (data, label)
-
-    return get_epoch
-
-
-def load_train(batch_size, width_height, data_root):
-    return [data_generator(batch_size, width_height, os.path.join(data_root, split + '.txt'))
-            for split in ["train", "database_nolabel", "test"]]
-
-def load_val(batch_size, width_height, data_root):
-    return [data_generator(batch_size, width_height, os.path.join(data_root, split + '.txt'))
-            for split in ["database", "test.txt"]]
+class Dataloader(object):
+
+    def __init__(self, batch_size, width_height, data_root):
+        self.batch_size = batch_size
+        self.width_height = width_height
+        self.data_root = data_root
+
+    def data_generator(self, split):
+        file_name = os.path.join(self.data_root, split + '.txt')
+        _dataset = Dataset(file_name, True, self.width_height)
+
+        def get_epoch():
+
+            _index_in_epoch = 0
+            _perm = np.arange(_dataset.n_samples)
+            np.random.shuffle(_perm)
+            for _ in range(int(math.ceil(_dataset.n_samples / self.batch_size))):
+                start = _index_in_epoch
+                _index_in_epoch += self.batch_size
+                # finish one epoch
+                if _index_in_epoch > _dataset.n_samples:
+                    data, label = _dataset.data(_perm[start:])
+                    data1, label1 = _dataset.data(
+                        _perm[:_index_in_epoch - _dataset.n_samples])
+                    data = np.concatenate([data, data1], axis=0)
+                    label = np.concatenate([label, label1], axis=0)
+                else:
+                    end = _index_in_epoch
+                    data, label = _dataset.data(_perm[start:end])
+
+                # n*h*w*c -> n*c*h*w
+                data = np.transpose(data, (0, 3, 1, 2))
+                # bgr -> rgb
+                data = data[:, ::-1, :, :]
+                data = np.reshape(data, (self.batch_size, -1))
+                yield (data, label)
+
+        return get_epoch
 
+    @property
+    def train_gen(self):
+        return self.data_generator('train')
+
+    @property
+    def test_gen(self):
+        return self.data_generator('test')
+
+    @property
+    def db_gen(self):
+        return self.data_generator('database')
+
+    @property
+    def unlabeled_db_gen(self):
+        return self.data_generator('database_nolabel')
+
+    @staticmethod
+    def inf_gen(gen):
+        def generator():
+            while True:
+                for images_iter_, labels_iter_ in gen():
+                    return images_iter_, labels_iter_
+        return generator
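
For reference, a minimal usage sketch of the refactored Dataloader. The constructor values below are illustrative, not taken from this repo's config, and it assumes a data_root directory containing the train.txt, test.txt, database.txt, and database_nolabel.txt split files:

from dataloader import Dataloader

loader = Dataloader(batch_size=4, width_height=32, data_root='./data/cifar10')

get_epoch = loader.train_gen          # property returns a fresh epoch factory
for images, labels in get_epoch():    # one shuffled pass over train.txt
    # For 3-channel input, images is (4, 3*32*32): batches are converted
    # NHWC -> NCHW, flipped BGR -> RGB, then flattened to (batch_size, -1).
    # Under Python 3 division, 10 samples give ceil(10/4) = 3 batches; the
    # last batch wraps around and reuses 2 samples from the epoch's start.
    print(images.shape, labels.shape)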

dataloader/__init__.py

Lines changed: 0 additions & 1 deletion
This file was deleted.

dataloader/cifar10.py

Lines changed: 0 additions & 30 deletions
This file was deleted.

dataloader/coco.py

Lines changed: 0 additions & 31 deletions
This file was deleted.

dataloader/nuswide81.py

Lines changed: 0 additions & 32 deletions
This file was deleted.

main.py

Lines changed: 7 additions & 9 deletions
@@ -20,7 +20,7 @@
 import tensorflow as tf
 from tensorflow.python.client import device_lib
 
-import dataloader
+from dataloader import Dataloader
 import tflib as lib
 import tflib.plot
 import tflib.save_images

@@ -33,9 +33,9 @@
 
 
 def main(cfg):
-    dataset = dataloader.__dict__[cfg.DATA.USE_DATASET]
     DEVICES = [x.name for x in device_lib.list_local_devices()
                if x.device_type == 'GPU']
+    dataloader = Dataloader(cfg.DATA.BATCH_SIZE, cfg.DATA.WIDTH_HEIGHT, cfg.DATA.DATA_ROOT)
 
     configProto = tf.ConfigProto()
     configProto.gpu_options.allow_growth = True

@@ -309,9 +309,8 @@ def generate_image(frame):
         lib.save_images.save_images(samples.reshape((100, 3, cfg.DATA.WIDTH_HEIGHT, cfg.DATA.WIDTH_HEIGHT)),
                                     '{}/samples_{}.png'.format(cfg.DATA.IMAGE_DIR, frame))
 
-    train_gen, unlabel_train_gen, dev_gen = dataset.load(cfg.TRAIN.BATCH_SIZE, cfg.DATA.WIDTH_HEIGHT)
-    gen = util.inf_gen(train_gen)
-    unlabel_gen = util.inf_gen(unlabel_train_gen)
+    gen = dataloader.inf_gen(dataloader.train_gen)
+    unlabel_gen = dataloader.inf_gen(dataloader.unlabeled_db_gen)
 
     util.print_param_size(gen_gv, disc_gv)
 

@@ -378,19 +377,18 @@ def generate_image(frame):
 
         # calculate mAP score w.r.t all db data every 10000 config.TRAIN.ITERS
         if (iteration + 1) % 10000 == 0:
-            _db_gen, _test_gen = dataset.load_val(cfg.TRAIN.BATCH_SIZE, cfg.DATA.WIDTH_HEIGHT)
             db_output = []
             db_labels = []
             test_output = []
             test_labels = []
-            for images, _labels in _test_gen():
-                _disc_acgan_output, __cost = session.run([disc_real_acgan, disc_real_acgan_cost],
+            for images, _labels in dataloader.test_gen():
+                _disc_acgan_output, _ = session.run([disc_real_acgan, disc_real_acgan_cost],
                                                          feed_dict={all_real_data_int: images,
                                                                     all_real_labels: _labels})
                 test_output.append(_disc_acgan_output)
                 test_labels.append(_labels)
 
-            for images, _labels in _db_gen():
+            for images, _labels in dataloader.db_gen():
                 _disc_acgan_output, _ = session.run([disc_real_acgan, disc_real_acgan_cost],
                                                     feed_dict={all_real_data_int: images, all_real_labels: _labels})
                 db_output.append(_disc_acgan_output)
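
Pulling the hunks above together, a hedged sketch of the new wiring in main(). Only the cfg field names and Dataloader calls come from this diff; the loop bodies are placeholders:

from dataloader import Dataloader

def main(cfg):
    dataloader = Dataloader(cfg.DATA.BATCH_SIZE, cfg.DATA.WIDTH_HEIGHT, cfg.DATA.DATA_ROOT)

    # Training draws batches through the inf_gen wrapper...
    gen = dataloader.inf_gen(dataloader.train_gen)
    unlabel_gen = dataloader.inf_gen(dataloader.unlabeled_db_gen)

    # ...while the periodic mAP evaluation drains whole epochs straight from
    # the split generators, replacing the old dataset.load_val() call.
    for images, _labels in dataloader.test_gen():
        pass  # session.run(...) on each test batch
    for images, _labels in dataloader.db_gen():
        pass  # session.run(...) on each database batch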

util.py

Lines changed: 0 additions & 8 deletions
@@ -8,14 +8,6 @@
 import locale
 
 
-def inf_gen(gen):
-    def generator():
-        while True:
-            for images_iter_, labels_iter_ in gen():
-                return images_iter_, labels_iter_
-    return generator
-
-
 # compute param size
 def print_param_size(gen_gv, disc_gv):
     print("computing param size")
