Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 68a9002

Browse files
authored
Merge pull request tensorflow#4827 from asimshankar/mnist-eager
[official/mnist]: Avoid some now unnecessary 'tfe' symbols.
2 parents 71c196c + 612ec83 commit 68a9002

File tree

1 file changed

+5
-4
lines changed

1 file changed

+5
-4
lines changed

official/mnist/mnist_eager.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,6 @@
3333
from absl import app as absl_app
3434
from absl import flags
3535
import tensorflow as tf
36-
import tensorflow.contrib.eager as tfe
3736
# pylint: enable=g-bad-import-order
3837

3938
from official.mnist import dataset as mnist_dataset
@@ -42,6 +41,8 @@
4241
from official.utils.misc import model_helpers
4342

4443

44+
tfe = tf.contrib.eager
45+
4546
def loss(logits, labels):
4647
return tf.reduce_mean(
4748
tf.nn.sparse_softmax_cross_entropy_with_logits(
@@ -60,7 +61,7 @@ def train(model, optimizer, dataset, step_counter, log_interval=None):
6061
"""Trains model on `dataset` using `optimizer`."""
6162

6263
start = time.time()
63-
for (batch, (images, labels)) in enumerate(tfe.Iterator(dataset)):
64+
for (batch, (images, labels)) in enumerate(dataset):
6465
with tf.contrib.summary.record_summaries_every_n_global_steps(
6566
10, global_step=step_counter):
6667
# Record the operations used to compute the loss given the input,
@@ -85,7 +86,7 @@ def test(model, dataset):
8586
avg_loss = tfe.metrics.Mean('loss')
8687
accuracy = tfe.metrics.Accuracy('accuracy')
8788

88-
for (images, labels) in tfe.Iterator(dataset):
89+
for (images, labels) in dataset:
8990
logits = model(images, training=False)
9091
avg_loss(loss(logits, labels))
9192
accuracy(
@@ -145,7 +146,7 @@ def run_mnist_eager(flags_obj):
145146
# Create and restore checkpoint (if one exists on the path)
146147
checkpoint_prefix = os.path.join(flags_obj.model_dir, 'ckpt')
147148
step_counter = tf.train.get_or_create_global_step()
148-
checkpoint = tfe.Checkpoint(
149+
checkpoint = tf.train.Checkpoint(
149150
model=model, optimizer=optimizer, step_counter=step_counter)
150151
# Restore variables on creation if a checkpoint exists.
151152
checkpoint.restore(tf.train.latest_checkpoint(flags_obj.model_dir))

0 commit comments

Comments
 (0)