diff --git a/doc/whats_new/v1.1.rst b/doc/whats_new/v1.1.rst index 39f8e405ebf7c..b88e54b46a4bf 100644 --- a/doc/whats_new/v1.1.rst +++ b/doc/whats_new/v1.1.rst @@ -393,6 +393,15 @@ Changelog beginning which speeds up fitting. :pr:`22206` by :user:`Christian Lorentzen `. +- |Enhancement| :class:`~linear_model.LogisticRegression` is faster for + ``solvers="lbfgs"`` and ``solver="newton-cg"``, for binary and in particular for + multiclass problems thanks to the new private loss function module. In the multiclass + case, the memory consumption has also been reduced for these solvers as the target is + now label encoded (mapped to integers) instead of label binarized (one-hot encoded). + The more classes, the larger the benefit. + :pr:`21808`, :pr:`20567` and :pr:`21814` by + :user:`Christian Lorentzen `. + - |Enhancement| Rename parameter `base_estimator` to `estimator` in :class:`linear_model.RANSACRegressor` to improve readability and consistency. `base_estimator` is deprecated and will be removed in 1.3. diff --git a/sklearn/linear_model/_linear_loss.py b/sklearn/linear_model/_linear_loss.py new file mode 100644 index 0000000000000..93a7684aea5b6 --- /dev/null +++ b/sklearn/linear_model/_linear_loss.py @@ -0,0 +1,411 @@ +""" +Loss functions for linear models with raw_prediction = X @ coef +""" +import numpy as np +from scipy import sparse +from ..utils.extmath import squared_norm + + +class LinearModelLoss: + """General class for loss functions with raw_prediction = X @ coef + intercept. + + The loss is the sum of per sample losses and includes a term for L2 + regularization:: + + loss = sum_i s_i loss(y_i, X_i @ coef + intercept) + + 1/2 * l2_reg_strength * ||coef||_2^2 + + with sample weights s_i=1 if sample_weight=None. + + Gradient and hessian, for simplicity without intercept, are:: + + gradient = X.T @ loss.gradient + l2_reg_strength * coef + hessian = X.T @ diag(loss.hessian) @ X + l2_reg_strength * identity + + Conventions: + if fit_intercept: + n_dof = n_features + 1 + else: + n_dof = n_features + + if base_loss.is_multiclass: + coef.shape = (n_classes, n_dof) or ravelled (n_classes * n_dof,) + else: + coef.shape = (n_dof,) + + The intercept term is at the end of the coef array: + if base_loss.is_multiclass: + if coef.shape (n_classes, n_dof): + intercept = coef[:, -1] + if coef.shape (n_classes * n_dof,) + intercept = coef[n_features::n_dof] = coef[(n_dof-1)::n_dof] + intercept.shape = (n_classes,) + else: + intercept = coef[-1] + + Note: If coef has shape (n_classes * n_dof,), the 2d-array can be reconstructed as + + coef.reshape((n_classes, -1), order="F") + + The option order="F" makes coef[:, i] contiguous. This, in turn, makes the + coefficients without intercept, coef[:, :-1], contiguous and speeds up + matrix-vector computations. + + Note: If the average loss per sample is wanted instead of the sum of the loss per + sample, one can simply use a rescaled sample_weight such that + sum(sample_weight) = 1. + + Parameters + ---------- + base_loss : instance of class BaseLoss from sklearn._loss. + fit_intercept : bool + """ + + def __init__(self, base_loss, fit_intercept): + self.base_loss = base_loss + self.fit_intercept = fit_intercept + + def _w_intercept_raw(self, coef, X): + """Helper function to get coefficients, intercept and raw_prediction. + + Parameters + ---------- + coef : ndarray of shape (n_dof,), (n_classes, n_dof) or (n_classes * n_dof,) + Coefficients of a linear model. + If shape (n_classes * n_dof,), the classes of one feature are contiguous, + i.e. 
one reconstructs the 2d-array via + coef.reshape((n_classes, -1), order="F"). + X : {array-like, sparse matrix} of shape (n_samples, n_features) + Training data. + + Returns + ------- + weights : ndarray of shape (n_features,) or (n_classes, n_features) + Coefficients without intercept term. + intercept : float or ndarray of shape (n_classes,) + Intercept terms. + raw_prediction : ndarray of shape (n_samples,) or \ + (n_samples, n_classes) + """ + if not self.base_loss.is_multiclass: + if self.fit_intercept: + intercept = coef[-1] + weights = coef[:-1] + else: + intercept = 0.0 + weights = coef + raw_prediction = X @ weights + intercept + else: + # reshape to (n_classes, n_dof) + if coef.ndim == 1: + weights = coef.reshape((self.base_loss.n_classes, -1), order="F") + else: + weights = coef + if self.fit_intercept: + intercept = weights[:, -1] + weights = weights[:, :-1] + else: + intercept = 0.0 + raw_prediction = X @ weights.T + intercept # ndarray, likely C-contiguous + + return weights, intercept, raw_prediction + + def loss(self, coef, X, y, sample_weight=None, l2_reg_strength=0.0, n_threads=1): + """Compute the loss as sum over point-wise losses. + + Parameters + ---------- + coef : ndarray of shape (n_dof,), (n_classes, n_dof) or (n_classes * n_dof,) + Coefficients of a linear model. + If shape (n_classes * n_dof,), the classes of one feature are contiguous, + i.e. one reconstructs the 2d-array via + coef.reshape((n_classes, -1), order="F"). + X : {array-like, sparse matrix} of shape (n_samples, n_features) + Training data. + y : contiguous array of shape (n_samples,) + Observed, true target values. + sample_weight : None or contiguous array of shape (n_samples,), default=None + Sample weights. + l2_reg_strength : float, default=0.0 + L2 regularization strength + n_threads : int, default=1 + Number of OpenMP threads to use. + + Returns + ------- + loss : float + Sum of losses per sample plus penalty. + """ + weights, intercept, raw_prediction = self._w_intercept_raw(coef, X) + + loss = self.base_loss.loss( + y_true=y, + raw_prediction=raw_prediction, + sample_weight=sample_weight, + n_threads=n_threads, + ) + loss = loss.sum() + + norm2_w = weights @ weights if weights.ndim == 1 else squared_norm(weights) + return loss + 0.5 * l2_reg_strength * norm2_w + + def loss_gradient( + self, coef, X, y, sample_weight=None, l2_reg_strength=0.0, n_threads=1 + ): + """Computes the sum of loss and gradient w.r.t. coef. + + Parameters + ---------- + coef : ndarray of shape (n_dof,), (n_classes, n_dof) or (n_classes * n_dof,) + Coefficients of a linear model. + If shape (n_classes * n_dof,), the classes of one feature are contiguous, + i.e. one reconstructs the 2d-array via + coef.reshape((n_classes, -1), order="F"). + X : {array-like, sparse matrix} of shape (n_samples, n_features) + Training data. + y : contiguous array of shape (n_samples,) + Observed, true target values. + sample_weight : None or contiguous array of shape (n_samples,), default=None + Sample weights. + l2_reg_strength : float, default=0.0 + L2 regularization strength + n_threads : int, default=1 + Number of OpenMP threads to use. + + Returns + ------- + loss : float + Sum of losses per sample plus penalty. + + gradient : ndarray of shape coef.shape + The gradient of the loss. 
+ """ + n_features, n_classes = X.shape[1], self.base_loss.n_classes + n_dof = n_features + int(self.fit_intercept) + weights, intercept, raw_prediction = self._w_intercept_raw(coef, X) + + loss, grad_per_sample = self.base_loss.loss_gradient( + y_true=y, + raw_prediction=raw_prediction, + sample_weight=sample_weight, + n_threads=n_threads, + ) + loss = loss.sum() + + if not self.base_loss.is_multiclass: + loss += 0.5 * l2_reg_strength * (weights @ weights) + grad = np.empty_like(coef, dtype=X.dtype) + grad[:n_features] = X.T @ grad_per_sample + l2_reg_strength * weights + if self.fit_intercept: + grad[-1] = grad_per_sample.sum() + else: + loss += 0.5 * l2_reg_strength * squared_norm(weights) + grad = np.empty((n_classes, n_dof), dtype=X.dtype, order="F") + # grad_per_sample.shape = (n_samples, n_classes) + grad[:, :n_features] = grad_per_sample.T @ X + l2_reg_strength * weights + if self.fit_intercept: + grad[:, -1] = grad_per_sample.sum(axis=0) + if coef.ndim == 1: + grad = grad.ravel(order="F") + + return loss, grad + + def gradient( + self, coef, X, y, sample_weight=None, l2_reg_strength=0.0, n_threads=1 + ): + """Computes the gradient w.r.t. coef. + + Parameters + ---------- + coef : ndarray of shape (n_dof,), (n_classes, n_dof) or (n_classes * n_dof,) + Coefficients of a linear model. + If shape (n_classes * n_dof,), the classes of one feature are contiguous, + i.e. one reconstructs the 2d-array via + coef.reshape((n_classes, -1), order="F"). + X : {array-like, sparse matrix} of shape (n_samples, n_features) + Training data. + y : contiguous array of shape (n_samples,) + Observed, true target values. + sample_weight : None or contiguous array of shape (n_samples,), default=None + Sample weights. + l2_reg_strength : float, default=0.0 + L2 regularization strength + n_threads : int, default=1 + Number of OpenMP threads to use. + + Returns + ------- + gradient : ndarray of shape coef.shape + The gradient of the loss. + """ + n_features, n_classes = X.shape[1], self.base_loss.n_classes + n_dof = n_features + int(self.fit_intercept) + weights, intercept, raw_prediction = self._w_intercept_raw(coef, X) + + grad_per_sample = self.base_loss.gradient( + y_true=y, + raw_prediction=raw_prediction, + sample_weight=sample_weight, + n_threads=n_threads, + ) + + if not self.base_loss.is_multiclass: + grad = np.empty_like(coef, dtype=X.dtype) + grad[:n_features] = X.T @ grad_per_sample + l2_reg_strength * weights + if self.fit_intercept: + grad[-1] = grad_per_sample.sum() + return grad + else: + grad = np.empty((n_classes, n_dof), dtype=X.dtype, order="F") + # gradient.shape = (n_samples, n_classes) + grad[:, :n_features] = grad_per_sample.T @ X + l2_reg_strength * weights + if self.fit_intercept: + grad[:, -1] = grad_per_sample.sum(axis=0) + if coef.ndim == 1: + return grad.ravel(order="F") + else: + return grad + + def gradient_hessian_product( + self, coef, X, y, sample_weight=None, l2_reg_strength=0.0, n_threads=1 + ): + """Computes gradient and hessp (hessian product function) w.r.t. coef. + + Parameters + ---------- + coef : ndarray of shape (n_dof,), (n_classes, n_dof) or (n_classes * n_dof,) + Coefficients of a linear model. + If shape (n_classes * n_dof,), the classes of one feature are contiguous, + i.e. one reconstructs the 2d-array via + coef.reshape((n_classes, -1), order="F"). + X : {array-like, sparse matrix} of shape (n_samples, n_features) + Training data. + y : contiguous array of shape (n_samples,) + Observed, true target values. 
+ sample_weight : None or contiguous array of shape (n_samples,), default=None + Sample weights. + l2_reg_strength : float, default=0.0 + L2 regularization strength + n_threads : int, default=1 + Number of OpenMP threads to use. + + Returns + ------- + gradient : ndarray of shape coef.shape + The gradient of the loss. + + hessp : callable + Function that takes in a vector input of shape of gradient and + and returns matrix-vector product with hessian. + """ + (n_samples, n_features), n_classes = X.shape, self.base_loss.n_classes + n_dof = n_features + int(self.fit_intercept) + weights, intercept, raw_prediction = self._w_intercept_raw(coef, X) + + if not self.base_loss.is_multiclass: + gradient, hessian = self.base_loss.gradient_hessian( + y_true=y, + raw_prediction=raw_prediction, + sample_weight=sample_weight, + n_threads=n_threads, + ) + grad = np.empty_like(coef, dtype=X.dtype) + grad[:n_features] = X.T @ gradient + l2_reg_strength * weights + if self.fit_intercept: + grad[-1] = gradient.sum() + + # Precompute as much as possible: hX, hX_sum and hessian_sum + hessian_sum = hessian.sum() + if sparse.issparse(X): + hX = sparse.dia_matrix((hessian, 0), shape=(n_samples, n_samples)) @ X + else: + hX = hessian[:, np.newaxis] * X + + if self.fit_intercept: + # Calculate the double derivative with respect to intercept. + # Note: In case hX is sparse, hX.sum is a matrix object. + hX_sum = np.squeeze(np.asarray(hX.sum(axis=0))) + + # With intercept included and l2_reg_strength = 0, hessp returns + # res = (X, 1)' @ diag(h) @ (X, 1) @ s + # = (X, 1)' @ (hX @ s[:n_features], sum(h) * s[-1]) + # res[:n_features] = X' @ hX @ s[:n_features] + sum(h) * s[-1] + # res[-1] = 1' @ hX @ s[:n_features] + sum(h) * s[-1] + def hessp(s): + ret = np.empty_like(s) + if sparse.issparse(X): + ret[:n_features] = X.T @ (hX @ s[:n_features]) + else: + ret[:n_features] = np.linalg.multi_dot([X.T, hX, s[:n_features]]) + ret[:n_features] += l2_reg_strength * s[:n_features] + + if self.fit_intercept: + ret[:n_features] += s[-1] * hX_sum + ret[-1] = hX_sum @ s[:n_features] + hessian_sum * s[-1] + return ret + + else: + # Here we may safely assume HalfMultinomialLoss aka categorical + # cross-entropy. + # HalfMultinomialLoss computes only the diagonal part of the hessian, i.e. + # diagonal in the classes. Here, we want the matrix-vector product of the + # full hessian. Therefore, we call gradient_proba. + gradient, proba = self.base_loss.gradient_proba( + y_true=y, + raw_prediction=raw_prediction, + sample_weight=sample_weight, + n_threads=n_threads, + ) + grad = np.empty((n_classes, n_dof), dtype=X.dtype, order="F") + grad[:, :n_features] = gradient.T @ X + l2_reg_strength * weights + if self.fit_intercept: + grad[:, -1] = gradient.sum(axis=0) + + # Full hessian-vector product, i.e. not only the diagonal part of the + # hessian. 
Derivation with some index battle for input vector s: + # - sample index i + # - feature indices j, m + # - class indices k, l + # - 1_{k=l} is one if k=l else 0 + # - p_i_k is the (predicted) probability that sample i belongs to class k + # for all i: sum_k p_i_k = 1 + # - s_l_m is input vector for class l and feature m + # - X' = X transposed + # + # Note: Hessian with dropping most indices is just: + # X' @ p_k (1(k=l) - p_l) @ X + # + # result_{k j} = sum_{i, l, m} Hessian_{i, k j, m l} * s_l_m + # = sum_{i, l, m} (X')_{ji} * p_i_k * (1_{k=l} - p_i_l) + # * X_{im} s_l_m + # = sum_{i, m} (X')_{ji} * p_i_k + # * (X_{im} * s_k_m - sum_l p_i_l * X_{im} * s_l_m) + # + # See also https://github.com/scikit-learn/scikit-learn/pull/3646#discussion_r17461411 # noqa + def hessp(s): + s = s.reshape((n_classes, -1), order="F") # shape = (n_classes, n_dof) + if self.fit_intercept: + s_intercept = s[:, -1] + s = s[:, :-1] # shape = (n_classes, n_features) + else: + s_intercept = 0 + tmp = X @ s.T + s_intercept # X_{im} * s_k_m + tmp += (-proba * tmp).sum(axis=1)[:, np.newaxis] # - sum_l .. + tmp *= proba # * p_i_k + if sample_weight is not None: + tmp *= sample_weight[:, np.newaxis] + # hess_prod = empty_like(grad), but we ravel grad below and this + # function is run after that. + hess_prod = np.empty((n_classes, n_dof), dtype=X.dtype, order="F") + hess_prod[:, :n_features] = tmp.T @ X + l2_reg_strength * s + if self.fit_intercept: + hess_prod[:, -1] = tmp.sum(axis=0) + if coef.ndim == 1: + return hess_prod.ravel(order="F") + else: + return hess_prod + + if coef.ndim == 1: + return grad.ravel(order="F"), hessp + + return grad, hessp diff --git a/sklearn/linear_model/_logistic.py b/sklearn/linear_model/_logistic.py index dd3a1a0c15a11..fed48fb07d924 100644 --- a/sklearn/linear_model/_logistic.py +++ b/sklearn/linear_model/_logistic.py @@ -14,17 +14,18 @@ import warnings import numpy as np -from scipy import optimize, sparse -from scipy.special import expit, logsumexp +from scipy import optimize from joblib import Parallel, effective_n_jobs from ._base import LinearClassifierMixin, SparseCoefMixin, BaseEstimator +from ._linear_loss import LinearModelLoss from ._sag import sag_solver +from .._loss.loss import HalfBinomialLoss, HalfMultinomialLoss from ..preprocessing import LabelEncoder, LabelBinarizer from ..svm._base import _fit_liblinear from ..utils import check_array, check_consistent_length, compute_class_weight from ..utils import check_random_state -from ..utils.extmath import log_logistic, safe_sparse_dot, softmax, squared_norm +from ..utils.extmath import softmax from ..utils.extmath import row_norms from ..utils.optimize import _newton_cg, _check_optimize_result from ..utils.validation import check_is_fitted, _check_sample_weight @@ -41,392 +42,6 @@ ) -# .. some helper functions for logistic_regression_path .. -def _intercept_dot(w, X, y): - """Computes y * np.dot(X, w). - - It takes into consideration if the intercept should be fit or not. - - Parameters - ---------- - w : ndarray of shape (n_features,) or (n_features + 1,) - Coefficient vector. - - X : {array-like, sparse matrix} of shape (n_samples, n_features) - Training data. - - y : ndarray of shape (n_samples,) - Array of labels. - - Returns - ------- - w : ndarray of shape (n_features,) - Coefficient vector without the intercept weight (w[-1]) if the - intercept should be fit. Unchanged otherwise. - - c : float - The intercept. - - yz : float - y * np.dot(X, w). 
- """ - c = 0.0 - if w.size == X.shape[1] + 1: - c = w[-1] - w = w[:-1] - - z = safe_sparse_dot(X, w) + c - yz = y * z - return w, c, yz - - -def _logistic_loss_and_grad(w, X, y, alpha, sample_weight=None): - """Computes the logistic loss and gradient. - - Parameters - ---------- - w : ndarray of shape (n_features,) or (n_features + 1,) - Coefficient vector. - - X : {array-like, sparse matrix} of shape (n_samples, n_features) - Training data. - - y : ndarray of shape (n_samples,) - Array of labels. - - alpha : float - Regularization parameter. alpha is equal to 1 / C. - - sample_weight : array-like of shape (n_samples,), default=None - Array of weights that are assigned to individual samples. - If not provided, then each sample is given unit weight. - - Returns - ------- - out : float - Logistic loss. - - grad : ndarray of shape (n_features,) or (n_features + 1,) - Logistic gradient. - """ - n_samples, n_features = X.shape - grad = np.empty_like(w) - - w, c, yz = _intercept_dot(w, X, y) - - if sample_weight is None: - sample_weight = np.ones(n_samples) - - # Logistic loss is the negative of the log of the logistic function. - out = -np.sum(sample_weight * log_logistic(yz)) + 0.5 * alpha * np.dot(w, w) - - z = expit(yz) - z0 = sample_weight * (z - 1) * y - - grad[:n_features] = safe_sparse_dot(X.T, z0) + alpha * w - - # Case where we fit the intercept. - if grad.shape[0] > n_features: - grad[-1] = z0.sum() - return out, grad - - -def _logistic_loss(w, X, y, alpha, sample_weight=None): - """Computes the logistic loss. - - Parameters - ---------- - w : ndarray of shape (n_features,) or (n_features + 1,) - Coefficient vector. - - X : {array-like, sparse matrix} of shape (n_samples, n_features) - Training data. - - y : ndarray of shape (n_samples,) - Array of labels. - - alpha : float - Regularization parameter. alpha is equal to 1 / C. - - sample_weight : array-like of shape (n_samples,) default=None - Array of weights that are assigned to individual samples. - If not provided, then each sample is given unit weight. - - Returns - ------- - out : float - Logistic loss. - """ - w, c, yz = _intercept_dot(w, X, y) - - if sample_weight is None: - sample_weight = np.ones(y.shape[0]) - - # Logistic loss is the negative of the log of the logistic function. - out = -np.sum(sample_weight * log_logistic(yz)) + 0.5 * alpha * np.dot(w, w) - return out - - -def _logistic_grad_hess(w, X, y, alpha, sample_weight=None): - """Computes the gradient and the Hessian, in the case of a logistic loss. - - Parameters - ---------- - w : ndarray of shape (n_features,) or (n_features + 1,) - Coefficient vector. - - X : {array-like, sparse matrix} of shape (n_samples, n_features) - Training data. - - y : ndarray of shape (n_samples,) - Array of labels. - - alpha : float - Regularization parameter. alpha is equal to 1 / C. - - sample_weight : array-like of shape (n_samples,) default=None - Array of weights that are assigned to individual samples. - If not provided, then each sample is given unit weight. - - Returns - ------- - grad : ndarray of shape (n_features,) or (n_features + 1,) - Logistic gradient. - - Hs : callable - Function that takes the gradient as a parameter and returns the - matrix product of the Hessian and gradient. 
- """ - n_samples, n_features = X.shape - grad = np.empty_like(w) - fit_intercept = grad.shape[0] > n_features - - w, c, yz = _intercept_dot(w, X, y) - - if sample_weight is None: - sample_weight = np.ones(y.shape[0]) - - z = expit(yz) - z0 = sample_weight * (z - 1) * y - - grad[:n_features] = safe_sparse_dot(X.T, z0) + alpha * w - - # Case where we fit the intercept. - if fit_intercept: - grad[-1] = z0.sum() - - # The mat-vec product of the Hessian - d = sample_weight * z * (1 - z) - if sparse.issparse(X): - dX = safe_sparse_dot(sparse.dia_matrix((d, 0), shape=(n_samples, n_samples)), X) - else: - # Precompute as much as possible - dX = d[:, np.newaxis] * X - - if fit_intercept: - # Calculate the double derivative with respect to intercept - # In the case of sparse matrices this returns a matrix object. - dd_intercept = np.squeeze(np.array(dX.sum(axis=0))) - - def Hs(s): - ret = np.empty_like(s) - if sparse.issparse(X): - ret[:n_features] = X.T.dot(dX.dot(s[:n_features])) - else: - ret[:n_features] = np.linalg.multi_dot([X.T, dX, s[:n_features]]) - ret[:n_features] += alpha * s[:n_features] - - # For the fit intercept case. - if fit_intercept: - ret[:n_features] += s[-1] * dd_intercept - ret[-1] = dd_intercept.dot(s[:n_features]) - ret[-1] += d.sum() * s[-1] - return ret - - return grad, Hs - - -def _multinomial_loss(w, X, Y, alpha, sample_weight): - """Computes multinomial loss and class probabilities. - - Parameters - ---------- - w : ndarray of shape (n_classes * n_features,) or - (n_classes * (n_features + 1),) - Coefficient vector. - - X : {array-like, sparse matrix} of shape (n_samples, n_features) - Training data. - - Y : ndarray of shape (n_samples, n_classes) - Transformed labels according to the output of LabelBinarizer. - - alpha : float - Regularization parameter. alpha is equal to 1 / C. - - sample_weight : array-like of shape (n_samples,) - Array of weights that are assigned to individual samples. - - Returns - ------- - loss : float - Multinomial loss. - - p : ndarray of shape (n_samples, n_classes) - Estimated class probabilities. - - w : ndarray of shape (n_classes, n_features) - Reshaped param vector excluding intercept terms. - - Reference - --------- - Bishop, C. M. (2006). Pattern recognition and machine learning. - Springer. (Chapter 4.3.4) - """ - n_classes = Y.shape[1] - n_features = X.shape[1] - fit_intercept = w.size == (n_classes * (n_features + 1)) - w = w.reshape(n_classes, -1) - sample_weight = sample_weight[:, np.newaxis] - if fit_intercept: - intercept = w[:, -1] - w = w[:, :-1] - else: - intercept = 0 - p = safe_sparse_dot(X, w.T) - p += intercept - p -= logsumexp(p, axis=1)[:, np.newaxis] - loss = -(sample_weight * Y * p).sum() - loss += 0.5 * alpha * squared_norm(w) - p = np.exp(p, p) - return loss, p, w - - -def _multinomial_loss_grad(w, X, Y, alpha, sample_weight): - """Computes the multinomial loss, gradient and class probabilities. - - Parameters - ---------- - w : ndarray of shape (n_classes * n_features,) or - (n_classes * (n_features + 1),) - Coefficient vector. - - X : {array-like, sparse matrix} of shape (n_samples, n_features) - Training data. - - Y : ndarray of shape (n_samples, n_classes) - Transformed labels according to the output of LabelBinarizer. - - alpha : float - Regularization parameter. alpha is equal to 1 / C. - - sample_weight : array-like of shape (n_samples,) - Array of weights that are assigned to individual samples. - - Returns - ------- - loss : float - Multinomial loss. 
- - grad : ndarray of shape (n_classes * n_features,) or \ - (n_classes * (n_features + 1),) - Ravelled gradient of the multinomial loss. - - p : ndarray of shape (n_samples, n_classes) - Estimated class probabilities - - Reference - --------- - Bishop, C. M. (2006). Pattern recognition and machine learning. - Springer. (Chapter 4.3.4) - """ - n_classes = Y.shape[1] - n_features = X.shape[1] - fit_intercept = w.size == n_classes * (n_features + 1) - grad = np.zeros((n_classes, n_features + bool(fit_intercept)), dtype=X.dtype) - loss, p, w = _multinomial_loss(w, X, Y, alpha, sample_weight) - sample_weight = sample_weight[:, np.newaxis] - diff = sample_weight * (p - Y) - grad[:, :n_features] = safe_sparse_dot(diff.T, X) - grad[:, :n_features] += alpha * w - if fit_intercept: - grad[:, -1] = diff.sum(axis=0) - return loss, grad.ravel(), p - - -def _multinomial_grad_hess(w, X, Y, alpha, sample_weight): - """ - Computes the gradient and the Hessian, in the case of a multinomial loss. - - Parameters - ---------- - w : ndarray of shape (n_classes * n_features,) or - (n_classes * (n_features + 1),) - Coefficient vector. - - X : {array-like, sparse matrix} of shape (n_samples, n_features) - Training data. - - Y : ndarray of shape (n_samples, n_classes) - Transformed labels according to the output of LabelBinarizer. - - alpha : float - Regularization parameter. alpha is equal to 1 / C. - - sample_weight : array-like of shape (n_samples,) - Array of weights that are assigned to individual samples. - - Returns - ------- - grad : ndarray of shape (n_classes * n_features,) or \ - (n_classes * (n_features + 1),) - Ravelled gradient of the multinomial loss. - - hessp : callable - Function that takes in a vector input of shape (n_classes * n_features) - or (n_classes * (n_features + 1)) and returns matrix-vector product - with hessian. - - References - ---------- - Barak A. Pearlmutter (1993). Fast Exact Multiplication by the Hessian. - http://www.bcl.hamilton.ie/~barak/papers/nc-hessian.pdf - """ - n_features = X.shape[1] - n_classes = Y.shape[1] - fit_intercept = w.size == (n_classes * (n_features + 1)) - - # `loss` is unused. Refactoring to avoid computing it does not - # significantly speed up the computation and decreases readability - loss, grad, p = _multinomial_loss_grad(w, X, Y, alpha, sample_weight) - sample_weight = sample_weight[:, np.newaxis] - - # Hessian-vector product derived by applying the R-operator on the gradient - # of the multinomial loss function. - def hessp(v): - v = v.reshape(n_classes, -1) - if fit_intercept: - inter_terms = v[:, -1] - v = v[:, :-1] - else: - inter_terms = 0 - # r_yhat holds the result of applying the R-operator on the multinomial - # estimator. - r_yhat = safe_sparse_dot(X, v.T) - r_yhat += inter_terms - r_yhat += (-p * r_yhat).sum(axis=1)[:, np.newaxis] - r_yhat *= p - r_yhat *= sample_weight - hessProd = np.zeros((n_classes, n_features + bool(fit_intercept))) - hessProd[:, :n_features] = safe_sparse_dot(r_yhat.T, X) - hessProd[:, :n_features] += v * alpha - if fit_intercept: - hessProd[:, -1] = r_yhat.sum(axis=0) - return hessProd.ravel() - - return grad, hessp - - def _check_solver(solver, penalty, dual): all_solvers = ["liblinear", "newton-cg", "lbfgs", "sag", "saga"] if solver not in all_solvers: @@ -504,6 +119,7 @@ def _logistic_regression_path( max_squared_sum=None, sample_weight=None, l1_ratio=None, + n_threads=1, ): """Compute a Logistic Regression model for a list of regularization parameters. 
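Note for orientation, not taken from the diff itself: the removed private helpers map onto the new LinearModelLoss API as follows. _logistic_loss, _logistic_loss_and_grad and _logistic_grad_hess correspond to the loss, loss_gradient and gradient_hessian_product methods with base_loss=HalfBinomialLoss(), with the binary target now encoded in {0, 1} instead of {-1, +1}; _multinomial_loss, _multinomial_loss_grad and _multinomial_grad_hess correspond to the same methods with base_loss=HalfMultinomialLoss(n_classes=...), a label-encoded integer target instead of the one-hot Y, coefficients ravelled with order="F", and alpha passed as l2_reg_strength (= 1/C). A minimal self-contained sketch of the multinomial case, with made-up data and assuming the private modules of scikit-learn 1.1:

import numpy as np
from sklearn._loss.loss import HalfMultinomialLoss
from sklearn.linear_model._linear_loss import LinearModelLoss

rng = np.random.RandomState(0)
n_samples, n_features, n_classes = 20, 4, 3
X = rng.randn(n_samples, n_features)
y = rng.randint(n_classes, size=n_samples).astype(np.float64)  # label encoded, not one-hot

linear_loss = LinearModelLoss(
    base_loss=HalfMultinomialLoss(n_classes=n_classes), fit_intercept=True
)
# lbfgs and newton-cg work on ravelled coefficients; classes are contiguous (order="F").
coef = np.zeros((n_classes, n_features + 1))  # last column holds the intercepts
w0 = coef.ravel(order="F")
loss_value, grad = linear_loss.loss_gradient(w0, X, y, l2_reg_strength=1.0)
grad2, hessp = linear_loss.gradient_hessian_product(w0, X, y, l2_reg_strength=1.0)
hs = hessp(grad)  # Hessian-vector product with the same ravelled shape as grad
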
@@ -628,6 +244,9 @@ def _logistic_regression_path( to using ``penalty='l1'``. For ``0 < l1_ratio <1``, the penalty is a combination of L1 and L2. + n_threads : int, default=1 + Number of OpenMP threads to use. + Returns ------- coefs : ndarray of shape (n_cs, n_features) or (n_cs, n_features + 1) @@ -695,12 +314,18 @@ def _logistic_regression_path( # multinomial case this is not necessary. if multi_class == "ovr": w0 = np.zeros(n_features + int(fit_intercept), dtype=X.dtype) - mask_classes = np.array([-1, 1]) mask = y == pos_class y_bin = np.ones(y.shape, dtype=X.dtype) - y_bin[~mask] = -1.0 - # for compute_class_weight + if solver in ["lbfgs", "newton-cg"]: + # HalfBinomialLoss, used for those solvers, represents y in [0, 1] instead + # of in [-1, 1]. + mask_classes = np.array([0, 1]) + y_bin[~mask] = 0.0 + else: + mask_classes = np.array([-1, 1]) + y_bin[~mask] = -1.0 + # for compute_class_weight if class_weight == "balanced": class_weight_ = compute_class_weight( class_weight, classes=mask_classes, y=y_bin @@ -708,15 +333,19 @@ def _logistic_regression_path( sample_weight *= class_weight_[le.fit_transform(y_bin)] else: - if solver not in ["sag", "saga"]: + if solver in ["sag", "saga", "lbfgs", "newton-cg"]: + # SAG, lbfgs and newton-cg multinomial solvers need LabelEncoder, + # not LabelBinarizer, i.e. y as a 1d-array of integers. + # LabelEncoder also saves memory compared to LabelBinarizer, especially + # when n_classes is large. + le = LabelEncoder() + Y_multi = le.fit_transform(y).astype(X.dtype, copy=False) + else: + # For liblinear solver, apply LabelBinarizer, i.e. y is one-hot encoded. lbin = LabelBinarizer() Y_multi = lbin.fit_transform(y) if Y_multi.shape[1] == 1: Y_multi = np.hstack([1 - Y_multi, Y_multi]) - else: - # SAG multinomial solver needs LabelEncoder, not LabelBinarizer - le = LabelEncoder() - Y_multi = le.fit_transform(y).astype(X.dtype, copy=False) w0 = np.zeros( (classes.size, n_features + int(fit_intercept)), order="F", dtype=X.dtype @@ -762,43 +391,45 @@ def _logistic_regression_path( w0[:, : coef.shape[1]] = coef if multi_class == "multinomial": - # scipy.optimize.minimize and newton-cg accepts only - # ravelled parameters. if solver in ["lbfgs", "newton-cg"]: - w0 = w0.ravel() + # scipy.optimize.minimize and newton-cg accept only ravelled parameters, + # i.e. 1d-arrays. LinearModelLoss expects classes to be contiguous and + # reconstructs the 2d-array via w0.reshape((n_classes, -1), order="F"). + # As w0 is F-contiguous, ravel(order="F") also avoids a copy. 
+ w0 = w0.ravel(order="F") + loss = LinearModelLoss( + base_loss=HalfMultinomialLoss(n_classes=classes.size), + fit_intercept=fit_intercept, + ) target = Y_multi - if solver == "lbfgs": - - def func(x, *args): - return _multinomial_loss_grad(x, *args)[0:2] - + if solver in "lbfgs": + func = loss.loss_gradient elif solver == "newton-cg": - - def func(x, *args): - return _multinomial_loss(x, *args)[0] - - def grad(x, *args): - return _multinomial_loss_grad(x, *args)[1] - - hess = _multinomial_grad_hess + func = loss.loss + grad = loss.gradient + hess = loss.gradient_hessian_product # hess = [gradient, hessp] warm_start_sag = {"coef": w0.T} else: target = y_bin if solver == "lbfgs": - func = _logistic_loss_and_grad + loss = LinearModelLoss( + base_loss=HalfBinomialLoss(), fit_intercept=fit_intercept + ) + func = loss.loss_gradient elif solver == "newton-cg": - func = _logistic_loss - - def grad(x, *args): - return _logistic_loss_and_grad(x, *args)[1] - - hess = _logistic_grad_hess + loss = LinearModelLoss( + base_loss=HalfBinomialLoss(), fit_intercept=fit_intercept + ) + func = loss.loss + grad = loss.gradient + hess = loss.gradient_hessian_product # hess = [gradient, hessp] warm_start_sag = {"coef": np.expand_dims(w0, axis=1)} coefs = list() n_iter = np.zeros(len(Cs), dtype=np.int32) for i, C in enumerate(Cs): if solver == "lbfgs": + l2_reg_strength = 1.0 / C iprint = [-1, 50, 1, 100, 101][ np.searchsorted(np.array([0, 1, 2, 3]), verbose) ] @@ -807,7 +438,7 @@ def grad(x, *args): w0, method="L-BFGS-B", jac=True, - args=(X, target, 1.0 / C, sample_weight), + args=(X, target, sample_weight, l2_reg_strength, n_threads), options={"iprint": iprint, "gtol": tol, "maxiter": max_iter}, ) n_iter_i = _check_optimize_result( @@ -818,7 +449,8 @@ def grad(x, *args): ) w0, loss = opt_res.x, opt_res.fun elif solver == "newton-cg": - args = (X, target, 1.0 / C, sample_weight) + l2_reg_strength = 1.0 / C + args = (X, target, sample_weight, l2_reg_strength, n_threads) w0, n_iter_i = _newton_cg( hess, func, grad, w0, args=args, maxiter=max_iter, tol=tol ) @@ -885,7 +517,10 @@ def grad(x, *args): if multi_class == "multinomial": n_classes = max(2, classes.size) - multi_w0 = np.reshape(w0, (n_classes, -1)) + if solver in ["lbfgs", "newton-cg"]: + multi_w0 = np.reshape(w0, (n_classes, -1), order="F") + else: + multi_w0 = w0 if n_classes == 2: multi_w0 = multi_w0[1][np.newaxis, :] coefs.append(multi_w0.copy()) @@ -1584,6 +1219,21 @@ def fit(self, X, y, sample_weight=None): prefer = "threads" else: prefer = "processes" + + # TODO: Refactor this to avoid joblib parallelism entirely when doing binary + # and multinomial multiclass classification and use joblib only for the + # one-vs-rest multiclass case. 
+ if ( + solver in ["lbfgs", "newton-cg"] + and len(classes_) == 1 + and effective_n_jobs(self.n_jobs) == 1 + ): + # In the future, we would like n_threads = _openmp_effective_n_threads() + # For the time being, we just do + n_threads = 1 + else: + n_threads = 1 + fold_coefs_ = Parallel(n_jobs=self.n_jobs, verbose=self.verbose, prefer=prefer)( path_func( X, @@ -1604,6 +1254,7 @@ def fit(self, X, y, sample_weight=None): penalty=penalty, max_squared_sum=max_squared_sum, sample_weight=sample_weight, + n_threads=n_threads, ) for class_, warm_start_coef_ in zip(classes_, warm_start_coef) ) diff --git a/sklearn/linear_model/tests/test_linear_loss.py b/sklearn/linear_model/tests/test_linear_loss.py new file mode 100644 index 0000000000000..d4e20ad69ca8a --- /dev/null +++ b/sklearn/linear_model/tests/test_linear_loss.py @@ -0,0 +1,308 @@ +""" +Tests for LinearModelLoss + +Note that correctness of losses (which compose LinearModelLoss) is already well +covered in the _loss module. +""" +import pytest +import numpy as np +from numpy.testing import assert_allclose +from scipy import linalg, optimize, sparse + +from sklearn._loss.loss import ( + HalfBinomialLoss, + HalfMultinomialLoss, + HalfPoissonLoss, +) +from sklearn.datasets import make_low_rank_matrix +from sklearn.linear_model._linear_loss import LinearModelLoss +from sklearn.utils.extmath import squared_norm + + +# We do not need to test all losses, just what LinearModelLoss does on top of the +# base losses. +LOSSES = [HalfBinomialLoss, HalfMultinomialLoss, HalfPoissonLoss] + + +def random_X_y_coef( + linear_model_loss, n_samples, n_features, coef_bound=(-2, 2), seed=42 +): + """Random generate y, X and coef in valid range.""" + rng = np.random.RandomState(seed) + n_dof = n_features + linear_model_loss.fit_intercept + X = make_low_rank_matrix( + n_samples=n_samples, + n_features=n_features, + random_state=rng, + ) + + if linear_model_loss.base_loss.is_multiclass: + n_classes = linear_model_loss.base_loss.n_classes + coef = np.empty((n_classes, n_dof)) + coef.flat[:] = rng.uniform( + low=coef_bound[0], + high=coef_bound[1], + size=n_classes * n_dof, + ) + if linear_model_loss.fit_intercept: + raw_prediction = X @ coef[:, :-1].T + coef[:, -1] + else: + raw_prediction = X @ coef.T + proba = linear_model_loss.base_loss.link.inverse(raw_prediction) + + # y = rng.choice(np.arange(n_classes), p=proba) does not work. 
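+        # (rng.choice accepts only a single 1d probability vector for p, so it cannot
+        # draw one class per row of proba in a vectorized call; hence the helper below.)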
+ # See https://stackoverflow.com/a/34190035/16761084 + def choice_vectorized(items, p): + s = p.cumsum(axis=1) + r = rng.rand(p.shape[0])[:, None] + k = (s < r).sum(axis=1) + return items[k] + + y = choice_vectorized(np.arange(n_classes), p=proba).astype(np.float64) + else: + coef = np.empty((n_dof,)) + coef.flat[:] = rng.uniform( + low=coef_bound[0], + high=coef_bound[1], + size=n_dof, + ) + if linear_model_loss.fit_intercept: + raw_prediction = X @ coef[:-1] + coef[-1] + else: + raw_prediction = X @ coef + y = linear_model_loss.base_loss.link.inverse( + raw_prediction + rng.uniform(low=-1, high=1, size=n_samples) + ) + + return X, y, coef + + +@pytest.mark.parametrize("base_loss", LOSSES) +@pytest.mark.parametrize("fit_intercept", [False, True]) +@pytest.mark.parametrize("sample_weight", [None, "range"]) +@pytest.mark.parametrize("l2_reg_strength", [0, 1]) +def test_loss_gradients_are_the_same( + base_loss, fit_intercept, sample_weight, l2_reg_strength +): + """Test that loss and gradient are the same across different functions.""" + loss = LinearModelLoss(base_loss=base_loss(), fit_intercept=fit_intercept) + X, y, coef = random_X_y_coef( + linear_model_loss=loss, n_samples=10, n_features=5, seed=42 + ) + + if sample_weight == "range": + sample_weight = np.linspace(1, y.shape[0], num=y.shape[0]) + + l1 = loss.loss( + coef, X, y, sample_weight=sample_weight, l2_reg_strength=l2_reg_strength + ) + g1 = loss.gradient( + coef, X, y, sample_weight=sample_weight, l2_reg_strength=l2_reg_strength + ) + l2, g2 = loss.loss_gradient( + coef, X, y, sample_weight=sample_weight, l2_reg_strength=l2_reg_strength + ) + g3, h3 = loss.gradient_hessian_product( + coef, X, y, sample_weight=sample_weight, l2_reg_strength=l2_reg_strength + ) + + assert_allclose(l1, l2) + assert_allclose(g1, g2) + assert_allclose(g1, g3) + + # same for sparse X + X = sparse.csr_matrix(X) + l1_sp = loss.loss( + coef, X, y, sample_weight=sample_weight, l2_reg_strength=l2_reg_strength + ) + g1_sp = loss.gradient( + coef, X, y, sample_weight=sample_weight, l2_reg_strength=l2_reg_strength + ) + l2_sp, g2_sp = loss.loss_gradient( + coef, X, y, sample_weight=sample_weight, l2_reg_strength=l2_reg_strength + ) + g3_sp, h3_sp = loss.gradient_hessian_product( + coef, X, y, sample_weight=sample_weight, l2_reg_strength=l2_reg_strength + ) + + assert_allclose(l1, l1_sp) + assert_allclose(l1, l2_sp) + assert_allclose(g1, g1_sp) + assert_allclose(g1, g2_sp) + assert_allclose(g1, g3_sp) + assert_allclose(h3(g1), h3_sp(g1_sp)) + + +@pytest.mark.parametrize("base_loss", LOSSES) +@pytest.mark.parametrize("sample_weight", [None, "range"]) +@pytest.mark.parametrize("l2_reg_strength", [0, 1]) +@pytest.mark.parametrize("X_sparse", [False, True]) +def test_loss_gradients_hessp_intercept( + base_loss, sample_weight, l2_reg_strength, X_sparse +): + """Test that loss and gradient handle intercept correctly.""" + loss = LinearModelLoss(base_loss=base_loss(), fit_intercept=False) + loss_inter = LinearModelLoss(base_loss=base_loss(), fit_intercept=True) + n_samples, n_features = 10, 5 + X, y, coef = random_X_y_coef( + linear_model_loss=loss, n_samples=n_samples, n_features=n_features, seed=42 + ) + + X[:, -1] = 1 # make last column of 1 to mimic intercept term + X_inter = X[ + :, :-1 + ] # exclude intercept column as it is added automatically by loss_inter + + if X_sparse: + X = sparse.csr_matrix(X) + + if sample_weight == "range": + sample_weight = np.linspace(1, y.shape[0], num=y.shape[0]) + + l, g = loss.loss_gradient( + coef, X, y, 
sample_weight=sample_weight, l2_reg_strength=l2_reg_strength + ) + _, hessp = loss.gradient_hessian_product( + coef, X, y, sample_weight=sample_weight, l2_reg_strength=l2_reg_strength + ) + l_inter, g_inter = loss_inter.loss_gradient( + coef, X_inter, y, sample_weight=sample_weight, l2_reg_strength=l2_reg_strength + ) + _, hessp_inter = loss_inter.gradient_hessian_product( + coef, X_inter, y, sample_weight=sample_weight, l2_reg_strength=l2_reg_strength + ) + + # Note, that intercept gets no L2 penalty. + assert l == pytest.approx( + l_inter + 0.5 * l2_reg_strength * squared_norm(coef.T[-1]) + ) + + g_inter_corrected = g_inter + g_inter_corrected.T[-1] += l2_reg_strength * coef.T[-1] + assert_allclose(g, g_inter_corrected) + + s = np.random.RandomState(42).randn(*coef.shape) + h = hessp(s) + h_inter = hessp_inter(s) + h_inter_corrected = h_inter + h_inter_corrected.T[-1] += l2_reg_strength * s.T[-1] + assert_allclose(h, h_inter_corrected) + + +@pytest.mark.parametrize("base_loss", LOSSES) +@pytest.mark.parametrize("fit_intercept", [False, True]) +@pytest.mark.parametrize("sample_weight", [None, "range"]) +@pytest.mark.parametrize("l2_reg_strength", [0, 1]) +def test_gradients_hessians_numerically( + base_loss, fit_intercept, sample_weight, l2_reg_strength +): + """Test gradients and hessians with numerical derivatives. + + Gradient should equal the numerical derivatives of the loss function. + Hessians should equal the numerical derivatives of gradients. + """ + loss = LinearModelLoss(base_loss=base_loss(), fit_intercept=fit_intercept) + n_samples, n_features = 10, 5 + X, y, coef = random_X_y_coef( + linear_model_loss=loss, n_samples=n_samples, n_features=n_features, seed=42 + ) + coef = coef.ravel(order="F") # this is important only for multinomial loss + + if sample_weight == "range": + sample_weight = np.linspace(1, y.shape[0], num=y.shape[0]) + + # 1. Check gradients numerically + eps = 1e-6 + g, hessp = loss.gradient_hessian_product( + coef, X, y, sample_weight=sample_weight, l2_reg_strength=l2_reg_strength + ) + # Use a trick to get central finite difference of accuracy 4 (five-point stencil) + # https://en.wikipedia.org/wiki/Numerical_differentiation + # https://en.wikipedia.org/wiki/Finite_difference_coefficient + # approx_g1 = (f(x + eps) - f(x - eps)) / (2*eps) + approx_g1 = optimize.approx_fprime( + coef, + lambda coef: loss.loss( + coef - eps, + X, + y, + sample_weight=sample_weight, + l2_reg_strength=l2_reg_strength, + ), + 2 * eps, + ) + # approx_g2 = (f(x + 2*eps) - f(x - 2*eps)) / (4*eps) + approx_g2 = optimize.approx_fprime( + coef, + lambda coef: loss.loss( + coef - 2 * eps, + X, + y, + sample_weight=sample_weight, + l2_reg_strength=l2_reg_strength, + ), + 4 * eps, + ) + # Five-point stencil approximation + # See: https://en.wikipedia.org/wiki/Five-point_stencil#1D_first_derivative + approx_g = (4 * approx_g1 - approx_g2) / 3 + assert_allclose(g, approx_g, rtol=1e-2, atol=1e-8) + + # 2. Check hessp numerically along the second direction of the gradient + vector = np.zeros_like(g) + vector[1] = 1 + hess_col = hessp(vector) + # Computation of the Hessian is particularly fragile to numerical errors when doing + # simple finite differences. 
# simple finite differences.
Here we compute the grad along a path in the direction + # of the vector and then use a least-square regression to estimate the slope + eps = 1e-3 + d_x = np.linspace(-eps, eps, 30) + d_grad = np.array( + [ + loss.gradient( + coef + t * vector, + X, + y, + sample_weight=sample_weight, + l2_reg_strength=l2_reg_strength, + ) + for t in d_x + ] + ) + d_grad -= d_grad.mean(axis=0) + approx_hess_col = linalg.lstsq(d_x[:, np.newaxis], d_grad)[0].ravel() + assert_allclose(approx_hess_col, hess_col, rtol=1e-3) + + +@pytest.mark.parametrize("fit_intercept", [False, True]) +def test_multinomial_coef_shape(fit_intercept): + """Test that multinomial LinearModelLoss respects shape of coef.""" + loss = LinearModelLoss(base_loss=HalfMultinomialLoss(), fit_intercept=fit_intercept) + n_samples, n_features = 10, 5 + X, y, coef = random_X_y_coef( + linear_model_loss=loss, n_samples=n_samples, n_features=n_features, seed=42 + ) + s = np.random.RandomState(42).randn(*coef.shape) + + l, g = loss.loss_gradient(coef, X, y) + g1 = loss.gradient(coef, X, y) + g2, hessp = loss.gradient_hessian_product(coef, X, y) + h = hessp(s) + assert g.shape == coef.shape + assert h.shape == coef.shape + assert_allclose(g, g1) + assert_allclose(g, g2) + + coef_r = coef.ravel(order="F") + s_r = s.ravel(order="F") + l_r, g_r = loss.loss_gradient(coef_r, X, y) + g1_r = loss.gradient(coef_r, X, y) + g2_r, hessp_r = loss.gradient_hessian_product(coef_r, X, y) + h_r = hessp_r(s_r) + assert g_r.shape == coef_r.shape + assert h_r.shape == coef_r.shape + assert_allclose(g_r, g1_r) + assert_allclose(g_r, g2_r) + + assert_allclose(g, g_r.reshape(loss.base_loss.n_classes, -1, order="F")) + assert_allclose(h, h_r.reshape(loss.base_loss.n_classes, -1, order="F")) diff --git a/sklearn/linear_model/tests/test_logistic.py b/sklearn/linear_model/tests/test_logistic.py index 3bdfec46e30dc..0c35795b8b642 100644 --- a/sklearn/linear_model/tests/test_logistic.py +++ b/sklearn/linear_model/tests/test_logistic.py @@ -5,8 +5,7 @@ import numpy as np from numpy.testing import assert_allclose, assert_almost_equal from numpy.testing import assert_array_almost_equal, assert_array_equal -import scipy.sparse as sp -from scipy import linalg, optimize, sparse +from scipy import sparse import pytest @@ -28,18 +27,14 @@ from sklearn.exceptions import ConvergenceWarning from sklearn.linear_model._logistic import ( - LogisticRegression, + _log_reg_scoring_path, _logistic_regression_path, + LogisticRegression, LogisticRegressionCV, - _logistic_loss_and_grad, - _logistic_grad_hess, - _multinomial_grad_hess, - _logistic_loss, - _log_reg_scoring_path, ) X = [[-1, 0], [0, 1], [1, 1]] -X_sp = sp.csr_matrix(X) +X_sp = sparse.csr_matrix(X) Y1 = [0, 1, 1] Y2 = [2, 1, 0] iris = load_iris() @@ -317,10 +312,10 @@ def test_sparsify(): pred_d_d = clf.decision_function(iris.data) clf.sparsify() - assert sp.issparse(clf.coef_) + assert sparse.issparse(clf.coef_) pred_s_d = clf.decision_function(iris.data) - sp_data = sp.coo_matrix(iris.data) + sp_data = sparse.coo_matrix(iris.data) pred_s_s = clf.decision_function(sp_data) clf.densify() @@ -501,82 +496,6 @@ def test_liblinear_dual_random_state(): assert_array_almost_equal(lr1.coef_, lr3.coef_) -def test_logistic_loss_and_grad(): - X_ref, y = make_classification(n_samples=20, random_state=0) - n_features = X_ref.shape[1] - - X_sp = X_ref.copy() - X_sp[X_sp < 0.1] = 0 - X_sp = sp.csr_matrix(X_sp) - for X in (X_ref, X_sp): - w = np.zeros(n_features) - - # First check that our derivation of the grad is correct - loss, grad = 
_logistic_loss_and_grad(w, X, y, alpha=1.0) - approx_grad = optimize.approx_fprime( - w, lambda w: _logistic_loss_and_grad(w, X, y, alpha=1.0)[0], 1e-3 - ) - assert_array_almost_equal(grad, approx_grad, decimal=2) - - # Second check that our intercept implementation is good - w = np.zeros(n_features + 1) - loss_interp, grad_interp = _logistic_loss_and_grad(w, X, y, alpha=1.0) - assert_array_almost_equal(loss, loss_interp) - - approx_grad = optimize.approx_fprime( - w, lambda w: _logistic_loss_and_grad(w, X, y, alpha=1.0)[0], 1e-3 - ) - assert_array_almost_equal(grad_interp, approx_grad, decimal=2) - - -def test_logistic_grad_hess(): - rng = np.random.RandomState(0) - n_samples, n_features = 50, 5 - X_ref = rng.randn(n_samples, n_features) - y = np.sign(X_ref.dot(5 * rng.randn(n_features))) - X_ref -= X_ref.mean() - X_ref /= X_ref.std() - X_sp = X_ref.copy() - X_sp[X_sp < 0.1] = 0 - X_sp = sp.csr_matrix(X_sp) - for X in (X_ref, X_sp): - w = np.full(n_features, 0.1) - - # First check that _logistic_grad_hess is consistent - # with _logistic_loss_and_grad - loss, grad = _logistic_loss_and_grad(w, X, y, alpha=1.0) - grad_2, hess = _logistic_grad_hess(w, X, y, alpha=1.0) - assert_array_almost_equal(grad, grad_2) - - # Now check our hessian along the second direction of the grad - vector = np.zeros_like(grad) - vector[1] = 1 - hess_col = hess(vector) - - # Computation of the Hessian is particularly fragile to numerical - # errors when doing simple finite differences. Here we compute the - # grad along a path in the direction of the vector and then use a - # least-square regression to estimate the slope - e = 1e-3 - d_x = np.linspace(-e, e, 30) - d_grad = np.array( - [_logistic_loss_and_grad(w + t * vector, X, y, alpha=1.0)[1] for t in d_x] - ) - - d_grad -= d_grad.mean(axis=0) - approx_hess_col = linalg.lstsq(d_x[:, np.newaxis], d_grad)[0].ravel() - - assert_array_almost_equal(approx_hess_col, hess_col, decimal=3) - - # Second check that our intercept implementation is good - w = np.zeros(n_features + 1) - loss_interp, grad_interp = _logistic_loss_and_grad(w, X, y, alpha=1.0) - loss_interp_2 = _logistic_loss(w, X, y, alpha=1.0) - grad_interp_2, hess = _logistic_grad_hess(w, X, y, alpha=1.0) - assert_array_almost_equal(loss_interp, loss_interp_2) - assert_array_almost_equal(grad_interp, grad_interp_2) - - def test_logistic_cv(): # test for LogisticRegressionCV object n_samples, n_features = 50, 5 @@ -690,7 +609,7 @@ def test_multinomial_logistic_regression_string_inputs(): def test_logistic_cv_sparse(): X, y = make_classification(n_samples=50, n_features=5, random_state=0) X[X < 1.0] = 0.0 - csr = sp.csr_matrix(X) + csr = sparse.csr_matrix(X) clf = LogisticRegressionCV() clf.fit(X, y) @@ -701,40 +620,6 @@ def test_logistic_cv_sparse(): assert clfs.C_ == clf.C_ -def test_intercept_logistic_helper(): - n_samples, n_features = 10, 5 - X, y = make_classification( - n_samples=n_samples, n_features=n_features, random_state=0 - ) - - # Fit intercept case. - alpha = 1.0 - w = np.ones(n_features + 1) - grad_interp, hess_interp = _logistic_grad_hess(w, X, y, alpha) - loss_interp = _logistic_loss(w, X, y, alpha) - - # Do not fit intercept. This can be considered equivalent to adding - # a feature vector of ones, i.e column of one vectors. - X_ = np.hstack((X, np.ones(10)[:, np.newaxis])) - grad, hess = _logistic_grad_hess(w, X_, y, alpha) - loss = _logistic_loss(w, X_, y, alpha) - - # In the fit_intercept=False case, the feature vector of ones is - # penalized. This should be taken care of. 
- assert_almost_equal(loss_interp + 0.5 * (w[-1] ** 2), loss) - - # Check gradient. - assert_array_almost_equal(grad_interp[:n_features], grad[:n_features]) - assert_almost_equal(grad_interp[-1] + alpha * w[-1], grad[-1]) - - rng = np.random.RandomState(0) - grad = rng.rand(n_features + 1) - hess_interp = hess_interp(grad) - hess = hess(grad) - assert_array_almost_equal(hess_interp[:n_features], hess[:n_features]) - assert_almost_equal(hess_interp[-1] + alpha * grad[-1], hess[-1]) - - def test_ovr_multinomial_iris(): # Test that OvR and multinomial are correct using the iris dataset. train, target = iris.data, iris.target @@ -1107,41 +992,6 @@ def test_logistic_regression_multinomial(): assert_allclose(clf_path.intercept_, ref_i.intercept_, rtol=2e-2) -def test_multinomial_grad_hess(): - rng = np.random.RandomState(0) - n_samples, n_features, n_classes = 100, 5, 3 - X = rng.randn(n_samples, n_features) - w = rng.rand(n_classes, n_features) - Y = np.zeros((n_samples, n_classes)) - ind = np.argmax(np.dot(X, w.T), axis=1) - Y[range(0, n_samples), ind] = 1 - w = w.ravel() - sample_weights = np.ones(X.shape[0]) - grad, hessp = _multinomial_grad_hess( - w, X, Y, alpha=1.0, sample_weight=sample_weights - ) - # extract first column of hessian matrix - vec = np.zeros(n_features * n_classes) - vec[0] = 1 - hess_col = hessp(vec) - - # Estimate hessian using least squares as done in - # test_logistic_grad_hess - e = 1e-3 - d_x = np.linspace(-e, e, 30) - d_grad = np.array( - [ - _multinomial_grad_hess( - w + t * vec, X, Y, alpha=1.0, sample_weight=sample_weights - )[0] - for t in d_x - ] - ) - d_grad -= d_grad.mean(axis=0) - approx_hess_col = linalg.lstsq(d_x[:, np.newaxis], d_grad)[0].ravel() - assert_array_almost_equal(hess_col, approx_hess_col) - - def test_liblinear_decision_function_zero(): # Test negative prediction when decision_function values are zero. # Liblinear predicts the positive class when decision_function values @@ -1533,8 +1383,8 @@ def test_dtype_match(solver, multi_class, fit_intercept): y_32 = np.array(Y1).astype(np.float32) X_64 = np.array(X).astype(np.float64) y_64 = np.array(Y1).astype(np.float64) - X_sparse_32 = sp.csr_matrix(X, dtype=np.float32) - X_sparse_64 = sp.csr_matrix(X, dtype=np.float64) + X_sparse_32 = sparse.csr_matrix(X, dtype=np.float32) + X_sparse_64 = sparse.csr_matrix(X, dtype=np.float64) solver_tol = 5e-4 lr_templ = LogisticRegression( @@ -2236,7 +2086,7 @@ def test_large_sparse_matrix(solver): # Non-regression test for pull-request #21093. 
# generate sparse matrix with int64 indices - X = sp.rand(20, 10, format="csr") + X = sparse.rand(20, 10, format="csr") for attr in ["indices", "indptr"]: setattr(X, attr, getattr(X, attr).astype("int64")) y = np.random.randint(2, size=X.shape[0]) diff --git a/sklearn/linear_model/tests/test_sag.py b/sklearn/linear_model/tests/test_sag.py index 88df6621f8176..d3a27c4088ab7 100644 --- a/sklearn/linear_model/tests/test_sag.py +++ b/sklearn/linear_model/tests/test_sag.py @@ -10,11 +10,12 @@ import scipy.sparse as sp from scipy.special import logsumexp +from sklearn._loss.loss import HalfMultinomialLoss +from sklearn.linear_model._linear_loss import LinearModelLoss from sklearn.linear_model._sag import get_auto_step_size from sklearn.linear_model._sag_fast import _multinomial_grad_loss_all_samples from sklearn.linear_model import LogisticRegression, Ridge from sklearn.linear_model._base import make_dataset -from sklearn.linear_model._logistic import _multinomial_loss_grad from sklearn.utils.extmath import row_norms from sklearn.utils._testing import assert_almost_equal @@ -933,13 +934,14 @@ def test_multinomial_loss(): dataset, weights, intercept, n_samples, n_features, n_classes ) # compute loss and gradient like in multinomial LogisticRegression - lbin = LabelBinarizer() - Y_bin = lbin.fit_transform(y) - weights_intercept = np.vstack((weights, intercept)).T.ravel() - loss_2, grad_2, _ = _multinomial_loss_grad( - weights_intercept, X, Y_bin, 0.0, sample_weights + loss = LinearModelLoss( + base_loss=HalfMultinomialLoss(n_classes=n_classes), + fit_intercept=True, + ) + weights_intercept = np.vstack((weights, intercept)).T + loss_2, grad_2 = loss.loss_gradient( + weights_intercept, X, y, l2_reg_strength=0.0, sample_weight=sample_weights ) - grad_2 = grad_2.reshape(n_classes, -1) grad_2 = grad_2[:, :-1].T # comparison @@ -951,7 +953,7 @@ def test_multinomial_loss_ground_truth(): # n_samples, n_features, n_classes = 4, 2, 3 n_classes = 3 X = np.array([[1.1, 2.2], [2.2, -4.4], [3.3, -2.2], [1.1, 1.1]]) - y = np.array([0, 1, 2, 0]) + y = np.array([0, 1, 2, 0], dtype=np.float64) lbin = LabelBinarizer() Y_bin = lbin.fit_transform(y) @@ -966,11 +968,14 @@ def test_multinomial_loss_ground_truth(): diff = sample_weights[:, np.newaxis] * (np.exp(p) - Y_bin) grad_1 = np.dot(X.T, diff) - weights_intercept = np.vstack((weights, intercept)).T.ravel() - loss_2, grad_2, _ = _multinomial_loss_grad( - weights_intercept, X, Y_bin, 0.0, sample_weights + loss = LinearModelLoss( + base_loss=HalfMultinomialLoss(n_classes=n_classes), + fit_intercept=True, + ) + weights_intercept = np.vstack((weights, intercept)).T + loss_2, grad_2 = loss.loss_gradient( + weights_intercept, X, y, l2_reg_strength=0.0, sample_weight=sample_weights ) - grad_2 = grad_2.reshape(n_classes, -1) grad_2 = grad_2[:, :-1].T assert_almost_equal(loss_1, loss_2)
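
As a usage illustration beyond the diff: a minimal sketch of how the new LinearModelLoss is driven by scipy.optimize.minimize in the solver="lbfgs" branch of _logistic_regression_path. The dataset, the C value and the variable names are made up, and the private imports assume scikit-learn >= 1.1.

import numpy as np
from scipy import optimize

from sklearn._loss.loss import HalfBinomialLoss
from sklearn.datasets import make_classification
from sklearn.linear_model._linear_loss import LinearModelLoss

X, y = make_classification(n_samples=200, n_features=5, random_state=0)
y = y.astype(np.float64)  # HalfBinomialLoss expects targets in {0, 1} as floats

linear_loss = LinearModelLoss(base_loss=HalfBinomialLoss(), fit_intercept=True)
w0 = np.zeros(X.shape[1] + 1)  # last entry is the intercept
C = 1.0
opt_res = optimize.minimize(
    linear_loss.loss_gradient,  # returns (loss, gradient), hence jac=True
    w0,
    method="L-BFGS-B",
    jac=True,
    # args = (X, y, sample_weight, l2_reg_strength, n_threads)
    args=(X, y, None, 1.0 / C, 1),
    options={"gtol": 1e-4, "maxiter": 100},
)
coef, intercept = opt_res.x[:-1], opt_res.x[-1]

Up to solver tolerance, coef and intercept should agree with LogisticRegression(C=C, solver="lbfgs").fit(X, y).coef_ and .intercept_, since both minimize the same penalized sum-of-losses objective.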