Thanks for visiting codestin.com
Credit goes to github.com

Skip to content

Commit eb0c0df

Browse files
authored
Add dataset info and hyper parameter logging for benchmark. (tensorflow#4152)
* Add dataset info and hyper parameter logging for benchmark.
* Address review comments.
* Address the review comment for data schema name.
* Fix test cases.
* Lint fix.
1 parent 8e73530 commit eb0c0df

File tree

6 files changed

+77
-16
lines changed

6 files changed

+77
-16
lines changed

official/benchmark/datastore/schema/benchmark_run.json

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -98,41 +98,41 @@
9898
"type": "RECORD"
9999
},
100100
{
101-
"description": "The list of hyperparameters of the model.",
101+
"description": "The list of parameters run with the model. It could contain hyperparameters or others.",
102102
"fields": [
103103
{
104-
"description": "The name of the hyperparameter.",
104+
"description": "The name of the parameter.",
105105
"mode": "REQUIRED",
106106
"name": "name",
107107
"type": "STRING"
108108
},
109109
{
110-
"description": "The string value of the hyperparameter.",
110+
"description": "The string value of the parameter.",
111111
"mode": "NULLABLE",
112112
"name": "string_value",
113113
"type": "STRING"
114114
},
115115
{
116-
"description": "The bool value of the hyperparameter.",
116+
"description": "The bool value of the parameter.",
117117
"mode": "NULLABLE",
118118
"name": "bool_value",
119119
"type": "STRING"
120120
},
121121
{
122-
"description": "The int/long value of the hyperparameter.",
122+
"description": "The int/long value of the parameter.",
123123
"mode": "NULLABLE",
124124
"name": "long_value",
125125
"type": "INTEGER"
126126
},
127127
{
128-
"description": "The double/float value of hyperparameter.",
128+
"description": "The double/float value of parameter.",
129129
"mode": "NULLABLE",
130130
"name": "float_value",
131131
"type": "FLOAT"
132132
}
133133
],
134134
"mode": "REPEATED",
135-
"name": "hyperparameter",
135+
"name": "run_parameters",
136136
"type": "RECORD"
137137
},
138138
{

official/resnet/cifar10_main.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,8 @@
4242
'validation': 10000,
4343
}
4444

45+
DATASET_NAME = 'CIFAR-10'
46+
4547

4648
###############################################################################
4749
# Data processing
@@ -237,7 +239,7 @@ def run_cifar(flags_obj):
237239
or input_fn)
238240

239241
resnet_run_loop.resnet_main(
240-
flags_obj, cifar10_model_fn, input_function,
242+
flags_obj, cifar10_model_fn, input_function, DATASET_NAME,
241243
shape=[_HEIGHT, _WIDTH, _NUM_CHANNELS])
242244

243245

official/resnet/imagenet_main.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
_NUM_TRAIN_FILES = 1024
4242
_SHUFFLE_BUFFER = 1500
4343

44+
DATASET_NAME = 'ImageNet'
4445

4546
###############################################################################
4647
# Data processing
@@ -312,7 +313,7 @@ def run_imagenet(flags_obj):
312313
or input_fn)
313314

314315
resnet_run_loop.resnet_main(
315-
flags_obj, imagenet_model_fn, input_function,
316+
flags_obj, imagenet_model_fn, input_function, DATASET_NAME,
316317
shape=[_DEFAULT_IMAGE_SIZE, _DEFAULT_IMAGE_SIZE, _NUM_CHANNELS])
317318

318319

official/resnet/resnet_run_loop.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -331,7 +331,8 @@ def per_device_batch_size(batch_size, num_gpus):
331331
return int(batch_size / num_gpus)
332332

333333

334-
def resnet_main(flags_obj, model_function, input_function, shape=None):
334+
def resnet_main(
335+
flags_obj, model_function, input_function, dataset_name, shape=None):
335336
"""Shared main loop for ResNet Models.
336337
337338
Args:
@@ -342,6 +343,8 @@ def resnet_main(flags_obj, model_function, input_function, shape=None):
342343
input_function: the function that processes the dataset and returns a
343344
dataset that the estimator can train on. This will be wrapped with
344345
all the relevant flags for running and passed to estimator.
346+
dataset_name: the name of the dataset for training and evaluation. This is
347+
used for logging purpose.
345348
shape: list of ints representing the shape of the images used for training.
346349
This is only used if flags_obj.export_dir is passed.
347350
"""
@@ -381,8 +384,16 @@ def resnet_main(flags_obj, model_function, input_function, shape=None):
381384
'dtype': flags_core.get_tf_dtype(flags_obj)
382385
})
383386

387+
run_params = {
388+
'batch_size': flags_obj.batch_size,
389+
'dtype': flags_core.get_tf_dtype(flags_obj),
390+
'resnet_size': flags_obj.resnet_size,
391+
'resnet_version': flags_obj.version,
392+
'synthetic_data': flags_obj.use_synthetic_data,
393+
'train_epochs': flags_obj.train_epochs,
394+
}
384395
benchmark_logger = logger.config_benchmark_logger(flags_obj.benchmark_log_dir)
385-
benchmark_logger.log_run_info('resnet')
396+
benchmark_logger.log_run_info('resnet', dataset_name, run_params)
386397

387398
train_hooks = hooks_helper.get_train_hooks(
388399
flags_obj.hooks,

official/utils/logs/logger.py

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -109,8 +109,9 @@ def log_metric(self, name, value, unit=None, global_step=None, extras=None):
109109
"Name %s, value %d, unit %s, global_step %d, extras %s",
110110
name, value, unit, global_step, extras)
111111

112-
def log_run_info(self, model_name):
113-
tf.logging.info("Benchmark run: %s", _gather_run_info(model_name))
112+
def log_run_info(self, model_name, dataset_name, run_params):
113+
tf.logging.info("Benchmark run: %s",
114+
_gather_run_info(model_name, dataset_name, run_params))
114115

115116

116117
class BenchmarkFileLogger(BaseBenchmarkLogger):
@@ -159,15 +160,18 @@ def log_metric(self, name, value, unit=None, global_step=None, extras=None):
159160
tf.logging.warning("Failed to dump metric to log file: "
160161
"name %s, value %s, error %s", name, value, e)
161162

162-
def log_run_info(self, model_name):
163+
def log_run_info(self, model_name, dataset_name, run_params):
163164
"""Collect most of the TF runtime information for the local env.
164165
165166
The schema of the run info follows official/benchmark/datastore/schema.
166167
167168
Args:
168169
model_name: string, the name of the model.
170+
dataset_name: string, the name of dataset for training and evaluation.
171+
run_params: dict, the dictionary of parameters for the run, it could
172+
include hyperparameters or other params that are important for the run.
169173
"""
170-
run_info = _gather_run_info(model_name)
174+
run_info = _gather_run_info(model_name, dataset_name, run_params)
171175

172176
with tf.gfile.GFile(os.path.join(
173177
self._logging_dir, BENCHMARK_RUN_LOG_FILE_NAME), "w") as f:
@@ -179,15 +183,17 @@ def log_run_info(self, model_name):
179183
e)
180184

181185

182-
def _gather_run_info(model_name):
186+
def _gather_run_info(model_name, dataset_name, run_params):
183187
"""Collect the benchmark run information for the local environment."""
184188
run_info = {
185189
"model_name": model_name,
190+
"dataset": {"name": dataset_name},
186191
"machine_config": {},
187192
"run_date": datetime.datetime.utcnow().strftime(
188193
_DATE_TIME_FORMAT_PATTERN)}
189194
_collect_tensorflow_info(run_info)
190195
_collect_tensorflow_environment_variables(run_info)
196+
_collect_run_params(run_info, run_params)
191197
_collect_cpu_info(run_info)
192198
_collect_gpu_info(run_info)
193199
_collect_memory_info(run_info)
@@ -199,6 +205,21 @@ def _collect_tensorflow_info(run_info):
199205
"version": tf.VERSION, "git_hash": tf.GIT_VERSION}
200206

201207

208+
def _collect_run_params(run_info, run_params):
  """Log the parameter information for the benchmark run.

  Each parameter is turned into a record holding its name and exactly one
  typed value field ("bool_value", "long_value", "float_value" or
  "string_value"). The records are stored, sorted by parameter name, under
  run_info["run_parameters"].

  Args:
    run_info: dict, updated in place. The "run_parameters" key is only written
      when run_params is non-empty.
    run_params: dict mapping parameter name to value, or None/empty when there
      is nothing to record.
  """
  def process_param(name, value):
    """Build the record for a single parameter."""
    # bool has to be tested before int because bool is a subclass of int.
    # isinstance (rather than an exact type match) keeps subclasses of the
    # basic types in their typed field instead of degrading them to strings.
    if isinstance(value, bool):
      # Bool values are recorded in their string form.
      return {"name": name, "bool_value": str(value)}
    if isinstance(value, int):
      return {"name": name, "long_value": value}
    if isinstance(value, float):
      return {"name": name, "float_value": value}
    if isinstance(value, str):
      return {"name": name, "string_value": value}
    # Anything else is recorded through its string representation.
    return {"name": name, "string_value": str(value)}

  if run_params:
    run_info["run_parameters"] = [
        process_param(k, v) for k, v in sorted(run_params.items())]
222+
202223
def _collect_tensorflow_environment_variables(run_info):
203224
run_info["tensorflow_environment_variables"] = [
204225
{"name": k, "value": v}

official/utils/logs/logger_test.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,32 @@ def test_collect_tensorflow_info(self):
180180
self.assertEqual(run_info["tensorflow_version"]["version"], tf.VERSION)
181181
self.assertEqual(run_info["tensorflow_version"]["git_hash"], tf.GIT_VERSION)
182182

183+
def test_collect_run_params(self):
  """Checks that run params come back sorted by name with typed fields."""
  params = {
      "batch_size": 32,
      "synthetic_data": True,
      "train_epochs": 100.00,
      "dtype": "fp16",
      "resnet_size": 50,
      "random_tensor": tf.constant(2.0)
  }
  # Expected records, listed in the name-sorted order the logger emits.
  expected_records = [
      {"name": "batch_size", "long_value": 32},
      {"name": "dtype", "string_value": "fp16"},
      {"name": "random_tensor", "string_value":
           "Tensor(\"Const:0\", shape=(), dtype=float32)"},
      {"name": "resnet_size", "long_value": 50},
      {"name": "synthetic_data", "bool_value": "True"},
      {"name": "train_epochs", "float_value": 100.00},
  ]

  collected = {}
  logger._collect_run_params(collected, params)

  self.assertEqual(len(collected["run_parameters"]), 6)
  for position, record in enumerate(expected_records):
    self.assertEqual(collected["run_parameters"][position], record)
208+
183209
def test_collect_tensorflow_environment_variables(self):
184210
os.environ["TF_ENABLE_WINOGRAD_NONFUSED"] = "1"
185211
os.environ["TF_OTHER"] = "2"

0 commit comments

Comments
 (0)