
Commit 1abd9a1

mattdangerw authored and tensorflower-gardener committed
Remove reset_state from adapt

It is only supported for Discretization and Normalization, and it is not tested even on those classes. Removing it gives us a cleaner API interface for the release.

PiperOrigin-RevId: 380906031
1 parent f224f33 commit 1abd9a1

15 files changed with 20 additions and 139 deletions.
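For callers, the upshot is that `adapt` now takes only `data`, `batch_size`, and `steps`. A minimal sketch of the updated call pattern (the toy data is made up for illustration; assumes a TF 2.x build that exposes these layers under `tf.keras.layers.experimental.preprocessing`):

    import numpy as np
    import tensorflow as tf

    # Toy data, purely illustrative.
    data = np.array([[0.1], [0.2], [0.3]], dtype="float32")

    norm = tf.keras.layers.experimental.preprocessing.Normalization()
    # New signature: adapt(data, batch_size=None, steps=None).
    # The reset_state keyword no longer exists.
    norm.adapt(data)
    print(norm(data))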

keras/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-discretization.pbtxt

Lines changed: 1 addition & 5 deletions
@@ -96,10 +96,6 @@ tf_class {
     name: "stateful"
     mtype: "<type \'property\'>"
   }
-  member {
-    name: "streaming"
-    mtype: "<type \'property\'>"
-  }
   member {
     name: "submodules"
     mtype: "<type \'property\'>"
@@ -142,7 +138,7 @@ tf_class {
   }
   member_method {
     name: "adapt"
-    argspec: "args=[\'self\', \'data\', \'batch_size\', \'steps\', \'reset_state\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'True\'], "
+    argspec: "args=[\'self\', \'data\', \'batch_size\', \'steps\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
   }
   member_method {
     name: "add_loss"

keras/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-normalization.pbtxt

Lines changed: 1 addition & 5 deletions
@@ -96,10 +96,6 @@ tf_class {
     name: "stateful"
     mtype: "<type \'property\'>"
   }
-  member {
-    name: "streaming"
-    mtype: "<type \'property\'>"
-  }
   member {
     name: "submodules"
     mtype: "<type \'property\'>"
@@ -142,7 +138,7 @@ tf_class {
   }
   member_method {
     name: "adapt"
-    argspec: "args=[\'self\', \'data\', \'batch_size\', \'steps\', \'reset_state\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'True\'], "
+    argspec: "args=[\'self\', \'data\', \'batch_size\', \'steps\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
   }
   member_method {
     name: "add_loss"

keras/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-preprocessing-layer.pbtxt

Lines changed: 2 additions & 6 deletions
@@ -95,10 +95,6 @@ tf_class {
     name: "stateful"
     mtype: "<type \'property\'>"
   }
-  member {
-    name: "streaming"
-    mtype: "<type \'property\'>"
-  }
   member {
     name: "submodules"
     mtype: "<type \'property\'>"
@@ -137,11 +133,11 @@ tf_class {
   }
   member_method {
     name: "__init__"
-    argspec: "args=[\'self\', \'streaming\'], varargs=None, keywords=kwargs, defaults=[\'True\'], "
+    argspec: "args=[\'self\'], varargs=None, keywords=kwargs, defaults=None"
   }
   member_method {
     name: "adapt"
-    argspec: "args=[\'self\', \'data\', \'batch_size\', \'steps\', \'reset_state\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'True\'], "
+    argspec: "args=[\'self\', \'data\', \'batch_size\', \'steps\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
   }
   member_method {
     name: "add_loss"

keras/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-discretization.pbtxt

Lines changed: 1 addition & 5 deletions
@@ -96,10 +96,6 @@ tf_class {
     name: "stateful"
     mtype: "<type \'property\'>"
   }
-  member {
-    name: "streaming"
-    mtype: "<type \'property\'>"
-  }
   member {
     name: "submodules"
     mtype: "<type \'property\'>"
@@ -142,7 +138,7 @@ tf_class {
   }
   member_method {
     name: "adapt"
-    argspec: "args=[\'self\', \'data\', \'batch_size\', \'steps\', \'reset_state\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'True\'], "
+    argspec: "args=[\'self\', \'data\', \'batch_size\', \'steps\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
   }
   member_method {
     name: "add_loss"

keras/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-integer-lookup.pbtxt

Lines changed: 1 addition & 5 deletions
@@ -97,10 +97,6 @@ tf_class {
     name: "stateful"
     mtype: "<type \'property\'>"
   }
-  member {
-    name: "streaming"
-    mtype: "<type \'property\'>"
-  }
   member {
     name: "submodules"
     mtype: "<type \'property\'>"
@@ -143,7 +139,7 @@ tf_class {
   }
   member_method {
     name: "adapt"
-    argspec: "args=[\'self\', \'data\', \'batch_size\', \'steps\', \'reset_state\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'True\'], "
+    argspec: "args=[\'self\', \'data\', \'batch_size\', \'steps\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
  }
   member_method {
     name: "add_loss"

keras/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-normalization.pbtxt

Lines changed: 1 addition & 5 deletions
@@ -96,10 +96,6 @@ tf_class {
     name: "stateful"
     mtype: "<type \'property\'>"
   }
-  member {
-    name: "streaming"
-    mtype: "<type \'property\'>"
-  }
   member {
     name: "submodules"
     mtype: "<type \'property\'>"
@@ -142,7 +138,7 @@ tf_class {
   }
   member_method {
     name: "adapt"
-    argspec: "args=[\'self\', \'data\', \'batch_size\', \'steps\', \'reset_state\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'True\'], "
+    argspec: "args=[\'self\', \'data\', \'batch_size\', \'steps\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
   }
   member_method {
     name: "add_loss"

keras/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-preprocessing-layer.pbtxt

Lines changed: 2 additions & 6 deletions
@@ -95,10 +95,6 @@ tf_class {
     name: "stateful"
     mtype: "<type \'property\'>"
   }
-  member {
-    name: "streaming"
-    mtype: "<type \'property\'>"
-  }
   member {
     name: "submodules"
     mtype: "<type \'property\'>"
@@ -137,11 +133,11 @@ tf_class {
   }
   member_method {
     name: "__init__"
-    argspec: "args=[\'self\', \'streaming\'], varargs=None, keywords=kwargs, defaults=[\'True\'], "
+    argspec: "args=[\'self\'], varargs=None, keywords=kwargs, defaults=None"
   }
   member_method {
     name: "adapt"
-    argspec: "args=[\'self\', \'data\', \'batch_size\', \'steps\', \'reset_state\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'True\'], "
+    argspec: "args=[\'self\', \'data\', \'batch_size\', \'steps\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
   }
   member_method {
     name: "add_loss"

keras/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-string-lookup.pbtxt

Lines changed: 1 addition & 5 deletions
@@ -97,10 +97,6 @@ tf_class {
     name: "stateful"
     mtype: "<type \'property\'>"
   }
-  member {
-    name: "streaming"
-    mtype: "<type \'property\'>"
-  }
   member {
     name: "submodules"
     mtype: "<type \'property\'>"
@@ -143,7 +139,7 @@ tf_class {
   }
   member_method {
     name: "adapt"
-    argspec: "args=[\'self\', \'data\', \'batch_size\', \'steps\', \'reset_state\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'True\'], "
+    argspec: "args=[\'self\', \'data\', \'batch_size\', \'steps\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
   }
   member_method {
     name: "add_loss"

keras/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-text-vectorization.pbtxt

Lines changed: 1 addition & 5 deletions
@@ -96,10 +96,6 @@ tf_class {
     name: "stateful"
     mtype: "<type \'property\'>"
   }
-  member {
-    name: "streaming"
-    mtype: "<type \'property\'>"
-  }
   member {
     name: "submodules"
     mtype: "<type \'property\'>"
@@ -142,7 +138,7 @@ tf_class {
   }
   member_method {
     name: "adapt"
-    argspec: "args=[\'self\', \'data\', \'batch_size\', \'steps\', \'reset_state\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'True\'], "
+    argspec: "args=[\'self\', \'data\', \'batch_size\', \'steps\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
   }
   member_method {
     name: "add_loss"

keras/engine/base_preprocessing_layer.py

Lines changed: 5 additions & 25 deletions
@@ -50,16 +50,11 @@ class PreprocessingLayer(Layer, metaclass=abc.ABCMeta):
 
   The `PreprocessingLayer` class is the base class you would subclass to
   implement your own preprocessing layers.
-
-  Attributes:
-    streaming: Whether a layer can be adapted multiple times without resetting
-      the state of the layer.
   """
   _must_restore_from_config = True
 
-  def __init__(self, streaming=True, **kwargs):
+  def __init__(self, **kwargs):
     super(PreprocessingLayer, self).__init__(**kwargs)
-    self._streaming = streaming
     self._is_compiled = False
     self._is_adapted = False
 
@@ -69,11 +64,6 @@ def __init__(self, streaming=True, **kwargs):
 
     self._adapt_function = None
 
-  @property
-  def streaming(self):
-    """Whether `adapt` can be called twice without resetting the state."""
-    return self._streaming
-
   @property
   def is_adapted(self):
     """Whether the layer has been fit to data already."""
@@ -163,7 +153,7 @@ def compile(self, run_eagerly=None, steps_per_execution=None):
 
     self._is_compiled = True
 
-  def adapt(self, data, batch_size=None, steps=None, reset_state=True):
+  def adapt(self, data, batch_size=None, steps=None):
     """Fits the state of the preprocessing layer to the data being passed.
 
     After calling `adapt` on a layer, a preprocessing layer's state will not
@@ -232,21 +222,13 @@ def adapt(self, data, batch_size=None, steps=None, reset_state=True):
         the input dataset is exhausted. When passing an infinitely
         repeating dataset, you must specify the `steps` argument. This
         argument is not supported with array inputs.
-      reset_state: Optional argument specifying whether to clear the state of
-        the layer at the start of the call to `adapt`, or whether to start
-        from the existing state. This argument may not be relevant to all
-        preprocessing layers: a subclass of PreprocessingLayer may choose to
-        throw if 'reset_state' is set to False.
     """
     _disallow_inside_tf_function('adapt')
     if not version_utils.should_use_v2():
       raise RuntimeError('`adapt` is only supported in tensorflow v2.')  # pylint: disable=g-doc-exception
-    if not self.streaming and self._is_adapted and not reset_state:
-      raise ValueError('{} does not supporting calling `adapt` twice without '
-                       'resetting the state.'.format(self.__class__.__name__))
     if not self._is_compiled:
       self.compile()  # Compile with defaults.
-    if self.built and reset_state:
+    if self.built:
       self.reset_state()
     data_handler = data_adapter.DataHandler(
         data,
@@ -345,11 +327,9 @@ def compile(self, run_eagerly=None, steps_per_execution=None):
     super(CombinerPreprocessingLayer, self).compile(
         run_eagerly=run_eagerly, steps_per_execution=steps_per_execution)
 
-  def adapt(self, data, batch_size=None, steps=None, reset_state=True):
-    if not reset_state:
-      self._adapt_accumulator = self._combiner.restore(self._restore_updates())
+  def adapt(self, data, batch_size=None, steps=None):
     super(CombinerPreprocessingLayer, self).adapt(
-        data, batch_size=batch_size, steps=steps, reset_state=reset_state)
+        data, batch_size=batch_size, steps=steps)
 
   def _add_state_variable(self,
                           name,
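One practical consequence of the base-class change: `reset_state` is now rejected by Python's argument binding rather than being handled (or turned into a `ValueError`) inside `adapt`. A hedged sketch of what a caller sees after this commit (the sample vocabulary is made up; assumes the experimental preprocessing namespace is available):

    import numpy as np
    import tensorflow as tf

    layer = tf.keras.layers.experimental.preprocessing.StringLookup()
    # Supported: fit the vocabulary; a built layer is reset before adapting.
    layer.adapt(np.array(["a", "b", "c"]))

    try:
        layer.adapt(np.array(["d"]), reset_state=False)
    except TypeError as err:
        # adapt() no longer accepts reset_state at all.
        print("removed keyword:", err)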

keras/engine/base_preprocessing_layer_test.py

Lines changed: 0 additions & 63 deletions
@@ -248,44 +248,6 @@ def test_post_build_adapt_update_dataset(self):
 
     self.assertAllEqual([[16], [17], [18]], model.predict([1., 2., 3.]))
 
-  def test_further_tuning(self):
-    """Test that models can be tuned with multiple calls to 'adapt'."""
-
-    input_dataset = np.array([1, 2, 3, 4, 5])
-
-    layer = AddingPreprocessingLayer()
-    layer.adapt(input_dataset)
-
-    input_data = keras.Input(shape=(1,))
-    output = layer(input_data)
-    model = keras.Model(input_data, output)
-    model._run_eagerly = testing_utils.should_run_eagerly()
-
-    self.assertAllEqual([[16], [17], [18]], model.predict([1., 2., 3.]))
-
-    layer.adapt(np.array([1, 2]), reset_state=False)
-    self.assertAllEqual([[19], [20], [21]], model.predict([1., 2., 3.]))
-
-  def test_further_tuning_post_injection(self):
-    """Test that models can be tuned with multiple calls to 'adapt'."""
-
-    input_dataset = np.array([1, 2, 3, 4, 5])
-
-    layer = AddingPreprocessingLayer()
-
-    input_data = keras.Input(shape=(1,))
-    output = layer(input_data)
-    model = keras.Model(input_data, output)
-    model._run_eagerly = testing_utils.should_run_eagerly()
-
-    combiner = layer._combiner
-    updates = combiner.extract(combiner.compute(input_dataset))
-    layer._set_state_variables(updates)
-    self.assertAllEqual([[16], [17], [18]], model.predict([1., 2., 3.]))
-
-    layer.adapt(np.array([1, 2]), reset_state=False)
-    self.assertAllEqual([[19], [20], [21]], model.predict([1., 2., 3.]))
-
   def test_weight_based_state_transfer(self):
     """Test that preproc layers can transfer state via get/set weights.."""
 
@@ -311,31 +273,6 @@ def get_model():
     model_2.set_weights(weights)
     self.assertAllEqual([[16], [17], [18]], model_2.predict([1., 2., 3.]))
 
-  def test_weight_based_state_transfer_with_further_tuning(self):
-    """Test that transferred state can be used to further tune a model.."""
-
-    def get_model():
-      input_data = keras.Input(shape=(1,))
-      layer = AddingPreprocessingLayer()
-      output = layer(input_data)
-      model = keras.Model(input_data, output)
-      model._run_eagerly = testing_utils.should_run_eagerly()
-      return (model, layer)
-
-    input_dataset = np.array([1, 2, 3, 4, 5])
-    model, layer = get_model()
-    layer.adapt(input_dataset)
-    self.assertAllEqual([[16], [17], [18]], model.predict([1., 2., 3.]))
-
-    # Transfer state from model to model_2 via get/set weights.
-    weights = model.get_weights()
-    model_2, layer_2 = get_model()
-    model_2.set_weights(weights)
-
-    # Further adapt this layer based on the transferred weights.
-    layer_2.adapt(np.array([1, 2]), reset_state=False)
-    self.assertAllEqual([[19], [20], [21]], model_2.predict([1., 2., 3.]))
-
   def test_loading_without_providing_class_fails(self):
     input_data = keras.Input(shape=(1,))
     layer = AddingPreprocessingLayer()
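The deleted tests covered "further tuning", i.e. calling `adapt(..., reset_state=False)` to refine existing state. That pattern has no direct replacement in this commit; one workaround (an assumption on my part, not something the commit provides) is to keep the raw data around and re-adapt from scratch on the combined set:

    import numpy as np
    import tensorflow as tf

    norm = tf.keras.layers.experimental.preprocessing.Normalization()

    first_batch = np.array([[1.0], [2.0], [3.0]], dtype="float32")
    norm.adapt(first_batch)

    # Instead of norm.adapt(more_data, reset_state=False), re-adapt on
    # the concatenation of everything seen so far.
    more_data = np.array([[4.0], [5.0]], dtype="float32")
    norm.adapt(np.concatenate([first_batch, more_data]))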

keras/layers/preprocessing/discretization.py

Lines changed: 1 addition & 1 deletion
@@ -181,7 +181,7 @@ def __init__(self,
     elif bin_boundaries is None:
       bin_boundaries = kwargs["bins"]
       del kwargs["bins"]
-    super().__init__(streaming=True, **kwargs)
+    super().__init__(**kwargs)
     base_preprocessing_layer.keras_kpl_gauge.get_cell("Discretization").set(
         True)
     if num_bins is not None and num_bins < 0:

keras/layers/preprocessing/index_lookup.py

Lines changed: 1 addition & 1 deletion
@@ -210,7 +210,7 @@ def __init__(self,
     kwargs.pop("vocabulary_size", None)
     kwargs.pop("has_static_table", None)
 
-    super().__init__(streaming=False, **kwargs)
+    super().__init__(**kwargs)
 
     if invert:
       self._key_dtype = tf.int64

keras/layers/preprocessing/normalization.py

Lines changed: 1 addition & 1 deletion
@@ -93,7 +93,7 @@ class Normalization(base_preprocessing_layer.PreprocessingLayer):
   """
 
   def __init__(self, axis=-1, mean=None, variance=None, **kwargs):
-    super().__init__(streaming=True, **kwargs)
+    super().__init__(**kwargs)
     base_preprocessing_layer.keras_kpl_gauge.get_cell('Normalization').set(True)
 
     # Standardize `axis` to a tuple.

keras/layers/preprocessing/text_vectorization.py

Lines changed: 1 addition & 1 deletion
@@ -304,7 +304,7 @@ def __init__(self,
     # Drop deprecated config options.
     kwargs.pop("vocabulary_size", None)
 
-    super().__init__(streaming=False, **kwargs)
+    super().__init__(**kwargs)
     base_preprocessing_layer.keras_kpl_gauge.get_cell("TextVectorization").set(
         True)
 
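From the user's side, construction of the built-in layers is unchanged; only the internal `super().__init__(streaming=...)` forwarding is gone. A small sketch with `TextVectorization` (the sample sentences are made up; assumes the same experimental namespace as above):

    import numpy as np
    import tensorflow as tf

    vectorize = tf.keras.layers.experimental.preprocessing.TextVectorization(
        max_tokens=20, output_mode="int")
    # Public construction and adapt calls look the same as before this commit.
    vectorize.adapt(np.array(["the quick brown fox", "jumped over the lazy dog"]))
    print(vectorize(np.array(["the lazy fox"])))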
