-
Notifications
You must be signed in to change notification settings - Fork 45.6k
Add dataset info and hyper parameter logging for benchmark. #4152
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
official/resnet/cifar10_main.py
Outdated
@@ -243,7 +243,7 @@ def main(flags_obj): | |||
or input_fn) | |||
|
|||
resnet_run_loop.resnet_main( | |||
flags_obj, cifar10_model_fn, input_function, | |||
flags_obj, cifar10_model_fn, input_function, 'CIFAR-10', |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe make this a module-level var or a class-level var? Ditto for Imagenet below.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
official/resnet/resnet_run_loop.py
Outdated
@@ -402,8 +404,16 @@ def resnet_main(flags_obj, model_function, input_function, shape=None): | |||
'dtype': flags_core.get_tf_dtype(flags_obj) | |||
}) | |||
|
|||
hyperparams = { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nitting here, but this is a mix of hyperparams and run params. dtype, version, synthetic data are not what would typically be considered hyperparams.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Talked offline. I will rename the data schema into run_parameter which covers both hyper/non-hyper params. Will do this in a separate PR.
official/resnet/resnet_run_loop.py
Outdated
@@ -361,6 +361,8 @@ def resnet_main(flags_obj, model_function, input_function, shape=None): | |||
input_function: the function that processes the dataset and returns a | |||
dataset that the estimator can train on. This will be wrapped with | |||
all the relevant flags for running and passed to estimator. | |||
dataset: the name of the dataset for training and evaluation. This is used |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe dataset_name to be explicit? Dataset sounds like the actual data.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
official/utils/logs/logger.py
Outdated
"""Collect most of the TF runtime information for the local env. | ||
|
||
The schema of the run info follows official/benchmark/datastore/schema. | ||
|
||
Args: | ||
model_name: string, the name of the model. | ||
dataset_name: string, the name of dataset for training and evaluation. | ||
hyperparams: dict, the dictionary of hyper parameters for the model. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Noted above, but maybe generalize this naming, or split out hyper from run, though the latter requires some semantic parsing of the params in ways most people don't care about.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ack, Will do this in a separate PR.
str: {"name": name, "string_value": value}, | ||
int: {"name": name, "long_value": value}, | ||
bool: {"name": name, "bool_value": str(value)}, | ||
float: {"name": name, "float_value": value}, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This seems fairly complicated, and gives us names for val that are different. Is it necessary? Why not just include the value and assume that we will worry about type elsewhere?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The type will inherited by the table, so that sql can work properly based on the value type. Currently I don't see a strong use case to do data manipulation for the hyper param values. I could also convert this to a simple key-str(value) dict in a future PR.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
SG.
official/utils/logs/logger.py
Outdated
return type_check.get(type(value), | ||
{"name": name, "string_value": str(value)}) | ||
if hyperparams: | ||
run_info["hyperparameter"] = [ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: regardless of name we choose, should probably be pluralized.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ack.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just a couple minor suggestions.
official/resnet/resnet_run_loop.py
Outdated
@@ -350,7 +350,7 @@ def validate_batch_size_for_multi_gpu(batch_size): | |||
raise ValueError(err) | |||
|
|||
|
|||
def resnet_main(flags_obj, model_function, input_function, shape=None): | |||
def resnet_main(flags_obj, model_function, input_function, dataset, shape=None): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: can we call this dataset_name
for clairity?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
official/resnet/resnet_run_loop.py
Outdated
@@ -402,8 +404,16 @@ def resnet_main(flags_obj, model_function, input_function, shape=None): | |||
'dtype': flags_core.get_tf_dtype(flags_obj) | |||
}) | |||
|
|||
hyperparams = { | |||
'batch_size': flags_obj.batch_size, | |||
'dtype': flags_obj.dtype, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this should be flags_core.get_tf_dtype(flags_obj).name
That way we're logging the official name rather than whatever abbreviation we choose.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
official/resnet/resnet_run_loop.py
Outdated
'resnet_size': flags_obj.resnet_size, | ||
'synthetic_data': flags_obj.use_synthetic_data, | ||
'train_epochs': flags_obj.train_epochs, | ||
'version': flags_obj.version, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Log this as resnet_version for clarity?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sure. Done.
Addressed the comment for data schema rename. |
Schema updated for the bigquery table, will rerun the backfill for the all the existing data. |
str: {"name": name, "string_value": value}, | ||
int: {"name": name, "long_value": value}, | ||
bool: {"name": name, "bool_value": str(value)}, | ||
float: {"name": name, "float_value": value}, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
SG.
* Add dataset info and hyper parameter logging for benchmark. * Address review comments. * Address the view comment for data schema name. * Fix test cases. * Lint fix.
…ow#4152) * Add dataset info and hyper parameter logging for benchmark. * Address review comments. * Address the view comment for data schema name. * Fix test cases. * Lint fix.
Probably will hit a merge conflict with Taylor's flag change.