 from official.utils.logs import hooks_helper
 from official.utils.logs import logger
 from official.utils.misc import model_helpers
-
+from tensorflow.contrib.data.python.ops import threadpool
 
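One note on the new import: `tensorflow.contrib` ships only with TF 1.x. A hedged sketch of a guarded variant (an assumption, not part of the commit) for environments where contrib is missing:

```python
# Sketch only: tensorflow.contrib was removed in TF 2.x, so a guarded import
# (not in the original commit) degrades gracefully when it is unavailable.
try:
  from tensorflow.contrib.data.python.ops import threadpool
except ImportError:
  threadpool = None  # the private-threadpool optimization becomes a no-op
```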
 ################################################################################
 # Functions for input processing.
 ################################################################################
-def process_record_dataset(dataset, is_training, batch_size, shuffle_buffer,
-                           parse_record_fn, num_epochs=1):
+def process_record_dataset(dataset, is_training, global_batch_size,
+                           shuffle_buffer, parse_record_fn, num_epochs=1,
+                           num_gpus=1, datasets_num_private_threads=None):
   """Given a Dataset with raw records, return an iterator over the records.
 
   Args:
     dataset: A Dataset representing raw records
     is_training: A boolean denoting whether the input is for training.
-    batch_size: The number of samples per batch.
+    global_batch_size: The number of samples per batch (across devices).
     shuffle_buffer: The buffer size to use when shuffling records. A larger
       value results in better randomness, but smaller values reduce startup
       time and use less memory.
     parse_record_fn: A function that takes a raw record and returns the
       corresponding (image, label) pair.
     num_epochs: The number of epochs to repeat the dataset.
+    num_gpus: The number of GPUs.
+    datasets_num_private_threads: Number of threads for a private
+      threadpool created for all datasets computation.
 
   Returns:
     Dataset of (image, label) pairs ready for iteration.
   """
 
   # We prefetch a batch at a time. This can help smooth out the time taken to
   # load input files as we go through shuffling and processing.
-  dataset = dataset.prefetch(buffer_size=batch_size)
+  dataset = dataset.prefetch(buffer_size=global_batch_size)
   if is_training:
     # Shuffle the records. Note that we shuffle before repeating to ensure
     # that the shuffling respects epoch boundaries.
-    dataset = dataset.shuffle(buffer_size=shuffle_buffer)
-
-  # If we are training over multiple epochs before evaluating, repeat the
-  # dataset for the appropriate number of epochs.
-  dataset = dataset.repeat(num_epochs)
+    # If we are training over multiple epochs before evaluating, repeat the
+    # dataset for the appropriate number of epochs.
+    # Using the fused shuffle_and_repeat method gives better performance.
+    dataset = dataset.apply(tf.contrib.data.shuffle_and_repeat(
+        buffer_size=shuffle_buffer, count=num_epochs))
+  else:
+    dataset = dataset.repeat(num_epochs)
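For reference, a minimal standalone sketch of the fused transformation (TF 1.x contrib API; the range and counts are illustrative). It behaves like `.shuffle(buffer_size).repeat(count)` but handles both stages in a single op, avoiding a per-element handoff:

```python
import tensorflow as tf

# Fused shuffle+repeat: equivalent to .shuffle(100).repeat(2), but cheaper
# because one dataset op performs both stages.
dataset = tf.data.Dataset.range(100)
dataset = dataset.apply(
    tf.contrib.data.shuffle_and_repeat(buffer_size=100, count=2))
```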
 
   # Parse the raw records into images and labels. Testing has shown that setting
   # num_parallel_batches > 1 produces no improvement in throughput, since
   # batch_size is almost always much greater than the number of CPU cores.
   dataset = dataset.apply(
       tf.contrib.data.map_and_batch(
           lambda value: parse_record_fn(value, is_training),
-          batch_size=batch_size,
-          num_parallel_batches=1))
+          batch_size=per_device_batch_size(global_batch_size, num_gpus),
+          num_parallel_batches=num_gpus,
+          drop_remainder=True))
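The batching change leans on the `per_device_batch_size` helper defined elsewhere in this file. A sketch of its contract, reconstructed from its call sites (treat the exact wording and error message as assumptions):

```python
def per_device_batch_size(batch_size, num_gpus):
  """Splits the global batch evenly across GPUs, failing fast on a remainder."""
  if num_gpus <= 1:
    return batch_size
  remainder = batch_size % num_gpus
  if remainder:
    raise ValueError(
        'batch_size must be divisible by num_gpus: got batch_size={}, '
        'num_gpus={}'.format(batch_size, num_gpus))
  return int(batch_size / num_gpus)
```

Together with `drop_remainder=True` above, this guarantees every one of the `num_gpus` parallel batches has the same fixed size, which multi-device replication requires.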
 
   # Operations between the final prefetch and the get_next call to the iterator
   # will happen synchronously during run time. We prefetch here again to
   # background all of the above processing work and keep it out of the
   # critical training path. Setting buffer_size to tf.contrib.data.AUTOTUNE
-  # allows DistributionStrategies to adjust how many batches to fetch based
+  # allows TensorFlow to adjust how many batches to fetch based
   # on how many devices are present.
-  dataset.prefetch(buffer_size=tf.contrib.data.AUTOTUNE)
+  dataset = dataset.prefetch(buffer_size=num_gpus)
+
+  if datasets_num_private_threads:
+    dataset = threadpool.override_threadpool(
+        dataset,
+        threadpool.PrivateThreadPool(
+            datasets_num_private_threads,
+            display_name="input_pipeline_thread_pool"))
 
   return dataset
 
@@ -109,7 +123,7 @@ def get_synth_input_fn(height, width, num_channels, num_classes):
   """
   def input_fn(is_training, data_dir, batch_size, *args, **kwargs):  # pylint: disable=unused-argument
     images = tf.zeros((batch_size, height, width, num_channels), tf.float32)
-    labels = tf.zeros((batch_size, num_classes), tf.int32)
+    labels = tf.zeros((batch_size), tf.int32)
     return tf.data.Dataset.from_tensors((images, labels)).repeat()
 
   return input_fn
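The synthetic labels switch from one-hot rows to integer class indices, matching the sparse loss introduced below. A shape sketch (the batch size and class count are illustrative):

```python
import tensorflow as tf

batch_size, num_classes = 4, 10
onehot = tf.zeros((batch_size, num_classes), tf.int32)  # old shape: [4, 10]
sparse = tf.zeros((batch_size,), tf.int32)              # new shape: [4]
```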
@@ -225,8 +239,8 @@ def resnet_model_fn(features, labels, mode, model_class,
         })
 
   # Calculate loss, which includes softmax cross entropy and L2 regularization.
-  cross_entropy = tf.losses.softmax_cross_entropy(
-      logits=logits, onehot_labels=labels)
+  cross_entropy = tf.losses.sparse_softmax_cross_entropy(
+      logits=logits, labels=labels)
 
   # Create a tensor named cross_entropy for logging purposes.
   tf.identity(cross_entropy, name='cross_entropy')
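The two losses are numerically equivalent when the one-hot rows encode the same class ids; the sparse form simply skips materializing the one-hot matrix. A small equivalence sketch (TF 1.x API, toy values):

```python
import tensorflow as tf

logits = tf.constant([[2.0, 1.0, 0.1], [0.5, 2.5, 0.2]])
labels = tf.constant([0, 1])  # integer class ids, shape [batch_size]

sparse = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
dense = tf.losses.softmax_cross_entropy(
    onehot_labels=tf.one_hot(labels, depth=3), logits=logits)
# sparse and dense evaluate to the same scalar loss
```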
@@ -280,8 +294,7 @@ def exclude_batch_norm(name):
     train_op = None
 
   if not tf.contrib.distribute.has_distribution_strategy():
-    accuracy = tf.metrics.accuracy(
-        tf.argmax(labels, axis=1), predictions['classes'])
+    accuracy = tf.metrics.accuracy(labels, predictions['classes'])
   else:
     # Metrics are currently not compatible with distribution strategies during
     # training. This does not affect the overall performance of the model.
@@ -352,6 +365,9 @@ def resnet_main(
 
   # Using the Winograd non-fused algorithms provides a small performance boost.
   os.environ['TF_ENABLE_WINOGRAD_NONFUSED'] = '1'
+  os.environ['TF_GPU_THREAD_MODE'] = flags_obj.tf_gpu_thread_mode
+  os.environ['TF_GPU_THREAD_COUNT'] = flags_obj.tf_gpu_thread_count
+
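A sketch of what the two new variables control; TensorFlow reads them at GPU-device startup. The values here are illustrative, and note that `os.environ` entries must be strings, so an integer-valued flag would need `str()` first:

```python
import os

# 'gpu_private' gives each GPU its own dedicated compute threads instead of
# sharing a global pool; TF_GPU_THREAD_COUNT sets the threads per GPU.
os.environ['TF_GPU_THREAD_MODE'] = 'gpu_private'  # or 'global', 'gpu_shared'
os.environ['TF_GPU_THREAD_COUNT'] = str(2)
```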
 
   # Create session config based on values of inter_op_parallelism_threads and
   # intra_op_parallelism_threads. Note that we default to having
@@ -391,7 +407,7 @@ def resnet_main(
       'resnet_size': flags_obj.resnet_size,
       'resnet_version': flags_obj.resnet_version,
       'synthetic_data': flags_obj.use_synthetic_data,
-      'train_epochs': flags_obj.train_epochs,
+      'train_epochs': flags_obj.train_epochs
   }
   benchmark_logger = logger.config_benchmark_logger(flags_obj.benchmark_log_dir)
   benchmark_logger.log_run_info('resnet', dataset_name, run_params)
@@ -404,16 +420,17 @@ def resnet_main(
   def input_fn_train():
     return input_function(
         is_training=True, data_dir=flags_obj.data_dir,
-        batch_size=per_device_batch_size(
-            flags_obj.batch_size, flags_core.get_num_gpus(flags_obj)),
-        num_epochs=flags_obj.epochs_between_evals)
+        global_batch_size=flags_obj.batch_size,
+        num_epochs=flags_obj.epochs_between_evals,
+        num_gpus=flags_core.get_num_gpus(flags_obj))
 
   def input_fn_eval():
     return input_function(
         is_training=False, data_dir=flags_obj.data_dir,
-        batch_size=per_device_batch_size(
-            flags_obj.batch_size, flags_core.get_num_gpus(flags_obj)),
-        num_epochs=1)
+        global_batch_size=flags_obj.batch_size,
+        num_epochs=1,
+        num_gpus=flags_core.get_num_gpus(flags_obj))
+
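Put together, the input functions now pass the global batch size and device count through, and `process_record_dataset` does the per-device split internally. An illustrative call (path and sizes are placeholders, not from the commit):

```python
# Illustrative only: a global batch of 256 across 8 GPUs -> 32 examples/device.
dataset = input_function(
    is_training=True, data_dir='/tmp/imagenet',  # placeholder path
    global_batch_size=256, num_epochs=4, num_gpus=8)
```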
 
   total_training_cycle = (flags_obj.train_epochs //
                           flags_obj.epochs_between_evals)
@@ -451,7 +468,11 @@ def input_fn_eval():
 def define_resnet_flags(resnet_size_choices=None):
   """Add flags and validators for ResNet."""
   flags_core.define_base()
-  flags_core.define_performance(num_parallel_calls=False)
+  flags_core.define_performance(
+      num_parallel_calls=False,
+      datasets_num_private_threads=True,
+      tf_gpu_thread_mode=True,
+      tf_gpu_thread_count=True)
   flags_core.define_image()
   flags_core.define_benchmark()
   flags.adopt_module_key_flags(flags_core)