Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 434d29d

Browse files
ckkuangtensorflower-gardener
authored andcommitted
Fix keras metric.result_state when the metric variables are sharded variable.
PiperOrigin-RevId: 381292911
1 parent 1232f05 commit 434d29d

File tree

2 files changed

+12
-5
lines changed

2 files changed

+12
-5
lines changed

keras/distribute/sharded_variable_test.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,22 +108,28 @@ def __init__(self):
108108

109109
def test_keras_metrics(self):
110110
with self.strategy.scope():
111+
fp = keras.metrics.FalsePositives(thresholds=[0.2, 0.5, 0.7, 0.8])
111112
auc = keras.metrics.AUC(num_thresholds=10)
112113

113114
@tf.function
114115
def update():
116+
fp.update_state([0., 1., 0., 0.], [0., 0., 0.3, 0.9])
115117
auc.update_state([0, 0, 1, 1], [0, 0.5, 0.3, 0.9])
116118

117119
@tf.function
118120
def reset():
121+
fp.reset_state()
119122
auc.reset_state()
120123

121124
update()
122125
self.assertEqual(auc.result(), 0.75)
126+
self.assertAllEqual(fp.result(), [2., 1., 1., 1.])
123127
reset()
124128
self.assertEqual(auc.result(), 0.0)
129+
self.assertAllEqual(fp.result(), [0., 0., 0., 0.])
125130

126131
self.assertTrue(hasattr(auc.true_positives, 'variables'))
132+
self.assertTrue(hasattr(fp.accumulator, 'variables'))
127133

128134
def test_saved_model(self):
129135

keras/metrics.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1038,9 +1038,9 @@ def result(self):
10381038
return tf.convert_to_tensor(result)
10391039

10401040
def reset_state(self):
1041-
num_thresholds = len(to_list(self.thresholds))
1042-
backend.batch_set_value(
1043-
[(v, np.zeros((num_thresholds,))) for v in self.variables])
1041+
backend.batch_set_value([
1042+
(v, np.zeros(v.shape.as_list())) for v in self.variables
1043+
])
10441044

10451045
def get_config(self):
10461046
config = {'thresholds': self.init_thresholds}
@@ -3175,8 +3175,9 @@ def result(self):
31753175

31763176
def reset_state(self):
31773177
if self._built:
3178-
backend.batch_set_value(
3179-
[(v, np.zeros(self._shape.as_list())) for v in self.variables])
3178+
backend.batch_set_value([
3179+
(v, np.zeros(v.shape.as_list())) for v in self.variables
3180+
])
31803181

31813182

31823183
@keras_export('keras.metrics.BinaryCrossentropy')

0 commit comments

Comments
 (0)