From 79b198f699655f758f94b5cc0cd9993bc18e8b40 Mon Sep 17 00:00:00 2001 From: OmarManzoor Date: Fri, 9 Sep 2022 16:15:31 +0500 Subject: [PATCH 1/4] Fix to ensure that GaussianProcessRegressor predict method does not modify input --- sklearn/gaussian_process/_gpr.py | 2 +- .../tests/_custom_min_t_kernel.py | 50 +++++++++++++++++++ sklearn/gaussian_process/tests/test_gpr.py | 18 +++++++ 3 files changed, 69 insertions(+), 1 deletion(-) create mode 100644 sklearn/gaussian_process/tests/_custom_min_t_kernel.py diff --git a/sklearn/gaussian_process/_gpr.py b/sklearn/gaussian_process/_gpr.py index 4777d0d80627c..9afdad451de8e 100644 --- a/sklearn/gaussian_process/_gpr.py +++ b/sklearn/gaussian_process/_gpr.py @@ -435,7 +435,7 @@ def predict(self, X, return_std=False, return_cov=False): # Compute variance of predictive distribution # Use einsum to avoid explicitly forming the large matrix # V^T @ V just to extract its diagonal afterward. - y_var = self.kernel_.diag(X) + y_var = self.kernel_.diag(np.copy(X)) y_var -= np.einsum("ij,ji->i", V.T, V) # Check if any of the variances is negative because of diff --git a/sklearn/gaussian_process/tests/_custom_min_t_kernel.py b/sklearn/gaussian_process/tests/_custom_min_t_kernel.py new file mode 100644 index 0000000000000..988aff36551cb --- /dev/null +++ b/sklearn/gaussian_process/tests/_custom_min_t_kernel.py @@ -0,0 +1,50 @@ +import numpy as np +from sklearn.gaussian_process.kernels import Hyperparameter, Kernel + + +class CustomMinT(Kernel): + """ + A custom kernel that has a diag method that returns the first column of the + input matrix X. This is a helper for the test to check that the input + matrix X is not mutated. 
+ """ + + def __init__(self, sigma_0=1.0, sigma_0_bounds=(0.01, 10)): + self.sigma_0 = sigma_0 + self.sigma_0_bounds = sigma_0_bounds + + @property + def hyperparameter_sigma_0(self): + return Hyperparameter("sigma_0", "numeric", self.sigma_0_bounds) + + def __call__(self, X, Y=None, eval_gradient=False): + if Y is not None and eval_gradient: + raise ValueError("Gradient can only be evaluated when Y is None.") + + X = np.atleast_2d(X) + ones_x = np.ones_like(X) + + if Y is None: + kc, kr = X * ones_x.T, ones_x * X.T + else: + ones_y = np.ones_like(Y) + kc, kr = X * ones_y.T, ones_x * Y.T + + kcr = np.concatenate((kc[..., None], kr[..., None]), axis=-1) + k = np.min(kcr, axis=-1) + + if eval_gradient: + if not self.hyperparameter_sigma_0.fixed: + k_gradient = np.empty((k.shape[0], k.shape[1], 1)) + k_gradient[..., 0] = self.sigma_0 + return k, k_gradient + else: + return k, np.empty((X.shape[0], X.shape[0], 0)) + else: + return k + + def diag(self, X): + return X[:, 0] + + def is_stationary(self): + return False diff --git a/sklearn/gaussian_process/tests/test_gpr.py b/sklearn/gaussian_process/tests/test_gpr.py index 784aa88d6487c..608dafe3fabe4 100644 --- a/sklearn/gaussian_process/tests/test_gpr.py +++ b/sklearn/gaussian_process/tests/test_gpr.py @@ -16,6 +16,7 @@ from sklearn.gaussian_process import GaussianProcessRegressor from sklearn.gaussian_process.kernels import RBF, ConstantKernel as C, WhiteKernel from sklearn.gaussian_process.kernels import DotProduct, ExpSineSquared +from sklearn.gaussian_process.tests._custom_min_t_kernel import CustomMinT from sklearn.gaussian_process.tests._mini_sequence_kernel import MiniSeqKernel from sklearn.exceptions import ConvergenceWarning from sklearn.utils._testing import ( @@ -23,6 +24,7 @@ assert_almost_equal, assert_array_almost_equal, assert_allclose, + assert_array_equal, ) @@ -767,3 +769,19 @@ def test_sample_y_shapes(normalize_y, n_targets): y_samples = model.sample_y(X_test, n_samples=n_samples_y_test) assert 
y_samples.shape == y_test_shape + + +def test_gpr_predict_input_not_modified(): + """ + Check that the input X is not modified by the predict method of the + GaussianProcessRegressor when setting return_std=True. + + Non-regression test for: + https://github.com/scikit-learn/scikit-learn/issues/24340 + """ + gpr = GaussianProcessRegressor(kernel=CustomMinT()).fit(X, y) + + original_x2 = np.copy(X2) + _, _ = gpr.predict(X2, return_std=True) + + assert_array_equal(original_x2, X2) From 6e6152388d871779a9678ec9979dd36b1ced1435 Mon Sep 17 00:00:00 2001 From: OmarManzoor Date: Fri, 9 Sep 2022 17:45:36 +0500 Subject: [PATCH 2/4] Added entry in change log and made code improvements --- doc/whats_new/v1.2.rst | 6 +- sklearn/gaussian_process/_gpr.py | 2 +- .../tests/_custom_min_t_kernel.py | 50 ---------------- sklearn/gaussian_process/tests/test_gpr.py | 57 ++++++++++++++++++- 4 files changed, 61 insertions(+), 54 deletions(-) delete mode 100644 sklearn/gaussian_process/tests/_custom_min_t_kernel.py diff --git a/doc/whats_new/v1.2.rst b/doc/whats_new/v1.2.rst index cde3949a2d410..42ea7ae71dba7 100644 --- a/doc/whats_new/v1.2.rst +++ b/doc/whats_new/v1.2.rst @@ -218,6 +218,10 @@ Changelog - |Fix| Fix :class:`gaussian_process.kernels.Matern` gradient computation with `nu=0.5` for PyPy (and possibly other non CPython interpreters). :pr:`24245` by :user:`Loïc Estève `. +- |Fix| The `predict` method of :class:`gaussian_process.GaussianProcessRegressor` + will not modify the input X in case a custom kernel is used, with a `diag` + method that returns part of the input X. :pr:`24405` + by :user:`Omar Salman `. :mod:`sklearn.linear_model` ........................... @@ -272,7 +276,7 @@ Changelog - |Fix| Allows `csr_matrix` as input for parameter: `y_true` of the :func:`metrics.label_ranking_average_precision_score` metric. :pr:`23442` by :user:`Sean Atukorala ` - + - |Fix| :func:`metrics.ndcg_score` will now trigger a warning when the `y_true` value contains a negative value.
Users may still use negative values, but the result may not be between 0 and 1. Starting in v1.4, passing in negative diff --git a/sklearn/gaussian_process/_gpr.py b/sklearn/gaussian_process/_gpr.py index 9afdad451de8e..c0a8dc71b7352 100644 --- a/sklearn/gaussian_process/_gpr.py +++ b/sklearn/gaussian_process/_gpr.py @@ -435,7 +435,7 @@ def predict(self, X, return_std=False, return_cov=False): # Compute variance of predictive distribution # Use einsum to avoid explicitly forming the large matrix # V^T @ V just to extract its diagonal afterward. - y_var = self.kernel_.diag(np.copy(X)) + y_var = self.kernel_.diag(X).copy() y_var -= np.einsum("ij,ji->i", V.T, V) # Check if any of the variances is negative because of diff --git a/sklearn/gaussian_process/tests/_custom_min_t_kernel.py b/sklearn/gaussian_process/tests/_custom_min_t_kernel.py deleted file mode 100644 index 988aff36551cb..0000000000000 --- a/sklearn/gaussian_process/tests/_custom_min_t_kernel.py +++ /dev/null @@ -1,50 +0,0 @@ -import numpy as np -from sklearn.gaussian_process.kernels import Hyperparameter, Kernel - - -class CustomMinT(Kernel): - """ - A custom kernel that has a diag method that returns the first column of the - input matrix X. This is a helper for the test to check that the input - matrix X is not mutated.
- """ - - def __init__(self, sigma_0=1.0, sigma_0_bounds=(0.01, 10)): - self.sigma_0 = sigma_0 - self.sigma_0_bounds = sigma_0_bounds - - @property - def hyperparameter_sigma_0(self): - return Hyperparameter("sigma_0", "numeric", self.sigma_0_bounds) - - def __call__(self, X, Y=None, eval_gradient=False): - if Y is not None and eval_gradient: - raise ValueError("Gradient can only be evaluated when Y is None.") - - X = np.atleast_2d(X) - ones_x = np.ones_like(X) - - if Y is None: - kc, kr = X * ones_x.T, ones_x * X.T - else: - ones_y = np.ones_like(Y) - kc, kr = X * ones_y.T, ones_x * Y.T - - kcr = np.concatenate((kc[..., None], kr[..., None]), axis=-1) - k = np.min(kcr, axis=-1) - - if eval_gradient: - if not self.hyperparameter_sigma_0.fixed: - k_gradient = np.empty((k.shape[0], k.shape[1], 1)) - k_gradient[..., 0] = self.sigma_0 - return k, k_gradient - else: - return k, np.empty((X.shape[0], X.shape[0], 0)) - else: - return k - - def diag(self, X): - return X[:, 0] - - def is_stationary(self): - return False diff --git a/sklearn/gaussian_process/tests/test_gpr.py b/sklearn/gaussian_process/tests/test_gpr.py index 608dafe3fabe4..adcee86d21ad5 100644 --- a/sklearn/gaussian_process/tests/test_gpr.py +++ b/sklearn/gaussian_process/tests/test_gpr.py @@ -14,9 +14,14 @@ import pytest from sklearn.gaussian_process import GaussianProcessRegressor -from sklearn.gaussian_process.kernels import RBF, ConstantKernel as C, WhiteKernel +from sklearn.gaussian_process.kernels import ( + RBF, + ConstantKernel as C, + WhiteKernel, + Kernel, + Hyperparameter, +) from sklearn.gaussian_process.kernels import DotProduct, ExpSineSquared -from sklearn.gaussian_process.tests._custom_min_t_kernel import CustomMinT from sklearn.gaussian_process.tests._mini_sequence_kernel import MiniSeqKernel from sklearn.exceptions import ConvergenceWarning from sklearn.utils._testing import ( @@ -771,6 +776,54 @@ def test_sample_y_shapes(normalize_y, n_targets): assert y_samples.shape == y_test_shape
+class CustomMinT(Kernel): + """ + A custom kernel that has a diag method that returns the first column of the + input matrix X. This is a helper for the test to check that the input + matrix X is not mutated. + """ + + def __init__(self, sigma_0=1.0, sigma_0_bounds=(0.01, 10)): + self.sigma_0 = sigma_0 + self.sigma_0_bounds = sigma_0_bounds + + @property + def hyperparameter_sigma_0(self): + return Hyperparameter("sigma_0", "numeric", self.sigma_0_bounds) + + def __call__(self, X, Y=None, eval_gradient=False): + if Y is not None and eval_gradient: + raise ValueError("Gradient can only be evaluated when Y is None.") + + X = np.atleast_2d(X) + ones_x = np.ones_like(X) + + if Y is None: + kc, kr = X * ones_x.T, ones_x * X.T + else: + ones_y = np.ones_like(Y) + kc, kr = X * ones_y.T, ones_x * Y.T + + kcr = np.concatenate((kc[..., None], kr[..., None]), axis=-1) + k = np.min(kcr, axis=-1) + + if eval_gradient: + if not self.hyperparameter_sigma_0.fixed: + k_gradient = np.empty((k.shape[0], k.shape[1], 1)) + k_gradient[..., 0] = self.sigma_0 + return k, k_gradient + else: + return k, np.empty((X.shape[0], X.shape[0], 0)) + else: + return k + + def diag(self, X): + return X[:, 0] + + def is_stationary(self): + return False + + def test_gpr_predict_input_not_modified(): """ Check that the input X is not modified by the predict method of the From f7ec0e0bb297a1fb250ee2452577a13b7a277c5f Mon Sep 17 00:00:00 2001 From: OmarManzoor Date: Fri, 9 Sep 2022 17:56:19 +0500 Subject: [PATCH 3/4] Adjustment in change log --- doc/whats_new/v1.2.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/doc/whats_new/v1.2.rst b/doc/whats_new/v1.2.rst index 42ea7ae71dba7..95c349ed13590 100644 --- a/doc/whats_new/v1.2.rst +++ b/doc/whats_new/v1.2.rst @@ -218,6 +218,7 @@ Changelog - |Fix| Fix :class:`gaussian_process.kernels.Matern` gradient computation with `nu=0.5` for PyPy (and possibly other non CPython interpreters). :pr:`24245` by :user:`Loïc Estève `.
+ - |Fix| The `predict` method of :class:`gaussian_process.GaussianProcessRegressor` will not modify the input X in case a custom kernel is used, with a `diag` method that returns part of the input X. :pr:`24405` From 32715c1f31a828c700bd293606efa965f2a226c0 Mon Sep 17 00:00:00 2001 From: OmarManzoor Date: Fri, 9 Sep 2022 22:07:01 +0500 Subject: [PATCH 4/4] Simplify the Custom kernel for the test --- sklearn/gaussian_process/tests/test_gpr.py | 43 +--------------------- 1 file changed, 2 insertions(+), 41 deletions(-) diff --git a/sklearn/gaussian_process/tests/test_gpr.py b/sklearn/gaussian_process/tests/test_gpr.py index 61169ae590bc9..c03778958a3ad 100644 --- a/sklearn/gaussian_process/tests/test_gpr.py +++ b/sklearn/gaussian_process/tests/test_gpr.py @@ -18,8 +18,6 @@ RBF, ConstantKernel as C, WhiteKernel, - Kernel, - Hyperparameter, ) from sklearn.gaussian_process.kernels import DotProduct, ExpSineSquared from sklearn.gaussian_process.tests._mini_sequence_kernel import MiniSeqKernel @@ -775,53 +773,16 @@ def test_sample_y_shapes(normalize_y, n_targets): assert y_samples.shape == y_test_shape -class CustomMinT(Kernel): +class CustomKernel(C): """ A custom kernel that has a diag method that returns the first column of the input matrix X. This is a helper for the test to check that the input matrix X is not mutated.
""" - def __init__(self, sigma_0=1.0, sigma_0_bounds=(0.01, 10)): - self.sigma_0 = sigma_0 - self.sigma_0_bounds = sigma_0_bounds - - @property - def hyperparameter_sigma_0(self): - return Hyperparameter("sigma_0", "numeric", self.sigma_0_bounds) - - def __call__(self, X, Y=None, eval_gradient=False): - if Y is not None and eval_gradient: - raise ValueError("Gradient can only be evaluated when Y is None.") - - X = np.atleast_2d(X) - ones_x = np.ones_like(X) - - if Y is None: - kc, kr = X * ones_x.T, ones_x * X.T - else: - ones_y = np.ones_like(Y) - kc, kr = X * ones_y.T, ones_x * Y.T - - kcr = np.concatenate((kc[..., None], kr[..., None]), axis=-1) - k = np.min(kcr, axis=-1) - - if eval_gradient: - if not self.hyperparameter_sigma_0.fixed: - k_gradient = np.empty((k.shape[0], k.shape[1], 1)) - k_gradient[..., 0] = self.sigma_0 - return k, k_gradient - else: - return k, np.empty((X.shape[0], X.shape[0], 0)) - else: - return k - def diag(self, X): return X[:, 0] - def is_stationary(self): - return False - def test_gpr_predict_input_not_modified(): """ @@ -831,7 +792,7 @@ def test_gpr_predict_input_not_modified(): Non-regression test for: https://github.com/scikit-learn/scikit-learn/issues/24340 """ - gpr = GaussianProcessRegressor(kernel=CustomMinT()).fit(X, y) + gpr = GaussianProcessRegressor(kernel=CustomKernel()).fit(X, y) X2_copy = np.copy(X2) _, _ = gpr.predict(X2, return_std=True)