@@ -867,19 +867,6 @@ def testProperties(self):
867
867
self .assertEqual (variable_scope .VariableAggregation .SUM ,
868
868
replica_local .aggregation )
869
869
870
- def testTensorConversion (self ):
871
- with context .graph_mode ():
872
- _ , replica_local = _make_replica_local (
873
- variable_scope .VariableAggregation .SUM )
874
- converted = ops .convert_to_tensor (replica_local , as_ref = False )
875
- self .assertIsInstance (converted , ops .Tensor )
876
- self .assertEqual (converted .dtype , replica_local .dtype )
877
-
878
- converted = ops .convert_to_tensor (replica_local , as_ref = True )
879
- # Resources variable are converted to tensors as well when as_ref is True.
880
- self .assertIsInstance (converted , ops .Tensor )
881
- self .assertEqual (converted .dtype , replica_local .dtype )
882
-
883
870
@test_util .run_v2_only
884
871
def testCanPassToDefFun (self ):
885
872
@def_function .function
@@ -937,6 +924,20 @@ def _save(self, sess, var):
937
924
save_path , _ = self ._save_return_saver (sess , var )
938
925
return save_path
939
926
927
+ @combinations .generate (mirrored_and_tpu_strategy_combinations ())
928
+ def testTensorConversion (self , distribution ):
929
+ with context .graph_mode ():
930
+ _ , replica_local = _make_replica_local (
931
+ variable_scope .VariableAggregation .SUM , distribution )
932
+ converted = ops .convert_to_tensor (replica_local , as_ref = False )
933
+ self .assertIsInstance (converted , ops .Tensor )
934
+ self .assertEqual (converted .dtype , replica_local .dtype )
935
+
936
+ converted = ops .convert_to_tensor (replica_local , as_ref = True )
937
+ # Resources variable are converted to tensors as well when as_ref is True.
938
+ self .assertIsInstance (converted , ops .Tensor )
939
+ self .assertEqual (converted .dtype , replica_local .dtype )
940
+
940
941
@combinations .generate (mirrored_and_tpu_strategy_combinations ())
941
942
def testSaveAndRestoreReplicaLocalSumOneGraph (self , distribution ):
942
943
with self .cached_session () as sess :
@@ -1240,7 +1241,7 @@ def testReadValueInCrossReplicaContext(self, distribution,
1240
1241
]
1241
1242
for aggregation in aggregations :
1242
1243
if isinstance (distribution , _TPU_STRATEGIES ):
1243
- resolver = tpu_cluster_resolver .TPUClusterResolver ('' )
1244
+ resolver = tpu_cluster_resolver .TPUClusterResolver ("" )
1244
1245
tpu_strategy_util .initialize_tpu_system (resolver )
1245
1246
with distribution .scope ():
1246
1247
v = variable_scope .variable (
@@ -1270,6 +1271,8 @@ def assign(v=v):
1270
1271
self .assertEqual (expected , self .evaluate (v .read_value ()), aggregation )
1271
1272
self .assertEqual (expected , self .evaluate (v .value ()), aggregation )
1272
1273
self .assertEqual (expected , self .evaluate (v ), aggregation )
1274
+ self .assertEqual (expected , self .evaluate (array_ops .identity (v )),
1275
+ aggregation )
1273
1276
1274
1277
# TODO(b/145574622): Re-enable this test once ReduceOp argument is
1275
1278
# respected on GPUs.
0 commit comments