Commit d626b90

Fix/log ex per sec (tensorflow#4360)
* Using BenchmarkLogger
* Using BenchmarkLogger
* Fixing tests
* Linting fixes.
* Adding comments
* Moving mock logger
* Moving mock logger
* Glinting
* Responding to CR
* Reverting assertEmpty
1 parent 023fc2b commit d626b90

File tree

6 files changed: +129, -90 lines

official/utils/logs/hooks.py

Lines changed: 21 additions & 9 deletions

@@ -20,7 +20,9 @@
 from __future__ import division
 from __future__ import print_function

-import tensorflow as tf
+import tensorflow as tf  # pylint: disable=g-bad-import-order
+
+from official.utils.logs import logger


 class ExamplesPerSecondHook(tf.train.SessionRunHook):
@@ -36,7 +38,8 @@ def __init__(self,
                batch_size,
                every_n_steps=None,
                every_n_secs=None,
-               warm_steps=0):
+               warm_steps=0,
+               metric_logger=None):
     """Initializer for ExamplesPerSecondHook.

     Args:
@@ -48,15 +51,20 @@ def __init__(self,
       warm_steps: The number of steps to be skipped before logging and running
         average calculation. warm_steps steps refers to global steps across all
         workers, not on each worker
+      metric_logger: instance of `BenchmarkLogger`, the benchmark logger that
+        hook should use to write the log. If None, BaseBenchmarkLogger will
+        be used.

     Raises:
       ValueError: if neither `every_n_steps` or `every_n_secs` is set, or
         both are set.
     """

     if (every_n_steps is None) == (every_n_secs is None):
-      raise ValueError('exactly one of every_n_steps'
-                       ' and every_n_secs should be provided.')
+      raise ValueError("exactly one of every_n_steps"
+                       " and every_n_secs should be provided.")
+
+    self._logger = metric_logger or logger.BaseBenchmarkLogger()

     self._timer = tf.train.SecondOrStepTimer(
         every_steps=every_n_steps, every_secs=every_n_secs)
@@ -71,7 +79,7 @@ def begin(self):
     self._global_step_tensor = tf.train.get_global_step()
     if self._global_step_tensor is None:
       raise RuntimeError(
-          'Global step should be created to use StepCounterHook.')
+          "Global step should be created to use StepCounterHook.")

   def before_run(self, run_context):  # pylint: disable=unused-argument
     """Called before each call to run().
@@ -109,7 +117,11 @@ def after_run(self, run_context, run_values):  # pylint: disable=unused-argument
       # and training time per batch
       current_examples_per_sec = self._batch_size * (
           elapsed_steps / elapsed_time)
-      # Current examples/sec followed by average examples/sec
-      tf.logging.info('Batch [%g]: current exp/sec = %g, average exp/sec = '
-                      '%g', self._total_steps, current_examples_per_sec,
-                      average_examples_per_sec)
+
+      self._logger.log_metric(
+          "average_examples_per_sec", average_examples_per_sec,
+          global_step=global_step)
+
+      self._logger.log_metric(
+          "current_examples_per_sec", current_examples_per_sec,
+          global_step=global_step)
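
With this change the hook no longer calls tf.logging.info directly; it forwards both the current and the running-average examples/sec to whatever logger it was constructed with. A minimal sketch of building the hook by hand (the batch size and step counts below are arbitrary; passing metric_logger=None would make the hook construct the same BaseBenchmarkLogger default internally):

    from official.utils.logs import hooks
    from official.utils.logs import logger

    examples_sec_hook = hooks.ExamplesPerSecondHook(
        batch_size=256,       # global batch size used to compute examples/sec
        every_n_steps=100,    # log once every 100 global steps
        warm_steps=10,        # skip the first 10 global steps entirely
        metric_logger=logger.BaseBenchmarkLogger())

    # The hook can then be passed to a training loop, e.g.
    # estimator.train(input_fn, hooks=[examples_sec_hook]).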

official/utils/logs/hooks_helper.py

Lines changed: 3 additions & 3 deletions

@@ -119,9 +119,9 @@ def get_examples_per_second_hook(every_n_steps=100,
   Returns a ProfilerHook that writes out timelines that can be loaded into
   profiling tools like chrome://tracing.
   """
-  return hooks.ExamplesPerSecondHook(every_n_steps=every_n_steps,
-                                     batch_size=batch_size,
-                                     warm_steps=warm_steps)
+  return hooks.ExamplesPerSecondHook(
+      batch_size=batch_size, every_n_steps=every_n_steps,
+      warm_steps=warm_steps, metric_logger=logger.get_benchmark_logger())


 def get_logging_metric_hook(tensors_to_log=None,
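
The helper now threads the process-wide benchmark logger into the hook instead of letting it fall back to the default. A sketch of how a caller might rely on this, assuming get_examples_per_second_hook accepts the keyword arguments its signature in the diff suggests:

    from official.utils.logs import hooks_helper
    from official.utils.logs import logger

    # Configure the global logger before building hooks so that
    # get_benchmark_logger() inside the helper returns the configured
    # instance (with no flags set, this is a BaseBenchmarkLogger).
    logger.config_benchmark_logger()

    examples_sec_hook = hooks_helper.get_examples_per_second_hook(
        every_n_steps=100, batch_size=128, warm_steps=5)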

official/utils/logs/hooks_test.py

Lines changed: 31 additions & 29 deletions

@@ -25,77 +25,74 @@
 from tensorflow.python.training import monitored_session  # pylint: disable=g-bad-import-order

 from official.utils.logs import hooks
+from official.utils.testing import mock_lib


-tf.logging.set_verbosity(tf.logging.ERROR)
+tf.logging.set_verbosity(tf.logging.DEBUG)


 class ExamplesPerSecondHookTest(tf.test.TestCase):
   """Tests for the ExamplesPerSecondHook."""

   def setUp(self):
     """Mock out logging calls to verify if correct info is being monitored."""
-    self._actual_log = tf.logging.info
-    self.logged_message = None
-
-    def mock_log(*args, **kwargs):
-      self.logged_message = args
-      self._actual_log(*args, **kwargs)
-
-    tf.logging.info = mock_log
+    self._logger = mock_lib.MockBenchmarkLogger()

     self.graph = tf.Graph()
     with self.graph.as_default():
       self.global_step = tf.train.get_or_create_global_step()
       self.train_op = tf.assign_add(self.global_step, 1)

-  def tearDown(self):
-    tf.logging.info = self._actual_log
-
   def test_raise_in_both_secs_and_steps(self):
     with self.assertRaises(ValueError):
       hooks.ExamplesPerSecondHook(
           batch_size=256,
           every_n_steps=10,
-          every_n_secs=20)
+          every_n_secs=20,
+          metric_logger=self._logger)

   def test_raise_in_none_secs_and_steps(self):
     with self.assertRaises(ValueError):
       hooks.ExamplesPerSecondHook(
           batch_size=256,
           every_n_steps=None,
-          every_n_secs=None)
+          every_n_secs=None,
+          metric_logger=self._logger)

   def _validate_log_every_n_steps(self, sess, every_n_steps, warm_steps):
     hook = hooks.ExamplesPerSecondHook(
         batch_size=256,
         every_n_steps=every_n_steps,
-        warm_steps=warm_steps)
+        warm_steps=warm_steps,
+        metric_logger=self._logger)
     hook.begin()
     mon_sess = monitored_session._HookedSession(sess, [hook])  # pylint: disable=protected-access
     sess.run(tf.global_variables_initializer())

-    self.logged_message = ''
     for _ in range(every_n_steps):
       mon_sess.run(self.train_op)
-      self.assertEqual(str(self.logged_message).find('exp/sec'), -1)
+      # Nothing should be in the list yet
+      self.assertFalse(self._logger.logged_metric)

     mon_sess.run(self.train_op)
     global_step_val = sess.run(self.global_step)
-    # assertNotRegexpMatches is not supported by python 3.1 and later
+
     if global_step_val > warm_steps:
-      self.assertRegexpMatches(str(self.logged_message), 'exp/sec')
+      self._assert_metrics()
     else:
-      self.assertEqual(str(self.logged_message).find('exp/sec'), -1)
+      # Nothing should be in the list yet
+      self.assertFalse(self._logger.logged_metric)

     # Add additional run to verify proper reset when called multiple times.
-    self.logged_message = ''
+    prev_log_len = len(self._logger.logged_metric)
     mon_sess.run(self.train_op)
     global_step_val = sess.run(self.global_step)
     if every_n_steps == 1 and global_step_val > warm_steps:
-      self.assertRegexpMatches(str(self.logged_message), 'exp/sec')
+      # Each time, we log two additional metrics. Did exactly 2 get added?
+      self.assertEqual(len(self._logger.logged_metric), prev_log_len + 2)
     else:
-      self.assertEqual(str(self.logged_message).find('exp/sec'), -1)
+      # No change in the size of the metric list.
+      self.assertEqual(len(self._logger.logged_metric), prev_log_len)

     hook.end(sess)

@@ -119,19 +116,19 @@ def _validate_log_every_n_secs(self, sess, every_n_secs):
     hook = hooks.ExamplesPerSecondHook(
         batch_size=256,
         every_n_steps=None,
-        every_n_secs=every_n_secs)
+        every_n_secs=every_n_secs,
+        metric_logger=self._logger)
     hook.begin()
     mon_sess = monitored_session._HookedSession(sess, [hook])  # pylint: disable=protected-access
     sess.run(tf.global_variables_initializer())

-    self.logged_message = ''
     mon_sess.run(self.train_op)
-    self.assertEqual(str(self.logged_message).find('exp/sec'), -1)
+    # Nothing should be in the list yet
+    self.assertFalse(self._logger.logged_metric)
     time.sleep(every_n_secs)

-    self.logged_message = ''
     mon_sess.run(self.train_op)
-    self.assertRegexpMatches(str(self.logged_message), 'exp/sec')
+    self._assert_metrics()

     hook.end(sess)

@@ -143,6 +140,11 @@ def test_examples_per_sec_every_5_secs(self):
     with self.graph.as_default(), tf.Session() as sess:
       self._validate_log_every_n_secs(sess, 5)

+  def _assert_metrics(self):
+    metrics = self._logger.logged_metric
+    self.assertEqual(metrics[-2]["name"], "average_examples_per_sec")
+    self.assertEqual(metrics[-1]["name"], "current_examples_per_sec")
+

-if __name__ == '__main__':
+if __name__ == "__main__":
   tf.test.main()
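
The tests depend on official.utils.testing.mock_lib.MockBenchmarkLogger, which is one of the six changed files but is not shown in this view. A minimal sketch consistent with how the tests consume it (a logged_metric list of dicts keyed by "name"); the real implementation and the log_metric signature here are assumptions:

    class MockBenchmarkLogger(object):
      """Test double that records metrics instead of writing them anywhere."""

      def __init__(self):
        self.logged_metric = []

      def log_metric(self, name, value, unit=None, global_step=None,
                     extras=None):
        # The tests above only inspect entry["name"]; the remaining fields
        # mirror what a real BenchmarkLogger receives. Signature is assumed.
        self.logged_metric.append({
            "name": name,
            "value": float(value),
            "unit": unit,
            "global_step": global_step,
            "extras": extras,
        })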

official/utils/logs/logger.py

Lines changed: 10 additions & 8 deletions

@@ -47,20 +47,20 @@


 def config_benchmark_logger(flag_obj=None):
-  """Config the global benchmark logger"""
+  """Config the global benchmark logger."""
   _logger_lock.acquire()
   try:
     global _benchmark_logger
     if not flag_obj:
       flag_obj = FLAGS

-    if (not hasattr(flag_obj, 'benchmark_logger_type') or
-        flag_obj.benchmark_logger_type == 'BaseBenchmarkLogger'):
+    if (not hasattr(flag_obj, "benchmark_logger_type") or
+        flag_obj.benchmark_logger_type == "BaseBenchmarkLogger"):
       _benchmark_logger = BaseBenchmarkLogger()
-    elif flag_obj.benchmark_logger_type == 'BenchmarkFileLogger':
+    elif flag_obj.benchmark_logger_type == "BenchmarkFileLogger":
       _benchmark_logger = BenchmarkFileLogger(flag_obj.benchmark_log_dir)
-    elif flag_obj.benchmark_logger_type == 'BenchmarkBigQueryLogger':
-      from official.benchmark import benchmark_uploader as bu  # pylint: disable=g-import-not-at-top
+    elif flag_obj.benchmark_logger_type == "BenchmarkBigQueryLogger":
+      from official.benchmark import benchmark_uploader as bu  # pylint: disable=g-import-not-at-top
       bq_uploader = bu.BigQueryUploader(gcp_project=flag_obj.gcp_project)
       _benchmark_logger = BenchmarkBigQueryLogger(
           bigquery_uploader=bq_uploader,
@@ -69,8 +69,8 @@ def config_benchmark_logger(flag_obj=None):
           bigquery_metric_table=flag_obj.bigquery_metric_table,
           run_id=str(uuid.uuid4()))
     else:
-      raise ValueError('Unrecognized benchmark_logger_type: %s',
-                       flag_obj.benchmark_logger_type)
+      raise ValueError("Unrecognized benchmark_logger_type: %s"
+                       % flag_obj.benchmark_logger_type)

   finally:
     _logger_lock.release()
@@ -247,6 +247,7 @@ def log_run_info(self, model_name, dataset_name, run_params):
             self._run_id,
             run_info))

+
 def _gather_run_info(model_name, dataset_name, run_params):
   """Collect the benchmark run information for the local environment."""
   run_info = {
@@ -303,6 +304,7 @@ def process_param(name, value):
   run_info["run_parameters"] = [
       process_param(k, v) for k, v in sorted(run_params.items())]

+
 def _collect_tensorflow_environment_variables(run_info):
   run_info["tensorflow_environment_variables"] = [
       {"name": k, "value": v}
