Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit e78c47f

Browse files
rxsangtensorflower-gardener
authored andcommitted
Fix calling tf.identity on SyncOnReadVariable in cross replica context.
PiperOrigin-RevId: 293881729 Change-Id: I94bfcd4380767cdbbacf30ac5a29ee7bb2b992fa
1 parent 6957c7b commit e78c47f

File tree

2 files changed

+20
-16
lines changed

2 files changed

+20
-16
lines changed

tensorflow/python/distribute/values.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1139,8 +1139,9 @@ def _saveable_factory(name=self._common_name):
11391139

11401140
def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False):
11411141
"""Converts a variable to a tensor."""
1142-
return ops.convert_to_tensor(
1143-
self._get(), dtype=dtype, name=name, as_ref=as_ref)
1142+
with _enter_or_assert_strategy(self._distribute_strategy):
1143+
return ops.convert_to_tensor(
1144+
self._get(), dtype=dtype, name=name, as_ref=as_ref)
11441145

11451146

11461147
# Register a conversion function for SyncOnReadVariable which allows as_ref to

tensorflow/python/distribute/values_test.py

Lines changed: 17 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -867,19 +867,6 @@ def testProperties(self):
867867
self.assertEqual(variable_scope.VariableAggregation.SUM,
868868
replica_local.aggregation)
869869

870-
def testTensorConversion(self):
871-
with context.graph_mode():
872-
_, replica_local = _make_replica_local(
873-
variable_scope.VariableAggregation.SUM)
874-
converted = ops.convert_to_tensor(replica_local, as_ref=False)
875-
self.assertIsInstance(converted, ops.Tensor)
876-
self.assertEqual(converted.dtype, replica_local.dtype)
877-
878-
converted = ops.convert_to_tensor(replica_local, as_ref=True)
879-
# Resources variable are converted to tensors as well when as_ref is True.
880-
self.assertIsInstance(converted, ops.Tensor)
881-
self.assertEqual(converted.dtype, replica_local.dtype)
882-
883870
@test_util.run_v2_only
884871
def testCanPassToDefFun(self):
885872
@def_function.function
@@ -937,6 +924,20 @@ def _save(self, sess, var):
937924
save_path, _ = self._save_return_saver(sess, var)
938925
return save_path
939926

927+
@combinations.generate(mirrored_and_tpu_strategy_combinations())
928+
def testTensorConversion(self, distribution):
929+
with context.graph_mode():
930+
_, replica_local = _make_replica_local(
931+
variable_scope.VariableAggregation.SUM, distribution)
932+
converted = ops.convert_to_tensor(replica_local, as_ref=False)
933+
self.assertIsInstance(converted, ops.Tensor)
934+
self.assertEqual(converted.dtype, replica_local.dtype)
935+
936+
converted = ops.convert_to_tensor(replica_local, as_ref=True)
937+
# Resources variable are converted to tensors as well when as_ref is True.
938+
self.assertIsInstance(converted, ops.Tensor)
939+
self.assertEqual(converted.dtype, replica_local.dtype)
940+
940941
@combinations.generate(mirrored_and_tpu_strategy_combinations())
941942
def testSaveAndRestoreReplicaLocalSumOneGraph(self, distribution):
942943
with self.cached_session() as sess:
@@ -1240,7 +1241,7 @@ def testReadValueInCrossReplicaContext(self, distribution,
12401241
]
12411242
for aggregation in aggregations:
12421243
if isinstance(distribution, _TPU_STRATEGIES):
1243-
resolver = tpu_cluster_resolver.TPUClusterResolver('')
1244+
resolver = tpu_cluster_resolver.TPUClusterResolver("")
12441245
tpu_strategy_util.initialize_tpu_system(resolver)
12451246
with distribution.scope():
12461247
v = variable_scope.variable(
@@ -1270,6 +1271,8 @@ def assign(v=v):
12701271
self.assertEqual(expected, self.evaluate(v.read_value()), aggregation)
12711272
self.assertEqual(expected, self.evaluate(v.value()), aggregation)
12721273
self.assertEqual(expected, self.evaluate(v), aggregation)
1274+
self.assertEqual(expected, self.evaluate(array_ops.identity(v)),
1275+
aggregation)
12731276

12741277
# TODO(b/145574622): Re-enable this test once ReduceOp argument is
12751278
# respected on GPUs.

0 commit comments

Comments
 (0)