################################################################################
def process_record_dataset(dataset, is_training, batch_size, shuffle_buffer,
                           parse_record_fn, num_epochs=1, num_gpus=None,
-                          examples_per_epoch=None):
+                          examples_per_epoch=None, dtype=tf.float32):
  """Given a Dataset with raw records, return an iterator over the records.

  Args:
@@ -60,6 +60,7 @@ def process_record_dataset(dataset, is_training, batch_size, shuffle_buffer,
    num_epochs: The number of epochs to repeat the dataset.
    num_gpus: The number of gpus used for training.
    examples_per_epoch: The number of examples in an epoch.
+   dtype: Data type to use for images/features.

  Returns:
    Dataset of (image, label) pairs ready for iteration.
@@ -92,7 +93,7 @@ def process_record_dataset(dataset, is_training, batch_size, shuffle_buffer,
  # batch_size is almost always much greater than the number of CPU cores.
  dataset = dataset.apply(
      tf.contrib.data.map_and_batch(
-         lambda value: parse_record_fn(value, is_training),
+         lambda value: parse_record_fn(value, is_training, dtype),
          batch_size=batch_size,
          num_parallel_batches=1,
          drop_remainder=False))
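With this change, parse_record_fn is expected to take a third dtype argument and perform the cast inside the input pipeline rather than in the model function. A minimal sketch of such a parse function, assuming a hypothetical _decode_and_preprocess helper that is not part of this diff:

import tensorflow as tf

def parse_record(raw_record, is_training, dtype=tf.float32):
  # Hypothetical helper: decodes the raw record into an (image, label) pair.
  image, label = _decode_and_preprocess(raw_record, is_training)
  # Cast on the CPU as part of the input pipeline, so resnet_model_fn can
  # simply assert the dtype instead of casting on the accelerator.
  image = tf.cast(image, dtype)
  return image, label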
@@ -248,8 +249,8 @@ def resnet_model_fn(features, labels, mode, model_class,

  # Generate a summary node for the images
  tf.summary.image('images', features, max_outputs=6)
-  # TODO(tobyboyd): Add cast as part of input pipeline on cpu and remove.
-  features = tf.cast(features, dtype)
+  # Checks that features/images have same data type being used for calculations.
+  assert features.dtype == dtype

  model = model_class(resnet_size, data_format, resnet_version=resnet_version,
                      dtype=dtype)
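With dtype threaded through the input functions above, features reach resnet_model_fn already in the requested precision, so the in-graph tf.cast (and the TODO that marked it for removal) can be replaced by a plain assertion on features.dtype.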
@@ -454,14 +455,16 @@ def input_fn_train(num_epochs):
        batch_size=distribution_utils.per_device_batch_size(
            flags_obj.batch_size, flags_core.get_num_gpus(flags_obj)),
        num_epochs=num_epochs,
-       num_gpus=flags_core.get_num_gpus(flags_obj))
+       num_gpus=flags_core.get_num_gpus(flags_obj),
+       dtype=flags_core.get_tf_dtype(flags_obj))

  def input_fn_eval():
    return input_function(
        is_training=False, data_dir=flags_obj.data_dir,
        batch_size=distribution_utils.per_device_batch_size(
            flags_obj.batch_size, flags_core.get_num_gpus(flags_obj)),
-       num_epochs=1)
+       num_epochs=1,
+       dtype=flags_core.get_tf_dtype(flags_obj))

  if flags_obj.eval_only or not flags_obj.train_epochs:
    # If --eval_only is set, perform a single loop with zero train epochs.
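Both input functions now resolve the dtype from the command-line flags via flags_core.get_tf_dtype. As an illustration only, and not the actual flags_core implementation, such a helper could map a string --dtype flag to a TensorFlow dtype roughly like this:

import tensorflow as tf

# Illustrative sketch; the real flags_core.get_tf_dtype may differ.
DTYPE_MAP = {'fp16': tf.float16, 'fp32': tf.float32}

def get_tf_dtype(flags_obj):
  # flags_obj.dtype is assumed to hold one of the keys above.
  return DTYPE_MAP[flags_obj.dtype]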
@@ -533,7 +536,7 @@ def define_resnet_flags(resnet_size_choices=None):
          'If not None initialize all the network except the final layer with '
          'these values'))
  flags.DEFINE_boolean(
-     name="eval_only", default=False,
+     name='eval_only', default=False,
      help=flags_core.help_wrap('Skip training and only perform evaluation on '
                                'the latest checkpoint.'))