Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit fea907c

Browse files
authored
Reduce redundancy on saving model (keras-team#18871)
* Fix saving_api.py for merging changelog of model.save() I took this code out of the .keras model saving logic. It is more reasonable for checking file overwriting each of saving types that .keras, .h5, .hdf5 . And When comparing the two versions, this code is more recent version. * Fix method docstring and warning description of keras.saving.save_model() and model.save() In deprectaion warning, It notice only the case of calling model.save(). so I added the case of keras.saving.save_model And method docstring too. Additionally, I added deprecation warning in Args section for save_format. * Remove redundancy of model.save() * Update test assersion in test_h5_deprecation_warning() of saving_api_test.py
1 parent 10252a9 commit fea907c

File tree

3 files changed

+28
-77
lines changed

3 files changed

+28
-77
lines changed

keras/models/model.py

Lines changed: 10 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
from keras import utils
88
from keras.api_export import keras_export
99
from keras.layers.layer import Layer
10-
from keras.legacy.saving import legacy_h5_format
1110
from keras.models.variable_mapping import map_trackable_variables
1211
from keras.saving import saving_api
1312
from keras.saving import saving_lib
@@ -269,13 +268,14 @@ def save(self, filepath, overwrite=True, **kwargs):
269268
"""Saves a model as a `.keras` file.
270269
271270
Args:
272-
filepath: `str` or `pathlib.Path` object.
273-
Path where to save the model. Must end in `.keras`.
274-
overwrite: Whether we should overwrite any existing model
275-
at the target location, or instead ask the user
276-
via an interactive prompt.
277-
save_format: Format to use, as a string. Only the `"keras"`
278-
format is supported at this time.
271+
filepath: `str` or `pathlib.Path` object. Path where to save
272+
the model. Must end in `.keras`.
273+
overwrite: Whether we should overwrite any existing model at
274+
the target location, or instead ask the user via
275+
an interactive prompt.
276+
save_format: The `save_format` argument is deprecated in Keras 3.
277+
Format to use, as a string. Only the `"keras"` format is
278+
supported at this time.
279279
280280
Example:
281281
@@ -292,8 +292,7 @@ def save(self, filepath, overwrite=True, **kwargs):
292292
assert np.allclose(model.predict(x), loaded_model.predict(x))
293293
```
294294
295-
Note that `model.save()` is an alias for
296-
`keras.saving.save_model()`.
295+
Note that `model.save()` is an alias for `keras.saving.save_model()`.
297296
298297
The saved `.keras` file contains:
299298
@@ -303,60 +302,7 @@ def save(self, filepath, overwrite=True, **kwargs):
303302
304303
Thus models can be reinstantiated in the exact same state.
305304
"""
306-
include_optimizer = kwargs.pop("include_optimizer", True)
307-
save_format = kwargs.pop("save_format", None)
308-
if kwargs:
309-
raise ValueError(
310-
"The following argument(s) are not supported: "
311-
f"{list(kwargs.keys())}"
312-
)
313-
if save_format:
314-
if str(filepath).endswith((".h5", ".hdf5")) or str(
315-
filepath
316-
).endswith(".keras"):
317-
warnings.warn(
318-
"The `save_format` argument is deprecated in Keras 3. "
319-
"We recommend removing this argument as it can be inferred "
320-
"from the file path. "
321-
f"Received: save_format={save_format}"
322-
)
323-
else:
324-
raise ValueError(
325-
"The `save_format` argument is deprecated in Keras 3. "
326-
"Please remove this argument and pass a file path with "
327-
"either `.keras` or `.h5` extension."
328-
f"Received: save_format={save_format}"
329-
)
330-
try:
331-
exists = os.path.exists(filepath)
332-
except TypeError:
333-
exists = False
334-
if exists and not overwrite:
335-
proceed = io_utils.ask_to_proceed_with_overwrite(filepath)
336-
if not proceed:
337-
return
338-
if str(filepath).endswith(".keras"):
339-
saving_lib.save_model(self, filepath)
340-
elif str(filepath).endswith((".h5", ".hdf5")):
341-
# Deprecation warnings
342-
warnings.warn(
343-
"You are saving your model as an HDF5 file via `model.save()`. "
344-
"This file format is considered legacy. "
345-
"We recommend using instead the native Keras format, "
346-
"e.g. `model.save('my_model.keras')`."
347-
)
348-
legacy_h5_format.save_model_to_hdf5(
349-
self, filepath, overwrite, include_optimizer
350-
)
351-
else:
352-
raise ValueError(
353-
"Invalid filepath extension for saving. "
354-
"Please add either a `.keras` extension for the native Keras "
355-
f"format (recommended) or a `.h5` extension. "
356-
"Use `tf.saved_model.save()` if you want to export a "
357-
"SavedModel for use with TFLite/TFServing/etc. "
358-
f"Received: filepath={filepath}."
359-
)
305+
return saving_api.save_model(self, filepath, overwrite, **kwargs)
360306

361307
@traceback_utils.filter_traceback
362308
def save_weights(self, filepath, overwrite=True):

keras/saving/saving_api.py

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -78,22 +78,25 @@ def save_model(model, filepath, overwrite=True, **kwargs):
7878
# Deprecation warnings
7979
if str(filepath).endswith((".h5", ".hdf5")):
8080
logging.warning(
81-
"You are saving your model as an HDF5 file via `model.save()`. "
81+
"You are saving your model as an HDF5 file via "
82+
"`model.save()` or `keras.saving.save_model(model)`. "
8283
"This file format is considered legacy. "
8384
"We recommend using instead the native Keras format, "
84-
"e.g. `model.save('my_model.keras')`."
85+
"e.g. `model.save('my_model.keras')` or "
86+
"`keras.saving.save_model(model, 'my_model.keras')`. "
8587
)
8688

89+
# If file exists and should not be overwritten.
90+
try:
91+
exists = os.path.exists(filepath)
92+
except TypeError:
93+
exists = False
94+
if exists and not overwrite:
95+
proceed = io_utils.ask_to_proceed_with_overwrite(filepath)
96+
if not proceed:
97+
return
98+
8799
if str(filepath).endswith(".keras"):
88-
# If file exists and should not be overwritten.
89-
try:
90-
exists = os.path.exists(filepath)
91-
except TypeError:
92-
exists = False
93-
if exists and not overwrite:
94-
proceed = io_utils.ask_to_proceed_with_overwrite(filepath)
95-
if not proceed:
96-
return
97100
saving_lib.save_model(model, filepath)
98101
elif str(filepath).endswith((".h5", ".hdf5")):
99102
legacy_h5_format.save_model_to_hdf5(

keras/saving/saving_api_test.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -171,8 +171,10 @@ def test_h5_deprecation_warning(self):
171171
with mock.patch.object(logging, "warning") as mock_warn:
172172
saving_api.save_model(model, filepath)
173173
mock_warn.assert_called_once_with(
174-
"You are saving your model as an HDF5 file via `model.save()`. "
174+
"You are saving your model as an HDF5 file via "
175+
"`model.save()` or `keras.saving.save_model(model)`. "
175176
"This file format is considered legacy. "
176177
"We recommend using instead the native Keras format, "
177-
"e.g. `model.save('my_model.keras')`."
178+
"e.g. `model.save('my_model.keras')` or "
179+
"`keras.saving.save_model(model, 'my_model.keras')`. "
178180
)

0 commit comments

Comments
 (0)