Commit 0270cac

Add benchmark logger that does stream upload to bigquery. (tensorflow#4210)
* Move the benchmark_uploader to new location.
* Update benchmark logger to streaming upload.
* Fix lint and unit test error.
* delint.
* Update the benchmark uploader test. Skip the import of benchmark_uploader when bigquery is not installed.
* Merge the 2 classes of benchmark uploader into 1.
* Address review comments.
* delint.
* Execute bigquery upload in a separate thread.
* Change to use python six.moves for importing.
* Address review comments and delint.
* Address review comment. Adding comment for potential performance impact for model on CPU.
* Fix random failure on py3.
* Fix the order of flag saver to avoid the randomness. The test is broken when the benchmark_logger_type is set first, and validated when the benchmark_log_dir is not set yet.
1 parent 80178fc commit 0270cac
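
The substantive change in this commit is the move from one-shot, post-run file upload to streaming upload while the benchmark runs, with the BigQuery insert executed on a separate thread. The new BenchmarkBigQueryLogger itself is not among the diffs captured below, so the following is only a rough sketch of that idea from the commit message; every name in it is illustrative rather than this commit's code.

# Illustrative sketch of "execute bigquery upload in a separate thread";
# not code from this commit. A daemon thread keeps the BigQuery round trip
# off the training hot path, which matters most for CPU-bound models (the
# commit message flags the potential performance impact on CPU).
import threading


def upload_metric_async(uploader, dataset, table, run_id, metric_json):
  """Fire-and-forget metric upload using BigQueryUploader (diffed below)."""
  thread = threading.Thread(
      target=uploader.upload_benchmark_metric_json,
      args=(dataset, table, run_id, [metric_json]))
  thread.daemon = True  # do not let a pending upload block process exit
  thread.start()
  return thread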

File tree

10 files changed: +450 / -116 lines

official/benchmark/__init__.py

Whitespace-only changes.

official/utils/logs/benchmark_uploader.py renamed to official/benchmark/benchmark_uploader.py

Lines changed: 62 additions & 61 deletions
@@ -25,30 +25,19 @@
 from __future__ import print_function
 
 import json
-import os
-import sys
-import uuid
 
 from google.cloud import bigquery
 
-# pylint: disable=g-bad-import-order
-from absl import app as absl_app
-from absl import flags
 import tensorflow as tf
-# pylint: enable=g-bad-import-order
-
-from official.utils.flags import core as flags_core
-from official.utils.logs import logger
 
 
 class BigQueryUploader(object):
-  """Upload the benchmark and metric info to BigQuery."""
+  """Upload the benchmark and metric info from JSON input to BigQuery."""
 
-  def __init__(self, logging_dir, gcp_project=None, credentials=None):
+  def __init__(self, gcp_project=None, credentials=None):
     """Initializes BigQueryUploader with proper settings.
 
     Args:
-      logging_dir: string, logging directory that contains the benchmark log.
       gcp_project: string, the name of the GCP project that the log will be
        uploaded to. The default project name will be detected from local
        environment if no value is provided.
@@ -58,11 +47,11 @@ def __init__(self, logging_dir, gcp_project=None, credentials=None):
        google.oauth2.service_account.Credentials to load credential from local
        file for the case that the test is run outside of GCP.
     """
-    self._logging_dir = logging_dir
     self._bq_client = bigquery.Client(
        project=gcp_project, credentials=credentials)
 
-  def upload_benchmark_run(self, dataset_name, table_name, run_id):
+  def upload_benchmark_run_json(
+      self, dataset_name, table_name, run_id, run_json):
     """Upload benchmark run information to Bigquery.
 
     Args:
@@ -72,19 +61,13 @@ def upload_benchmark_run(self, dataset_name, table_name, run_id):
        the data will be uploaded.
       run_id: string, a unique ID that will be attached to the data, usually
        this is a UUID4 format.
+      run_json: dict, the JSON data that contains the benchmark run info.
     """
-    expected_file = os.path.join(
-        self._logging_dir, logger.BENCHMARK_RUN_LOG_FILE_NAME)
-    with tf.gfile.GFile(expected_file) as f:
-      benchmark_json = json.load(f)
-      benchmark_json["model_id"] = run_id
-      table_ref = self._bq_client.dataset(dataset_name).table(table_name)
-      errors = self._bq_client.insert_rows_json(table_ref, [benchmark_json])
-      if errors:
-        tf.logging.error(
-            "Failed to upload benchmark info to bigquery: {}".format(errors))
-
-  def upload_metric(self, dataset_name, table_name, run_id):
+    run_json["model_id"] = run_id
+    self._upload_json(dataset_name, table_name, [run_json])
+
+  def upload_benchmark_metric_json(
+      self, dataset_name, table_name, run_id, metric_json_list):
     """Upload metric information to Bigquery.
 
     Args:
@@ -95,39 +78,57 @@ def upload_metric(self, dataset_name, table_name, run_id):
        benchmark_run table.
       run_id: string, a unique ID that will be attached to the data, usually
        this is a UUID4 format. This should be the same as the benchmark run_id.
+      metric_json_list: list, a list of JSON objects that record the metric info.
+    """
+    for m in metric_json_list:
+      m["run_id"] = run_id
+    self._upload_json(dataset_name, table_name, metric_json_list)
+
+  def upload_benchmark_run_file(
+      self, dataset_name, table_name, run_id, run_json_file):
+    """Upload benchmark run information to Bigquery from input json file.
+
+    Args:
+      dataset_name: string, the name of bigquery dataset where the data will be
+        uploaded.
+      table_name: string, the name of bigquery table under the dataset where
+        the data will be uploaded.
+      run_id: string, a unique ID that will be attached to the data, usually
+        this is a UUID4 format.
+      run_json_file: string, the file path that contains the run JSON data.
+    """
+    with tf.gfile.GFile(run_json_file) as f:
+      benchmark_json = json.load(f)
+      self.upload_benchmark_run_json(
+          dataset_name, table_name, run_id, benchmark_json)
+
+  def upload_metric_file(
+      self, dataset_name, table_name, run_id, metric_json_file):
+    """Upload metric information to Bigquery from input json file.
+
+    Args:
+      dataset_name: string, the name of bigquery dataset where the data will be
+        uploaded.
+      table_name: string, the name of bigquery table under the dataset where
+        the metric data will be uploaded. This is different from the
+        benchmark_run table.
+      run_id: string, a unique ID that will be attached to the data, usually
+        this is a UUID4 format. This should be the same as the benchmark run_id.
+      metric_json_file: string, the file path that contains the metric JSON
+        data.
     """
-    expected_file = os.path.join(
-        self._logging_dir, logger.METRIC_LOG_FILE_NAME)
-    with tf.gfile.GFile(expected_file) as f:
-      lines = f.readlines()
+    with tf.gfile.GFile(metric_json_file) as f:
       metrics = []
-      for line in filter(lambda l: l.strip(), lines):
-        metric = json.loads(line)
-        metric["run_id"] = run_id
-        metrics.append(metric)
-      table_ref = self._bq_client.dataset(dataset_name).table(table_name)
-      errors = self._bq_client.insert_rows_json(table_ref, metrics)
-      if errors:
-        tf.logging.error(
-            "Failed to upload benchmark info to bigquery: {}".format(errors))
-
-
-def main(_):
-  if not flags.FLAGS.benchmark_log_dir:
-    print("Usage: benchmark_uploader.py --benchmark_log_dir=/some/dir")
-    sys.exit(1)
-
-  uploader = BigQueryUploader(
-      flags.FLAGS.benchmark_log_dir,
-      gcp_project=flags.FLAGS.gcp_project)
-  run_id = str(uuid.uuid4())
-  uploader.upload_benchmark_run(
-      flags.FLAGS.bigquery_data_set, flags.FLAGS.bigquery_run_table, run_id)
-  uploader.upload_metric(
-      flags.FLAGS.bigquery_data_set, flags.FLAGS.bigquery_metric_table, run_id)
-
-
-if __name__ == "__main__":
-  flags_core.define_benchmark()
-  flags.adopt_module_key_flags(flags_core)
-  absl_app.run(main=main)
+      for line in f:
+        metrics.append(json.loads(line.strip()))
+      self.upload_benchmark_metric_json(
+          dataset_name, table_name, run_id, metrics)
+
+  def _upload_json(self, dataset_name, table_name, json_list):
+    # Find the unique table reference based on dataset and table name, so that
+    # the data can be inserted to it.
+    table_ref = self._bq_client.dataset(dataset_name).table(table_name)
+    errors = self._bq_client.insert_rows_json(table_ref, json_list)
+    if errors:
+      tf.logging.error(
+          "Failed to upload benchmark info to bigquery: {}".format(errors))
New file: 62 additions & 0 deletions

@@ -0,0 +1,62 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Binary to upload benchmark generated by BenchmarkLogger to remote repo.
+
+This binary requires the google cloud bigquery lib as a dependency, which can
+be installed with:
+> pip install --upgrade google-cloud-bigquery
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import sys
+import uuid
+
+from absl import app as absl_app
+from absl import flags
+
+from official.benchmark import benchmark_uploader
+from official.utils.flags import core as flags_core
+from official.utils.logs import logger
+
+def main(_):
+  if not flags.FLAGS.benchmark_log_dir:
+    print("Usage: benchmark_uploader.py --benchmark_log_dir=/some/dir")
+    sys.exit(1)
+
+  uploader = benchmark_uploader.BigQueryUploader(
+      gcp_project=flags.FLAGS.gcp_project)
+  run_id = str(uuid.uuid4())
+  run_json_file = os.path.join(
+      flags.FLAGS.benchmark_log_dir, logger.BENCHMARK_RUN_LOG_FILE_NAME)
+  metric_json_file = os.path.join(
+      flags.FLAGS.benchmark_log_dir, logger.METRIC_LOG_FILE_NAME)
+
+  uploader.upload_benchmark_run_file(
+      flags.FLAGS.bigquery_data_set, flags.FLAGS.bigquery_run_table, run_id,
+      run_json_file)
+  uploader.upload_metric_file(
+      flags.FLAGS.bigquery_data_set, flags.FLAGS.bigquery_metric_table, run_id,
+      metric_json_file)
+
+
+if __name__ == "__main__":
+  flags_core.define_benchmark()
+  flags.adopt_module_key_flags(flags_core)
+  absl_app.run(main=main)
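
This binary is the file-based path: it reads the two log files written by BenchmarkLogger and pushes them through the file-wrapper methods above. With the flags defined in official/utils/flags/_benchmark.py, an invocation might look like the following; the directory and GCP names are hypothetical:

> python benchmark_uploader.py --benchmark_log_dir=/tmp/benchmark \
    --gcp_project=my-gcp-project --bigquery_data_set=test_benchmark \
    --bigquery_run_table=benchmark_run --bigquery_metric_table=benchmark_metric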
New file: 107 additions & 0 deletions

@@ -0,0 +1,107 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Tests for benchmark_uploader."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import json
+import os
+import tempfile
+import unittest
+from mock import MagicMock
+from mock import patch
+
+import tensorflow as tf  # pylint: disable=g-bad-import-order
+
+try:
+  from google.cloud import bigquery
+  from official.benchmark import benchmark_uploader
+except ImportError:
+  bigquery = None
+  benchmark_uploader = None
+
+
+@unittest.skipIf(bigquery is None, 'Bigquery dependency is not installed.')
+class BigQueryUploaderTest(tf.test.TestCase):
+
+  @patch.object(bigquery, 'Client')
+  def setUp(self, mock_bigquery):
+    self.mock_client = mock_bigquery.return_value
+    self.mock_dataset = MagicMock(name="dataset")
+    self.mock_table = MagicMock(name="table")
+    self.mock_client.dataset.return_value = self.mock_dataset
+    self.mock_dataset.table.return_value = self.mock_table
+    self.mock_client.insert_rows_json.return_value = []
+
+    self.benchmark_uploader = benchmark_uploader.BigQueryUploader()
+    self.benchmark_uploader._bq_client = self.mock_client
+
+    self.log_dir = tempfile.mkdtemp(dir=self.get_temp_dir())
+    with open(os.path.join(self.log_dir, 'metric.log'), 'a') as f:
+      json.dump({'name': 'accuracy', 'value': 1.0}, f)
+      f.write("\n")
+      json.dump({'name': 'loss', 'value': 0.5}, f)
+      f.write("\n")
+    with open(os.path.join(self.log_dir, 'run.log'), 'w') as f:
+      json.dump({'model_name': 'value'}, f)
+
+  def tearDown(self):
+    tf.gfile.DeleteRecursively(self.get_temp_dir())
+
+  def test_upload_benchmark_run_json(self):
+    self.benchmark_uploader.upload_benchmark_run_json(
+        'dataset', 'table', 'run_id', {'model_name': 'value'})
+
+    self.mock_client.insert_rows_json.assert_called_once_with(
+        self.mock_table, [{'model_name': 'value', 'model_id': 'run_id'}])
+
+  def test_upload_benchmark_metric_json(self):
+    metric_json_list = [
+        {'name': 'accuracy', 'value': 1.0},
+        {'name': 'loss', 'value': 0.5}
+    ]
+    expected_params = [
+        {'run_id': 'run_id', 'name': 'accuracy', 'value': 1.0},
+        {'run_id': 'run_id', 'name': 'loss', 'value': 0.5}
+    ]
+    self.benchmark_uploader.upload_benchmark_metric_json(
+        'dataset', 'table', 'run_id', metric_json_list)
+    self.mock_client.insert_rows_json.assert_called_once_with(
+        self.mock_table, expected_params)
+
+  def test_upload_benchmark_run_file(self):
+    self.benchmark_uploader.upload_benchmark_run_file(
+        'dataset', 'table', 'run_id', os.path.join(self.log_dir, 'run.log'))
+
+    self.mock_client.insert_rows_json.assert_called_once_with(
+        self.mock_table, [{'model_name': 'value', 'model_id': 'run_id'}])
+
+  def test_upload_metric_file(self):
+    self.benchmark_uploader.upload_metric_file(
+        'dataset', 'table', 'run_id',
+        os.path.join(self.log_dir, 'metric.log'))
+    expected_params = [
+        {'run_id': 'run_id', 'name': 'accuracy', 'value': 1.0},
+        {'run_id': 'run_id', 'name': 'loss', 'value': 0.5}
+    ]
+    self.mock_client.insert_rows_json.assert_called_once_with(
+        self.mock_table, expected_params)
+
+
+if __name__ == '__main__':
+  tf.test.main()

official/resnet/resnet_run_loop.py

Lines changed: 2 additions & 3 deletions
@@ -395,13 +395,12 @@ def resnet_main(
       'synthetic_data': flags_obj.use_synthetic_data,
       'train_epochs': flags_obj.train_epochs,
   }
-  benchmark_logger = logger.config_benchmark_logger(flags_obj.benchmark_log_dir)
+  benchmark_logger = logger.config_benchmark_logger(flags_obj)
   benchmark_logger.log_run_info('resnet', dataset_name, run_params)
 
   train_hooks = hooks_helper.get_train_hooks(
       flags_obj.hooks,
-      batch_size=flags_obj.batch_size,
-      benchmark_log_dir=flags_obj.benchmark_log_dir)
+      batch_size=flags_obj.batch_size)
 
   def input_fn_train():
     return input_function(
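
Passing the whole flags object instead of only benchmark_log_dir is what lets config_benchmark_logger dispatch on the new benchmark_logger_type flag. The real factory lives in official/utils/logs/logger.py and is not part of the diffs shown here; a plausible sketch of the shape it takes, with simplified stand-in logger classes:

# Plausible sketch only; the actual config_benchmark_logger may differ.
class BaseBenchmarkLogger(object):
  """Stand-in for the default logger, which logs benchmark info to STDOUT."""


class BenchmarkFileLogger(BaseBenchmarkLogger):
  """Stand-in for the logger that writes under benchmark_log_dir."""

  def __init__(self, logging_dir):
    self._logging_dir = logging_dir


def config_benchmark_logger(flags_obj):
  # Dispatch on the flag added in official/utils/flags/_benchmark.py below;
  # the validator there guarantees benchmark_log_dir is set for this branch.
  if flags_obj.benchmark_logger_type == "BenchmarkFileLogger":
    return BenchmarkFileLogger(flags_obj.benchmark_log_dir)
  return BaseBenchmarkLogger()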

official/utils/flags/_benchmark.py

Lines changed: 19 additions & 1 deletion
@@ -36,6 +36,14 @@ def define_benchmark(benchmark_log_dir=True, bigquery_uploader=True):
 
   key_flags = []
 
+  flags.DEFINE_enum(
+      name="benchmark_logger_type", default="BaseBenchmarkLogger",
+      enum_values=["BaseBenchmarkLogger", "BenchmarkFileLogger",
+                   "BenchmarkBigQueryLogger"],
+      help=help_wrap("The type of benchmark logger to use. Defaults to using "
+                     "BaseBenchmarkLogger which logs to STDOUT. Different "
+                     "loggers will require other flags to be able to work."))
+
   if benchmark_log_dir:
     flags.DEFINE_string(
         name="benchmark_log_dir", short_name="bld", default=None,
@@ -64,4 +72,14 @@ def define_benchmark(benchmark_log_dir=True, bigquery_uploader=True):
         help=help_wrap("The Bigquery table name where the benchmark metric "
                        "information will be uploaded."))
 
-  return key_flags
+  @flags.multi_flags_validator(
+      ["benchmark_logger_type", "benchmark_log_dir"],
+      message="--benchmark_logger_type=BenchmarkFileLogger will require "
+              "--benchmark_log_dir being set")
+  def _check_benchmark_log_dir(flags_dict):
+    benchmark_logger_type = flags_dict["benchmark_logger_type"]
+    if benchmark_logger_type == "BenchmarkFileLogger":
+      return flags_dict["benchmark_log_dir"]
+    return True
+
+  return key_flags
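
The multi-flag validator added above runs at parse time, so a run configured with the file logger but no log directory fails before any work starts. A small sketch of the expected behavior, assuming the flags are registered via flags_core.define_benchmark() exactly as diffed:

# Sketch of the validator's effect; absl runs validators when argv is parsed.
from absl import flags

from official.utils.flags import core as flags_core

flags_core.define_benchmark()

try:
  # BenchmarkFileLogger without --benchmark_log_dir should be rejected.
  flags.FLAGS(["prog", "--benchmark_logger_type=BenchmarkFileLogger"])
except flags.IllegalFlagValueError as e:
  print("Rejected at parse time:", e)
# Adding --benchmark_log_dir=/tmp/benchmark to argv satisfies the validator.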
