40
40
from __future__ import division
41
41
from __future__ import print_function
42
42
43
-
44
43
import tensorflow as tf
45
44
46
45
FLAGS = tf .app .flags .FLAGS
52
51
tf .app .flags .DEFINE_integer ('num_preprocess_threads' , 4 ,
53
52
"""Number of preprocessing threads per tower. """
54
53
"""Please make this a multiple of 4.""" )
54
+ tf .app .flags .DEFINE_integer ('num_readers' , 4 ,
55
+ """Number of parallel readers during train.""" )
55
56
56
57
# Images are preprocessed asynchronously using multiple threads specifed by
57
58
# --num_preprocss_threads and the resulting processed images are stored in a
@@ -97,7 +98,8 @@ def inputs(dataset, batch_size=None, num_preprocess_threads=None):
97
98
with tf .device ('/cpu:0' ):
98
99
images , labels = batch_inputs (
99
100
dataset , batch_size , train = False ,
100
- num_preprocess_threads = num_preprocess_threads )
101
+ num_preprocess_threads = num_preprocess_threads ,
102
+ num_readers = 1 )
101
103
102
104
return images , labels
103
105
@@ -130,7 +132,8 @@ def distorted_inputs(dataset, batch_size=None, num_preprocess_threads=None):
130
132
with tf .device ('/cpu:0' ):
131
133
images , labels = batch_inputs (
132
134
dataset , batch_size , train = True ,
133
- num_preprocess_threads = num_preprocess_threads )
135
+ num_preprocess_threads = num_preprocess_threads ,
136
+ num_readers = FLAGS .num_readers )
134
137
return images , labels
135
138
136
139
@@ -401,7 +404,8 @@ def parse_example_proto(example_serialized):
401
404
return features ['image/encoded' ], label , bbox , features ['image/class/text' ]
402
405
403
406
404
- def batch_inputs (dataset , batch_size , train , num_preprocess_threads = None ):
407
+ def batch_inputs (dataset , batch_size , train , num_preprocess_threads = None ,
408
+ num_readers = 1 ):
405
409
"""Contruct batches of training or evaluation examples from the image dataset.
406
410
407
411
Args:
@@ -410,6 +414,7 @@ def batch_inputs(dataset, batch_size, train, num_preprocess_threads=None):
410
414
batch_size: integer
411
415
train: boolean
412
416
num_preprocess_threads: integer, total number of preprocessing threads
417
+ num_readers: integer, number of parallel readers
413
418
414
419
Returns:
415
420
images: 4-D float Tensor of a batch of images
@@ -422,26 +427,28 @@ def batch_inputs(dataset, batch_size, train, num_preprocess_threads=None):
422
427
data_files = dataset .data_files ()
423
428
if data_files is None :
424
429
raise ValueError ('No data files found for this dataset' )
425
- filename_queue = tf .train .string_input_producer (data_files , capacity = 16 )
426
430
431
+ # Create filename_queue
432
+ if train :
433
+ filename_queue = tf .train .string_input_producer (data_files ,
434
+ shuffle = True ,
435
+ capacity = 16 )
436
+ else :
437
+ filename_queue = tf .train .string_input_producer (data_files ,
438
+ shuffle = False ,
439
+ capacity = 1 )
427
440
if num_preprocess_threads is None :
428
441
num_preprocess_threads = FLAGS .num_preprocess_threads
429
442
430
443
if num_preprocess_threads % 4 :
431
444
raise ValueError ('Please make num_preprocess_threads a multiple '
432
445
'of 4 (%d % 4 != 0).' , num_preprocess_threads )
433
- # Create a subgraph with its own reader (but sharing the
434
- # filename_queue) for each preprocessing thread.
435
- images_and_labels = []
436
- for thread_id in range (num_preprocess_threads ):
437
- reader = dataset .reader ()
438
- _ , example_serialized = reader .read (filename_queue )
439
446
440
- # Parse a serialized Example proto to extract the image and metadata.
441
- image_buffer , label_index , bbox , _ = parse_example_proto (
442
- example_serialized )
443
- image = image_preprocessing ( image_buffer , bbox , train , thread_id )
444
- images_and_labels . append ([ image , label_index ] )
447
+ if num_readers is None :
448
+ num_readers = FLAGS . num_readers
449
+
450
+ if num_readers < 1 :
451
+ raise ValueError ( 'Please make num_readers at least 1' )
445
452
446
453
# Approximate number of examples per shard.
447
454
examples_per_shard = 1024
@@ -451,19 +458,43 @@ def batch_inputs(dataset, batch_size, train, num_preprocess_threads=None):
451
458
# The default input_queue_memory_factor is 16 implying a shuffling queue
452
459
# size: examples_per_shard * 16 * 1MB = 17.6GB
453
460
min_queue_examples = examples_per_shard * FLAGS .input_queue_memory_factor
454
-
455
- # Create a queue that produces the examples in batches after shuffling.
456
461
if train :
457
- images , label_index_batch = tf .train .shuffle_batch_join (
458
- images_and_labels ,
459
- batch_size = batch_size ,
462
+ examples_queue = tf .RandomShuffleQueue (
460
463
capacity = min_queue_examples + 3 * batch_size ,
461
- min_after_dequeue = min_queue_examples )
464
+ min_after_dequeue = min_queue_examples ,
465
+ dtypes = [tf .string ])
466
+ else :
467
+ examples_queue = tf .FIFOQueue (
468
+ capacity = examples_per_shard + 3 * batch_size ,
469
+ dtypes = [tf .string ])
470
+
471
+ # Create multiple readers to populate the queue of examples.
472
+ if num_readers > 1 :
473
+ enqueue_ops = []
474
+ for _ in range (num_readers ):
475
+ reader = dataset .reader ()
476
+ _ , value = reader .read (filename_queue )
477
+ enqueue_ops .append (examples_queue .enqueue ([value ]))
478
+
479
+ tf .train .queue_runner .add_queue_runner (
480
+ tf .train .queue_runner .QueueRunner (examples_queue , enqueue_ops ))
481
+ example_serialized = examples_queue .dequeue ()
462
482
else :
463
- images , label_index_batch = tf .train .batch_join (
464
- images_and_labels ,
465
- batch_size = batch_size ,
466
- capacity = min_queue_examples + 3 * batch_size )
483
+ reader = dataset .reader ()
484
+ _ , example_serialized = reader .read (filename_queue )
485
+
486
+ images_and_labels = []
487
+ for thread_id in range (num_preprocess_threads ):
488
+ # Parse a serialized Example proto to extract the image and metadata.
489
+ image_buffer , label_index , bbox , _ = parse_example_proto (
490
+ example_serialized )
491
+ image = image_preprocessing (image_buffer , bbox , train , thread_id )
492
+ images_and_labels .append ([image , label_index ])
493
+
494
+ images , label_index_batch = tf .train .batch_join (
495
+ images_and_labels ,
496
+ batch_size = batch_size ,
497
+ capacity = 2 * num_preprocess_threads * batch_size )
467
498
468
499
# Reshape images into these desired dimensions.
469
500
height = FLAGS .image_size
0 commit comments