Commit c4c58d2

Input perf improvements and thread related flags

1 parent 827d250 commit c4c58d2

4 files changed: +92 −41 lines

official/resnet/cifar10_main.py

Lines changed: 8 additions & 5 deletions
@@ -73,7 +73,6 @@ def parse_record(raw_record, is_training):
   # The first byte represents the label, which we convert from uint8 to int32
   # and then to one-hot.
   label = tf.cast(record_vector[0], tf.int32)
-  label = tf.one_hot(label, _NUM_CLASSES)

   # The remaining bytes after the label represent the image, which we reshape
   # from [depth * height * width] to [depth, height, width].
@@ -107,14 +106,18 @@ def preprocess_image(image, is_training):
   return image


-def input_fn(is_training, data_dir, batch_size, num_epochs=1):
+def input_fn(is_training, data_dir, global_batch_size, num_epochs=1, num_gpus=1):
   """Input_fn using the tf.data input pipeline for CIFAR-10 dataset.

   Args:
     is_training: A boolean denoting whether the input is for training.
     data_dir: The directory containing the input data.
-    batch_size: The number of samples per batch.
+    global_batch_size: The number of samples per batch.
     num_epochs: The number of epochs to repeat the dataset.
+    num_gpus: The number of GPUs.
+    datasets_num_private_threads: Number of threads for a private
+      threadpool created for all datasets computation.
+

   Returns:
     A dataset that can be used for iteration.
@@ -123,8 +126,8 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1):
   dataset = tf.data.FixedLengthRecordDataset(filenames, _RECORD_BYTES)

   return resnet_run_loop.process_record_dataset(
-      dataset, is_training, batch_size, _NUM_IMAGES['train'],
-      parse_record, num_epochs,
+      dataset, is_training, global_batch_size, _NUM_IMAGES['train'],
+      parse_record, num_epochs, num_gpus, datasets_num_private_threads
   )

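Note that the new call to resnet_run_loop.process_record_dataset above passes datasets_num_private_threads, but the CIFAR-10 input_fn signature shown in this diff never defines that name (only the docstring describes it). For the call to resolve, the signature would need to mirror the ImageNet version; a minimal sketch of that assumed follow-up (not part of this commit) is:

def input_fn(is_training, data_dir, global_batch_size, num_epochs=1,
             num_gpus=1, datasets_num_private_threads=None):
  # Hypothetical corrected signature: defines the name that the
  # process_record_dataset call above already relies on.
  ...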

official/resnet/imagenet_main.py

Lines changed: 11 additions & 8 deletions
@@ -39,7 +39,7 @@
 }

 _NUM_TRAIN_FILES = 1024
-_SHUFFLE_BUFFER = 1500
+_SHUFFLE_BUFFER = 10000

 DATASET_NAME = 'ImageNet'

@@ -152,19 +152,21 @@ def parse_record(raw_record, is_training):
       num_channels=_NUM_CHANNELS,
       is_training=is_training)

-  label = tf.one_hot(tf.reshape(label, shape=[]), _NUM_CLASSES)
-
   return image, label


-def input_fn(is_training, data_dir, batch_size, num_epochs=1):
+def input_fn(is_training, data_dir, global_batch_size, num_epochs=1,
+             num_gpus=1, datasets_num_private_threads=None):
   """Input function which provides batches for train or eval.

   Args:
     is_training: A boolean denoting whether the input is for training.
     data_dir: The directory containing the input data.
-    batch_size: The number of samples per batch.
+    global_batch_size: The number of samples per batch.
     num_epochs: The number of epochs to repeat the dataset.
+    num_gpus: The number of GPUs.
+    datasets_num_private_threads: Number of threads for a private
+      threadpool created for all datasets computation.

   Returns:
     A dataset that can be used for iteration.
@@ -177,11 +179,12 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1):
   dataset = dataset.shuffle(buffer_size=_NUM_TRAIN_FILES)

   # Convert to individual records
-  dataset = dataset.flat_map(tf.data.TFRecordDataset)
+  dataset = dataset.apply(tf.contrib.data.parallel_interleave(
+      tf.data.TFRecordDataset, cycle_length=10))

   return resnet_run_loop.process_record_dataset(
-      dataset, is_training, batch_size, _SHUFFLE_BUFFER, parse_record,
-      num_epochs
+      dataset, is_training, global_batch_size, _SHUFFLE_BUFFER, parse_record,
+      num_epochs, num_gpus, datasets_num_private_threads
   )

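The switch from flat_map to parallel_interleave lets the pipeline read several of ImageNet's 1024 training shards concurrently instead of draining one file at a time, and the larger _SHUFFLE_BUFFER improves shuffling over the interleaved records. tf.contrib was removed in later TensorFlow releases; a rough core-API equivalent of the interleave step (an illustrative sketch with a hypothetical file glob, not code from this commit) looks like:

import tensorflow as tf

# Sketch: read up to 10 TFRecord shards in parallel with the core tf.data API,
# which in current TensorFlow releases replaces tf.contrib.data.parallel_interleave.
filenames = tf.data.Dataset.list_files('/path/to/train-*-of-01024')  # hypothetical path
dataset = filenames.interleave(
    tf.data.TFRecordDataset,
    cycle_length=10,                      # number of shard files open at once
    num_parallel_calls=tf.data.AUTOTUNE)  # let the runtime pick the parallelism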

official/resnet/resnet_run_loop.py

Lines changed: 48 additions & 27 deletions
@@ -34,58 +34,72 @@
 from official.utils.logs import hooks_helper
 from official.utils.logs import logger
 from official.utils.misc import model_helpers
-
+from tensorflow.contrib.data.python.ops import threadpool

 ################################################################################
 # Functions for input processing.
 ################################################################################
-def process_record_dataset(dataset, is_training, batch_size, shuffle_buffer,
-                           parse_record_fn, num_epochs=1):
+def process_record_dataset(dataset, is_training, global_batch_size,
+                           shuffle_buffer, parse_record_fn, num_epochs=1,
+                           num_gpus=1, datasets_num_private_threads=None):
   """Given a Dataset with raw records, return an iterator over the records.

   Args:
     dataset: A Dataset representing raw records
     is_training: A boolean denoting whether the input is for training.
-    batch_size: The number of samples per batch.
+    global_batch_size: The number of samples per batch (across devices).
     shuffle_buffer: The buffer size to use when shuffling records. A larger
       value results in better randomness, but smaller values reduce startup
       time and use less memory.
     parse_record_fn: A function that takes a raw record and returns the
       corresponding (image, label) pair.
     num_epochs: The number of epochs to repeat the dataset.
+    num_gpus: The number of GPUs.
+    datasets_num_private_threads: Number of threads for a private
+      threadpool created for all datasets computation.

   Returns:
     Dataset of (image, label) pairs ready for iteration.
   """

   # We prefetch a batch at a time, This can help smooth out the time taken to
   # load input files as we go through shuffling and processing.
-  dataset = dataset.prefetch(buffer_size=batch_size)
+  dataset = dataset.prefetch(buffer_size=global_batch_size)
   if is_training:
     # Shuffle the records. Note that we shuffle before repeating to ensure
     # that the shuffling respects epoch boundaries.
-    dataset = dataset.shuffle(buffer_size=shuffle_buffer)
-
-    # If we are training over multiple epochs before evaluating, repeat the
-    # dataset for the appropriate number of epochs.
-    dataset = dataset.repeat(num_epochs)
+    # If we are training over multiple epochs before evaluating, repeat the
+    # dataset for the appropriate number of epochs.
+    # Using the fused shuffle_and_repeat method gives better performance.
+    dataset = dataset.apply(tf.contrib.data.shuffle_and_repeat(
+        buffer_size=shuffle_buffer, count=num_epochs))
+  else:
+    dataset = dataset.repeat(num_epochs)

   # Parse the raw records into images and labels. Testing has shown that setting
   # num_parallel_batches > 1 produces no improvement in throughput, since
   # batch_size is almost always much greater than the number of CPU cores.
   dataset = dataset.apply(
       tf.contrib.data.map_and_batch(
           lambda value: parse_record_fn(value, is_training),
-          batch_size=batch_size,
-          num_parallel_batches=1))
+          batch_size=per_device_batch_size(global_batch_size, num_gpus),
+          num_parallel_batches=num_gpus,
+          drop_remainder=True))

   # Operations between the final prefetch and the get_next call to the iterator
   # will happen synchronously during run time. We prefetch here again to
   # background all of the above processing work and keep it out of the
   # critical training path. Setting buffer_size to tf.contrib.data.AUTOTUNE
-  # allows DistributionStrategies to adjust how many batches to fetch based
+  # allows TensorFlow to adjust how many batches to fetch based
   # on how many devices are present.
-  dataset.prefetch(buffer_size=tf.contrib.data.AUTOTUNE)
+  dataset.prefetch(buffer_size=num_gpus)
+
+  if datasets_num_private_threads:
+    dataset = threadpool.override_threadpool(
+        dataset,
+        threadpool.PrivateThreadPool(
+            datasets_num_private_threads,
+            display_name="input_pipeline_thread_pool"))

   return dataset
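With this change, map_and_batch batches per device: the global batch size is split by per_device_batch_size, a helper defined elsewhere in the official models code, and num_parallel_batches is raised to the GPU count so each device's batch can be assembled in parallel. A minimal sketch of the arithmetic that helper presumably performs (assuming the global batch must divide evenly across GPUs) is:

def per_device_batch_size(batch_size, num_gpus):
  """Sketch: split a global batch size evenly across devices."""
  if num_gpus <= 1:
    return batch_size
  if batch_size % num_gpus:
    # Assumed behavior: refuse uneven splits so every GPU sees the same shape.
    raise ValueError(
        'batch_size (%d) must be divisible by num_gpus (%d).'
        % (batch_size, num_gpus))
  return batch_size // num_gpus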

@@ -109,7 +123,7 @@ def get_synth_input_fn(height, width, num_channels, num_classes):
   """
   def input_fn(is_training, data_dir, batch_size, *args, **kwargs):  # pylint: disable=unused-argument
     images = tf.zeros((batch_size, height, width, num_channels), tf.float32)
-    labels = tf.zeros((batch_size, num_classes), tf.int32)
+    labels = tf.zeros((batch_size), tf.int32)
     return tf.data.Dataset.from_tensors((images, labels)).repeat()

   return input_fn
@@ -225,8 +239,8 @@ def resnet_model_fn(features, labels, mode, model_class,
       })

   # Calculate loss, which includes softmax cross entropy and L2 regularization.
-  cross_entropy = tf.losses.softmax_cross_entropy(
-      logits=logits, onehot_labels=labels)
+  cross_entropy = tf.losses.sparse_softmax_cross_entropy(
+      logits=logits, labels=labels)

   # Create a tensor named cross_entropy for logging purposes.
   tf.identity(cross_entropy, name='cross_entropy')
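Because the parse functions no longer one-hot encode labels, the model now receives integer class indices, so the loss moves from softmax_cross_entropy (one-hot targets) to sparse_softmax_cross_entropy (integer targets). A small sketch of the equivalence using the TF 1.x losses API this code targets (example values are illustrative only):

import tensorflow as tf

logits = tf.constant([[2.0, 1.0, 0.1, -1.0],
                      [0.3, 2.5, 0.2, 0.0]])    # two examples, four classes
int_labels = tf.constant([0, 1])                 # integer class indices
onehot_labels = tf.one_hot(int_labels, depth=4)  # same labels, one-hot encoded

# Both losses compute the same cross entropy; the sparse variant simply skips
# the explicit one-hot conversion that parse_record used to perform.
sparse_loss = tf.losses.sparse_softmax_cross_entropy(
    labels=int_labels, logits=logits)
dense_loss = tf.losses.softmax_cross_entropy(
    onehot_labels=onehot_labels, logits=logits)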
@@ -280,8 +294,7 @@ def exclude_batch_norm(name):
     train_op = None

   if not tf.contrib.distribute.has_distribution_strategy():
-    accuracy = tf.metrics.accuracy(
-        tf.argmax(labels, axis=1), predictions['classes'])
+    accuracy = tf.metrics.accuracy(labels, predictions['classes'])
   else:
     # Metrics are currently not compatible with distribution strategies during
     # training. This does not affect the overall performance of the model.
@@ -352,6 +365,9 @@ def resnet_main(

   # Using the Winograd non-fused algorithms provides a small performance boost.
   os.environ['TF_ENABLE_WINOGRAD_NONFUSED'] = '1'
+  os.environ['TF_GPU_THREAD_MODE'] = flags_obj.tf_gpu_thread_mode
+  os.environ['TF_GPU_THREAD_COUNT'] = flags_obj.tf_gpu_thread_count
+

   # Create session config based on values of inter_op_parallelism_threads and
   # intra_op_parallelism_threads. Note that we default to having
@@ -391,7 +407,7 @@
       'resnet_size': flags_obj.resnet_size,
       'resnet_version': flags_obj.resnet_version,
       'synthetic_data': flags_obj.use_synthetic_data,
-      'train_epochs': flags_obj.train_epochs,
+      'train_epochs': flags_obj.train_epochs
   }
   benchmark_logger = logger.config_benchmark_logger(flags_obj.benchmark_log_dir)
   benchmark_logger.log_run_info('resnet', dataset_name, run_params)
@@ -404,16 +420,17 @@
   def input_fn_train():
     return input_function(
         is_training=True, data_dir=flags_obj.data_dir,
-        batch_size=per_device_batch_size(
-            flags_obj.batch_size, flags_core.get_num_gpus(flags_obj)),
-        num_epochs=flags_obj.epochs_between_evals)
+        global_batch_size=flags_obj.batch_size,
+        num_epochs=flags_obj.epochs_between_evals,
+        num_gpus=flags_core.get_num_gpus(flags_obj))

   def input_fn_eval():
     return input_function(
         is_training=False, data_dir=flags_obj.data_dir,
-        batch_size=per_device_batch_size(
-            flags_obj.batch_size, flags_core.get_num_gpus(flags_obj)),
-        num_epochs=1)
+        global_batch_size=flags_obj.batch_size,
+        num_epochs=1,
+        num_gpus=flags_core.get_num_gpus(flags_obj))
+

   total_training_cycle = (flags_obj.train_epochs //
                           flags_obj.epochs_between_evals)
@@ -451,7 +468,11 @@ def input_fn_eval():
 def define_resnet_flags(resnet_size_choices=None):
   """Add flags and validators for ResNet."""
   flags_core.define_base()
-  flags_core.define_performance(num_parallel_calls=False)
+  flags_core.define_performance(
+      num_parallel_calls=False,
+      datasets_num_private_threads=True,
+      tf_gpu_thread_mode=True,
+      tf_gpu_thread_count=True)
   flags_core.define_image()
   flags_core.define_benchmark()
   flags.adopt_module_key_flags(flags_core)
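The two new environment variables exported in resnet_main hand the GPU thread settings to the TensorFlow runtime before the session is created. Note that os.environ values must be strings in Python, so the integer tf_gpu_thread_count flag would need conversion; a hedged sketch of setting these by hand (the non-default mode names are an assumption based on TensorFlow's performance guidance, not on this diff) is:

import os

# Keep GPU kernel-launch threads in a pool separate from tf.data's threads.
# TF_GPU_THREAD_MODE is commonly 'global', 'gpu_private', or 'gpu_shared'
# (assumed values; only the 'global' default appears in this commit).
os.environ['TF_GPU_THREAD_MODE'] = 'gpu_private'
os.environ['TF_GPU_THREAD_COUNT'] = str(2)  # os.environ accepts only strings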

official/utils/flags/_performance.py

Lines changed: 25 additions & 1 deletion
@@ -44,7 +44,9 @@ def get_loss_scale(flags_obj):


 def define_performance(num_parallel_calls=True, inter_op=True, intra_op=True,
-                       synthetic_data=True, max_train_steps=True, dtype=True):
+                       synthetic_data=True, max_train_steps=True, dtype=True,
+                       tf_gpu_thread_mode=False, tf_gpu_thread_count=False,
+                       datasets_num_private_threads=False):
   """Register flags for specifying performance tuning arguments.

   Args:
@@ -129,4 +131,26 @@ def _check_loss_scale(loss_scale):  # pylint: disable=unused-variable

     return loss_scale > 0

+  if tf_gpu_thread_mode:
+    flags.DEFINE_string(
+        name="tf_gpu_thread_mode", short_name="gt_mode", default="global",
+        help=help_wrap(
+            "Whether and how the GPU device uses its own threadpool.")
+    )
+
+  if tf_gpu_thread_count:
+    flags.DEFINE_integer(
+        name="tf_gpu_thread_count", short_name="gt_count", default=2,
+        help=help_wrap("How many threads to reserve for GPU based on mode.")
+    )
+
+  if datasets_num_private_threads:
+    flags.DEFINE_integer(
+        name="datasets_num_private_threads", short_name="dataset_thread_count",
+        default=None,
+        help=help_wrap(
+            "Number of threads for a private threadpool created for all "
+            "datasets computation.")
+    )
+
   return key_flags
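The three new define_performance arguments default to False, so each model opts in explicitly, as define_resnet_flags does above. A sketch of how another model's flag module might opt in (a hypothetical example, assuming the flags_core import convention used elsewhere in official/):

from official.utils.flags import core as flags_core

def define_my_model_flags():
  # Hypothetical example: requesting the new performance flags adds
  # --datasets_num_private_threads, --tf_gpu_thread_mode and
  # --tf_gpu_thread_count to this model's command line.
  flags_core.define_performance(
      num_parallel_calls=False,
      datasets_num_private_threads=True,
      tf_gpu_thread_mode=True,
      tf_gpu_thread_count=True)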
