diff --git a/doc/whats_new/v0.21.rst b/doc/whats_new/v0.21.rst
index 2e1c639e267b7..cf3302ad62c00 100644
--- a/doc/whats_new/v0.21.rst
+++ b/doc/whats_new/v0.21.rst
@@ -12,6 +12,18 @@ Version 0.21.3
 Changelog
 ---------
 
+:mod:`sklearn.datasets`
+.......................
+
+- |Fix| :func:`datasets.fetch_california_housing`,
+  :func:`datasets.fetch_covtype`,
+  :func:`datasets.fetch_kddcup99`, :func:`datasets.fetch_olivetti_faces`,
+  :func:`datasets.fetch_rcv1`, and :func:`datasets.fetch_species_distributions`
+  try to persist the previously cached data using the new ``joblib`` if the
+  cached data was persisted using the deprecated ``sklearn.externals.joblib``.
+  This behavior is set to be deprecated and removed in v0.23.
+  :pr:`14197` by `Adrin Jalali`_.
+
 :mod:`sklearn.impute`
 .....................
 
diff --git a/sklearn/datasets/base.py b/sklearn/datasets/base.py
index 0b8f73c86117b..c353746c1c326 100644
--- a/sklearn/datasets/base.py
+++ b/sklearn/datasets/base.py
@@ -10,6 +10,7 @@
 import csv
 import sys
 import shutil
+import warnings
 from collections import namedtuple
 from os import environ, listdir, makedirs
 from os.path import dirname, exists, expanduser, isdir, join, splitext
@@ -919,3 +920,31 @@ def _fetch_remote(remote, dirname=None):
                       "file may be corrupted.".format(file_path, checksum,
                                                       remote.checksum))
     return file_path
+
+
+def _refresh_cache(files, compress):
+    # TODO: REMOVE in v0.23
+    import joblib
+    msg = "sklearn.externals.joblib is deprecated in 0.21"
+    with warnings.catch_warnings(record=True) as warns:
+        data = tuple([joblib.load(f) for f in files])
+
+    refresh_needed = any([str(x.message).startswith(msg) for x in warns])
+
+    other_warns = [w for w in warns if not str(w.message).startswith(msg)]
+    for w in other_warns:
+        warnings.warn(message=w.message, category=w.category)
+
+    if refresh_needed:
+        try:
+            for value, path in zip(data, files):
+                joblib.dump(value, path, compress=compress)
+        except IOError:
+            message = ("This dataset will stop being loadable in scikit-learn "
+                       "version 0.23 because it references a deprecated "
+                       "import path. Consider removing the following files "
+                       "and allowing them to be cached anew:\n%s"
+                       % ("\n".join(files)))
+            warnings.warn(message=message, category=DeprecationWarning)
+
+    return data[0] if len(data) == 1 else data
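For readers skimming the patch: the core trick in ``_refresh_cache`` is
``warnings.catch_warnings(record=True)``, which buffers every warning raised
while the old pickles load. The helper can then detect the deprecated-import
warning (emitted when unpickling resolves ``sklearn.externals.joblib``),
rewrite the cache with plain ``joblib``, and re-emit any unrelated warnings.
A minimal, self-contained sketch of that pattern, using toy ``load``/``save``
stand-ins rather than the scikit-learn API::

    import warnings

    def toy_load(path):
        # stand-in for joblib.load on a cache written via the old import path
        warnings.warn("sklearn.externals.joblib is deprecated in 0.21",
                      DeprecationWarning)
        return {"path": path}

    def load_and_refresh(paths, save):
        with warnings.catch_warnings(record=True) as caught:
            # record even categories that are ignored by default
            warnings.simplefilter("always")
            data = [toy_load(p) for p in paths]
        stale = any(str(w.message).startswith("sklearn.externals.joblib")
                    for w in caught)
        if stale:
            for value, path in zip(data, paths):
                save(value, path)  # re-persist with the new joblib
        return data

    refreshed = []
    load_and_refresh(["a.pkl", "b.pkl"], save=lambda v, p: refreshed.append(p))
    assert refreshed == ["a.pkl", "b.pkl"]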
diff --git a/sklearn/datasets/california_housing.py b/sklearn/datasets/california_housing.py
index 35f0847c1de05..7d8b1aa3ede45 100644
--- a/sklearn/datasets/california_housing.py
+++ b/sklearn/datasets/california_housing.py
@@ -34,6 +34,7 @@
 from .base import _fetch_remote
 from .base import _pkl_filepath
 from .base import RemoteFileMetadata
+from .base import _refresh_cache
 from ..utils import Bunch
 
 # The original data can be found at:
@@ -129,7 +130,9 @@ def fetch_california_housing(data_home=None, download_if_missing=True,
 
         remove(archive_path)
     else:
-        cal_housing = joblib.load(filepath)
+        cal_housing = _refresh_cache([filepath], 6)
+        # TODO: Revert to the following line in v0.23
+        # cal_housing = joblib.load(filepath)
 
     feature_names = ["MedInc", "HouseAge", "AveRooms", "AveBedrms",
                      "Population", "AveOccup", "Latitude", "Longitude"]
diff --git a/sklearn/datasets/covtype.py b/sklearn/datasets/covtype.py
index 9d995810bee3f..4108b1d79f84b 100644
--- a/sklearn/datasets/covtype.py
+++ b/sklearn/datasets/covtype.py
@@ -25,6 +25,7 @@
 from .base import get_data_home
 from .base import _fetch_remote
 from .base import RemoteFileMetadata
+from .base import _refresh_cache
 from ..utils import Bunch
 from .base import _pkl_filepath
 from ..utils import check_random_state
@@ -125,8 +126,10 @@ def fetch_covtype(data_home=None, download_if_missing=True,
     try:
         X, y
     except NameError:
-        X = joblib.load(samples_path)
-        y = joblib.load(targets_path)
+        X, y = _refresh_cache([samples_path, targets_path], 9)
+        # TODO: Revert to the following two lines in v0.23
+        # X = joblib.load(samples_path)
+        # y = joblib.load(targets_path)
 
     if shuffle:
         ind = np.arange(X.shape[0])
diff --git a/sklearn/datasets/kddcup99.py b/sklearn/datasets/kddcup99.py
index 837a489e7212c..f50f49f85ab6f 100644
--- a/sklearn/datasets/kddcup99.py
+++ b/sklearn/datasets/kddcup99.py
@@ -20,6 +20,7 @@
 from .base import _fetch_remote
 from .base import get_data_home
 from .base import RemoteFileMetadata
+from .base import _refresh_cache
 from ..utils import Bunch
 from ..utils import check_random_state
 from ..utils import shuffle as shuffle_method
@@ -292,8 +293,10 @@ def _fetch_brute_kddcup99(data_home=None,
     try:
         X, y
     except NameError:
-        X = joblib.load(samples_path)
-        y = joblib.load(targets_path)
+        X, y = _refresh_cache([samples_path, targets_path], 0)
+        # TODO: Revert to the following two lines in v0.23
+        # X = joblib.load(samples_path)
+        # y = joblib.load(targets_path)
 
     return Bunch(data=X, target=y)
 
diff --git a/sklearn/datasets/olivetti_faces.py b/sklearn/datasets/olivetti_faces.py
index a52f90414e104..24eeb7927abcf 100644
--- a/sklearn/datasets/olivetti_faces.py
+++ b/sklearn/datasets/olivetti_faces.py
@@ -24,6 +24,7 @@
 from .base import _fetch_remote
 from .base import RemoteFileMetadata
 from .base import _pkl_filepath
+from .base import _refresh_cache
 from ..utils import check_random_state, Bunch
 
 # The original data can be found at:
@@ -107,7 +108,9 @@ def fetch_olivetti_faces(data_home=None, shuffle=False, random_state=0,
         joblib.dump(faces, filepath, compress=6)
         del mfile
     else:
-        faces = joblib.load(filepath)
+        faces = _refresh_cache([filepath], 6)
+        # TODO: Revert to the following line in v0.23
+        # faces = joblib.load(filepath)
 
     # We want floating point data, but float32 is enough (there is only
     # one byte of precision in the original uint8s anyway)
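A note on the call sites above: ``_refresh_cache`` returns the bare object for
a one-element file list and a tuple otherwise (see ``return data[0] if
len(data) == 1 else data`` in ``base.py``), which is why
``fetch_california_housing`` binds one name while ``fetch_covtype`` unpacks
two. The ``compress`` value passed at each call site appears to mirror the
level the fetcher itself uses when writing the cache (6, 9, or 0 above). A
small, hypothetical round trip illustrating the return convention; it leans on
the private helper introduced in this PR, so it is slated to break in v0.23::

    import os
    import tempfile

    import joblib
    from sklearn.datasets.base import _refresh_cache  # private, added above

    tmpdir = tempfile.mkdtemp()
    a_path = os.path.join(tmpdir, "a.pkl")
    b_path = os.path.join(tmpdir, "b.pkl")
    joblib.dump([1, 2, 3], a_path, compress=6)
    joblib.dump({"k": "v"}, b_path, compress=6)

    single = _refresh_cache([a_path], 6)  # one path -> the object itself
    first, second = _refresh_cache([a_path, b_path], 6)  # several -> a tuple
    assert single == [1, 2, 3] and second == {"k": "v"}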
diff --git a/sklearn/datasets/rcv1.py b/sklearn/datasets/rcv1.py
index c95cf1d1be75a..c000acf13e249 100644
--- a/sklearn/datasets/rcv1.py
+++ b/sklearn/datasets/rcv1.py
@@ -22,6 +22,7 @@
 from .base import _pkl_filepath
 from .base import _fetch_remote
 from .base import RemoteFileMetadata
+from .base import _refresh_cache
 from .svmlight_format import load_svmlight_files
 from ..utils import shuffle as shuffle_
 from ..utils import Bunch
@@ -189,8 +190,10 @@ def fetch_rcv1(data_home=None, subset='all', download_if_missing=True,
             f.close()
             remove(f.name)
     else:
-        X = joblib.load(samples_path)
-        sample_id = joblib.load(sample_id_path)
+        X, sample_id = _refresh_cache([samples_path, sample_id_path], 9)
+        # TODO: Revert to the following two lines in v0.23
+        # X = joblib.load(samples_path)
+        # sample_id = joblib.load(sample_id_path)
 
     # load target (y), categories, and sample_id_bis
     if download_if_missing and (not exists(sample_topics_path) or
@@ -243,8 +246,10 @@ def fetch_rcv1(data_home=None, subset='all', download_if_missing=True,
         joblib.dump(y, sample_topics_path, compress=9)
         joblib.dump(categories, topics_path, compress=9)
     else:
-        y = joblib.load(sample_topics_path)
-        categories = joblib.load(topics_path)
+        y, categories = _refresh_cache([sample_topics_path, topics_path], 9)
+        # TODO: Revert to the following two lines in v0.23
+        # y = joblib.load(sample_topics_path)
+        # categories = joblib.load(topics_path)
 
     if subset == 'all':
         pass
diff --git a/sklearn/datasets/species_distributions.py b/sklearn/datasets/species_distributions.py
index f9a04f92b8486..82ae22129ab9b 100644
--- a/sklearn/datasets/species_distributions.py
+++ b/sklearn/datasets/species_distributions.py
@@ -51,6 +51,7 @@
 from .base import RemoteFileMetadata
 from ..utils import Bunch
 from .base import _pkl_filepath
+from .base import _refresh_cache
 
 # The original data can be found at:
 # https://biodiversityinformatics.amnh.org/open_source/maxent/samples.zip
@@ -259,6 +260,8 @@ def fetch_species_distributions(data_home=None,
                                 **extra_params)
         joblib.dump(bunch, archive_path, compress=9)
     else:
-        bunch = joblib.load(archive_path)
+        bunch = _refresh_cache([archive_path], 9)
+        # TODO: Revert to the following line in v0.23
+        # bunch = joblib.load(archive_path)
 
     return bunch
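End to end, a successful refresh is silent: the dataset loads as before and
the cache is rewritten in place. The new ``DeprecationWarning`` only surfaces
when the rewrite fails, e.g. on a read-only ``data_home``. A quick,
hypothetical way to inspect what a stale cache produces after upgrading (it
assumes a cache written by scikit-learn < 0.21 already exists locally)::

    import warnings

    from sklearn.datasets import fetch_species_distributions

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        data = fetch_species_distributions()  # refreshes a stale cache in place

    # empty on the happy path; the "stop being loadable" warning shows up
    # only if the refreshed pickle could not be written back
    for w in caught:
        print(w.category.__name__, w.message)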
diff --git a/sklearn/datasets/tests/test_base.py b/sklearn/datasets/tests/test_base.py
index 1b58115d337e7..5e0af0318729f 100644
--- a/sklearn/datasets/tests/test_base.py
+++ b/sklearn/datasets/tests/test_base.py
@@ -8,6 +8,7 @@
 from functools import partial
 
 import pytest
+import joblib
 
 import numpy as np
 from sklearn.datasets import get_data_home
@@ -23,6 +24,7 @@
 from sklearn.datasets import load_boston
 from sklearn.datasets import load_wine
 from sklearn.datasets.base import Bunch
+from sklearn.datasets.base import _refresh_cache
 from sklearn.datasets.tests.test_common import check_return_X_y
 
 from sklearn.externals._pilutil import pillow_installed
@@ -276,3 +278,55 @@ def test_bunch_dir():
     # check that dir (important for autocomplete) shows attributes
     data = load_iris()
     assert "data" in dir(data)
+
+
+def test_refresh_cache(monkeypatch):
+    # uses pytest's monkeypatch fixture
+    # https://docs.pytest.org/en/latest/monkeypatch.html
+
+    def _load_warn(*args, **kwargs):
+        # raise the warning from "externals.joblib.__init__.py"
+        # this is raised when a file persisted by the old joblib is loaded now
+        msg = ("sklearn.externals.joblib is deprecated in 0.21 and will be "
+               "removed in 0.23. Please import this functionality directly "
+               "from joblib, which can be installed with: pip install joblib. "
+               "If this warning is raised when loading pickled models, you "
+               "may need to re-serialize those models with scikit-learn "
+               "0.21+.")
+        warnings.warn(msg, DeprecationWarning)
+        return 0
+
+    def _load_warn_unrelated(*args, **kwargs):
+        warnings.warn("unrelated warning", DeprecationWarning)
+        return 0
+
+    def _dump_safe(*args, **kwargs):
+        pass
+
+    def _dump_raise(*args, **kwargs):
+        # this happens if the file is read-only and joblib.dump fails to
+        # write to it.
+        raise IOError()
+
+    # test that the dataset-specific warning is raised if load raises the
+    # joblib warning, and dump fails to dump with new joblib
+    monkeypatch.setattr(joblib, "load", _load_warn)
+    monkeypatch.setattr(joblib, "dump", _dump_raise)
+    msg = "This dataset will stop being loadable in scikit-learn"
+    with pytest.warns(DeprecationWarning, match=msg):
+        _refresh_cache('test', 0)
+
+    # make sure no warning is raised if load raises the warning, but dump
+    # manages to dump the new data
+    monkeypatch.setattr(joblib, "load", _load_warn)
+    monkeypatch.setattr(joblib, "dump", _dump_safe)
+    with pytest.warns(None) as warns:
+        _refresh_cache('test', 0)
+    assert len(warns) == 0
+
+    # test that an unrelated warning is still passed through and not
+    # suppressed by _refresh_cache
+    monkeypatch.setattr(joblib, "load", _load_warn_unrelated)
+    monkeypatch.setattr(joblib, "dump", _dump_safe)
+    with pytest.warns(DeprecationWarning, match="unrelated warning"):
+        _refresh_cache('test', 0)
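For context on what the test fakes: ``_load_warn`` replays, verbatim, the
message that ``sklearn/externals/joblib/__init__.py`` emits at import time in
0.21, which is exactly what fires when ``joblib.load`` unpickles an object
referencing the old path. In a fresh interpreter running scikit-learn 0.21.x
(module imports warn only once per session), the real warning can be
reproduced directly::

    import warnings

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        from sklearn.externals import joblib  # noqa: F401  deprecated alias

    print([str(w.message) for w in caught])
    # expected to include "sklearn.externals.joblib is deprecated in 0.21 ..."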