Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 5d7612c

Browse files
jmchen-g authored and mrry committed
Improve image processing (tensorflow#45)
* improve image processing performance for Inception.
1 parent 84b58a6 commit 5d7612c

File tree

1 file changed

+57
-26
lines changed

1 file changed

+57
-26
lines changed

inception/inception/image_processing.py

Lines changed: 57 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,6 @@
4040
from __future__ import division
4141
from __future__ import print_function
4242

43-
4443
import tensorflow as tf
4544

4645
FLAGS = tf.app.flags.FLAGS
@@ -52,6 +51,8 @@
5251
tf.app.flags.DEFINE_integer('num_preprocess_threads', 4,
5352
"""Number of preprocessing threads per tower. """
5453
"""Please make this a multiple of 4.""")
54+
tf.app.flags.DEFINE_integer('num_readers', 4,
55+
"""Number of parallel readers during train.""")
5556

5657
# Images are preprocessed asynchronously using multiple threads specifed by
5758
# --num_preprocss_threads and the resulting processed images are stored in a
@@ -97,7 +98,8 @@ def inputs(dataset, batch_size=None, num_preprocess_threads=None):
9798
with tf.device('/cpu:0'):
9899
images, labels = batch_inputs(
99100
dataset, batch_size, train=False,
100-
num_preprocess_threads=num_preprocess_threads)
101+
num_preprocess_threads=num_preprocess_threads,
102+
num_readers=1)
101103

102104
return images, labels
103105

@@ -130,7 +132,8 @@ def distorted_inputs(dataset, batch_size=None, num_preprocess_threads=None):
130132
with tf.device('/cpu:0'):
131133
images, labels = batch_inputs(
132134
dataset, batch_size, train=True,
133-
num_preprocess_threads=num_preprocess_threads)
135+
num_preprocess_threads=num_preprocess_threads,
136+
num_readers=FLAGS.num_readers)
134137
return images, labels
135138

136139

@@ -401,7 +404,8 @@ def parse_example_proto(example_serialized):
401404
return features['image/encoded'], label, bbox, features['image/class/text']
402405

403406

404-
def batch_inputs(dataset, batch_size, train, num_preprocess_threads=None):
407+
def batch_inputs(dataset, batch_size, train, num_preprocess_threads=None,
408+
num_readers=1):
405409
"""Contruct batches of training or evaluation examples from the image dataset.
406410
407411
Args:
@@ -410,6 +414,7 @@ def batch_inputs(dataset, batch_size, train, num_preprocess_threads=None):
410414
batch_size: integer
411415
train: boolean
412416
num_preprocess_threads: integer, total number of preprocessing threads
417+
num_readers: integer, number of parallel readers
413418
414419
Returns:
415420
images: 4-D float Tensor of a batch of images
@@ -422,26 +427,28 @@ def batch_inputs(dataset, batch_size, train, num_preprocess_threads=None):
422427
data_files = dataset.data_files()
423428
if data_files is None:
424429
raise ValueError('No data files found for this dataset')
425-
filename_queue = tf.train.string_input_producer(data_files, capacity=16)
426430

431+
# Create filename_queue
432+
if train:
433+
filename_queue = tf.train.string_input_producer(data_files,
434+
shuffle=True,
435+
capacity=16)
436+
else:
437+
filename_queue = tf.train.string_input_producer(data_files,
438+
shuffle=False,
439+
capacity=1)
427440
if num_preprocess_threads is None:
428441
num_preprocess_threads = FLAGS.num_preprocess_threads
429442

430443
if num_preprocess_threads % 4:
431444
raise ValueError('Please make num_preprocess_threads a multiple '
432445
'of 4 (%d % 4 != 0).', num_preprocess_threads)
433-
# Create a subgraph with its own reader (but sharing the
434-
# filename_queue) for each preprocessing thread.
435-
images_and_labels = []
436-
for thread_id in range(num_preprocess_threads):
437-
reader = dataset.reader()
438-
_, example_serialized = reader.read(filename_queue)
439446

440-
# Parse a serialized Example proto to extract the image and metadata.
441-
image_buffer, label_index, bbox, _ = parse_example_proto(
442-
example_serialized)
443-
image = image_preprocessing(image_buffer, bbox, train, thread_id)
444-
images_and_labels.append([image, label_index])
447+
if num_readers is None:
448+
num_readers = FLAGS.num_readers
449+
450+
if num_readers < 1:
451+
raise ValueError('Please make num_readers at least 1')
445452

446453
# Approximate number of examples per shard.
447454
examples_per_shard = 1024
@@ -451,19 +458,43 @@ def batch_inputs(dataset, batch_size, train, num_preprocess_threads=None):
451458
# The default input_queue_memory_factor is 16 implying a shuffling queue
452459
# size: examples_per_shard * 16 * 1MB = 17.6GB
453460
min_queue_examples = examples_per_shard * FLAGS.input_queue_memory_factor
454-
455-
# Create a queue that produces the examples in batches after shuffling.
456461
if train:
457-
images, label_index_batch = tf.train.shuffle_batch_join(
458-
images_and_labels,
459-
batch_size=batch_size,
462+
examples_queue = tf.RandomShuffleQueue(
460463
capacity=min_queue_examples + 3 * batch_size,
461-
min_after_dequeue=min_queue_examples)
464+
min_after_dequeue=min_queue_examples,
465+
dtypes=[tf.string])
466+
else:
467+
examples_queue = tf.FIFOQueue(
468+
capacity=examples_per_shard + 3 * batch_size,
469+
dtypes=[tf.string])
470+
471+
# Create multiple readers to populate the queue of examples.
472+
if num_readers > 1:
473+
enqueue_ops = []
474+
for _ in range(num_readers):
475+
reader = dataset.reader()
476+
_, value = reader.read(filename_queue)
477+
enqueue_ops.append(examples_queue.enqueue([value]))
478+
479+
tf.train.queue_runner.add_queue_runner(
480+
tf.train.queue_runner.QueueRunner(examples_queue, enqueue_ops))
481+
example_serialized = examples_queue.dequeue()
462482
else:
463-
images, label_index_batch = tf.train.batch_join(
464-
images_and_labels,
465-
batch_size=batch_size,
466-
capacity=min_queue_examples + 3 * batch_size)
483+
reader = dataset.reader()
484+
_, example_serialized = reader.read(filename_queue)
485+
486+
images_and_labels = []
487+
for thread_id in range(num_preprocess_threads):
488+
# Parse a serialized Example proto to extract the image and metadata.
489+
image_buffer, label_index, bbox, _ = parse_example_proto(
490+
example_serialized)
491+
image = image_preprocessing(image_buffer, bbox, train, thread_id)
492+
images_and_labels.append([image, label_index])
493+
494+
images, label_index_batch = tf.train.batch_join(
495+
images_and_labels,
496+
batch_size=batch_size,
497+
capacity=2 * num_preprocess_threads * batch_size)
467498

468499
# Reshape images into these desired dimensions.
469500
height = FLAGS.image_size

0 commit comments

Comments
 (0)