25
25
from tensorflow .python .training import monitored_session # pylint: disable=g-bad-import-order
26
26
27
27
from official .utils .logs import hooks
28
+ from official .utils .testing import mock_lib
28
29
29
30
30
- tf .logging .set_verbosity (tf .logging .ERROR )
31
+ tf .logging .set_verbosity (tf .logging .DEBUG )
31
32
32
33
33
34
class ExamplesPerSecondHookTest (tf .test .TestCase ):
34
35
"""Tests for the ExamplesPerSecondHook."""
35
36
36
37
def setUp (self ):
37
38
"""Mock out logging calls to verify if correct info is being monitored."""
38
- self ._actual_log = tf .logging .info
39
- self .logged_message = None
40
-
41
- def mock_log (* args , ** kwargs ):
42
- self .logged_message = args
43
- self ._actual_log (* args , ** kwargs )
44
-
45
- tf .logging .info = mock_log
39
+ self ._logger = mock_lib .MockBenchmarkLogger ()
46
40
47
41
self .graph = tf .Graph ()
48
42
with self .graph .as_default ():
49
43
self .global_step = tf .train .get_or_create_global_step ()
50
44
self .train_op = tf .assign_add (self .global_step , 1 )
51
45
52
- def tearDown (self ):
53
- tf .logging .info = self ._actual_log
54
-
55
46
def test_raise_in_both_secs_and_steps (self ):
56
47
with self .assertRaises (ValueError ):
57
48
hooks .ExamplesPerSecondHook (
58
49
batch_size = 256 ,
59
50
every_n_steps = 10 ,
60
- every_n_secs = 20 )
51
+ every_n_secs = 20 ,
52
+ metric_logger = self ._logger )
61
53
62
54
def test_raise_in_none_secs_and_steps (self ):
63
55
with self .assertRaises (ValueError ):
64
56
hooks .ExamplesPerSecondHook (
65
57
batch_size = 256 ,
66
58
every_n_steps = None ,
67
- every_n_secs = None )
59
+ every_n_secs = None ,
60
+ metric_logger = self ._logger )
68
61
69
62
def _validate_log_every_n_steps (self , sess , every_n_steps , warm_steps ):
70
63
hook = hooks .ExamplesPerSecondHook (
71
64
batch_size = 256 ,
72
65
every_n_steps = every_n_steps ,
73
- warm_steps = warm_steps )
66
+ warm_steps = warm_steps ,
67
+ metric_logger = self ._logger )
74
68
hook .begin ()
75
69
mon_sess = monitored_session ._HookedSession (sess , [hook ]) # pylint: disable=protected-access
76
70
sess .run (tf .global_variables_initializer ())
77
71
78
- self .logged_message = ''
79
72
for _ in range (every_n_steps ):
80
73
mon_sess .run (self .train_op )
81
- self .assertEqual (str (self .logged_message ).find ('exp/sec' ), - 1 )
74
+ # Nothing should be in the list yet
75
+ self .assertFalse (self ._logger .logged_metric )
82
76
83
77
mon_sess .run (self .train_op )
84
78
global_step_val = sess .run (self .global_step )
85
- # assertNotRegexpMatches is not supported by python 3.1 and later
79
+
86
80
if global_step_val > warm_steps :
87
- self .assertRegexpMatches ( str ( self . logged_message ), 'exp/sec' )
81
+ self ._assert_metrics ( )
88
82
else :
89
- self .assertEqual (str (self .logged_message ).find ('exp/sec' ), - 1 )
83
+ # Nothing should be in the list yet
84
+ self .assertFalse (self ._logger .logged_metric )
90
85
91
86
# Add additional run to verify proper reset when called multiple times.
92
- self . logged_message = ''
87
+ prev_log_len = len ( self . _logger . logged_metric )
93
88
mon_sess .run (self .train_op )
94
89
global_step_val = sess .run (self .global_step )
95
90
if every_n_steps == 1 and global_step_val > warm_steps :
96
- self .assertRegexpMatches (str (self .logged_message ), 'exp/sec' )
91
+ # Each time, we log two additional metrics. Did exactly 2 get added?
92
+ self .assertEqual (len (self ._logger .logged_metric ), prev_log_len + 2 )
97
93
else :
98
- self .assertEqual (str (self .logged_message ).find ('exp/sec' ), - 1 )
94
+ # No change in the size of the metric list.
95
+ self .assertEqual (len (self ._logger .logged_metric ), prev_log_len )
99
96
100
97
hook .end (sess )
101
98
@@ -119,19 +116,19 @@ def _validate_log_every_n_secs(self, sess, every_n_secs):
119
116
hook = hooks .ExamplesPerSecondHook (
120
117
batch_size = 256 ,
121
118
every_n_steps = None ,
122
- every_n_secs = every_n_secs )
119
+ every_n_secs = every_n_secs ,
120
+ metric_logger = self ._logger )
123
121
hook .begin ()
124
122
mon_sess = monitored_session ._HookedSession (sess , [hook ]) # pylint: disable=protected-access
125
123
sess .run (tf .global_variables_initializer ())
126
124
127
- self .logged_message = ''
128
125
mon_sess .run (self .train_op )
129
- self .assertEqual (str (self .logged_message ).find ('exp/sec' ), - 1 )
126
+ # Nothing should be in the list yet
127
+ self .assertFalse (self ._logger .logged_metric )
130
128
time .sleep (every_n_secs )
131
129
132
- self .logged_message = ''
133
130
mon_sess .run (self .train_op )
134
- self .assertRegexpMatches ( str ( self . logged_message ), 'exp/sec' )
131
+ self ._assert_metrics ( )
135
132
136
133
hook .end (sess )
137
134
@@ -143,6 +140,11 @@ def test_examples_per_sec_every_5_secs(self):
143
140
with self .graph .as_default (), tf .Session () as sess :
144
141
self ._validate_log_every_n_secs (sess , 5 )
145
142
143
+ def _assert_metrics (self ):
144
+ metrics = self ._logger .logged_metric
145
+ self .assertEqual (metrics [- 2 ]["name" ], "average_examples_per_sec" )
146
+ self .assertEqual (metrics [- 1 ]["name" ], "current_examples_per_sec" )
147
+
146
148
147
- if __name__ == ' __main__' :
149
+ if __name__ == " __main__" :
148
150
tf .test .main ()
0 commit comments