diff --git a/sklearn/datasets/_base.py b/sklearn/datasets/_base.py index bcd163caf738d..c910d1c04e217 100644 --- a/sklearn/datasets/_base.py +++ b/sklearn/datasets/_base.py @@ -22,7 +22,7 @@ from ..utils import check_random_state from ..utils import check_pandas_support from ..utils.fixes import _open_binary, _open_text, _read_text, _contents -from ..utils._param_validation import validate_params, Interval +from ..utils._param_validation import validate_params, Interval, StrOptions import numpy as np @@ -1252,6 +1252,11 @@ def load_sample_images(): return Bunch(images=images, filenames=filenames, DESCR=descr) +@validate_params( + { + "image_name": [StrOptions({"china.jpg", "flower.jpg"})], + } +) def load_sample_image(image_name): """Load the numpy array of a single sample image. diff --git a/sklearn/datasets/tests/test_base.py b/sklearn/datasets/tests/test_base.py index d810e99db5878..9394b33f88f57 100644 --- a/sklearn/datasets/tests/test_base.py +++ b/sklearn/datasets/tests/test_base.py @@ -221,12 +221,6 @@ def test_load_sample_image(): warnings.warn("Could not load sample images, PIL is not available.") -def test_load_missing_sample_image_error(): - pytest.importorskip("PIL") - with pytest.raises(AttributeError): - load_sample_image("blop.jpg") - - def test_load_diabetes_raw(): """Test to check that we load a scaled version by default but that we can get an unscaled version when setting `scaled=False`.""" diff --git a/sklearn/tests/test_public_functions.py b/sklearn/tests/test_public_functions.py index 2bb6846dc4cbf..0d57ff7963ab3 100644 --- a/sklearn/tests/test_public_functions.py +++ b/sklearn/tests/test_public_functions.py @@ -135,6 +135,7 @@ def _check_function_param_validation( "sklearn.datasets.load_digits", "sklearn.datasets.load_iris", "sklearn.datasets.load_linnerud", + "sklearn.datasets.load_sample_image", "sklearn.datasets.load_svmlight_file", "sklearn.datasets.load_svmlight_files", "sklearn.datasets.load_wine",