Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit dea7ecf

Browse files
authored
Merge pull request tensorflow#4143 from asimshankar/mnist-eager-1.8
official/mnist: Updates with the release of TensorFlow 1.8.
2 parents 505f554 + ee968ba commit dea7ecf

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

official/mnist/mnist_eager.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ def train(model, optimizer, dataset, step_counter, log_interval=None):
6363
# Record the operations used to compute the loss given the input,
6464
# so that the gradient of the loss with respect to the variables
6565
# can be computed.
66-
with tfe.GradientTape() as tape:
66+
with tf.GradientTape() as tape:
6767
logits = model(images, training=True)
6868
loss_value = loss(logits, labels)
6969
tf.contrib.summary.scalar('loss', loss_value)
@@ -99,11 +99,11 @@ def main(argv):
9999
parser = MNISTEagerArgParser()
100100
flags = parser.parse_args(args=argv[1:])
101101

102-
tfe.enable_eager_execution()
102+
tf.enable_eager_execution()
103103

104104
# Automatically determine device and data_format
105105
(device, data_format) = ('/gpu:0', 'channels_first')
106-
if flags.no_gpu or tfe.num_gpus() <= 0:
106+
if flags.no_gpu or not tf.test.is_gpu_available():
107107
(device, data_format) = ('/cpu:0', 'channels_last')
108108
# If data_format is defined in FLAGS, overwrite automatically set value.
109109
if flags.data_format is not None:

0 commit comments

Comments
 (0)