Commit febaae9
rxsang authored and tensorflower-gardener committed
Some improvements and bug fixes to Controller:

1. Fix a bug where a checkpoint was saved after every training loop.
2. Only create the training and eval summary writers if the corresponding `train_fn` and `eval_fn` are passed.
3. Flush the summary writers after training and eval finish.
4. Add a Controller test.

Also make sure there is no evaluation happening in the ResNet CTL example if `skip_eval=True`.

PiperOrigin-RevId: 301489305
1 parent 101f1f0 · commit febaae9
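Taken together, points 1-3 define a loop discipline: summaries are written each loop and flushed once at the end, while checkpoints respect `checkpoint_interval` during training and are only forced when training finishes. A minimal sketch of that discipline (illustrative names; `train_fn` here stands for any callable that runs `steps_per_loop` steps, not the Controller's exact internals):

import tensorflow as tf

def train_loop(train_fn, global_step, writer, manager, train_steps):
  """Illustrative outer loop: interval-gated checkpoints, one final flush."""
  while global_step.numpy() < train_steps:
    outputs = train_fn()  # advances `global_step` by `steps_per_loop`
    with writer.as_default():
      for name, value in outputs.items():
        tf.summary.scalar(name, value, step=global_step.numpy())
    # check_interval=True honors `checkpoint_interval` instead of saving
    # after every loop (the bug fixed in point 1).
    manager.save(checkpoint_number=global_step.numpy(), check_interval=True)
  manager.save(checkpoint_number=global_step.numpy(), check_interval=False)
  writer.flush()  # point 3: push buffered events to disk once training ends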

File tree

4 files changed, +306 -36 lines

official/staging/training/controller.py

+37 -32
@@ -78,9 +78,10 @@ def __init__(
       eval_summary_dir: The directory to write eval summaries. If None, it will
         be set to `summary_dir`.
       eval_steps: Number of steps to run evaluation.
-      eval_interval: Step interval for evaluation. If None, will skip
-        evaluation. Note that evaluation only happens outside the training loop,
-        which the loop iteration is specify by `steps_per_loop` parameter.
+      eval_interval: Step interval for evaluation. If None, will skip evaluation
+        in the middle of training. Note that evaluation only happens outside the
+        training loop, whose iterations are specified by the `steps_per_loop`
+        parameter.
 
     Raises:
       ValueError: If both `train_fn` and `eval_fn` are None.
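The revised wording matters for predicting when evaluation fires: since evaluation can only run between training loops, the effective evaluation step is rounded up to a loop boundary. A small sketch of that trigger logic (illustrative only; the repo's real helper is `utils.IntervalTrigger`, whose exact API may differ):

def should_eval(last_step, current_step, eval_interval):
  """Fire when `current_step` crosses a new multiple of `eval_interval`."""
  if eval_interval is None:
    return False  # None disables mid-training evaluation
  return current_step // eval_interval > last_step // eval_interval

# With steps_per_loop=2 and eval_interval=5, the Controller only observes
# steps 2, 4, 6, 8, 10, so evaluation actually runs at steps 6 and 10.
assert [s for s in range(2, 11, 2) if should_eval(s - 2, s, 5)] == [6, 10]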
@@ -111,35 +112,41 @@ def __init__(
     self.train_fn = train_fn
     self.eval_fn = eval_fn
     self.global_step = global_step
-
-    self.train_steps = train_steps
-
-    self.steps_per_loop = steps_per_loop
-
-    self.summary_dir = summary_dir or checkpoint_manager.directory
     self.checkpoint_manager = checkpoint_manager
 
-    self.summary_interval = summary_interval
-    summary_writer = tf.summary.create_file_writer(
-        self.summary_dir) if self.summary_interval else None
-    # TODO(rxsang): Consider pass SummaryManager directly into Controller for
-    # maximum customizability.
-    self.summary_manager = utils.SummaryManager(
-        summary_writer,
-        tf.summary.scalar,
-        global_step=self.global_step,
-        summary_interval=self.summary_interval)
+    if self.train_fn is not None:
+      self.train_steps = train_steps
+      self.steps_per_loop = steps_per_loop
+      self.summary_dir = summary_dir or checkpoint_manager.directory
+
+      self.summary_interval = summary_interval
+      summary_writer = tf.summary.create_file_writer(
+          self.summary_dir) if self.summary_interval else None
+      # TODO(rxsang): Consider pass SummaryManager directly into Controller for
+      # maximum customizability.
+      self.summary_manager = utils.SummaryManager(
+          summary_writer,
+          tf.summary.scalar,
+          global_step=self.global_step,
+          summary_interval=self.summary_interval)
+
+    if self.eval_fn is not None:
+      eval_summary_dir = eval_summary_dir or self.summary_dir
+      eval_summary_writer = tf.summary.create_file_writer(
+          eval_summary_dir) if eval_summary_dir else None
+      self.eval_summary_manager = utils.SummaryManager(
+          eval_summary_writer, tf.summary.scalar, global_step=self.global_step)
+
+      self.eval_steps = eval_steps
+      self.eval_interval = eval_interval
+
+      # Create and initialize the interval triggers.
+      self.eval_trigger = utils.IntervalTrigger(self.eval_interval,
+                                                self.global_step.numpy())
 
     if self.global_step:
       tf.summary.experimental.set_step(self.global_step)
 
-    self.eval_summary_dir = eval_summary_dir or self.summary_dir
-    eval_summary_writer = tf.summary.create_file_writer(self.eval_summary_dir)
-    self.eval_summary_manager = utils.SummaryManager(
-        eval_summary_writer, tf.summary.scalar, global_step=self.global_step)
-
-    self.eval_steps = eval_steps
-    self.eval_interval = eval_interval
-
     # Restore Model if needed.
     if self.checkpoint_manager is not None:
       model_restored = self._restore_model()
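The effect of the restructuring above is that summary writers, summary managers, and the eval trigger now exist only when the corresponding `train_fn` or `eval_fn` is supplied, so a train-only Controller never materializes an empty eval event directory. A minimal sketch of the guard pattern, with illustrative names:

import tensorflow as tf

def maybe_create_writer(phase_fn, summary_dir):
  """Create an event-file writer only when the phase (train/eval) exists."""
  if phase_fn is None or summary_dir is None:
    return None  # no fn: no writer, and no empty event files on disk
  return tf.summary.create_file_writer(summary_dir)

The new `test_train_only` below relies on exactly this behavior: it asserts that `summaries/eval` is never created when only `train_fn` is passed.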
@@ -150,10 +157,6 @@ def __init__(
             checkpoint_number=self.global_step)
         logging.info("Saved checkpoins in %s", ckpt_path)
 
-    # Create and initialize the interval triggers.
-    self.eval_trigger = utils.IntervalTrigger(self.eval_interval,
-                                              self.global_step.numpy())
-
   def _restore_model(self, checkpoint_path=None):
     """Restore or initialize the model.
 
@@ -186,11 +189,12 @@ def _evaluate_once(self, current_step):
     self._log_info(info)
 
     self.eval_summary_manager.write_summaries(eval_outputs)
+    self.eval_summary_manager.flush()
 
   def _maybe_save_checkpoints(self, current_step, force_trigger=False):
     if self.checkpoint_manager.checkpoint_interval:
       ckpt_path = self.checkpoint_manager.save(
-          checkpoint_number=current_step, check_interval=force_trigger)
+          checkpoint_number=current_step, check_interval=not force_trigger)
       if ckpt_path is not None:
         logging.info("Saved checkpoins in %s", ckpt_path)
 
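The flipped `check_interval` flag is the actual bug fix from point 1 of the commit message. Previously, ordinary loop-end saves called `save(check_interval=force_trigger)` with `force_trigger=False`, and `check_interval=False` tells `tf.train.CheckpointManager` to skip the interval check entirely, so a checkpoint was written after every training loop. Negating the flag restores interval gating during training while still forcing a save at the end. A small sketch of the `CheckpointManager` semantics (the `/tmp/ctl_demo` path is illustrative):

import tensorflow as tf

step = tf.Variable(0, dtype=tf.int64)
manager = tf.train.CheckpointManager(
    tf.train.Checkpoint(step=step), "/tmp/ctl_demo", max_to_keep=None,
    step_counter=step, checkpoint_interval=10)

manager.save(check_interval=True)   # first save: always writes
step.assign_add(3)
manager.save(check_interval=True)   # returns None: only 3 < 10 steps elapsed
manager.save(check_interval=False)  # forced save: ignores the interval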
@@ -265,6 +269,7 @@ def train(self, evaluate=True):
         self._maybe_evaluate(current_step)
 
     self.summary_manager.write_summaries(train_outputs, always_write=True)
+    self.summary_manager.flush()
     self._maybe_save_checkpoints(current_step, force_trigger=True)
     if evaluate:
       self._maybe_evaluate(current_step, force_trigger=True)
official/staging/training/controller_test.py

+262
@@ -0,0 +1,262 @@ (new file)
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for official.staging.training.controller."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

from absl.testing import parameterized
import numpy as np
import tensorflow as tf

from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import strategy_combinations
from official.staging.training import controller
from official.staging.training import standard_runnable


def all_strategy_combinations():
  """Gets combinations of distribution strategies."""
  return combinations.combine(
      strategy=[
          strategy_combinations.one_device_strategy,
          strategy_combinations.tpu_strategy,
          strategy_combinations.one_device_strategy_gpu,
          strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
      ],
      mode="eager",
  )


def create_model():
  x = tf.keras.layers.Input(shape=(3,), name="input")
  y = tf.keras.layers.Dense(4, name="dense")(x)
  model = tf.keras.Model(x, y)
  return model


def summaries_with_matching_keyword(keyword, summary_dir):
  """Yields summary protos matching given keyword from event file."""
  event_paths = tf.io.gfile.glob(os.path.join(summary_dir, "events*"))
  for event in tf.compat.v1.train.summary_iterator(event_paths[-1]):
    if event.summary is not None:
      for value in event.summary.value:
        if keyword in value.tag:
          tf.compat.v1.logging.error(event)
          yield event.summary


def check_eventfile_for_keyword(keyword, summary_dir):
  """Checks event files for the keyword."""
  return any(summaries_with_matching_keyword(keyword, summary_dir))


def dataset_fn(ctx):
  del ctx
  inputs = np.zeros((10, 3), dtype=np.float32)
  targets = np.zeros((10, 4), dtype=np.float32)
  dataset = tf.data.Dataset.from_tensor_slices((inputs, targets))
  dataset = dataset.repeat(100)
  dataset = dataset.batch(10, drop_remainder=True)
  return dataset


class TestRunnable(standard_runnable.StandardTrainable,
                   standard_runnable.StandardEvaluable):
  """Implements the training and evaluation APIs for the test model."""

  def __init__(self):
    standard_runnable.StandardTrainable.__init__(self)
    standard_runnable.StandardEvaluable.__init__(self)
    self.strategy = tf.distribute.get_strategy()
    self.model = create_model()
    self.optimizer = tf.keras.optimizers.RMSprop()
    self.global_step = self.optimizer.iterations
    self.train_loss = tf.keras.metrics.Mean("train_loss", dtype=tf.float32)
    self.eval_loss = tf.keras.metrics.Mean("eval_loss", dtype=tf.float32)

  def build_train_dataset(self):
    return self.strategy.experimental_distribute_datasets_from_function(
        dataset_fn)

  def train_step(self, iterator):

    def _replicated_step(inputs):
      """Replicated training step."""
      inputs, targets = inputs
      with tf.GradientTape() as tape:
        outputs = self.model(inputs)
        loss = tf.math.reduce_sum(outputs - targets)
      grads = tape.gradient(loss, self.model.variables)
      self.optimizer.apply_gradients(zip(grads, self.model.variables))
      self.train_loss.update_state(loss)

    self.strategy.run(_replicated_step, args=(next(iterator),))

  def train_loop_end(self):
    return {
        "loss": self.train_loss.result(),
    }

  def build_eval_dataset(self):
    return self.strategy.experimental_distribute_datasets_from_function(
        dataset_fn)

  def eval_begin(self):
    self.eval_loss.reset_states()

  def eval_step(self, iterator):

    def _replicated_step(inputs):
      """Replicated evaluation step."""
      inputs, targets = inputs
      outputs = self.model(inputs)
      loss = tf.math.reduce_sum(outputs - targets)
      self.eval_loss.update_state(loss)

    self.strategy.run(_replicated_step, args=(next(iterator),))

  def eval_end(self):
    return {
        "eval_loss": self.eval_loss.result(),
    }


class ControllerTest(tf.test.TestCase, parameterized.TestCase):

  def setUp(self):
    super(ControllerTest, self).setUp()
    self.model_dir = self.get_temp_dir()

  @combinations.generate(all_strategy_combinations())
  def test_train_and_evaluate(self, strategy):
    with strategy.scope():
      test_runnable = TestRunnable()

    checkpoint = tf.train.Checkpoint(
        model=test_runnable.model, optimizer=test_runnable.optimizer)
    checkpoint_manager = tf.train.CheckpointManager(
        checkpoint,
        self.model_dir,
        max_to_keep=None,
        step_counter=test_runnable.global_step,
        checkpoint_interval=10)
    test_controller = controller.Controller(
        strategy=strategy,
        train_fn=test_runnable.train,
        eval_fn=test_runnable.evaluate,
        global_step=test_runnable.global_step,
        train_steps=10,
        steps_per_loop=2,
        summary_dir=os.path.join(self.model_dir, "summaries/train"),
        summary_interval=2,
        checkpoint_manager=checkpoint_manager,
        eval_summary_dir=os.path.join(self.model_dir, "summaries/eval"),
        eval_steps=2,
        eval_interval=5)
    test_controller.train(evaluate=True)

    # Checkpoints are saved.
    self.assertNotEmpty(tf.io.gfile.glob(os.path.join(self.model_dir, "ckpt*")))

    # Loss and accuracy values should be written into summaries.
    self.assertNotEmpty(
        tf.io.gfile.listdir(os.path.join(self.model_dir, "summaries/train")))
    self.assertTrue(
        check_eventfile_for_keyword(
            "loss", os.path.join(self.model_dir, "summaries/train")))
    self.assertNotEmpty(
        tf.io.gfile.listdir(os.path.join(self.model_dir, "summaries/eval")))
    self.assertTrue(
        check_eventfile_for_keyword(
            "eval_loss", os.path.join(self.model_dir, "summaries/eval")))

  @combinations.generate(all_strategy_combinations())
  def test_train_only(self, strategy):
    with strategy.scope():
      test_runnable = TestRunnable()

    checkpoint = tf.train.Checkpoint(
        model=test_runnable.model, optimizer=test_runnable.optimizer)
    checkpoint_manager = tf.train.CheckpointManager(
        checkpoint,
        self.model_dir,
        max_to_keep=None,
        step_counter=test_runnable.global_step,
        checkpoint_interval=10)
    test_controller = controller.Controller(
        strategy=strategy,
        train_fn=test_runnable.train,
        global_step=test_runnable.global_step,
        train_steps=10,
        steps_per_loop=2,
        summary_dir=os.path.join(self.model_dir, "summaries/train"),
        summary_interval=2,
        checkpoint_manager=checkpoint_manager,
        eval_summary_dir=os.path.join(self.model_dir, "summaries/eval"),
    )
    test_controller.train(evaluate=False)

    # Checkpoints are saved.
    self.assertNotEmpty(tf.io.gfile.glob(os.path.join(self.model_dir, "ckpt*")))

    # Only train summaries are written.
    self.assertNotEmpty(
        tf.io.gfile.listdir(os.path.join(self.model_dir, "summaries/train")))
    self.assertTrue(
        check_eventfile_for_keyword(
            "loss", os.path.join(self.model_dir, "summaries/train")))
    self.assertFalse(
        tf.io.gfile.exists(os.path.join(self.model_dir, "summaries/eval")))

  @combinations.generate(all_strategy_combinations())
  def test_evaluate_only(self, strategy):
    with strategy.scope():
      test_runnable = TestRunnable()

    checkpoint = tf.train.Checkpoint(model=test_runnable.model)
    checkpoint.save(os.path.join(self.model_dir, "ckpt"))

    checkpoint_manager = tf.train.CheckpointManager(
        checkpoint,
        self.model_dir,
        max_to_keep=None,
        step_counter=test_runnable.global_step)
    test_controller = controller.Controller(
        strategy=strategy,
        eval_fn=test_runnable.evaluate,
        global_step=test_runnable.global_step,
        checkpoint_manager=checkpoint_manager,
        summary_dir=os.path.join(self.model_dir, "summaries/train"),
        eval_summary_dir=os.path.join(self.model_dir, "summaries/eval"),
        eval_steps=2,
        eval_interval=5)
    test_controller.evaluate()

    # Only eval summaries are written.
    self.assertFalse(
        tf.io.gfile.exists(os.path.join(self.model_dir, "summaries/train")))
    self.assertNotEmpty(
        tf.io.gfile.listdir(os.path.join(self.model_dir, "summaries/eval")))
    self.assertTrue(
        check_eventfile_for_keyword(
            "eval_loss", os.path.join(self.model_dir, "summaries/eval")))


if __name__ == "__main__":
  tf.test.main()

official/staging/training/utils.py

+5
@@ -193,6 +193,11 @@ def summary_writer(self):
     """Returns the underlying summary writer."""
     return self._summary_writer
 
+  def flush(self):
+    """Flush the underlying summary writer."""
+    if self._enabled:
+      tf.summary.flush(self._summary_writer)
+
   def write_summaries(self, items, always_write=True):
     """Write a bulk of summaries.
 
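Flushing matters because event files are buffered: without it, summaries written at the very end of training or evaluation may not reach disk until the writer is finalized. A hedged usage sketch of the new method, assuming `SummaryManager` is constructed the same way `controller.py` constructs it above (the `/tmp/demo_summaries` path is illustrative):

import tensorflow as tf
from official.staging.training import utils

global_step = tf.Variable(1, dtype=tf.int64)
tf.summary.experimental.set_step(global_step)
writer = tf.summary.create_file_writer("/tmp/demo_summaries")
manager = utils.SummaryManager(
    writer, tf.summary.scalar, global_step=global_step)
manager.write_summaries({"loss": 0.5})
manager.flush()  # no-op when the manager has no enabled writer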