diff --git a/doc/modules/array_api.rst b/doc/modules/array_api.rst index ee049937f5ce0..cc8ddddec5029 100644 --- a/doc/modules/array_api.rst +++ b/doc/modules/array_api.rst @@ -117,6 +117,8 @@ Estimators - :class:`preprocessing.MaxAbsScaler` - :class:`preprocessing.MinMaxScaler` - :class:`preprocessing.Normalizer` +- :class:`mixture.GaussianMixture` (with `init_params="random"` or + `init_params="random_from_data"` and `warm_start=False`) Meta-estimators --------------- diff --git a/doc/whats_new/upcoming_changes/array-api/30777.feature.rst b/doc/whats_new/upcoming_changes/array-api/30777.feature.rst new file mode 100644 index 0000000000000..ab3510a72e6d3 --- /dev/null +++ b/doc/whats_new/upcoming_changes/array-api/30777.feature.rst @@ -0,0 +1,4 @@ +- :class:`sklearn.mixture.GaussianMixture` with + `init_params="random"` or `init_params="random_from_data"` and + `warm_start=False` now supports Array API compatible inputs. + By :user:`Stefanie Senger ` and :user:`Loïc Estève ` diff --git a/sklearn/mixture/_base.py b/sklearn/mixture/_base.py index f66344a284753..a9627a0e74e7f 100644 --- a/sklearn/mixture/_base.py +++ b/sklearn/mixture/_base.py @@ -5,17 +5,24 @@ import warnings from abc import ABCMeta, abstractmethod +from contextlib import nullcontext from numbers import Integral, Real from time import time import numpy as np -from scipy.special import logsumexp from .. import cluster from ..base import BaseEstimator, DensityMixin, _fit_context from ..cluster import kmeans_plusplus from ..exceptions import ConvergenceWarning from ..utils import check_random_state +from ..utils._array_api import ( + _convert_to_numpy, + _is_numpy_namespace, + _logsumexp, + get_namespace, + get_namespace_and_device, +) from ..utils._param_validation import Interval, StrOptions from ..utils.validation import check_is_fitted, validate_data @@ -31,7 +38,6 @@ def _check_shape(param, param_shape, name): name : str """ - param = np.array(param) if param.shape != param_shape: raise ValueError( "The parameter '%s' should have the shape of %s, but got %s" @@ -86,7 +92,7 @@ def __init__( self.verbose_interval = verbose_interval @abstractmethod - def _check_parameters(self, X): + def _check_parameters(self, X, xp=None): """Check initial parameters of the derived class. Parameters @@ -95,7 +101,7 @@ def _check_parameters(self, X): """ pass - def _initialize_parameters(self, X, random_state): + def _initialize_parameters(self, X, random_state, xp=None): """Initialize the model parameters. Parameters @@ -106,6 +112,7 @@ def _initialize_parameters(self, X, random_state): A random number generator instance that controls the random seed used for the method chosen to initialize the parameters. 
""" + xp, _, device = get_namespace_and_device(X, xp=xp) n_samples, _ = X.shape if self.init_params == "kmeans": @@ -119,16 +126,25 @@ def _initialize_parameters(self, X, random_state): ) resp[np.arange(n_samples), label] = 1 elif self.init_params == "random": - resp = np.asarray( - random_state.uniform(size=(n_samples, self.n_components)), dtype=X.dtype + resp = xp.asarray( + random_state.uniform(size=(n_samples, self.n_components)), + dtype=X.dtype, + device=device, ) - resp /= resp.sum(axis=1)[:, np.newaxis] + resp /= xp.sum(resp, axis=1)[:, xp.newaxis] elif self.init_params == "random_from_data": - resp = np.zeros((n_samples, self.n_components), dtype=X.dtype) + resp = xp.zeros( + (n_samples, self.n_components), dtype=X.dtype, device=device + ) indices = random_state.choice( n_samples, size=self.n_components, replace=False ) - resp[indices, np.arange(self.n_components)] = 1 + # TODO: when array API supports __setitem__ with fancy indexing we + # can use the previous code: + # resp[indices, xp.arange(self.n_components)] = 1 + # Until then we use a for loop on one dimension. + for col, index in enumerate(indices): + resp[index, col] = 1 elif self.init_params == "k-means++": resp = np.zeros((n_samples, self.n_components), dtype=X.dtype) _, indices = kmeans_plusplus( @@ -210,20 +226,21 @@ def fit_predict(self, X, y=None): labels : array, shape (n_samples,) Component labels. """ - X = validate_data(self, X, dtype=[np.float64, np.float32], ensure_min_samples=2) + xp, _ = get_namespace(X) + X = validate_data(self, X, dtype=[xp.float64, xp.float32], ensure_min_samples=2) if X.shape[0] < self.n_components: raise ValueError( "Expected n_samples >= n_components " f"but got n_components = {self.n_components}, " f"n_samples = {X.shape[0]}" ) - self._check_parameters(X) + self._check_parameters(X, xp=xp) # if we enable warm_start, we will have a unique initialisation do_init = not (self.warm_start and hasattr(self, "converged_")) n_init = self.n_init if do_init else 1 - max_lower_bound = -np.inf + max_lower_bound = -xp.inf best_lower_bounds = [] self.converged_ = False @@ -234,9 +251,9 @@ def fit_predict(self, X, y=None): self._print_verbose_msg_init_beg(init) if do_init: - self._initialize_parameters(X, random_state) + self._initialize_parameters(X, random_state, xp=xp) - lower_bound = -np.inf if do_init else self.lower_bound_ + lower_bound = -xp.inf if do_init else self.lower_bound_ current_lower_bounds = [] if self.max_iter == 0: @@ -247,8 +264,8 @@ def fit_predict(self, X, y=None): for n_iter in range(1, self.max_iter + 1): prev_lower_bound = lower_bound - log_prob_norm, log_resp = self._e_step(X) - self._m_step(X, log_resp) + log_prob_norm, log_resp = self._e_step(X, xp=xp) + self._m_step(X, log_resp, xp=xp) lower_bound = self._compute_lower_bound(log_resp, log_prob_norm) current_lower_bounds.append(lower_bound) @@ -261,7 +278,7 @@ def fit_predict(self, X, y=None): self._print_verbose_msg_init_end(lower_bound, converged) - if lower_bound > max_lower_bound or max_lower_bound == -np.inf: + if lower_bound > max_lower_bound or max_lower_bound == -xp.inf: max_lower_bound = lower_bound best_params = self._get_parameters() best_n_iter = n_iter @@ -281,7 +298,7 @@ def fit_predict(self, X, y=None): ConvergenceWarning, ) - self._set_parameters(best_params) + self._set_parameters(best_params, xp=xp) self.n_iter_ = best_n_iter self.lower_bound_ = max_lower_bound self.lower_bounds_ = best_lower_bounds @@ -289,11 +306,11 @@ def fit_predict(self, X, y=None): # Always do a final e-step to guarantee that the 
labels returned by # fit_predict(X) are always consistent with fit(X).predict(X) # for any value of max_iter and tol (and any random_state). - _, log_resp = self._e_step(X) + _, log_resp = self._e_step(X, xp=xp) - return log_resp.argmax(axis=1) + return xp.argmax(log_resp, axis=1) - def _e_step(self, X): + def _e_step(self, X, xp=None): """E step. Parameters @@ -309,8 +326,9 @@ def _e_step(self, X): Logarithm of the posterior probabilities (or responsibilities) of the point of each sample in X. """ - log_prob_norm, log_resp = self._estimate_log_prob_resp(X) - return np.mean(log_prob_norm), log_resp + xp, _ = get_namespace(X, xp=xp) + log_prob_norm, log_resp = self._estimate_log_prob_resp(X, xp=xp) + return xp.mean(log_prob_norm), log_resp @abstractmethod def _m_step(self, X, log_resp): @@ -351,7 +369,7 @@ def score_samples(self, X): check_is_fitted(self) X = validate_data(self, X, reset=False) - return logsumexp(self._estimate_weighted_log_prob(X), axis=1) + return _logsumexp(self._estimate_weighted_log_prob(X), axis=1) def score(self, X, y=None): """Compute the per-sample average log-likelihood of the given data X. @@ -370,7 +388,8 @@ def score(self, X, y=None): log_likelihood : float Log-likelihood of `X` under the Gaussian mixture model. """ - return self.score_samples(X).mean() + xp, _ = get_namespace(X) + return float(xp.mean(self.score_samples(X))) def predict(self, X): """Predict the labels for the data samples in X using trained model. @@ -387,8 +406,9 @@ def predict(self, X): Component labels. """ check_is_fitted(self) + xp, _ = get_namespace(X) X = validate_data(self, X, reset=False) - return self._estimate_weighted_log_prob(X).argmax(axis=1) + return xp.argmax(self._estimate_weighted_log_prob(X), axis=1) def predict_proba(self, X): """Evaluate the components' density for each sample. @@ -406,8 +426,9 @@ def predict_proba(self, X): """ check_is_fitted(self) X = validate_data(self, X, reset=False) - _, log_resp = self._estimate_log_prob_resp(X) - return np.exp(log_resp) + xp, _ = get_namespace(X) + _, log_resp = self._estimate_log_prob_resp(X, xp=xp) + return xp.exp(log_resp) def sample(self, n_samples=1): """Generate random samples from the fitted Gaussian distribution. @@ -426,6 +447,7 @@ def sample(self, n_samples=1): Component labels. 
""" check_is_fitted(self) + xp, _, device_ = get_namespace_and_device(self.means_) if n_samples < 1: raise ValueError( @@ -435,22 +457,30 @@ def sample(self, n_samples=1): _, n_features = self.means_.shape rng = check_random_state(self.random_state) - n_samples_comp = rng.multinomial(n_samples, self.weights_) + n_samples_comp = rng.multinomial( + n_samples, _convert_to_numpy(self.weights_, xp) + ) if self.covariance_type == "full": X = np.vstack( [ rng.multivariate_normal(mean, covariance, int(sample)) for (mean, covariance, sample) in zip( - self.means_, self.covariances_, n_samples_comp + _convert_to_numpy(self.means_, xp), + _convert_to_numpy(self.covariances_, xp), + n_samples_comp, ) ] ) elif self.covariance_type == "tied": X = np.vstack( [ - rng.multivariate_normal(mean, self.covariances_, int(sample)) - for (mean, sample) in zip(self.means_, n_samples_comp) + rng.multivariate_normal( + mean, _convert_to_numpy(self.covariances_, xp), int(sample) + ) + for (mean, sample) in zip( + _convert_to_numpy(self.means_, xp), n_samples_comp + ) ] ) else: @@ -460,18 +490,23 @@ def sample(self, n_samples=1): + rng.standard_normal(size=(sample, n_features)) * np.sqrt(covariance) for (mean, covariance, sample) in zip( - self.means_, self.covariances_, n_samples_comp + _convert_to_numpy(self.means_, xp), + _convert_to_numpy(self.covariances_, xp), + n_samples_comp, ) ] ) - y = np.concatenate( - [np.full(sample, j, dtype=int) for j, sample in enumerate(n_samples_comp)] + y = xp.concat( + [ + xp.full(int(n_samples_comp[i]), i, dtype=xp.int64, device=device_) + for i in range(len(n_samples_comp)) + ] ) - return (X, y) + return xp.asarray(X, device=device_), y - def _estimate_weighted_log_prob(self, X): + def _estimate_weighted_log_prob(self, X, xp=None): """Estimate the weighted log-probabilities, log P(X | Z) + log weights. Parameters @@ -482,10 +517,10 @@ def _estimate_weighted_log_prob(self, X): ------- weighted_log_prob : array, shape (n_samples, n_component) """ - return self._estimate_log_prob(X) + self._estimate_log_weights() + return self._estimate_log_prob(X, xp=xp) + self._estimate_log_weights(xp=xp) @abstractmethod - def _estimate_log_weights(self): + def _estimate_log_weights(self, xp=None): """Estimate log-weights in EM algorithm, E[ log pi ] in VB algorithm. Returns @@ -495,7 +530,7 @@ def _estimate_log_weights(self): pass @abstractmethod - def _estimate_log_prob(self, X): + def _estimate_log_prob(self, X, xp=None): """Estimate the log-probabilities log P(X | Z). Compute the log-probabilities per each component for each sample. @@ -510,7 +545,7 @@ def _estimate_log_prob(self, X): """ pass - def _estimate_log_prob_resp(self, X): + def _estimate_log_prob_resp(self, X, xp=None): """Estimate log probabilities and responsibilities for each sample. 
Compute the log probabilities, weighted log probabilities per @@ -529,11 +564,17 @@ def _estimate_log_prob_resp(self, X): log_responsibilities : array, shape (n_samples, n_components) logarithm of the responsibilities """ - weighted_log_prob = self._estimate_weighted_log_prob(X) - log_prob_norm = logsumexp(weighted_log_prob, axis=1) - with np.errstate(under="ignore"): + xp, _ = get_namespace(X, xp=xp) + weighted_log_prob = self._estimate_weighted_log_prob(X, xp=xp) + log_prob_norm = _logsumexp(weighted_log_prob, axis=1, xp=xp) + + # There is no errstate equivalent for warning/error management in array API + context_manager = ( + np.errstate(under="ignore") if _is_numpy_namespace(xp) else nullcontext() + ) + with context_manager: # ignore underflow - log_resp = weighted_log_prob - log_prob_norm[:, np.newaxis] + log_resp = weighted_log_prob - log_prob_norm[:, xp.newaxis] return log_prob_norm, log_resp def _print_verbose_msg_init_beg(self, n_init): diff --git a/sklearn/mixture/_bayesian_mixture.py b/sklearn/mixture/_bayesian_mixture.py index 57220186faf61..76589c8214a99 100644 --- a/sklearn/mixture/_bayesian_mixture.py +++ b/sklearn/mixture/_bayesian_mixture.py @@ -410,7 +410,7 @@ def __init__( self.degrees_of_freedom_prior = degrees_of_freedom_prior self.covariance_prior = covariance_prior - def _check_parameters(self, X): + def _check_parameters(self, X, xp=None): """Check that the parameters are well defined. Parameters @@ -722,7 +722,7 @@ def _estimate_wishart_spherical(self, nk, xk, sk): # Contrary to the original bishop book, we normalize the covariances self.covariances_ /= self.degrees_of_freedom_ - def _m_step(self, X, log_resp): + def _m_step(self, X, log_resp, xp=None): """M step. Parameters @@ -742,7 +742,7 @@ def _m_step(self, X, log_resp): self._estimate_means(nk, xk) self._estimate_precisions(nk, xk, sk) - def _estimate_log_weights(self): + def _estimate_log_weights(self, xp=None): if self.weight_concentration_prior_type == "dirichlet_process": digamma_sum = digamma( self.weight_concentration_[0] + self.weight_concentration_[1] @@ -760,7 +760,7 @@ def _estimate_log_weights(self): np.sum(self.weight_concentration_) ) - def _estimate_log_prob(self, X): + def _estimate_log_prob(self, X, xp=None): _, n_features = X.shape # We remove `n_features * np.log(self.degrees_of_freedom_)` because # the precision matrix is normalized @@ -847,7 +847,7 @@ def _get_parameters(self): self.precisions_cholesky_, ) - def _set_parameters(self, params): + def _set_parameters(self, params, xp=None): ( self.weight_concentration_, self.mean_precision_, diff --git a/sklearn/mixture/_gaussian_mixture.py b/sklearn/mixture/_gaussian_mixture.py index c4bdd3a0d68c8..cd6523d1d2784 100644 --- a/sklearn/mixture/_gaussian_mixture.py +++ b/sklearn/mixture/_gaussian_mixture.py @@ -2,11 +2,19 @@ # Authors: The scikit-learn developers # SPDX-License-Identifier: BSD-3-Clause +import math import numpy as np -from scipy import linalg +from .._config import get_config +from ..externals import array_api_extra as xpx from ..utils import check_array +from ..utils._array_api import ( + _cholesky, + _linalg_solve, + get_namespace, + get_namespace_and_device, +) from ..utils._param_validation import StrOptions from ..utils.extmath import row_norms from ._base import BaseMixture, _check_shape @@ -15,7 +23,7 @@ # Gaussian mixture shape checkers used by the GaussianMixture class -def _check_weights(weights, n_components): +def _check_weights(weights, n_components, xp=None): """Check the user provided 'weights'. 
Parameters @@ -30,28 +38,28 @@ def _check_weights(weights, n_components): ------- weights : array, shape (n_components,) """ - weights = check_array(weights, dtype=[np.float64, np.float32], ensure_2d=False) + weights = check_array(weights, dtype=[xp.float64, xp.float32], ensure_2d=False) _check_shape(weights, (n_components,), "weights") # check range - if any(np.less(weights, 0.0)) or any(np.greater(weights, 1.0)): + if any(xp.less(weights, 0.0)) or any(xp.greater(weights, 1.0)): raise ValueError( "The parameter 'weights' should be in the range " "[0, 1], but got max value %.5f, min value %.5f" - % (np.min(weights), np.max(weights)) + % (xp.min(weights), xp.max(weights)) ) # check normalization - atol = 1e-6 if weights.dtype == np.float32 else 1e-8 - if not np.allclose(np.abs(1.0 - np.sum(weights)), 0.0, atol=atol): + atol = 1e-6 if weights.dtype == xp.float32 else 1e-8 + if not np.allclose(float(xp.abs(1.0 - xp.sum(weights))), 0.0, atol=atol): raise ValueError( "The parameter 'weights' should be normalized, but got sum(weights) = %.5f" - % np.sum(weights) + % xp.sum(weights) ) return weights -def _check_means(means, n_components, n_features): +def _check_means(means, n_components, n_features, xp=None): """Validate the provided 'means'. Parameters @@ -69,34 +77,39 @@ def _check_means(means, n_components, n_features): ------- means : array, (n_components, n_features) """ - means = check_array(means, dtype=[np.float64, np.float32], ensure_2d=False) + xp, _ = get_namespace(means, xp=xp) + means = check_array(means, dtype=[xp.float64, xp.float32], ensure_2d=False) _check_shape(means, (n_components, n_features), "means") return means -def _check_precision_positivity(precision, covariance_type): +def _check_precision_positivity(precision, covariance_type, xp=None): """Check a precision vector is positive-definite.""" - if np.any(np.less_equal(precision, 0.0)): + xp, _ = get_namespace(precision, xp=xp) + if xp.any(xp.less_equal(precision, 0.0)): raise ValueError("'%s precision' should be positive" % covariance_type) -def _check_precision_matrix(precision, covariance_type): +def _check_precision_matrix(precision, covariance_type, xp=None): """Check a precision matrix is symmetric and positive-definite.""" + xp, _ = get_namespace(precision, xp=xp) if not ( - np.allclose(precision, precision.T) and np.all(linalg.eigvalsh(precision) > 0.0) + xp.all(xpx.isclose(precision, precision.T)) + and xp.all(xp.linalg.eigvalsh(precision) > 0.0) ): raise ValueError( "'%s precision' should be symmetric, positive-definite" % covariance_type ) -def _check_precisions_full(precisions, covariance_type): +def _check_precisions_full(precisions, covariance_type, xp=None): """Check the precision matrices are symmetric and positive-definite.""" - for prec in precisions: - _check_precision_matrix(prec, covariance_type) + xp, _ = get_namespace(precisions, xp=xp) + for i in range(precisions.shape[0]): + _check_precision_matrix(precisions[i, :, :], covariance_type, xp=xp) -def _check_precisions(precisions, covariance_type, n_components, n_features): +def _check_precisions(precisions, covariance_type, n_components, n_features, xp=None): """Validate user provided precisions. 
Parameters @@ -119,9 +132,10 @@ def _check_precisions(precisions, covariance_type, n_components, n_features): ------- precisions : array """ + xp, _ = get_namespace(precisions, xp=xp) precisions = check_array( precisions, - dtype=[np.float64, np.float32], + dtype=[xp.float64, xp.float32], ensure_2d=False, allow_nd=covariance_type == "full", ) @@ -142,7 +156,7 @@ def _check_precisions(precisions, covariance_type, n_components, n_features): "diag": _check_precision_positivity, "spherical": _check_precision_positivity, } - _check_precisions[covariance_type](precisions, covariance_type) + _check_precisions[covariance_type](precisions, covariance_type, xp=xp) return precisions @@ -150,7 +164,7 @@ def _check_precisions(precisions, covariance_type, n_components, n_features): # Gaussian mixture parameters estimators (used by the M-Step) -def _estimate_gaussian_covariances_full(resp, X, nk, means, reg_covar): +def _estimate_gaussian_covariances_full(resp, X, nk, means, reg_covar, xp=None): """Estimate the full covariance matrices. Parameters @@ -170,16 +184,20 @@ def _estimate_gaussian_covariances_full(resp, X, nk, means, reg_covar): covariances : array, shape (n_components, n_features, n_features) The covariance matrix of the current components. """ + xp, _, device_ = get_namespace_and_device(X, xp=xp) n_components, n_features = means.shape - covariances = np.empty((n_components, n_features, n_features), dtype=X.dtype) + covariances = xp.empty( + (n_components, n_features, n_features), device=device_, dtype=X.dtype + ) for k in range(n_components): - diff = X - means[k] - covariances[k] = np.dot(resp[:, k] * diff.T, diff) / nk[k] - covariances[k].flat[:: n_features + 1] += reg_covar + diff = X - means[k, :] + covariances[k, :, :] = ((resp[:, k] * diff.T) @ diff) / nk[k] + covariances_flat = xp.reshape(covariances[k, :, :], (-1,)) + covariances_flat[:: n_features + 1] += reg_covar return covariances -def _estimate_gaussian_covariances_tied(resp, X, nk, means, reg_covar): +def _estimate_gaussian_covariances_tied(resp, X, nk, means, reg_covar, xp=None): """Estimate the tied covariance matrix. Parameters @@ -199,15 +217,17 @@ def _estimate_gaussian_covariances_tied(resp, X, nk, means, reg_covar): covariance : array, shape (n_features, n_features) The tied covariance matrix of the components. """ - avg_X2 = np.dot(X.T, X) - avg_means2 = np.dot(nk * means.T, means) + xp, _ = get_namespace(X, means, xp=xp) + avg_X2 = X.T @ X + avg_means2 = nk * means.T @ means covariance = avg_X2 - avg_means2 - covariance /= nk.sum() - covariance.flat[:: len(covariance) + 1] += reg_covar + covariance /= xp.sum(nk) + covariance_flat = xp.reshape(covariance, (-1,)) + covariance_flat[:: covariance.shape[0] + 1] += reg_covar return covariance -def _estimate_gaussian_covariances_diag(resp, X, nk, means, reg_covar): +def _estimate_gaussian_covariances_diag(resp, X, nk, means, reg_covar, xp=None): """Estimate the diagonal covariance vectors. Parameters @@ -227,12 +247,13 @@ def _estimate_gaussian_covariances_diag(resp, X, nk, means, reg_covar): covariances : array, shape (n_components, n_features) The covariance vector of the current components. 
""" - avg_X2 = np.dot(resp.T, X * X) / nk[:, np.newaxis] + xp, _ = get_namespace(X, xp=xp) + avg_X2 = (resp.T @ (X * X)) / nk[:, xp.newaxis] avg_means2 = means**2 return avg_X2 - avg_means2 + reg_covar -def _estimate_gaussian_covariances_spherical(resp, X, nk, means, reg_covar): +def _estimate_gaussian_covariances_spherical(resp, X, nk, means, reg_covar, xp=None): """Estimate the spherical variance values. Parameters @@ -252,10 +273,14 @@ def _estimate_gaussian_covariances_spherical(resp, X, nk, means, reg_covar): variances : array, shape (n_components,) The variance values of each components. """ - return _estimate_gaussian_covariances_diag(resp, X, nk, means, reg_covar).mean(1) + xp, _ = get_namespace(X) + return xp.mean( + _estimate_gaussian_covariances_diag(resp, X, nk, means, reg_covar, xp=xp), + axis=1, + ) -def _estimate_gaussian_parameters(X, resp, reg_covar, covariance_type): +def _estimate_gaussian_parameters(X, resp, reg_covar, covariance_type, xp=None): """Estimate the Gaussian distribution parameters. Parameters @@ -284,18 +309,19 @@ def _estimate_gaussian_parameters(X, resp, reg_covar, covariance_type): The covariance matrix of the current components. The shape depends of the covariance_type. """ - nk = resp.sum(axis=0) + 10 * np.finfo(resp.dtype).eps - means = np.dot(resp.T, X) / nk[:, np.newaxis] + xp, _ = get_namespace(X, xp=xp) + nk = xp.sum(resp, axis=0) + 10 * xp.finfo(resp.dtype).eps + means = (resp.T @ X) / nk[:, xp.newaxis] covariances = { "full": _estimate_gaussian_covariances_full, "tied": _estimate_gaussian_covariances_tied, "diag": _estimate_gaussian_covariances_diag, "spherical": _estimate_gaussian_covariances_spherical, - }[covariance_type](resp, X, nk, means, reg_covar) + }[covariance_type](resp, X, nk, means, reg_covar, xp=xp) return nk, means, covariances -def _compute_precision_cholesky(covariances, covariance_type): +def _compute_precision_cholesky(covariances, covariance_type, xp=None): """Compute the Cholesky decomposition of the precisions. Parameters @@ -313,6 +339,8 @@ def _compute_precision_cholesky(covariances, covariance_type): The cholesky decomposition of sample precisions of the current components. The shape depends of the covariance_type. """ + xp, _, device_ = get_namespace_and_device(covariances, xp=xp) + estimate_precision_error_message = ( "Fitting the mixture model failed because some components have " "ill-defined empirical covariance (for instance caused by singleton " @@ -320,7 +348,7 @@ def _compute_precision_cholesky(covariances, covariance_type): "increase reg_covar, or scale the input data." ) dtype = covariances.dtype - if dtype == np.float32: + if dtype == xp.float32: estimate_precision_error_message += ( " The numerical accuracy can also be improved by passing float64" " data instead of float32." 
@@ -328,37 +356,43 @@ def _compute_precision_cholesky(covariances, covariance_type): if covariance_type == "full": n_components, n_features, _ = covariances.shape - precisions_chol = np.empty((n_components, n_features, n_features), dtype=dtype) - for k, covariance in enumerate(covariances): + precisions_chol = xp.empty( + (n_components, n_features, n_features), device=device_, dtype=dtype + ) + for k in range(covariances.shape[0]): + covariance = covariances[k, :, :] try: - cov_chol = linalg.cholesky(covariance, lower=True) - except linalg.LinAlgError: + cov_chol = _cholesky(covariance, xp) + # catch only numpy exceptions, b/c exceptions aren't part of array api spec + except np.linalg.LinAlgError: raise ValueError(estimate_precision_error_message) - precisions_chol[k] = linalg.solve_triangular( - cov_chol, np.eye(n_features, dtype=dtype), lower=True + precisions_chol[k, :, :] = _linalg_solve( + cov_chol, xp.eye(n_features, dtype=dtype, device=device_), xp ).T elif covariance_type == "tied": _, n_features = covariances.shape try: - cov_chol = linalg.cholesky(covariances, lower=True) - except linalg.LinAlgError: + cov_chol = _cholesky(covariances, xp) + # catch only numpy exceptions, since exceptions are not part of array api spec + except np.linalg.LinAlgError: raise ValueError(estimate_precision_error_message) - precisions_chol = linalg.solve_triangular( - cov_chol, np.eye(n_features, dtype=dtype), lower=True + precisions_chol = _linalg_solve( + cov_chol, xp.eye(n_features, dtype=dtype, device=device_), xp ).T else: - if np.any(np.less_equal(covariances, 0.0)): + if xp.any(covariances <= 0.0): raise ValueError(estimate_precision_error_message) - precisions_chol = 1.0 / np.sqrt(covariances) + precisions_chol = 1.0 / xp.sqrt(covariances) return precisions_chol -def _flipudlr(array): +def _flipudlr(array, xp=None): """Reverse the rows and columns of an array.""" - return np.flipud(np.fliplr(array)) + xp, _ = get_namespace(array, xp=xp) + return xp.flip(xp.flip(array, axis=1), axis=0) -def _compute_precision_cholesky_from_precisions(precisions, covariance_type): +def _compute_precision_cholesky_from_precisions(precisions, covariance_type, xp=None): r"""Compute the Cholesky decomposition of precisions using precisions themselves. As implemented in :func:`_compute_precision_cholesky`, the `precisions_cholesky_` is @@ -393,24 +427,26 @@ def _compute_precision_cholesky_from_precisions(precisions, covariance_type): components. The shape depends on the covariance_type. """ if covariance_type == "full": - precisions_cholesky = np.array( + precisions_cholesky = xp.stack( [ - _flipudlr(linalg.cholesky(_flipudlr(precision), lower=True)) - for precision in precisions + _flipudlr( + _cholesky(_flipudlr(precisions[i, :, :], xp=xp), xp=xp), xp=xp + ) + for i in range(precisions.shape[0]) ] ) elif covariance_type == "tied": precisions_cholesky = _flipudlr( - linalg.cholesky(_flipudlr(precisions), lower=True) + _cholesky(_flipudlr(precisions, xp=xp), xp=xp), xp=xp ) else: - precisions_cholesky = np.sqrt(precisions) + precisions_cholesky = xp.sqrt(precisions) return precisions_cholesky ############################################################################### # Gaussian mixture probability estimators -def _compute_log_det_cholesky(matrix_chol, covariance_type, n_features): +def _compute_log_det_cholesky(matrix_chol, covariance_type, n_features, xp=None): """Compute the log-det of the cholesky decomposition of matrices. 
Parameters @@ -432,25 +468,27 @@ def _compute_log_det_cholesky(matrix_chol, covariance_type, n_features): log_det_precision_chol : array-like of shape (n_components,) The determinant of the precision matrix for each component. """ + xp, _ = get_namespace(matrix_chol, xp=xp) if covariance_type == "full": n_components, _, _ = matrix_chol.shape - log_det_chol = np.sum( - np.log(matrix_chol.reshape(n_components, -1)[:, :: n_features + 1]), axis=1 + log_det_chol = xp.sum( + xp.log(xp.reshape(matrix_chol, (n_components, -1))[:, :: n_features + 1]), + axis=1, ) elif covariance_type == "tied": - log_det_chol = np.sum(np.log(np.diag(matrix_chol))) + log_det_chol = xp.sum(xp.log(xp.linalg.diagonal(matrix_chol))) elif covariance_type == "diag": - log_det_chol = np.sum(np.log(matrix_chol), axis=1) + log_det_chol = xp.sum(xp.log(matrix_chol), axis=1) else: - log_det_chol = n_features * np.log(matrix_chol) + log_det_chol = n_features * xp.log(matrix_chol) return log_det_chol -def _estimate_log_gaussian_prob(X, means, precisions_chol, covariance_type): +def _estimate_log_gaussian_prob(X, means, precisions_chol, covariance_type, xp=None): """Estimate the log Gaussian probability. Parameters @@ -472,6 +510,7 @@ def _estimate_log_gaussian_prob(X, means, precisions_chol, covariance_type): ------- log_prob : array, shape (n_samples, n_components) """ + xp, _, device_ = get_namespace_and_device(X, means, precisions_chol, xp=xp) n_samples, n_features = X.shape n_components, _ = means.shape # The determinant of the precision matrix from the Cholesky decomposition @@ -481,35 +520,38 @@ def _estimate_log_gaussian_prob(X, means, precisions_chol, covariance_type): log_det = _compute_log_det_cholesky(precisions_chol, covariance_type, n_features) if covariance_type == "full": - log_prob = np.empty((n_samples, n_components), dtype=X.dtype) - for k, (mu, prec_chol) in enumerate(zip(means, precisions_chol)): - y = np.dot(X, prec_chol) - np.dot(mu, prec_chol) - log_prob[:, k] = np.sum(np.square(y), axis=1) + log_prob = xp.empty((n_samples, n_components), dtype=X.dtype, device=device_) + for k in range(means.shape[0]): + mu = means[k, :] + prec_chol = precisions_chol[k, :, :] + y = (X @ prec_chol) - (mu @ prec_chol) + log_prob[:, k] = xp.sum(xp.square(y), axis=1) elif covariance_type == "tied": - log_prob = np.empty((n_samples, n_components), dtype=X.dtype) - for k, mu in enumerate(means): - y = np.dot(X, precisions_chol) - np.dot(mu, precisions_chol) - log_prob[:, k] = np.sum(np.square(y), axis=1) + log_prob = xp.empty((n_samples, n_components), dtype=X.dtype, device=device_) + for k in range(means.shape[0]): + mu = means[k, :] + y = (X @ precisions_chol) - (mu @ precisions_chol) + log_prob[:, k] = xp.sum(xp.square(y), axis=1) elif covariance_type == "diag": precisions = precisions_chol**2 log_prob = ( - np.sum((means**2 * precisions), 1) - - 2.0 * np.dot(X, (means * precisions).T) - + np.dot(X**2, precisions.T) + xp.sum((means**2 * precisions), axis=1) + - 2.0 * (X @ (means * precisions).T) + + (X**2 @ precisions.T) ) elif covariance_type == "spherical": precisions = precisions_chol**2 log_prob = ( - np.sum(means**2, 1) * precisions - - 2 * np.dot(X, means.T * precisions) - + np.outer(row_norms(X, squared=True), precisions) + xp.sum(means**2, axis=1) * precisions + - 2 * (X @ means.T * precisions) + + xp.linalg.outer(row_norms(X, squared=True), precisions) ) # Since we are using the precision of the Cholesky decomposition, # `- 0.5 * log_det_precision` becomes `+ log_det_precision_chol` - return -0.5 * (n_features * 
np.log(2 * np.pi).astype(X.dtype) + log_prob) + log_det + return -0.5 * (n_features * math.log(2 * math.pi) + log_prob) + log_det class GaussianMixture(BaseMixture): @@ -752,16 +794,18 @@ def __init__( self.means_init = means_init self.precisions_init = precisions_init - def _check_parameters(self, X): + def _check_parameters(self, X, xp=None): """Check the Gaussian mixture parameters are well defined.""" _, n_features = X.shape if self.weights_init is not None: - self.weights_init = _check_weights(self.weights_init, self.n_components) + self.weights_init = _check_weights( + self.weights_init, self.n_components, xp=xp + ) if self.means_init is not None: self.means_init = _check_means( - self.means_init, self.n_components, n_features + self.means_init, self.n_components, n_features, xp=xp ) if self.precisions_init is not None: @@ -770,9 +814,23 @@ def _check_parameters(self, X): self.covariance_type, self.n_components, n_features, + xp=xp, ) - def _initialize_parameters(self, X, random_state): + allowed_init_params = ["random", "random_from_data"] + if ( + get_config()["array_api_dispatch"] + and self.init_params not in allowed_init_params + ): + raise NotImplementedError( + f"Allowed `init_params` are {allowed_init_params} if " + f"'array_api_dispatch' is enabled. You passed " + f"init_params={self.init_params!r}, which are not implemented to work " + "with 'array_api_dispatch' enabled. Please disable " + f"'array_api_dispatch' to use init_params={self.init_params!r}." + ) + + def _initialize_parameters(self, X, random_state, xp=None): # If all the initial parameters are all provided, then there is no need to run # the initialization. compute_resp = ( @@ -781,11 +839,11 @@ def _initialize_parameters(self, X, random_state): or self.precisions_init is None ) if compute_resp: - super()._initialize_parameters(X, random_state) + super()._initialize_parameters(X, random_state, xp=xp) else: - self._initialize(X, None) + self._initialize(X, None, xp=xp) - def _initialize(self, X, resp): + def _initialize(self, X, resp, xp=None): """Initialization of the Gaussian mixture parameters. Parameters @@ -794,29 +852,32 @@ def _initialize(self, X, resp): resp : array-like of shape (n_samples, n_components) """ + xp, _, device_ = get_namespace_and_device(X, xp=xp) n_samples, _ = X.shape weights, means, covariances = None, None, None if resp is not None: weights, means, covariances = _estimate_gaussian_parameters( - X, resp, self.reg_covar, self.covariance_type + X, resp, self.reg_covar, self.covariance_type, xp=xp ) if self.weights_init is None: weights /= n_samples self.weights_ = weights if self.weights_init is None else self.weights_init + self.weights_ = xp.asarray(self.weights_, device=device_) + self.means_ = means if self.means_init is None else self.means_init if self.precisions_init is None: self.covariances_ = covariances self.precisions_cholesky_ = _compute_precision_cholesky( - covariances, self.covariance_type + covariances, self.covariance_type, xp=xp ) else: self.precisions_cholesky_ = _compute_precision_cholesky_from_precisions( - self.precisions_init, self.covariance_type + self.precisions_init, self.covariance_type, xp=xp ) - def _m_step(self, X, log_resp): + def _m_step(self, X, log_resp, xp=None): """M step. Parameters @@ -827,21 +888,23 @@ def _m_step(self, X, log_resp): Logarithm of the posterior probabilities (or responsibilities) of the point of each sample in X. 
""" + xp, _ = get_namespace(X, log_resp, xp=xp) self.weights_, self.means_, self.covariances_ = _estimate_gaussian_parameters( - X, np.exp(log_resp), self.reg_covar, self.covariance_type + X, xp.exp(log_resp), self.reg_covar, self.covariance_type, xp=xp ) - self.weights_ /= self.weights_.sum() + self.weights_ /= xp.sum(self.weights_) self.precisions_cholesky_ = _compute_precision_cholesky( - self.covariances_, self.covariance_type + self.covariances_, self.covariance_type, xp=xp ) - def _estimate_log_prob(self, X): + def _estimate_log_prob(self, X, xp=None): return _estimate_log_gaussian_prob( - X, self.means_, self.precisions_cholesky_, self.covariance_type + X, self.means_, self.precisions_cholesky_, self.covariance_type, xp=xp ) - def _estimate_log_weights(self): - return np.log(self.weights_) + def _estimate_log_weights(self, xp=None): + xp, _ = get_namespace(self.weights_, xp=xp) + return xp.log(self.weights_) def _compute_lower_bound(self, _, log_prob_norm): return log_prob_norm @@ -854,7 +917,8 @@ def _get_parameters(self): self.precisions_cholesky_, ) - def _set_parameters(self, params): + def _set_parameters(self, params, xp=None): + xp, _, device_ = get_namespace_and_device(params, xp=xp) ( self.weights_, self.means_, @@ -867,14 +931,14 @@ def _set_parameters(self, params): dtype = self.precisions_cholesky_.dtype if self.covariance_type == "full": - self.precisions_ = np.empty_like(self.precisions_cholesky_) - for k, prec_chol in enumerate(self.precisions_cholesky_): - self.precisions_[k] = np.dot(prec_chol, prec_chol.T) + self.precisions_ = xp.empty_like(self.precisions_cholesky_, device=device_) + for k in range(self.precisions_cholesky_.shape[0]): + prec_chol = self.precisions_cholesky_[k, :, :] + self.precisions_[k, :, :] = prec_chol @ prec_chol.T elif self.covariance_type == "tied": - self.precisions_ = np.dot( - self.precisions_cholesky_, self.precisions_cholesky_.T - ) + self.precisions_ = self.precisions_cholesky_ @ self.precisions_cholesky_.T + else: self.precisions_ = self.precisions_cholesky_**2 @@ -911,7 +975,7 @@ def bic(self, X): bic : float The lower the better. 
""" - return -2 * self.score(X) * X.shape[0] + self._n_parameters() * np.log( + return -2 * self.score(X) * X.shape[0] + self._n_parameters() * math.log( X.shape[0] ) diff --git a/sklearn/mixture/tests/test_gaussian_mixture.py b/sklearn/mixture/tests/test_gaussian_mixture.py index 488a2ab147e83..794a4dfc070ce 100644 --- a/sklearn/mixture/tests/test_gaussian_mixture.py +++ b/sklearn/mixture/tests/test_gaussian_mixture.py @@ -17,6 +17,7 @@ from sklearn.cluster import KMeans from sklearn.covariance import EmpiricalCovariance from sklearn.datasets import make_spd_matrix +from sklearn.datasets._samples_generator import make_blobs from sklearn.exceptions import ConvergenceWarning, NotFittedError from sklearn.metrics.cluster import adjusted_rand_score from sklearn.mixture import GaussianMixture @@ -29,11 +30,20 @@ _estimate_gaussian_covariances_tied, _estimate_gaussian_parameters, ) +from sklearn.utils._array_api import ( + _convert_to_numpy, + _get_namespace_device_dtype_ids, + device, + get_namespace, + yield_namespace_device_dtype_combinations, +) from sklearn.utils._testing import ( + _array_api_for_tests, assert_allclose, assert_almost_equal, assert_array_almost_equal, assert_array_equal, + skip_if_array_api_compat_not_configured, ) from sklearn.utils.extmath import fast_logdet @@ -1471,3 +1481,161 @@ def test_gaussian_mixture_all_init_does_not_estimate_gaussian_parameters( # The initial gaussian parameters are not estimated. They are estimated for every # m_step. assert mock.call_count == gm.n_iter_ + + +@pytest.mark.parametrize("init_params", ["random", "random_from_data"]) +@pytest.mark.parametrize("covariance_type", ["full", "tied", "diag", "spherical"]) +@pytest.mark.parametrize( + "array_namespace, device_, dtype", + yield_namespace_device_dtype_combinations(), + ids=_get_namespace_device_dtype_ids, +) +@pytest.mark.parametrize("use_gmm_array_constructor_arguments", [False, True]) +def test_gaussian_mixture_array_api_compliance( + init_params, + covariance_type, + array_namespace, + device_, + dtype, + use_gmm_array_constructor_arguments, +): + """Test that array api works in GaussianMixture.fit().""" + xp = _array_api_for_tests(array_namespace, device_) + + rng = np.random.RandomState(0) + rand_data = RandomData(rng) + X = rand_data.X[covariance_type] + X = X.astype(dtype) + + if use_gmm_array_constructor_arguments: + additional_kwargs = { + "means_init": rand_data.means.astype(dtype), + "precisions_init": rand_data.precisions[covariance_type].astype(dtype), + "weights_init": rand_data.weights.astype(dtype), + } + else: + additional_kwargs = {} + + gmm = GaussianMixture( + n_components=rand_data.n_components, + covariance_type=covariance_type, + random_state=0, + init_params=init_params, + **additional_kwargs, + ) + gmm.fit(X) + + X_xp = xp.asarray(X, device=device_) + + with sklearn.config_context(array_api_dispatch=True): + gmm_xp = sklearn.clone(gmm) + for param_name, param_value in additional_kwargs.items(): + arg_xp = xp.asarray(param_value, device=device_) + setattr(gmm_xp, param_name, arg_xp) + + gmm_xp.fit(X_xp) + + assert get_namespace(gmm_xp.means_)[0] == xp + assert get_namespace(gmm_xp.covariances_)[0] == xp + assert device(gmm_xp.means_) == device(X_xp) + assert device(gmm_xp.covariances_) == device(X_xp) + + predict_xp = gmm_xp.predict(X_xp) + predict_proba_xp = gmm_xp.predict_proba(X_xp) + score_samples_xp = gmm_xp.score_samples(X_xp) + score_xp = gmm_xp.score(X_xp) + aic_xp = gmm_xp.aic(X_xp) + bic_xp = gmm_xp.bic(X_xp) + sample_X_xp, sample_y_xp = gmm_xp.sample(10) + 
+ results = [ + predict_xp, + predict_proba_xp, + score_samples_xp, + sample_X_xp, + sample_y_xp, + ] + for result in results: + assert get_namespace(result)[0] == xp + assert device(result) == device(X_xp) + + for score in [score_xp, aic_xp, bic_xp]: + assert isinstance(score, float) + + # Define specific rtol to make tests pass + default_rtol = 1e-4 if dtype == "float32" else 1e-7 + increased_atol = 5e-4 if dtype == "float32" else 0 + increased_rtol = 1e-3 if dtype == "float32" else 1e-7 + + # Check fitted attributes + assert_allclose(gmm.means_, _convert_to_numpy(gmm_xp.means_, xp=xp)) + assert_allclose(gmm.weights_, _convert_to_numpy(gmm_xp.weights_, xp=xp)) + assert_allclose( + gmm.covariances_, + _convert_to_numpy(gmm_xp.covariances_, xp=xp), + atol=increased_atol, + rtol=increased_rtol, + ) + assert_allclose( + gmm.precisions_cholesky_, + _convert_to_numpy(gmm_xp.precisions_cholesky_, xp=xp), + atol=increased_atol, + rtol=increased_rtol, + ) + assert_allclose( + gmm.precisions_, + _convert_to_numpy(gmm_xp.precisions_, xp=xp), + atol=increased_atol, + rtol=increased_rtol, + ) + + # Check methods + assert ( + adjusted_rand_score(gmm.predict(X), _convert_to_numpy(predict_xp, xp=xp)) > 0.95 + ) + assert_allclose( + gmm.predict_proba(X), + _convert_to_numpy(predict_proba_xp, xp=xp), + rtol=increased_rtol, + atol=increased_atol, + ) + assert_allclose( + gmm.score_samples(X), + _convert_to_numpy(score_samples_xp, xp=xp), + rtol=increased_rtol, + ) + # comparing Python float so need explicit rtol when X has dtype float32 + assert_allclose(gmm.score(X), score_xp, rtol=default_rtol) + assert_allclose(gmm.aic(X), aic_xp, rtol=default_rtol) + assert_allclose(gmm.bic(X), bic_xp, rtol=default_rtol) + sample_X, sample_y = gmm.sample(10) + # generated samples are float64 so need explicit rtol when X has dtype float32 + assert_allclose(sample_X, _convert_to_numpy(sample_X_xp, xp=xp), rtol=default_rtol) + assert_allclose(sample_y, _convert_to_numpy(sample_y_xp, xp=xp)) + + +@skip_if_array_api_compat_not_configured +@pytest.mark.parametrize("init_params", ["kmeans", "k-means++"]) +@pytest.mark.parametrize( + "array_namespace, device_, dtype", + yield_namespace_device_dtype_combinations(), + ids=_get_namespace_device_dtype_ids, +) +def test_gaussian_mixture_raises_where_array_api_not_implemented( + init_params, array_namespace, device_, dtype +): + X, _ = make_blobs( + n_samples=100, + n_features=2, + centers=3, + ) + gmm = GaussianMixture( + n_components=3, covariance_type="diag", init_params=init_params + ) + + with sklearn.config_context(array_api_dispatch=True): + with pytest.raises( + NotImplementedError, + match="Allowed `init_params`.+if 'array_api_dispatch' is enabled", + ): + gmm.fit(X) diff --git a/sklearn/utils/_array_api.py b/sklearn/utils/_array_api.py index e2bee3530f26f..cbaaa9f5168a9 100644 --- a/sklearn/utils/_array_api.py +++ b/sklearn/utils/_array_api.py @@ -1028,3 +1028,52 @@ def _tolist(array, xp=None): return array.tolist() array_np = _convert_to_numpy(array, xp=xp) return [element.item() for element in array_np] + + +def _logsumexp(array, axis=None, xp=None): + # TODO replace by scipy.special.logsumexp when + # https://github.com/scipy/scipy/pull/22683 is part of a release. 
+ # The following code is strongly inspired and simplified from + # scipy.special._logsumexp.logsumexp + xp, _, device = get_namespace_and_device(array, xp=xp) + axis = tuple(range(array.ndim)) if axis is None else axis + + supported_dtypes = supported_float_dtypes(xp) + if array.dtype not in supported_dtypes: + array = xp.asarray(array, dtype=supported_dtypes[0]) + + array_max = xp.max(array, axis=axis, keepdims=True) + index_max = array == array_max + + array = xp.asarray(array, copy=True) + array[index_max] = -xp.inf + i_max_dt = xp.astype(index_max, array.dtype) + m = xp.sum(i_max_dt, axis=axis, keepdims=True, dtype=array.dtype) + # Specifying device explicitly is the fix for https://github.com/scipy/scipy/issues/22680 + shift = xp.where( + xp.isfinite(array_max), + array_max, + xp.asarray(0, dtype=array_max.dtype, device=device), + ) + exp = xp.exp(array - shift) + s = xp.sum(exp, axis=axis, keepdims=True, dtype=exp.dtype) + s = xp.where(s == 0, s, s / m) + out = xp.log1p(s) + xp.log(m) + array_max + out = xp.squeeze(out, axis=axis) + out = out[()] if out.ndim == 0 else out + + return out + + +def _cholesky(covariance, xp): + if _is_numpy_namespace(xp): + return scipy.linalg.cholesky(covariance, lower=True) + else: + return xp.linalg.cholesky(covariance) + + +def _linalg_solve(cov_chol, eye_matrix, xp): + if _is_numpy_namespace(xp): + return scipy.linalg.solve_triangular(cov_chol, eye_matrix, lower=True) + else: + return xp.linalg.solve(cov_chol, eye_matrix) diff --git a/sklearn/utils/tests/test_array_api.py b/sklearn/utils/tests/test_array_api.py index 4d74b0bf8db43..5d35d86432f3c 100644 --- a/sklearn/utils/tests/test_array_api.py +++ b/sklearn/utils/tests/test_array_api.py @@ -3,6 +3,7 @@ import numpy import pytest +import scipy from numpy.testing import assert_allclose from sklearn._config import config_context @@ -18,6 +19,7 @@ _get_namespace_device_dtype_ids, _is_numpy_namespace, _isin, + _logsumexp, _max_precision_float_dtype, _median, _nanmax, @@ -634,3 +636,58 @@ def test_median(namespace, device, dtype_name, axis): assert get_namespace(result_xp)[0] == xp assert result_xp.device == X_xp.device assert_allclose(result_np, _convert_to_numpy(result_xp, xp=xp)) + + +@pytest.mark.parametrize( + "array_namespace, device_, dtype_name", yield_namespace_device_dtype_combinations() +) +@pytest.mark.parametrize("axis", [0, 1, None]) +def test_logsumexp_like_scipy_logsumexp(array_namespace, device_, dtype_name, axis): + xp = _array_api_for_tests(array_namespace, device_) + array_np = numpy.asarray( + [ + [0, 3, 1000], + [2, -1, 1000], + [-10, 0, 0], + [-50, 8, -numpy.inf], + [4, 0, 5], + ], + dtype=dtype_name, + ) + array_xp = xp.asarray(array_np, device=device_) + + res_np = scipy.special.logsumexp(array_np, axis=axis) + + rtol = 1e-6 if "float32" in str(dtype_name) else 1e-12 + + # if torch on CPU or array api strict on default device + # check that _logsumexp works when array API dispatch is disabled + if (array_namespace == "torch" and device_ == "cpu") or ( + array_namespace == "array_api_strict" and "CPU" in str(device_) + ): + assert_allclose(_logsumexp(array_xp, axis=axis), res_np, rtol=rtol) + + with config_context(array_api_dispatch=True): + res_xp = _logsumexp(array_xp, axis=axis) + res_xp = _convert_to_numpy(res_xp, xp) + assert_allclose(res_np, res_xp, rtol=rtol) + + # Test with NaNs and +np.inf + array_np_2 = numpy.asarray( + [ + [0, numpy.nan, 1000], + [2, -1, 1000], + [numpy.inf, 0, 0], + [-50, 8, -numpy.inf], + [4, 0, 5], + ], + dtype=dtype_name, + ) + array_xp_2 = 
xp.asarray(array_np_2, device=device_) + + res_np_2 = scipy.special.logsumexp(array_np_2, axis=axis) + + with config_context(array_api_dispatch=True): + res_xp_2 = _logsumexp(array_xp_2, axis=axis) + res_xp_2 = _convert_to_numpy(res_xp_2, xp) + assert_allclose(res_np_2, res_xp_2, rtol=rtol)
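
For reference, a minimal usage sketch of what this change enables, mirroring the new compliance test above. This is an illustration rather than part of the patch; it assumes PyTorch and `array-api-compat` are installed, and uses `init_params="random"` since only `"random"` and `"random_from_data"` are supported when array API dispatch is enabled:

    import numpy as np
    import torch
    from sklearn import config_context
    from sklearn.mixture import GaussianMixture

    X = np.random.RandomState(0).standard_normal((200, 2)).astype(np.float32)
    X_torch = torch.asarray(X)  # could also be moved to a CUDA or MPS device

    with config_context(array_api_dispatch=True):
        gmm = GaussianMixture(n_components=3, init_params="random", random_state=0)
        gmm.fit(X_torch)
        labels = gmm.predict(X_torch)  # torch tensor on the same device as X_torch

    # Fitted attributes stay in the input namespace and on the input device.
    assert isinstance(gmm.means_, torch.Tensor)
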