Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 04fede8

Browse files
Jonathan Hseutensorflower-gardener
authored andcommitted
Fix StopAtStepHook with num_steps when multiple steps are executed in a single
session.run(). PiperOrigin-RevId: 157277945
1 parent 53f8e36 commit 04fede8

File tree

7 files changed

+88
-48
lines changed

7 files changed

+88
-48
lines changed

tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined_test.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1043,8 +1043,8 @@ def input_fn():
10431043
fix_global_step_increment_bug=False)
10441044
classifier.fit(input_fn=input_fn, steps=100, monitors=[step_counter])
10451045

1046-
# Expected is 100, but because of the global step increment bug, this is 51.
1047-
self.assertEqual(51, step_counter.steps)
1046+
# Expected is 100, but because of the global step increment bug, this is 50.
1047+
self.assertEqual(50, step_counter.steps)
10481048

10491049
def testGlobalStepDNNLinearCombinedBugFixed(self):
10501050
"""Tests global step update for dnn-linear combined model."""
@@ -1785,14 +1785,14 @@ def feature_engineering_fn(features, labels):
17851785
dnn_hidden_units=[3, 3],
17861786
config=run_config.RunConfig(tf_random_seed=1),
17871787
feature_engineering_fn=feature_engineering_fn)
1788-
estimator_with_fe_fn.fit(input_fn=input_fn, steps=100)
1788+
estimator_with_fe_fn.fit(input_fn=input_fn, steps=110)
17891789

17901790
estimator_without_fe_fn = dnn_linear_combined.DNNLinearCombinedRegressor(
17911791
linear_feature_columns=[feature_column.real_valued_column('x')],
17921792
dnn_feature_columns=[feature_column.real_valued_column('x')],
17931793
dnn_hidden_units=[3, 3],
17941794
config=run_config.RunConfig(tf_random_seed=1))
1795-
estimator_without_fe_fn.fit(input_fn=input_fn, steps=100)
1795+
estimator_without_fe_fn.fit(input_fn=input_fn, steps=110)
17961796

17971797
# predictions = y
17981798
prediction_with_fe_fn = next(

tensorflow/contrib/learn/python/learn/estimators/estimator_test.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -396,7 +396,9 @@ def testInvalidModelFn_no_train_op(self):
396396
def _invalid_model_fn(features, labels):
397397
# pylint: disable=unused-argument
398398
w = variables_lib.Variable(42.0, 'weight')
399-
loss = 100.0 - w
399+
update_global_step = variables.get_global_step().assign_add(1)
400+
with control_flow_ops.control_dependencies([update_global_step]):
401+
loss = 100.0 - w
400402
return None, loss, None
401403

402404
est = estimator.Estimator(model_fn=_invalid_model_fn)
@@ -409,7 +411,9 @@ def _invalid_model_fn(features, labels, mode):
409411
# pylint: disable=unused-argument
410412
w = variables_lib.Variable(42.0, 'weight')
411413
loss = 100.0 - w
412-
train_op = w.assign_add(loss / 100.0)
414+
update_global_step = variables.get_global_step().assign_add(1)
415+
with control_flow_ops.control_dependencies([update_global_step]):
416+
train_op = w.assign_add(loss / 100.0)
413417
predictions = loss
414418
if mode == model_fn.ModeKeys.EVAL:
415419
loss = None
@@ -426,7 +430,9 @@ def _invalid_model_fn(features, labels):
426430
# pylint: disable=unused-argument
427431
w = variables_lib.Variable(42.0, 'weight')
428432
loss = 100.0 - w
429-
train_op = w.assign_add(loss / 100.0)
433+
update_global_step = variables.get_global_step().assign_add(1)
434+
with control_flow_ops.control_dependencies([update_global_step]):
435+
train_op = w.assign_add(loss / 100.0)
430436
return None, loss, train_op
431437

432438
est = estimator.Estimator(model_fn=_invalid_model_fn)

tensorflow/contrib/learn/python/learn/estimators/estimators_test.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,15 +22,15 @@
2222

2323
import numpy as np
2424

25+
from tensorflow.contrib.framework.python.ops import variables
2526
from tensorflow.contrib.learn.python import learn
2627
from tensorflow.contrib.learn.python.learn import datasets
2728
from tensorflow.contrib.learn.python.learn import metric_spec
2829
from tensorflow.contrib.learn.python.learn.estimators import estimator as estimator_lib
2930
from tensorflow.contrib.learn.python.learn.estimators._sklearn import accuracy_score
3031
from tensorflow.contrib.learn.python.learn.estimators._sklearn import train_test_split
3132
from tensorflow.python.framework import constant_op
32-
from tensorflow.python.ops import control_flow_ops
33-
from tensorflow.python.ops import variables
33+
from tensorflow.python.ops import variables as variables_lib
3434
from tensorflow.python.platform import test
3535
from tensorflow.python.training import momentum as momentum_lib
3636

@@ -57,11 +57,12 @@ def feature_engineering_fn(features, labels):
5757

5858
def model_fn(features, labels):
5959
# dummy variable:
60-
_ = variables.Variable([0.])
60+
_ = variables_lib.Variable([0.])
6161
_ = labels
6262
predictions = features["transformed_x"]
6363
loss = constant_op.constant([2.])
64-
return predictions, loss, control_flow_ops.no_op()
64+
update_global_step = variables.get_global_step().assign_add(1)
65+
return predictions, loss, update_global_step
6566

6667
estimator = estimator_lib.Estimator(
6768
model_fn=model_fn, feature_engineering_fn=feature_engineering_fn)
@@ -95,11 +96,12 @@ def feature_engineering_fn(features, labels):
9596

9697
def model_fn(features, labels):
9798
# dummy variable:
98-
_ = variables.Variable([0.])
99+
_ = variables_lib.Variable([0.])
99100
_ = labels
100101
predictions = features["x"]
101102
loss = constant_op.constant([2.])
102-
return predictions, loss, control_flow_ops.no_op()
103+
update_global_step = variables.get_global_step().assign_add(1)
104+
return predictions, loss, update_global_step
103105

104106
estimator_with_fe_fn = estimator_lib.Estimator(
105107
model_fn=model_fn, feature_engineering_fn=feature_engineering_fn)

0 commit comments

Comments
 (0)