@@ -54,6 +54,11 @@ def run_keras_model_benchmark(_):
     raise AssertionError("The --model command line argument should "
                          "be a key in the `MODELS` dictionary.")
 
+  # Check if eager execution is enabled
+  if FLAGS.eager:
+    tf.logging.info("Eager execution is enabled...")
+    tf.enable_eager_execution()
+
   # Load the model
   tf.logging.info("Benchmark on {} model...".format(FLAGS.model))
   keras_model = MODELS[FLAGS.model]
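Side note on why the call sits ahead of model construction (my own illustration, not part of this change): on a TF 1.x install, tf.enable_eager_execution() has to run before any graph ops are created, otherwise it raises an error. A minimal standalone sketch:

import tensorflow as tf

tf.enable_eager_execution()        # must be called before any TF ops are built
print(tf.executing_eagerly())      # True

x = tf.constant([1.0, 2.0]) * 3.0  # executes immediately under eager mode
print(x.numpy())                   # [3. 6.]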
@@ -64,30 +69,35 @@ def run_keras_model_benchmark(_):
   if FLAGS.use_synthetic_data:
     tf.logging.info("Using synthetic dataset...")
     dataset_name += "_Synthetic"
-    train_num_images = FLAGS.batch_size
-    val_num_images = FLAGS.batch_size
     train_dataset = dataset.generate_synthetic_input_dataset(
-        FLAGS.model, train_num_images)
+        FLAGS.model, FLAGS.batch_size)
     val_dataset = dataset.generate_synthetic_input_dataset(
-        FLAGS.model, val_num_images)
+        FLAGS.model, FLAGS.batch_size)
   else:
     raise ValueError("Only synthetic dataset is supported!")
 
   # If run with multiple GPUs
+  # If eager execution is enabled, only one GPU is utilized even if multiple
+  # GPUs are provided.
   num_gpus = flags_core.get_num_gpus(FLAGS)
-  if num_gpus > 0:
+  if num_gpus > 1:
+    if FLAGS.eager:
+      tf.logging.warning(
+          "{} GPUs are provided, but only one GPU is utilized as "
+          "eager execution is enabled.".format(num_gpus))
     model = tf.keras.utils.multi_gpu_model(model, gpus=num_gpus)
 
-  # Configure the model
   model.compile(loss="categorical_crossentropy",
-                optimizer="sgd",
+                optimizer=tf.train.AdamOptimizer(),
                 metrics=["accuracy"])
 
   # Create benchmark logger for benchmark logging
   run_params = {
       "batch_size": FLAGS.batch_size,
       "synthetic_data": FLAGS.use_synthetic_data,
-      "train_epochs": FLAGS.train_epochs
+      "train_epochs": FLAGS.train_epochs,
+      "num_train_images": FLAGS.num_images,
+      "num_eval_images": FLAGS.num_images,
   }
 
   benchmark_logger = logger.get_benchmark_logger()
@@ -108,8 +118,8 @@ def run_keras_model_benchmark(_):
       epochs=FLAGS.train_epochs,
       callbacks=callbacks,
       validation_data=val_dataset,
-      steps_per_epoch=int(np.ceil(train_num_images / FLAGS.batch_size)),
-      validation_steps=int(np.ceil(val_num_images / FLAGS.batch_size))
+      steps_per_epoch=int(np.ceil(FLAGS.num_images / FLAGS.batch_size)),
+      validation_steps=int(np.ceil(FLAGS.num_images / FLAGS.batch_size))
   )
 
   tf.logging.info("Logging the evaluation results...")
@@ -118,7 +128,7 @@ def run_keras_model_benchmark(_):
         "accuracy": history.history["val_acc"][epoch],
         "loss": history.history["val_loss"][epoch],
         tf.GraphKeys.GLOBAL_STEP: (epoch + 1) * np.ceil(
-            train_num_images / FLAGS.batch_size)
+            FLAGS.num_images / FLAGS.batch_size)
     }
 
     benchmark_logger.log_evaluation_result(eval_results)
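For a quick sanity check of the step arithmetic that now uses FLAGS.num_images, here is a standalone sketch; only the num_images default of 1000 comes from this change, the batch size is an illustrative value:

import numpy as np

num_images = 1000   # FLAGS.num_images default added in this change
batch_size = 32     # illustrative value, not a default from this diff

steps_per_epoch = int(np.ceil(num_images / batch_size))   # ceil(31.25) -> 32

# Global step recorded after each epoch, mirroring the eval_results dict above.
for epoch in range(2):
    global_step = (epoch + 1) * np.ceil(num_images / batch_size)
    print(epoch, steps_per_epoch, global_step)   # 0 32 32.0, then 1 32 64.0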
@@ -146,6 +156,18 @@ def define_keras_benchmark_flags():
       help=flags_core.help_wrap(
           "Model to be benchmarked."))
 
+  flags.DEFINE_integer(
+      name="num_images", default=1000,
+      help=flags_core.help_wrap(
+          "The number of synthetic images for training and evaluation. The "
+          "default value is 1000."))
+
+  flags.DEFINE_boolean(
+      name="eager", default=False, help=flags_core.help_wrap(
+          "To enable eager execution. Note that if eager execution is enabled, "
+          "only one GPU is utilized even if multiple GPUs are provided and "
+          "multi_gpu_model is used."))
+
   flags.DEFINE_list(
       name="callbacks",
       default=["ExamplesPerSecondCallback", "LoggingMetricCallback"],
@@ -159,6 +181,7 @@ def main(_):
   with logger.benchmark_context(FLAGS):
     run_keras_model_benchmark(FLAGS)
 
+
 if __name__ == "__main__":
   tf.logging.set_verbosity(tf.logging.INFO)
   define_keras_benchmark_flags()
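As a usage sketch (the script name and model key are assumptions for illustration; the flag spellings are taken from the code in this diff), the new options would be passed on the command line roughly like:

python benchmark_main.py --model resnet50 --use_synthetic_data --num_images 2000 --eager

With --eager, the script calls tf.enable_eager_execution() before the model is built; if more than one GPU is configured, it additionally logs the new warning that only one GPU will actually be utilized.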