From 6093daf178732faf1bff8d3d3259ca6e6f76597c Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Mon, 27 Dec 2021 23:31:08 -0500 Subject: [PATCH 01/11] ENH Uses gaussian mixture --- sklearn/mixture/_base.py | 20 ++-- sklearn/mixture/_gaussian_mixture.py | 55 ++++++---- sklearn/utils/_array_api.py | 154 +++++++++++++++++++++++++++ sklearn/utils/validation.py | 27 +++-- 4 files changed, 224 insertions(+), 32 deletions(-) create mode 100644 sklearn/utils/_array_api.py diff --git a/sklearn/mixture/_base.py b/sklearn/mixture/_base.py index bbe9699859ded..b0b7436046998 100644 --- a/sklearn/mixture/_base.py +++ b/sklearn/mixture/_base.py @@ -9,13 +9,13 @@ from time import time import numpy as np -from scipy.special import logsumexp from .. import cluster from ..base import BaseEstimator from ..base import DensityMixin from ..exceptions import ConvergenceWarning from ..utils import check_random_state +from ..utils._array_api import get_namespace, logsumexp from ..utils.validation import check_is_fitted @@ -136,6 +136,7 @@ def _initialize_parameters(self, X, random_state): used for the method chosen to initialize the parameters. """ n_samples, _ = X.shape + np, _ = get_namespace(X) if self.init_params == "kmeans": resp = np.zeros((n_samples, self.n_components)) @@ -149,7 +150,8 @@ def _initialize_parameters(self, X, random_state): resp[np.arange(n_samples), label] = 1 elif self.init_params == "random": resp = random_state.rand(n_samples, self.n_components) - resp /= resp.sum(axis=1)[:, np.newaxis] + resp = np.asarray(resp) + resp /= np.reshape(np.sum(resp, axis=1), (-1, 1)) else: raise ValueError( "Unimplemented initialization method '%s'" % self.init_params @@ -225,6 +227,7 @@ def fit_predict(self, X, y=None): labels : array, shape (n_samples,) Component labels. """ + np, _ = get_namespace(X) X = self._validate_data(X, dtype=[np.float64, np.float32], ensure_min_samples=2) if X.shape[0] < self.n_components: raise ValueError( @@ -291,7 +294,7 @@ def fit_predict(self, X, y=None): # for any value of max_iter and tol (and any random_state). _, log_resp = self._e_step(X) - return log_resp.argmax(axis=1) + return np.argmax(log_resp, axis=1) def _e_step(self, X): """E step. @@ -309,6 +312,7 @@ def _e_step(self, X): Logarithm of the posterior probabilities (or responsibilities) of the point of each sample in X. 
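The pattern this first patch applies throughout `_base.py` is worth spelling out: NumPy-only idioms such as `resp.sum(axis=1)[:, np.newaxis]` are rewritten with functions that exist both in NumPy and in Array API namespaces (`sum`, `reshape` called as namespace functions). A minimal sketch of the normalization step (plain NumPy standing in for the namespace; `normalize_rows` is a hypothetical helper, not scikit-learn code):

```python
import numpy as np

def normalize_rows(resp, xp=np):
    # xp.sum and xp.reshape are part of the Array API standard, unlike
    # ndarray.sum(...) followed by [:, np.newaxis] indexing.
    return resp / xp.reshape(xp.sum(resp, axis=1), (-1, 1))

resp = np.random.default_rng(0).random((5, 3))
assert np.allclose(normalize_rows(resp).sum(axis=1), 1.0)
```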
""" + np, _ = get_namespace(X) log_prob_norm, log_resp = self._estimate_log_prob_resp(X) return np.mean(log_prob_norm), log_resp @@ -527,11 +531,15 @@ def _estimate_log_prob_resp(self, X): log_responsibilities : array, shape (n_samples, n_components) logarithm of the responsibilities """ + np, is_array_api = get_namespace(X) weighted_log_prob = self._estimate_weighted_log_prob(X) log_prob_norm = logsumexp(weighted_log_prob, axis=1) - with np.errstate(under="ignore"): - # ignore underflow - log_resp = weighted_log_prob - log_prob_norm[:, np.newaxis] + if is_array_api: + log_resp = weighted_log_prob - np.reshape(log_prob_norm, (-1, 1)) + else: + with np.errstate(under="ignore"): + # ignore underflow + log_resp = weighted_log_prob - log_prob_norm[:, np.newaxis] return log_prob_norm, log_resp def _print_verbose_msg_init_beg(self, n_init): diff --git a/sklearn/mixture/_gaussian_mixture.py b/sklearn/mixture/_gaussian_mixture.py index d710b0d018c4c..ae524a57bba0f 100644 --- a/sklearn/mixture/_gaussian_mixture.py +++ b/sklearn/mixture/_gaussian_mixture.py @@ -5,12 +5,16 @@ # License: BSD 3 clause import numpy as np +from math import log +from functools import partial from scipy import linalg +import scipy from ._base import BaseMixture, _check_shape from ..utils import check_array from ..utils.extmath import row_norms +from ..utils._array_api import get_namespace ############################################################################### @@ -171,12 +175,13 @@ def _estimate_gaussian_covariances_full(resp, X, nk, means, reg_covar): covariances : array, shape (n_components, n_features, n_features) The covariance matrix of the current components. """ + np, _ = get_namespace(resp, X, nk) n_components, n_features = means.shape covariances = np.empty((n_components, n_features, n_features)) for k in range(n_components): - diff = X - means[k] - covariances[k] = np.dot(resp[:, k] * diff.T, diff) / nk[k] - covariances[k].flat[:: n_features + 1] += reg_covar + diff = X - means[k, :] + covariances[k, :, :] = ((resp[:, k] * diff.T) @ diff) / nk[k] + np.reshape(covariances[k, :, :], (-1,))[:: n_features + 1] += reg_covar return covariances @@ -286,8 +291,9 @@ def _estimate_gaussian_parameters(X, resp, reg_covar, covariance_type): The covariance matrix of the current components. The shape depends of the covariance_type. """ - nk = resp.sum(axis=0) + 10 * np.finfo(resp.dtype).eps - means = np.dot(resp.T, X) / nk[:, np.newaxis] + np, _ = get_namespace(X, resp) + nk = np.sum(resp, axis=0) + 10 * np.finfo(resp.dtype).eps + means = resp.T @ X / np.reshape(nk, (-1, 1)) covariances = { "full": _estimate_gaussian_covariances_full, "tied": _estimate_gaussian_covariances_tied, @@ -321,27 +327,31 @@ def _compute_precision_cholesky(covariances, covariance_type): "or collapsed samples). Try to decrease the number of components, " "or increase reg_covar." 
) + np, is_array_api = get_namespace(covariances) + if is_array_api: + cholesky = np.linalg.cholesky + solve = np.linalg.solve + else: + cholesky = partial(scipy.linalg.cholesky, lower=True) + solve = partial(scipy.linalg.solve_triangular, lower=True) if covariance_type == "full": n_components, n_features, _ = covariances.shape precisions_chol = np.empty((n_components, n_features, n_features)) - for k, covariance in enumerate(covariances): + for k in range(n_components): try: - cov_chol = linalg.cholesky(covariance, lower=True) + cov_chol = cholesky(covariances[k, :, :]) except linalg.LinAlgError: raise ValueError(estimate_precision_error_message) - precisions_chol[k] = linalg.solve_triangular( - cov_chol, np.eye(n_features), lower=True - ).T + precisions_chol[k, :, :] = solve(cov_chol, np.eye(n_features)).T + elif covariance_type == "tied": _, n_features = covariances.shape try: - cov_chol = linalg.cholesky(covariances, lower=True) + cov_chol = cholesky(covariances) except linalg.LinAlgError: raise ValueError(estimate_precision_error_message) - precisions_chol = linalg.solve_triangular( - cov_chol, np.eye(n_features), lower=True - ).T + precisions_chol = linalg.solve(cov_chol, np.eye(n_features)).T else: if np.any(np.less_equal(covariances, 0.0)): raise ValueError(estimate_precision_error_message) @@ -373,11 +383,11 @@ def _compute_log_det_cholesky(matrix_chol, covariance_type, n_features): log_det_precision_chol : array-like of shape (n_components,) The determinant of the precision matrix for each component. """ + np, _ = get_namespace(matrix_chol) if covariance_type == "full": n_components, _, _ = matrix_chol.shape - log_det_chol = np.sum( - np.log(matrix_chol.reshape(n_components, -1)[:, :: n_features + 1]), 1 - ) + matrix_col_reshape = np.reshape(matrix_chol, (n_components, -1)) + log_det_chol = np.sum(np.log(matrix_col_reshape[:, :: n_features + 1]), axis=1) elif covariance_type == "tied": log_det_chol = np.sum(np.log(np.diag(matrix_chol))) @@ -413,6 +423,7 @@ def _estimate_log_gaussian_prob(X, means, precisions_chol, covariance_type): ------- log_prob : array, shape (n_samples, n_components) """ + np, _ = get_namespace(X, means, precisions_chol) n_samples, n_features = X.shape n_components, _ = means.shape # The determinant of the precision matrix from the Cholesky decomposition @@ -423,8 +434,10 @@ def _estimate_log_gaussian_prob(X, means, precisions_chol, covariance_type): if covariance_type == "full": log_prob = np.empty((n_samples, n_components)) - for k, (mu, prec_chol) in enumerate(zip(means, precisions_chol)): - y = np.dot(X, prec_chol) - np.dot(mu, prec_chol) + for k in range(n_components): + mu = means[k, :] + prec_chol = precisions_chol[k, :, :] + y = X @ prec_chol - mu @ prec_chol log_prob[:, k] = np.sum(np.square(y), axis=1) elif covariance_type == "tied": @@ -450,7 +463,7 @@ def _estimate_log_gaussian_prob(X, means, precisions_chol, covariance_type): ) # Since we are using the precision of the Cholesky decomposition, # `- 0.5 * log_det_precision` becomes `+ log_det_precision_chol` - return -0.5 * (n_features * np.log(2 * np.pi) + log_prob) + log_det + return -0.5 * (n_features * log(2 * np.pi) + log_prob) + log_det class GaussianMixture(BaseMixture): @@ -742,6 +755,7 @@ def _m_step(self, X, log_resp): the point of each sample in X. 
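As background for `_compute_precision_cholesky` and `_estimate_log_gaussian_prob` above: the estimator stores the Cholesky factor of each precision matrix, obtained by a triangular solve against the Cholesky factor of the covariance. A self-contained check of the identity being relied on (plain NumPy/SciPy, not part of the patch):

```python
import numpy as np
from scipy import linalg

rng = np.random.default_rng(0)
A = rng.standard_normal((4, 4))
cov = A @ A.T + 4.0 * np.eye(4)  # symmetric positive definite

cov_chol = linalg.cholesky(cov, lower=True)  # cov = L @ L.T
prec_chol = linalg.solve_triangular(cov_chol, np.eye(4), lower=True).T  # L^{-T}

# prec_chol @ prec_chol.T == inv(cov), which is exactly how _set_parameters
# later rebuilds precisions_ from precisions_cholesky_.
assert np.allclose(prec_chol @ prec_chol.T, np.linalg.inv(cov))
```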
""" n_samples, _ = X.shape + np, _ = get_namespace(X, log_resp) self.weights_, self.means_, self.covariances_ = _estimate_gaussian_parameters( X, np.exp(log_resp), self.reg_covar, self.covariance_type ) @@ -756,6 +770,7 @@ def _estimate_log_prob(self, X): ) def _estimate_log_weights(self): + np, _ = get_namespace(self.weights_) return np.log(self.weights_) def _compute_lower_bound(self, _, log_prob_norm): diff --git a/sklearn/utils/_array_api.py b/sklearn/utils/_array_api.py new file mode 100644 index 0000000000000..aa424343f5776 --- /dev/null +++ b/sklearn/utils/_array_api.py @@ -0,0 +1,154 @@ +"""Tools to support array_api.""" +import numpy as np +from scipy.special import logsumexp as sp_logsumexp + + +def get_namespace(*xs): + # `xs` contains one or more arrays, or possibly Python scalars (accepting + # those is a matter of taste, but doesn't seem unreasonable). + namespaces = { + x.__array_namespace__() if hasattr(x, "__array_namespace__") else None + for x in xs + if not isinstance(x, (bool, int, float, complex)) + } + + if not namespaces: + # one could special-case np.ndarray above or use np.asarray here if + # older numpy versions need to be supported. + raise ValueError("Unrecognized array input") + + if len(namespaces) != 1: + raise ValueError(f"Multiple namespaces for array inputs: {namespaces}") + + (xp,) = namespaces + if xp is None: + # Use numpy as default + return np, False + + return xp, True + + +def logsumexp(a, axis=None, b=None, keepdims=False, return_sign=False): + """Compute the log of the sum of exponentials of input elements. + + Parameters + ---------- + a : array_like + Input array. + axis : None or int or tuple of ints, optional + Axis or axes over which the sum is taken. By default `axis` is None, + and all elements are summed. + + .. versionadded:: 0.11.0 + keepdims : bool, optional + If this is set to True, the axes which are reduced are left in the + result as dimensions with size one. With this option, the result + will broadcast correctly against the original array. + + .. versionadded:: 0.15.0 + b : array-like, optional + Scaling factor for exp(`a`) must be of the same shape as `a` or + broadcastable to `a`. These values may be negative in order to + implement subtraction. + + .. versionadded:: 0.12.0 + return_sign : bool, optional + If this is set to True, the result will be a pair containing sign + information; if False, results that are negative will be returned + as NaN. Default is False (no sign information). + + .. versionadded:: 0.16.0 + + Returns + ------- + res : ndarray + The result, ``np.log(np.sum(np.exp(a)))`` calculated in a numerically + more stable way. If `b` is given then ``np.log(np.sum(b*np.exp(a)))`` + is returned. + sgn : ndarray + If return_sign is True, this will be an array of floating-point + numbers matching res and +1, 0, or -1 depending on the sign + of the result. If False, only one result is returned. + + See Also + -------- + numpy.logaddexp, numpy.logaddexp2 + + Notes + ----- + NumPy has a logaddexp function which is very similar to `logsumexp`, but + only handles two arguments. `logaddexp.reduce` is similar to this + function, but may be less stable. 
+ + Examples + -------- + >>> from scipy.special import logsumexp + >>> a = np.arange(10) + >>> np.log(np.sum(np.exp(a))) + 9.4586297444267107 + >>> logsumexp(a) + 9.4586297444267107 + + With weights + + >>> a = np.arange(10) + >>> b = np.arange(10, 0, -1) + >>> logsumexp(a, b=b) + 9.9170178533034665 + >>> np.log(np.sum(b*np.exp(a))) + 9.9170178533034647 + + Returning a sign flag + + >>> logsumexp([1,2],b=[1,-1],return_sign=True) + (1.5413248546129181, -1.0) + + Notice that `logsumexp` does not directly support masked arrays. To use it + on a masked array, convert the mask into zero weights: + + >>> a = np.ma.array([np.log(2), 2, np.log(3)], + ... mask=[False, True, False]) + >>> b = (~a.mask).astype(int) + >>> logsumexp(a.data, b=b), np.log(5) + 1.6094379124341005, 1.6094379124341005 + + """ + np, is_array_api = get_namespace(a) + if not is_array_api: + return sp_logsumexp( + a, axis=axis, b=b, keepdims=keepdims, return_sign=return_sign + ) + if b is not None: + a, b = np.broadcast_arrays(a, b) + if np.any(b == 0): + a = a + 0.0 # promote to at least float + a[b == 0] = -np.inf + + a_max = np.max(a, axis=axis, keepdims=True) + + if a_max.ndim > 0: + a_max[~np.isfinite(a_max)] = 0 + elif not np.isfinite(a_max): + a_max = 0 + + if b is not None: + b = np.asarray(b) + tmp = b * np.exp(a - a_max) + else: + tmp = np.exp(a - a_max) + + # suppress warnings about log of zero + s = np.sum(tmp, axis=axis, keepdims=keepdims) + if return_sign: + sgn = np.sign(s) + s *= sgn # /= makes more sense but we need zero -> zero + out = np.log(s) + + if not keepdims: + a_max = np.squeeze(a_max, axis=axis) + out += a_max + + if return_sign: + return out, sgn + else: + return out diff --git a/sklearn/utils/validation.py b/sklearn/utils/validation.py index 5936415c776b8..6380b844daca3 100644 --- a/sklearn/utils/validation.py +++ b/sklearn/utils/validation.py @@ -15,6 +15,7 @@ import operator import numpy as np +import numpy import scipy.sparse as sp from inspect import signature, isclass, Parameter @@ -29,6 +30,7 @@ from ..exceptions import PositiveSpectrumWarning from ..exceptions import NotFittedError from ..exceptions import DataConversionWarning +from ..utils._array_api import get_namespace FLOAT_DTYPES = (np.float64, np.float32, np.float16) @@ -698,7 +700,7 @@ def check_array( array_converted : object The converted and validated array. """ - if isinstance(array, np.matrix): + if isinstance(array, numpy.matrix): warnings.warn( "np.matrix usage is deprecated in 1.0 and will raise a TypeError " "in 1.2. Please convert to a numpy array with np.asarray. For " @@ -706,6 +708,7 @@ def check_array( "https://numpy.org/doc/stable/reference/generated/numpy.matrix.html", # noqa FutureWarning, ) + np, is_array_api = get_namespace(array) # store reference to original array to check if copy is needed when # function returns @@ -821,7 +824,9 @@ def check_array( # Conversion float -> int should not contain NaN or # inf (numpy#14412). We cannot use casting='safe' because # then conversion float -> int would be disallowed. 
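(As an aside on the comment above: `casting="safe"` rejects every float -> int conversion outright, so the code keeps `casting="unsafe"` and relies on the finiteness check running first — otherwise a NaN would be silently cast to an arbitrary integer. A quick plain-NumPy illustration:)

```python
import numpy as np

x = np.array([1.0, np.nan])
try:
    x.astype(np.int64, casting="safe")
except TypeError:
    pass  # float64 -> int64 is never considered "safe"

x.astype(np.int64, casting="unsafe")  # succeeds, but NaN becomes an
                                      # arbitrary platform-dependent integer
```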
- array = np.asarray(array, order=order) + if not is_array_api: + # array_api does not have order + array = np.asarray(array, order=order) if array.dtype.kind == "f": _assert_all_finite( array, @@ -830,9 +835,15 @@ def check_array( estimator_name=estimator_name, input_name=input_name, ) - array = array.astype(dtype, casting="unsafe", copy=False) + if is_array_api: + array = np.astype(dtype, copy=False) + else: + array = array.astype(dtype, casting="unsafe", copy=False) else: - array = np.asarray(array, order=order, dtype=dtype) + if is_array_api: + array = np.astype(array, dtype) + else: + array = np.asarray(array, order=order, dtype=dtype) except ComplexWarning as complex_warning: raise ValueError( "Complex data not supported\n{}\n".format(array) @@ -911,8 +922,12 @@ def check_array( % (n_features, array.shape, ensure_min_features, context) ) - if copy and np.may_share_memory(array, array_orig): - array = np.array(array, dtype=dtype, order=order) + if copy: + if not is_array_api: + if np.may_share_memory(array, array_orig): + array = np.array(array, dtype=dtype, order=order) + else: + array = np.asarray(array, dtype=dtype, copy=True) return array From ab1ab7a8ab53c68fa0814ba5b00b86832717b97b Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Tue, 28 Dec 2021 11:33:26 -0500 Subject: [PATCH 02/11] ENH Use array_api in validation --- sklearn/utils/_array_api.py | 89 +++---------------------------------- sklearn/utils/validation.py | 8 +++- 2 files changed, 11 insertions(+), 86 deletions(-) diff --git a/sklearn/utils/_array_api.py b/sklearn/utils/_array_api.py index aa424343f5776..ce7a47846d669 100644 --- a/sklearn/utils/_array_api.py +++ b/sklearn/utils/_array_api.py @@ -6,6 +6,8 @@ def get_namespace(*xs): # `xs` contains one or more arrays, or possibly Python scalars (accepting # those is a matter of taste, but doesn't seem unreasonable). + # Returns a tuple: (array_namespace, is_array_api) + namespaces = { x.__array_namespace__() if hasattr(x, "__array_namespace__") else None for x in xs @@ -29,95 +31,14 @@ def get_namespace(*xs): def logsumexp(a, axis=None, b=None, keepdims=False, return_sign=False): - """Compute the log of the sum of exponentials of input elements. - - Parameters - ---------- - a : array_like - Input array. - axis : None or int or tuple of ints, optional - Axis or axes over which the sum is taken. By default `axis` is None, - and all elements are summed. - - .. versionadded:: 0.11.0 - keepdims : bool, optional - If this is set to True, the axes which are reduced are left in the - result as dimensions with size one. With this option, the result - will broadcast correctly against the original array. - - .. versionadded:: 0.15.0 - b : array-like, optional - Scaling factor for exp(`a`) must be of the same shape as `a` or - broadcastable to `a`. These values may be negative in order to - implement subtraction. - - .. versionadded:: 0.12.0 - return_sign : bool, optional - If this is set to True, the result will be a pair containing sign - information; if False, results that are negative will be returned - as NaN. Default is False (no sign information). - - .. versionadded:: 0.16.0 - - Returns - ------- - res : ndarray - The result, ``np.log(np.sum(np.exp(a)))`` calculated in a numerically - more stable way. If `b` is given then ``np.log(np.sum(b*np.exp(a)))`` - is returned. - sgn : ndarray - If return_sign is True, this will be an array of floating-point - numbers matching res and +1, 0, or -1 depending on the sign - of the result. If False, only one result is returned. 
- - See Also - -------- - numpy.logaddexp, numpy.logaddexp2 - - Notes - ----- - NumPy has a logaddexp function which is very similar to `logsumexp`, but - only handles two arguments. `logaddexp.reduce` is similar to this - function, but may be less stable. - - Examples - -------- - >>> from scipy.special import logsumexp - >>> a = np.arange(10) - >>> np.log(np.sum(np.exp(a))) - 9.4586297444267107 - >>> logsumexp(a) - 9.4586297444267107 - - With weights - - >>> a = np.arange(10) - >>> b = np.arange(10, 0, -1) - >>> logsumexp(a, b=b) - 9.9170178533034665 - >>> np.log(np.sum(b*np.exp(a))) - 9.9170178533034647 - - Returning a sign flag - - >>> logsumexp([1,2],b=[1,-1],return_sign=True) - (1.5413248546129181, -1.0) - - Notice that `logsumexp` does not directly support masked arrays. To use it - on a masked array, convert the mask into zero weights: - - >>> a = np.ma.array([np.log(2), 2, np.log(3)], - ... mask=[False, True, False]) - >>> b = (~a.mask).astype(int) - >>> logsumexp(a.data, b=b), np.log(5) - 1.6094379124341005, 1.6094379124341005 - - """ np, is_array_api = get_namespace(a) + + # Use SciPy if a is an ndarray if not is_array_api: return sp_logsumexp( a, axis=axis, b=b, keepdims=keepdims, return_sign=return_sign ) + if b is not None: a, b = np.broadcast_arrays(a, b) if np.any(b == 0): diff --git a/sklearn/utils/validation.py b/sklearn/utils/validation.py index 6380b844daca3..356af6d85e48e 100644 --- a/sklearn/utils/validation.py +++ b/sklearn/utils/validation.py @@ -96,9 +96,13 @@ def _assert_all_finite( # validation is also imported in extmath from .extmath import _safe_accumulator_op + np, is_array_api = get_namespace(X) + if _get_config()["assume_finite"]: return - X = np.asanyarray(X) + + if not is_array_api: + X = np.asanyarray(X) # First try an O(n) time, O(1) space solution for the common case that # everything is finite; fall back to O(n) space np.isfinite to prevent # false positives from overflow in sum method. The sum is also calculated @@ -142,7 +146,7 @@ def _assert_all_finite( # for object dtype data, we only check for NaNs (GH-13254) elif X.dtype == np.dtype("object") and not allow_nan: - if _object_dtype_isnan(X).any(): + if np.any(_object_dtype_isnan(X)): raise ValueError("Input contains NaN") From bd7a4e6d5616d57e14bc3ced6f02fa15862d46be Mon Sep 17 00:00:00 2001 From: "Thomas J. 
Fan" Date: Tue, 4 Jan 2022 23:02:41 -0500 Subject: [PATCH 03/11] ENH Adds global configuration option --- sklearn/_config.py | 17 +++++++++++++++-- sklearn/utils/_array_api.py | 4 ++++ 2 files changed, 19 insertions(+), 2 deletions(-) diff --git a/sklearn/_config.py b/sklearn/_config.py index c41c180012056..8786f66d6b4c5 100644 --- a/sklearn/_config.py +++ b/sklearn/_config.py @@ -9,6 +9,7 @@ "working_memory": int(os.environ.get("SKLEARN_WORKING_MEMORY", 1024)), "print_changed_only": True, "display": "text", + "array_api_dispatch": False, } _threadlocal = threading.local() @@ -40,7 +41,11 @@ def get_config(): def set_config( - assume_finite=None, working_memory=None, print_changed_only=None, display=None + assume_finite=None, + working_memory=None, + print_changed_only=None, + display=None, + array_api_dispatch=None, ): """Set global scikit-learn configuration @@ -95,11 +100,18 @@ def set_config( local_config["print_changed_only"] = print_changed_only if display is not None: local_config["display"] = display + if array_api_dispatch is not None: + local_config["array_api_dispatch"] = array_api_dispatch @contextmanager def config_context( - *, assume_finite=None, working_memory=None, print_changed_only=None, display=None + *, + assume_finite=None, + working_memory=None, + print_changed_only=None, + display=None, + array_api_dispatch=None, ): """Context manager for global scikit-learn configuration. @@ -171,6 +183,7 @@ def config_context( working_memory=working_memory, print_changed_only=print_changed_only, display=display, + array_api_dispatch=array_api_dispatch, ) try: diff --git a/sklearn/utils/_array_api.py b/sklearn/utils/_array_api.py index ce7a47846d669..de5acaf30d54d 100644 --- a/sklearn/utils/_array_api.py +++ b/sklearn/utils/_array_api.py @@ -1,6 +1,7 @@ """Tools to support array_api.""" import numpy as np from scipy.special import logsumexp as sp_logsumexp +from .._config import get_config def get_namespace(*xs): @@ -8,6 +9,9 @@ def get_namespace(*xs): # those is a matter of taste, but doesn't seem unreasonable). # Returns a tuple: (array_namespace, is_array_api) + if not get_config()["array_api_dispatch"]: + return np, False + namespaces = { x.__array_namespace__() if hasattr(x, "__array_namespace__") else None for x in xs From 6124b007059307c5e8423019cbbfd8e68135aec3 Mon Sep 17 00:00:00 2001 From: "Thomas J. 
Fan" Date: Thu, 13 Jan 2022 15:14:31 -0500 Subject: [PATCH 04/11] ENH Reduce the usage of is_array_api --- sklearn/mixture/_base.py | 9 ++--- sklearn/mixture/_gaussian_mixture.py | 10 ++++-- sklearn/utils/_array_api.py | 54 +++++++++++++++++++++++++--- sklearn/utils/validation.py | 29 +++++---------- 4 files changed, 69 insertions(+), 33 deletions(-) diff --git a/sklearn/mixture/_base.py b/sklearn/mixture/_base.py index b0b7436046998..49ff6a86d53fb 100644 --- a/sklearn/mixture/_base.py +++ b/sklearn/mixture/_base.py @@ -531,15 +531,12 @@ def _estimate_log_prob_resp(self, X): log_responsibilities : array, shape (n_samples, n_components) logarithm of the responsibilities """ - np, is_array_api = get_namespace(X) + np, _ = get_namespace(X) weighted_log_prob = self._estimate_weighted_log_prob(X) log_prob_norm = logsumexp(weighted_log_prob, axis=1) - if is_array_api: + with np.errstate(under="ignore"): + # ignore underflow log_resp = weighted_log_prob - np.reshape(log_prob_norm, (-1, 1)) - else: - with np.errstate(under="ignore"): - # ignore underflow - log_resp = weighted_log_prob - log_prob_norm[:, np.newaxis] return log_prob_norm, log_resp def _print_verbose_msg_init_beg(self, n_init): diff --git a/sklearn/mixture/_gaussian_mixture.py b/sklearn/mixture/_gaussian_mixture.py index ae524a57bba0f..38e6c354d303b 100644 --- a/sklearn/mixture/_gaussian_mixture.py +++ b/sklearn/mixture/_gaussian_mixture.py @@ -175,13 +175,19 @@ def _estimate_gaussian_covariances_full(resp, X, nk, means, reg_covar): covariances : array, shape (n_components, n_features, n_features) The covariance matrix of the current components. """ - np, _ = get_namespace(resp, X, nk) + np, is_array_api = get_namespace(resp, X, nk) n_components, n_features = means.shape covariances = np.empty((n_components, n_features, n_features)) for k in range(n_components): diff = X - means[k, :] covariances[k, :, :] = ((resp[:, k] * diff.T) @ diff) / nk[k] - np.reshape(covariances[k, :, :], (-1,))[:: n_features + 1] += reg_covar + + if is_array_api: + for i in range(n_features): + covariances[k, i, i] += reg_covar + else: + covariances[k].flat[:: n_features + 1] += reg_covar + return covariances diff --git a/sklearn/utils/_array_api.py b/sklearn/utils/_array_api.py index de5acaf30d54d..23886e83fd818 100644 --- a/sklearn/utils/_array_api.py +++ b/sklearn/utils/_array_api.py @@ -1,8 +1,54 @@ """Tools to support array_api.""" -import numpy as np +import numpy from scipy.special import logsumexp as sp_logsumexp from .._config import get_config +from contextlib import nullcontext + + +# There are more clever ways to wrap the API to ignore kwargs, but I am writing them out +# explicitly for demonstration purposes +class _ArrayAPIWrapper: + def __init__(self, array_namespace): + self._array_namespace = array_namespace + + def __getattr__(self, name): + return getattr(self._array_namespace, name) + + def errstate(self, *args, **kwargs): + # errstate not in `array_api` + return nullcontext() + + def astype(self, x, dtype, copy=True, **kwargs): + # ignore parameters that is not supported by array-api + f = self._array_namespace.astype + return f(x, dtype, copy=copy) + + def asarray(self, obj, dtype=None, device=None, copy=None, **kwargs): + f = self._array_namespace.asarray + return f(obj, dtype=dtype, device=device, copy=copy) + + def array(self, obj, dtype=None, device=None, copy=True, **kwargs): + f = self._array_namespace.asarray + return f(obj, dtype=dtype, device=device, copy=copy) + + def asanyarray(self, obj, *args, **kwargs): + # no-op for now + 
return obj + + def may_share_memory(self, *args, **kwargs): + # The safe choice is to return True all the time + return True + + +class _NumPyApiWrapper: + def __getattr__(self, name): + return getattr(numpy, name) + + def astype(self, x, dtype, *args, **kwargs): + # astype is not defined in the top level numpy namespace + return x.astype(dtype, *args, **kwargs) + def get_namespace(*xs): # `xs` contains one or more arrays, or possibly Python scalars (accepting @@ -10,7 +56,7 @@ def get_namespace(*xs): # Returns a tuple: (array_namespace, is_array_api) if not get_config()["array_api_dispatch"]: - return np, False + return _NumPyApiWrapper(), False namespaces = { x.__array_namespace__() if hasattr(x, "__array_namespace__") else None @@ -29,9 +75,9 @@ def get_namespace(*xs): (xp,) = namespaces if xp is None: # Use numpy as default - return np, False + return _NumPyApiWrapper(), False - return xp, True + return _ArrayAPIWrapper(xp), True def logsumexp(a, axis=None, b=None, keepdims=False, return_sign=False): diff --git a/sklearn/utils/validation.py b/sklearn/utils/validation.py index 356af6d85e48e..2bbda21914e4f 100644 --- a/sklearn/utils/validation.py +++ b/sklearn/utils/validation.py @@ -96,13 +96,12 @@ def _assert_all_finite( # validation is also imported in extmath from .extmath import _safe_accumulator_op - np, is_array_api = get_namespace(X) + np, _ = get_namespace(X) if _get_config()["assume_finite"]: return - if not is_array_api: - X = np.asanyarray(X) + X = np.asanyarray(X) # First try an O(n) time, O(1) space solution for the common case that # everything is finite; fall back to O(n) space np.isfinite to prevent # false positives from overflow in sum method. The sum is also calculated @@ -712,7 +711,7 @@ def check_array( "https://numpy.org/doc/stable/reference/generated/numpy.matrix.html", # noqa FutureWarning, ) - np, is_array_api = get_namespace(array) + np, _ = get_namespace(array) # store reference to original array to check if copy is needed when # function returns @@ -828,9 +827,7 @@ def check_array( # Conversion float -> int should not contain NaN or # inf (numpy#14412). We cannot use casting='safe' because # then conversion float -> int would be disallowed. - if not is_array_api: - # array_api does not have order - array = np.asarray(array, order=order) + array = np.asarray(array, order=order) if array.dtype.kind == "f": _assert_all_finite( array, @@ -839,15 +836,9 @@ def check_array( estimator_name=estimator_name, input_name=input_name, ) - if is_array_api: - array = np.astype(dtype, copy=False) - else: - array = array.astype(dtype, casting="unsafe", copy=False) + array = np.astype(dtype, casting="unsafe", copy=False) else: - if is_array_api: - array = np.astype(array, dtype) - else: - array = np.asarray(array, order=order, dtype=dtype) + array = np.asarray(array, order=order, dtype=dtype) except ComplexWarning as complex_warning: raise ValueError( "Complex data not supported\n{}\n".format(array) @@ -926,12 +917,8 @@ def check_array( % (n_features, array.shape, ensure_min_features, context) ) - if copy: - if not is_array_api: - if np.may_share_memory(array, array_orig): - array = np.array(array, dtype=dtype, order=order) - else: - array = np.asarray(array, dtype=dtype, copy=True) + if copy and np.may_share_memory(array, array_orig): + array = np.array(array, dtype=dtype, order=order) return array From 19abb1a75ee2d151946fb0f9fab7165370e15972 Mon Sep 17 00:00:00 2001 From: "Thomas J. 
Fan" Date: Fri, 11 Feb 2022 15:19:08 -0500 Subject: [PATCH 05/11] CLN Become more array-api like --- sklearn/_config.py | 10 +++ sklearn/mixture/_base.py | 26 +++---- sklearn/mixture/_gaussian_mixture.py | 61 +++++++-------- .../mixture/tests/test_gaussian_mixture.py | 30 ++++++++ sklearn/tests/test_config.py | 3 + sklearn/utils/_array_api.py | 74 +++++++++---------- sklearn/utils/validation.py | 35 +++++---- 7 files changed, 141 insertions(+), 98 deletions(-) diff --git a/sklearn/_config.py b/sklearn/_config.py index 8786f66d6b4c5..44a6be3d885cb 100644 --- a/sklearn/_config.py +++ b/sklearn/_config.py @@ -85,6 +85,11 @@ def set_config( .. versionadded:: 0.23 + array_api_dispatch : bool, default=None + Configure scikit-learn to use Array-API. Global default: False + + .. versionadded:: 1.2 + See Also -------- config_context : Context manager for global scikit-learn configuration. @@ -150,6 +155,11 @@ def config_context( .. versionadded:: 0.23 + array_api_dispatch : bool, default=None + Configure scikit-learn to use Array-API. Global default: False + + .. versionadded:: 1.2 + Yields ------ None. diff --git a/sklearn/mixture/_base.py b/sklearn/mixture/_base.py index c74cec9dc9ae4..054e92c0b2933 100644 --- a/sklearn/mixture/_base.py +++ b/sklearn/mixture/_base.py @@ -136,7 +136,6 @@ def _initialize_parameters(self, X, random_state): used for the method chosen to initialize the parameters. """ n_samples, _ = X.shape - np, _ = get_namespace(X) if self.init_params == "kmeans": resp = np.zeros((n_samples, self.n_components)) @@ -149,9 +148,10 @@ def _initialize_parameters(self, X, random_state): ) resp[np.arange(n_samples), label] = 1 elif self.init_params == "random": + xp, _ = get_namespace(X) resp = random_state.uniform(size=(n_samples, self.n_components)) - resp = np.asarray(resp) - resp /= np.reshape(np.sum(resp, axis=1), (-1, 1)) + resp = xp.asarray(resp) + resp /= xp.reshape(xp.sum(resp, axis=1), (-1, 1)) else: raise ValueError( "Unimplemented initialization method '%s'" % self.init_params @@ -227,8 +227,8 @@ def fit_predict(self, X, y=None): labels : array, shape (n_samples,) Component labels. """ - np, _ = get_namespace(X) - X = self._validate_data(X, dtype=[np.float64, np.float32], ensure_min_samples=2) + xp, _ = get_namespace(X) + X = self._validate_data(X, dtype=[xp.float64, xp.float32], ensure_min_samples=2) if X.shape[0] < self.n_components: raise ValueError( "Expected n_samples >= n_components " @@ -241,7 +241,7 @@ def fit_predict(self, X, y=None): do_init = not (self.warm_start and hasattr(self, "converged_")) n_init = self.n_init if do_init else 1 - max_lower_bound = -np.inf + max_lower_bound = -xp.inf self.converged_ = False random_state = check_random_state(self.random_state) @@ -253,7 +253,7 @@ def fit_predict(self, X, y=None): if do_init: self._initialize_parameters(X, random_state) - lower_bound = -np.inf if do_init else self.lower_bound_ + lower_bound = -xp.inf if do_init else self.lower_bound_ for n_iter in range(1, self.max_iter + 1): prev_lower_bound = lower_bound @@ -271,7 +271,7 @@ def fit_predict(self, X, y=None): self._print_verbose_msg_init_end(lower_bound) - if lower_bound > max_lower_bound or max_lower_bound == -np.inf: + if lower_bound > max_lower_bound or max_lower_bound == -xp.inf: max_lower_bound = lower_bound best_params = self._get_parameters() best_n_iter = n_iter @@ -294,7 +294,7 @@ def fit_predict(self, X, y=None): # for any value of max_iter and tol (and any random_state). 
_, log_resp = self._e_step(X) - return np.argmax(log_resp, axis=1) + return xp.argmax(log_resp, axis=1) def _e_step(self, X): """E step. @@ -312,9 +312,9 @@ def _e_step(self, X): Logarithm of the posterior probabilities (or responsibilities) of the point of each sample in X. """ - np, _ = get_namespace(X) + xp, _ = get_namespace(X) log_prob_norm, log_resp = self._estimate_log_prob_resp(X) - return np.mean(log_prob_norm), log_resp + return xp.mean(log_prob_norm), log_resp @abstractmethod def _m_step(self, X, log_resp): @@ -533,12 +533,12 @@ def _estimate_log_prob_resp(self, X): log_responsibilities : array, shape (n_samples, n_components) logarithm of the responsibilities """ - np, _ = get_namespace(X) + xp, _ = get_namespace(X) weighted_log_prob = self._estimate_weighted_log_prob(X) log_prob_norm = logsumexp(weighted_log_prob, axis=1) with np.errstate(under="ignore"): # ignore underflow - log_resp = weighted_log_prob - np.reshape(log_prob_norm, (-1, 1)) + log_resp = weighted_log_prob - xp.reshape(log_prob_norm, (-1, 1)) return log_prob_norm, log_resp def _print_verbose_msg_init_beg(self, n_init): diff --git a/sklearn/mixture/_gaussian_mixture.py b/sklearn/mixture/_gaussian_mixture.py index 38e6c354d303b..f2315d385aa3a 100644 --- a/sklearn/mixture/_gaussian_mixture.py +++ b/sklearn/mixture/_gaussian_mixture.py @@ -175,13 +175,12 @@ def _estimate_gaussian_covariances_full(resp, X, nk, means, reg_covar): covariances : array, shape (n_components, n_features, n_features) The covariance matrix of the current components. """ - np, is_array_api = get_namespace(resp, X, nk) + xp, is_array_api = get_namespace(resp, X, nk) n_components, n_features = means.shape - covariances = np.empty((n_components, n_features, n_features)) + covariances = xp.empty((n_components, n_features, n_features)) for k in range(n_components): diff = X - means[k, :] covariances[k, :, :] = ((resp[:, k] * diff.T) @ diff) / nk[k] - if is_array_api: for i in range(n_features): covariances[k, i, i] += reg_covar @@ -297,9 +296,9 @@ def _estimate_gaussian_parameters(X, resp, reg_covar, covariance_type): The covariance matrix of the current components. The shape depends of the covariance_type. """ - np, _ = get_namespace(X, resp) - nk = np.sum(resp, axis=0) + 10 * np.finfo(resp.dtype).eps - means = resp.T @ X / np.reshape(nk, (-1, 1)) + xp, _ = get_namespace(X, resp) + nk = xp.sum(resp, axis=0) + 10 * xp.finfo(resp.dtype).eps + means = resp.T @ X / xp.reshape(nk, (-1, 1)) covariances = { "full": _estimate_gaussian_covariances_full, "tied": _estimate_gaussian_covariances_tied, @@ -333,23 +332,26 @@ def _compute_precision_cholesky(covariances, covariance_type): "or collapsed samples). Try to decrease the number of components, " "or increase reg_covar." 
) - np, is_array_api = get_namespace(covariances) + xp, is_array_api = get_namespace(covariances) if is_array_api: - cholesky = np.linalg.cholesky - solve = np.linalg.solve + cholesky = xp.linalg.cholesky + solve = xp.linalg.solve else: cholesky = partial(scipy.linalg.cholesky, lower=True) solve = partial(scipy.linalg.solve_triangular, lower=True) if covariance_type == "full": n_components, n_features, _ = covariances.shape - precisions_chol = np.empty((n_components, n_features, n_features)) + precisions_chol = xp.empty((n_components, n_features, n_features)) for k in range(n_components): try: cov_chol = cholesky(covariances[k, :, :]) except linalg.LinAlgError: raise ValueError(estimate_precision_error_message) - precisions_chol[k, :, :] = solve(cov_chol, np.eye(n_features)).T + precisions_chol[k, :, :] = solve(cov_chol, xp.eye(n_features)).T + + if is_array_api: + precisions_chol[k, :, :] = xp.triu(precisions_chol[k, :, :]) elif covariance_type == "tied": _, n_features = covariances.shape @@ -389,20 +391,20 @@ def _compute_log_det_cholesky(matrix_chol, covariance_type, n_features): log_det_precision_chol : array-like of shape (n_components,) The determinant of the precision matrix for each component. """ - np, _ = get_namespace(matrix_chol) + xp, _ = get_namespace(matrix_chol) if covariance_type == "full": n_components, _, _ = matrix_chol.shape - matrix_col_reshape = np.reshape(matrix_chol, (n_components, -1)) - log_det_chol = np.sum(np.log(matrix_col_reshape[:, :: n_features + 1]), axis=1) + matrix_col_reshape = xp.reshape(matrix_chol, (n_components, -1)) + log_det_chol = xp.sum(xp.log(matrix_col_reshape[:, :: n_features + 1]), axis=1) elif covariance_type == "tied": - log_det_chol = np.sum(np.log(np.diag(matrix_chol))) + log_det_chol = xp.sum(xp.log(xp.diag(matrix_chol))) elif covariance_type == "diag": - log_det_chol = np.sum(np.log(matrix_chol), axis=1) + log_det_chol = xp.sum(xp.log(matrix_chol), axis=1) else: - log_det_chol = n_features * (np.log(matrix_chol)) + log_det_chol = n_features * (xp.log(matrix_chol)) return log_det_chol @@ -429,7 +431,7 @@ def _estimate_log_gaussian_prob(X, means, precisions_chol, covariance_type): ------- log_prob : array, shape (n_samples, n_components) """ - np, _ = get_namespace(X, means, precisions_chol) + xp, _ = get_namespace(X, means, precisions_chol) n_samples, n_features = X.shape n_components, _ = means.shape # The determinant of the precision matrix from the Cholesky decomposition @@ -439,12 +441,12 @@ def _estimate_log_gaussian_prob(X, means, precisions_chol, covariance_type): log_det = _compute_log_det_cholesky(precisions_chol, covariance_type, n_features) if covariance_type == "full": - log_prob = np.empty((n_samples, n_components)) + log_prob = xp.empty((n_samples, n_components)) for k in range(n_components): mu = means[k, :] prec_chol = precisions_chol[k, :, :] y = X @ prec_chol - mu @ prec_chol - log_prob[:, k] = np.sum(np.square(y), axis=1) + log_prob[:, k] = xp.sum(xp.square(y), axis=1) elif covariance_type == "tied": log_prob = np.empty((n_samples, n_components)) @@ -469,7 +471,7 @@ def _estimate_log_gaussian_prob(X, means, precisions_chol, covariance_type): ) # Since we are using the precision of the Cholesky decomposition, # `- 0.5 * log_det_precision` becomes `+ log_det_precision_chol` - return -0.5 * (n_features * log(2 * np.pi) + log_prob) + log_det + return -0.5 * (n_features * log(2 * xp.pi) + log_prob) + log_det class GaussianMixture(BaseMixture): @@ -761,9 +763,9 @@ def _m_step(self, X, log_resp): the point of each sample in 
X. """ n_samples, _ = X.shape - np, _ = get_namespace(X, log_resp) + xp, _ = get_namespace(X, log_resp) self.weights_, self.means_, self.covariances_ = _estimate_gaussian_parameters( - X, np.exp(log_resp), self.reg_covar, self.covariance_type + X, xp.exp(log_resp), self.reg_covar, self.covariance_type ) self.weights_ /= n_samples self.precisions_cholesky_ = _compute_precision_cholesky( @@ -776,8 +778,8 @@ def _estimate_log_prob(self, X): ) def _estimate_log_weights(self): - np, _ = get_namespace(self.weights_) - return np.log(self.weights_) + xp, _ = get_namespace(self.weights_) + return xp.log(self.weights_) def _compute_lower_bound(self, _, log_prob_norm): return log_prob_norm @@ -800,11 +802,12 @@ def _set_parameters(self, params): # Attributes computation _, n_features = self.means_.shape - if self.covariance_type == "full": - self.precisions_ = np.empty(self.precisions_cholesky_.shape) - for k, prec_chol in enumerate(self.precisions_cholesky_): - self.precisions_[k] = np.dot(prec_chol, prec_chol.T) + prec_cho = self.precisions_cholesky_ + xp, _ = get_namespace(prec_cho) + self.precisions_ = xp.empty(prec_cho.shape) + for k in range(prec_cho.shape[0]): + self.precisions_[k, :, :] = prec_cho[k, :, :] @ prec_cho[k, :, :].T elif self.covariance_type == "tied": self.precisions_ = np.dot( diff --git a/sklearn/mixture/tests/test_gaussian_mixture.py b/sklearn/mixture/tests/test_gaussian_mixture.py index e251b4dd521ea..553cfbb7a7e1e 100644 --- a/sklearn/mixture/tests/test_gaussian_mixture.py +++ b/sklearn/mixture/tests/test_gaussian_mixture.py @@ -12,6 +12,8 @@ from scipy import stats, linalg from sklearn.cluster import KMeans +from sklearn.base import clone +from sklearn._config import config_context from sklearn.covariance import EmpiricalCovariance from sklearn.datasets import make_spd_matrix from io import StringIO @@ -1322,3 +1324,31 @@ def test_gaussian_mixture_precisions_init_diag(): assert_allclose( gm_with_init.precisions_cholesky_, gm_without_init.precisions_cholesky_ ) + + +def test_gaussian_mixture_array_api(): + """Check that the array_api Array gives the same results as ndarrays""" + pytest.importorskip("numpy", minversion="1.22", reason="Requires Array API") + xp = pytest.importorskip("numpy.array_api") + + rng = np.random.RandomState(0) + X = rng.rand(10, 4) + X_xp = xp.asarray(X) + + gm = GaussianMixture(n_components=2, random_state=0, init_params="random") + gm.fit(X) + + gm_xp = clone(gm) + with config_context(array_api_dispatch=True): + gm_xp.fit(X_xp) + + gm_attributes_array = { + key: value for key, value in vars(gm).items() if isinstance(value, np.ndarray) + } + for key in gm_attributes_array: + gm_xp_param = getattr(gm_xp, key) + assert hasattr(gm_xp_param, "__array_namespace__") + + assert_allclose( + gm_attributes_array[key], gm_xp_param, err_msg=f"{key} not the same" + ) diff --git a/sklearn/tests/test_config.py b/sklearn/tests/test_config.py index e469f23104398..bc150af39d714 100644 --- a/sklearn/tests/test_config.py +++ b/sklearn/tests/test_config.py @@ -14,6 +14,7 @@ def test_config_context(): "working_memory": 1024, "print_changed_only": True, "display": "text", + "array_api_dispatch": False, } # Not using as a context manager affects nothing @@ -26,6 +27,7 @@ def test_config_context(): "working_memory": 1024, "print_changed_only": True, "display": "text", + "array_api_dispatch": False, } assert get_config()["assume_finite"] is False @@ -55,6 +57,7 @@ def test_config_context(): "working_memory": 1024, "print_changed_only": True, "display": "text", + 
"array_api_dispatch": False, } # No positional arguments diff --git a/sklearn/utils/_array_api.py b/sklearn/utils/_array_api.py index 23886e83fd818..1cde43976c69e 100644 --- a/sklearn/utils/_array_api.py +++ b/sklearn/utils/_array_api.py @@ -1,43 +1,41 @@ """Tools to support array_api.""" import numpy -from scipy.special import logsumexp as sp_logsumexp +import scipy.special from .._config import get_config -from contextlib import nullcontext - - # There are more clever ways to wrap the API to ignore kwargs, but I am writing them out # explicitly for demonstration purposes class _ArrayAPIWrapper: def __init__(self, array_namespace): - self._array_namespace = array_namespace + self._namespace = array_namespace def __getattr__(self, name): - return getattr(self._array_namespace, name) + return getattr(self._namespace, name) - def errstate(self, *args, **kwargs): - # errstate not in `array_api` - return nullcontext() + def astype(self, x, dtype, *, copy=True, casting="unsafe"): + # support casting for NumPy + if self._namespace.__name__ == "numpy.aray_api": + x_np = x.astype(dtype, casting=casting, copy=copy) + return self._namespace.asarray(x_np) - def astype(self, x, dtype, copy=True, **kwargs): - # ignore parameters that is not supported by array-api - f = self._array_namespace.astype + f = self._namespace.astype return f(x, dtype, copy=copy) - def asarray(self, obj, dtype=None, device=None, copy=None, **kwargs): - f = self._array_namespace.asarray - return f(obj, dtype=dtype, device=device, copy=copy) + def asarray(self, obj, *, dtype=None, device=None, copy=None, order=None): + # support order in NumPy + if self._namespace.__name__ == "numpy.aray_api": + x_np = numpy.asarray(obj, dtype=dtype, copy=copy, order=order) + return self._namespace(x_np) - def array(self, obj, dtype=None, device=None, copy=True, **kwargs): - f = self._array_namespace.asarray + f = self._namespace.asarray return f(obj, dtype=dtype, device=device, copy=copy) - def asanyarray(self, obj, *args, **kwargs): - # no-op for now - return obj + def may_share_memory(self, a, b): + # support may_share_memory in NumPy + if self._namespace.__name__ == "numpy.aray_api": + return numpy.may_share_memory(a, b) - def may_share_memory(self, *args, **kwargs): - # The safe choice is to return True all the time + # The safe choice is to return True for all other array_api Arrays return True @@ -46,7 +44,7 @@ def __getattr__(self, name): return getattr(numpy, name) def astype(self, x, dtype, *args, **kwargs): - # astype is not defined in the top level numpy namespace + # astype is not defined in the top level NumPy namespace return x.astype(dtype, *args, **kwargs) @@ -81,42 +79,42 @@ def get_namespace(*xs): def logsumexp(a, axis=None, b=None, keepdims=False, return_sign=False): - np, is_array_api = get_namespace(a) + xp, is_array_api = get_namespace(a) # Use SciPy if a is an ndarray if not is_array_api: - return sp_logsumexp( + return scipy.special.logsumexp( a, axis=axis, b=b, keepdims=keepdims, return_sign=return_sign ) if b is not None: - a, b = np.broadcast_arrays(a, b) - if np.any(b == 0): + a, b = xp.broadcast_arrays(a, b) + if xp.any(b == 0): a = a + 0.0 # promote to at least float - a[b == 0] = -np.inf + a[b == 0] = -xp.inf - a_max = np.max(a, axis=axis, keepdims=True) + a_max = xp.max(a, axis=axis, keepdims=True) if a_max.ndim > 0: - a_max[~np.isfinite(a_max)] = 0 - elif not np.isfinite(a_max): + a_max[~xp.isfinite(a_max)] = 0 + elif not xp.isfinite(a_max): a_max = 0 if b is not None: - b = np.asarray(b) - tmp = b * np.exp(a - 
a_max) + b = xp.asarray(b) + tmp = b * xp.exp(a - a_max) else: - tmp = np.exp(a - a_max) + tmp = xp.exp(a - a_max) # suppress warnings about log of zero - s = np.sum(tmp, axis=axis, keepdims=keepdims) + s = xp.sum(tmp, axis=axis, keepdims=keepdims) if return_sign: - sgn = np.sign(s) + sgn = xp.sign(s) s *= sgn # /= makes more sense but we need zero -> zero - out = np.log(s) + out = xp.log(s) if not keepdims: - a_max = np.squeeze(a_max, axis=axis) + a_max = xp.squeeze(a_max, axis=axis) out += a_max if return_sign: diff --git a/sklearn/utils/validation.py b/sklearn/utils/validation.py index 57f99792b0251..ff2fa34807692 100644 --- a/sklearn/utils/validation.py +++ b/sklearn/utils/validation.py @@ -96,27 +96,26 @@ def _assert_all_finite( # validation is also imported in extmath from .extmath import _safe_accumulator_op - np, _ = get_namespace(X) + xp, _ = get_namespace(X) if _get_config()["assume_finite"]: return - X = np.asanyarray(X) # First try an O(n) time, O(1) space solution for the common case that # everything is finite; fall back to O(n) space np.isfinite to prevent # false positives from overflow in sum method. The sum is also calculated # safely to reduce dtype induced overflows. is_float = X.dtype.kind in "fc" - if is_float and (np.isfinite(_safe_accumulator_op(np.sum, X))): + if is_float and (xp.isfinite(_safe_accumulator_op(xp.sum, X))): pass elif is_float: if ( allow_nan - and np.isinf(X).any() + and xp.any(xp.isinf(X)) or not allow_nan - and not np.isfinite(X).all() + and not xp.all(xp.isfinite(X)) ): - if not allow_nan and np.isnan(X).any(): + if not allow_nan and xp.any(xp.isnan(X)): type_err = "NaN" else: msg_dtype = msg_dtype if msg_dtype is not None else X.dtype @@ -127,7 +126,7 @@ def _assert_all_finite( not allow_nan and estimator_name and input_name == "X" - and np.isnan(X).any() + and xp.any(xp.isnan(X)) ): # Improve the error message on how to handle missing values in # scikit-learn. @@ -144,8 +143,8 @@ def _assert_all_finite( raise ValueError(msg_err) # for object dtype data, we only check for NaNs (GH-13254) - elif X.dtype == np.dtype("object") and not allow_nan: - if np.any(_object_dtype_isnan(X)): + elif X.dtype == xp.dtype("object") and not allow_nan: + if xp.any(_object_dtype_isnan(X)): raise ValueError("Input contains NaN") @@ -700,7 +699,7 @@ def check_array( array_converted : object The converted and validated array. """ - if isinstance(array, numpy.matrix): + if isinstance(array, np.matrix): warnings.warn( "np.matrix usage is deprecated in 1.0 and will raise a TypeError " "in 1.2. Please convert to a numpy array with np.asarray. For " @@ -708,7 +707,7 @@ def check_array( "https://numpy.org/doc/stable/reference/generated/numpy.matrix.html", # noqa FutureWarning, ) - np, _ = get_namespace(array) + xp, _ = get_namespace(array) # store reference to original array to check if copy is needed when # function returns @@ -754,7 +753,7 @@ def check_array( if dtype_numeric: if dtype_orig is not None and dtype_orig.kind == "O": # if input is object, convert to float. - dtype = np.float64 + dtype = xp.float64 else: dtype = None @@ -824,7 +823,7 @@ def check_array( # Conversion float -> int should not contain NaN or # inf (numpy#14412). We cannot use casting='safe' because # then conversion float -> int would be disallowed. 
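# (note the ordering below: convert while the data is still floating
# point, run the finite check, and only then perform the unsafe cast)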
- array = np.asarray(array, order=order) + array = xp.asarray(array, order=order) if array.dtype.kind == "f": _assert_all_finite( array, @@ -833,9 +832,9 @@ def check_array( estimator_name=estimator_name, input_name=input_name, ) - array = np.astype(dtype, casting="unsafe", copy=False) + array = xp.astype(dtype, casting="unsafe", copy=False) else: - array = np.asarray(array, order=order, dtype=dtype) + array = xp.asarray(array, order=order, dtype=dtype) except ComplexWarning as complex_warning: raise ValueError( "Complex data not supported\n{}\n".format(array) @@ -876,7 +875,7 @@ def check_array( stacklevel=2, ) try: - array = array.astype(np.float64) + array = xp.astype(array, np.float64) except ValueError as e: raise ValueError( "Unable to convert array of bytes/strings " @@ -914,8 +913,8 @@ def check_array( % (n_features, array.shape, ensure_min_features, context) ) - if copy and np.may_share_memory(array, array_orig): - array = np.array(array, dtype=dtype, order=order) + if copy and xp.may_share_memory(array, array_orig): + array = xp.asarray(array, dtype=dtype, order=order, copy=True) return array From aeb1488c4a1b0ba505b6095abb00f306591495c4 Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Fri, 11 Feb 2022 15:35:19 -0500 Subject: [PATCH 06/11] DOC Adds period --- sklearn/mixture/tests/test_gaussian_mixture.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/mixture/tests/test_gaussian_mixture.py b/sklearn/mixture/tests/test_gaussian_mixture.py index 553cfbb7a7e1e..83c7bf2f0bf21 100644 --- a/sklearn/mixture/tests/test_gaussian_mixture.py +++ b/sklearn/mixture/tests/test_gaussian_mixture.py @@ -1327,7 +1327,7 @@ def test_gaussian_mixture_precisions_init_diag(): def test_gaussian_mixture_array_api(): - """Check that the array_api Array gives the same results as ndarrays""" + """Check that the array_api Array gives the same results as ndarrays.""" pytest.importorskip("numpy", minversion="1.22", reason="Requires Array API") xp = pytest.importorskip("numpy.array_api") From f574e76151309ae8eceb2154c6424a00923a31ad Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Fri, 11 Feb 2022 15:41:46 -0500 Subject: [PATCH 07/11] FIX Fixes for copy --- sklearn/utils/_array_api.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/sklearn/utils/_array_api.py b/sklearn/utils/_array_api.py index 1cde43976c69e..5275ddf860b3a 100644 --- a/sklearn/utils/_array_api.py +++ b/sklearn/utils/_array_api.py @@ -24,7 +24,10 @@ def astype(self, x, dtype, *, copy=True, casting="unsafe"): def asarray(self, obj, *, dtype=None, device=None, copy=None, order=None): # support order in NumPy if self._namespace.__name__ == "numpy.aray_api": - x_np = numpy.asarray(obj, dtype=dtype, copy=copy, order=order) + if copy: + x_np = numpy.array(obj, dtype=dtype, order=order, copy=True) + else: + x_np = numpy.asarray(obj, dtype=dtype, order=order) return self._namespace(x_np) f = self._namespace.asarray @@ -47,6 +50,12 @@ def astype(self, x, dtype, *args, **kwargs): # astype is not defined in the top level NumPy namespace return x.astype(dtype, *args, **kwargs) + def asarray(self, obj, *, dtype=None, device=None, copy=None, order=None): + if copy: + return numpy.array(obj, dtype=dtype, order=order, copy=True) + else: + return numpy.asarray(obj, dtype=dtype, order=order) + def get_namespace(*xs): # `xs` contains one or more arrays, or possibly Python scalars (accepting From b518f5bb1e469bbfeb01acc72422a58efd70be39 Mon Sep 17 00:00:00 2001 From: "Thomas J. 
Fan" Date: Fri, 11 Feb 2022 15:42:18 -0500 Subject: [PATCH 08/11] CLN Adds comment --- sklearn/utils/_array_api.py | 1 + 1 file changed, 1 insertion(+) diff --git a/sklearn/utils/_array_api.py b/sklearn/utils/_array_api.py index 5275ddf860b3a..7f2db931fa3a5 100644 --- a/sklearn/utils/_array_api.py +++ b/sklearn/utils/_array_api.py @@ -51,6 +51,7 @@ def astype(self, x, dtype, *args, **kwargs): return x.astype(dtype, *args, **kwargs) def asarray(self, obj, *, dtype=None, device=None, copy=None, order=None): + # copy is in the ArrayAPI spec but not in NumPy's asarray if copy: return numpy.array(obj, dtype=dtype, order=order, copy=True) else: From fb12c034d7827df115a4173153517ec9721089d1 Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Fri, 11 Feb 2022 15:55:22 -0500 Subject: [PATCH 09/11] CLN Fixes tests --- doc/whats_new/v1.1.rst | 5 ++++- sklearn/utils/_array_api.py | 4 ++-- sklearn/utils/validation.py | 3 ++- 3 files changed, 8 insertions(+), 4 deletions(-) diff --git a/doc/whats_new/v1.1.rst b/doc/whats_new/v1.1.rst index 7110d93f72a68..3e4e4280ec93d 100644 --- a/doc/whats_new/v1.1.rst +++ b/doc/whats_new/v1.1.rst @@ -148,7 +148,7 @@ Changelog :user:`Sebastian Pujalte `. - |Enhancement| :func:`datasets.make_blobs` no longer copies data during the generation - process, therefore uses less memory. + process, therefore uses less memory. :pr:`22412` by :user:`Zhehao Liu `. - |Enhancement| :func:`datasets.load_diabetes` now accepts the parameter @@ -491,6 +491,9 @@ Changelog :mod:`sklearn.mixture` ...................... +- |Enhancement| Added ArrayAPI support for :class:`mixture.GaussianMixture` + for `init_params="random"` and `covariance_type="full"`. :pr:`xxxxx` by `Thomas Fan`_. + - |Fix| Fix a bug that correctly initialize `precisions_cholesky_` in :class:`mixture.GaussianMixture` when providing `precisions_init` by taking its square root. diff --git a/sklearn/utils/_array_api.py b/sklearn/utils/_array_api.py index 7f2db931fa3a5..0224f949c5237 100644 --- a/sklearn/utils/_array_api.py +++ b/sklearn/utils/_array_api.py @@ -46,9 +46,9 @@ class _NumPyApiWrapper: def __getattr__(self, name): return getattr(numpy, name) - def astype(self, x, dtype, *args, **kwargs): + def astype(self, x, dtype, *, copy=True, casting="unsafe"): # astype is not defined in the top level NumPy namespace - return x.astype(dtype, *args, **kwargs) + return x.astype(dtype, copy=copy, casting=casting) def asarray(self, obj, *, dtype=None, device=None, copy=None, order=None): # copy is in the ArrayAPI spec but not in NumPy's asarray diff --git a/sklearn/utils/validation.py b/sklearn/utils/validation.py index ff2fa34807692..be298806d76b5 100644 --- a/sklearn/utils/validation.py +++ b/sklearn/utils/validation.py @@ -101,6 +101,7 @@ def _assert_all_finite( if _get_config()["assume_finite"]: return + X = xp.asarray(X) # First try an O(n) time, O(1) space solution for the common case that # everything is finite; fall back to O(n) space np.isfinite to prevent # false positives from overflow in sum method. The sum is also calculated @@ -832,7 +833,7 @@ def check_array( estimator_name=estimator_name, input_name=input_name, ) - array = xp.astype(dtype, casting="unsafe", copy=False) + array = xp.astype(array, dtype, casting="unsafe", copy=False) else: array = xp.asarray(array, order=order, dtype=dtype) except ComplexWarning as complex_warning: From bf5975fead1e400b5aef1de327199e6f0ac4e4f2 Mon Sep 17 00:00:00 2001 From: "Thomas J. 
Fan" Date: Fri, 11 Feb 2022 16:09:30 -0500 Subject: [PATCH 10/11] CLN Less diff --- sklearn/mixture/_gaussian_mixture.py | 1 - 1 file changed, 1 deletion(-) diff --git a/sklearn/mixture/_gaussian_mixture.py b/sklearn/mixture/_gaussian_mixture.py index f2315d385aa3a..09b1c9a070be3 100644 --- a/sklearn/mixture/_gaussian_mixture.py +++ b/sklearn/mixture/_gaussian_mixture.py @@ -186,7 +186,6 @@ def _estimate_gaussian_covariances_full(resp, X, nk, means, reg_covar): covariances[k, i, i] += reg_covar else: covariances[k].flat[:: n_features + 1] += reg_covar - return covariances From 3c08e1dd52c5dbdca319650c80a839ca8030b4f5 Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Sat, 12 Feb 2022 20:23:10 -0500 Subject: [PATCH 11/11] CLN Fix spelling --- sklearn/utils/_array_api.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/sklearn/utils/_array_api.py b/sklearn/utils/_array_api.py index 0224f949c5237..b290e32839df8 100644 --- a/sklearn/utils/_array_api.py +++ b/sklearn/utils/_array_api.py @@ -3,8 +3,7 @@ import scipy.special from .._config import get_config -# There are more clever ways to wrap the API to ignore kwargs, but I am writing them out -# explicitly for demonstration purposes + class _ArrayAPIWrapper: def __init__(self, array_namespace): self._namespace = array_namespace @@ -14,7 +13,7 @@ def __getattr__(self, name): def astype(self, x, dtype, *, copy=True, casting="unsafe"): # support casting for NumPy - if self._namespace.__name__ == "numpy.aray_api": + if self._namespace.__name__ == "numpy.array_api": x_np = x.astype(dtype, casting=casting, copy=copy) return self._namespace.asarray(x_np) @@ -23,7 +22,7 @@ def astype(self, x, dtype, *, copy=True, casting="unsafe"): def asarray(self, obj, *, dtype=None, device=None, copy=None, order=None): # support order in NumPy - if self._namespace.__name__ == "numpy.aray_api": + if self._namespace.__name__ == "numpy.array_api": if copy: x_np = numpy.array(obj, dtype=dtype, order=order, copy=True) else: @@ -35,7 +34,7 @@ def asarray(self, obj, *, dtype=None, device=None, copy=None, order=None): def may_share_memory(self, a, b): # support may_share_memory in NumPy - if self._namespace.__name__ == "numpy.aray_api": + if self._namespace.__name__ == "numpy.array_api": return numpy.may_share_memory(a, b) # The safe choice is to return True for all other array_api Arrays