@@ -33,7 +33,6 @@
 from absl import app as absl_app
 from absl import flags
 import tensorflow as tf
-import tensorflow.contrib.eager as tfe
 # pylint: enable=g-bad-import-order
 
 from official.mnist import dataset as mnist_dataset
@@ -42,6 +41,8 @@
 from official.utils.misc import model_helpers
 
 
+tfe = tf.contrib.eager
+
 def loss(logits, labels):
   return tf.reduce_mean(
       tf.nn.sparse_softmax_cross_entropy_with_logits(
@@ -60,7 +61,7 @@ def train(model, optimizer, dataset, step_counter, log_interval=None):
   """Trains model on `dataset` using `optimizer`."""
 
   start = time.time()
-  for (batch, (images, labels)) in enumerate(tfe.Iterator(dataset)):
+  for (batch, (images, labels)) in enumerate(dataset):
     with tf.contrib.summary.record_summaries_every_n_global_steps(
         10, global_step=step_counter):
       # Record the operations used to compute the loss given the input,
@@ -85,7 +86,7 @@ def test(model, dataset):
   avg_loss = tfe.metrics.Mean('loss')
   accuracy = tfe.metrics.Accuracy('accuracy')
 
-  for (images, labels) in tfe.Iterator(dataset):
+  for (images, labels) in dataset:
     logits = model(images, training=False)
     avg_loss(loss(logits, labels))
     accuracy(
@@ -145,7 +146,7 @@ def run_mnist_eager(flags_obj):
   # Create and restore checkpoint (if one exists on the path)
   checkpoint_prefix = os.path.join(flags_obj.model_dir, 'ckpt')
   step_counter = tf.train.get_or_create_global_step()
-  checkpoint = tfe.Checkpoint(
+  checkpoint = tf.train.Checkpoint(
       model=model, optimizer=optimizer, step_counter=step_counter)
   # Restore variables on creation if a checkpoint exists.
   checkpoint.restore(tf.train.latest_checkpoint(flags_obj.model_dir))
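
Note: the diff above replaces two pieces of the old `tf.contrib.eager` surface with their newer equivalents: `tfe.Iterator(dataset)` gives way to iterating the `tf.data.Dataset` directly, and `tfe.Checkpoint` gives way to `tf.train.Checkpoint`. A minimal standalone sketch of that pattern follows, assuming a newer TF 1.x release with eager execution enabled; the toy dataset, model, and checkpoint path are illustrative placeholders, not code from official/mnist.

import tensorflow as tf

tf.enable_eager_execution()

# Toy stand-in for the MNIST input pipeline: 32 random feature vectors
# with dummy labels, batched into groups of 8.
dataset = tf.data.Dataset.from_tensor_slices(
    (tf.random_normal([32, 10]),
     tf.zeros([32], dtype=tf.int64))).batch(8)

model = tf.keras.Sequential([tf.keras.layers.Dense(2)])
optimizer = tf.train.GradientDescentOptimizer(0.01)
step_counter = tf.train.get_or_create_global_step()

# Direct iteration under eager execution: no tfe.Iterator wrapper needed.
for images, labels in dataset:
  with tf.GradientTape() as tape:
    logits = model(images)
    loss_value = tf.reduce_mean(
        tf.nn.sparse_softmax_cross_entropy_with_logits(
            logits=logits, labels=labels))
  grads = tape.gradient(loss_value, model.variables)
  optimizer.apply_gradients(
      zip(grads, model.variables), global_step=step_counter)

# Object-based checkpointing via tf.train.Checkpoint (formerly
# tfe.Checkpoint); '/tmp/ckpt_demo/ckpt' is a hypothetical prefix.
checkpoint = tf.train.Checkpoint(
    model=model, optimizer=optimizer, step_counter=step_counter)
checkpoint.save('/tmp/ckpt_demo/ckpt')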