From bf2b0cf4cf2e08ef732318b22768b27b7ad79156 Mon Sep 17 00:00:00 2001 From: Christian Lorentzen Date: Fri, 1 Jan 2021 20:44:01 +0100 Subject: [PATCH 01/44] ENH replace loss in linear logistic regression --- sklearn/_loss/loss.py | 2 +- sklearn/linear_model/_linear_loss.py | 359 ++++++++++++++++++++ sklearn/linear_model/_logistic.py | 73 ++-- sklearn/linear_model/tests/test_logistic.py | 126 ++++--- 4 files changed, 483 insertions(+), 77 deletions(-) create mode 100644 sklearn/linear_model/_linear_loss.py diff --git a/sklearn/_loss/loss.py b/sklearn/_loss/loss.py index a394bd9de06c3..d883c0e1bd190 100644 --- a/sklearn/_loss/loss.py +++ b/sklearn/_loss/loss.py @@ -119,7 +119,7 @@ class BaseLoss: differentiable = True is_multiclass = False - def __init__(self, closs, link, n_classes=1): + def __init__(self, closs, link, n_classes=None): self.closs = closs self.link = link self.approx_hessian = False diff --git a/sklearn/linear_model/_linear_loss.py b/sklearn/linear_model/_linear_loss.py new file mode 100644 index 0000000000000..f3ad51260b2f1 --- /dev/null +++ b/sklearn/linear_model/_linear_loss.py @@ -0,0 +1,359 @@ +""" +Loss functions for linear models with raw_prediction = X @ coef +""" +import numpy as np +from scipy import sparse +from ..utils.extmath import squared_norm + + +class LinearLoss: + """General class for loss functions with raw_prediction = X @ coef. + + The loss is the sum of per sample losses and includes an L2 term:: + + loss = sum_i s_i loss(y_i, X_i @ coef + intercept) + 1/2 * alpha * ||coef||_2^2 + + with sample weights s_i=1 if sample_weight=None. + + Gradient and hessian, for simplicity without intercept, are:: + + gradient = X.T @ loss.gradient + alpha * coef + hessian = X.T @ diag(loss.hessian) @ X + alpha * identity + + Conventions: + if fit_intercept: + n_dof = n_features + 1 + else: + n_dof = n_features + + if loss.is_multiclass: + coef.shape = (n_classes * n_dof,) + intercept.shape = (n_classes) + else: + coef.shape = (n_dof,) + intercept.shape = (1) or it is a float + + The intercept term is at the end of the coef array: + if loss.is_multiclass: + coef[n_features::n_dof] = coef[(n_dof-1)::n_dof] + else: + coef[-1] + + Note: If the average loss per sample is wanted instead of the sum of the + loss per sample, one can simply use a rescaled sample_weight such that + sum(sample_weight) = 1. + + Parameters + ---------- + _loss : instance of a loss function from sklearn._loss. + fit_intercept : bool + """ + + def __init__(self, loss, fit_intercept): + self._loss = loss + self.fit_intercept = fit_intercept + + def _w_intercept_raw(self, coef, X): + """Helper function to get coefficients, intercept and raw_prediction. + + Parameters + ---------- + coef : ndarray of shape (n_dof,) or (n_classes * n_dof,) + Coefficients of a linear model. + X : {array-like, sparse matrix} of shape (n_samples, n_features) + Training data. + + Returns + ------- + w : ndarray of shape (n_features,) or (n_classes, n_features) + Coefficients without intercept term. + intercept : float or ndarray of shape (n_classes,) + Intercept terms. 
+ raw_prediction : ndarray of shape (n_samples,) or \ + (n_samples, n_classes) + """ + if not self._loss.is_multiclass: + if self.fit_intercept: + intercept = coef[-1] + w = coef[:-1] + else: + intercept = 0.0 + w = coef + else: + # reshape to (n_classes, n_dof) + w = coef.reshape(self._loss.n_classes, -1) + if self.fit_intercept: + intercept = w[:, -1] + w = w[:, :-1] + else: + intercept = 0.0 + + raw_prediction = X @ w.T + intercept + return w, intercept, raw_prediction + + def loss(self, coef, X, y, sample_weight=None, alpha=0.0, n_threads=1): + """Compute the loss as sum over point-wise losses. + + Parameters + ---------- + coef : ndarray of shape (n_dof,) or (n_classes * n_dof,) + Coefficients of a linear model. + y : C/F-contiguous array of shape (n_samples,) + Observed, true target values. + X : {array-like, sparse matrix} of shape (n_samples, n_features) + Training data. + sample_weight : None or C/F-contiguous array of shape (n_samples,) + Sample weights. + alpha: float + L2 regularization strength + n_threads : int, default=1 + Number of OpenMP threads to use. + + Returns + ------- + loss : float + Sum of losses per sample plus penalty. + """ + w, intercept, raw_prediction = self._w_intercept_raw(coef, X) + + loss = self._loss.loss( + y_true=y, + raw_prediction=raw_prediction, + sample_weight=sample_weight, + n_threads=n_threads, + ) + loss = loss.sum() + + if w.ndim == 1: + return loss + 0.5 * alpha * (w @ w) + else: + return loss + 0.5 * alpha * squared_norm(w) + + def loss_gradient(self, coef, X, y, sample_weight=None, alpha=0.0, n_threads=1): + """Computes the sum/average of loss and gradient. + + Parameters + ---------- + coef : ndarray of shape (n_dof,) or (n_classes * n_dof,) + Coefficients of a linear model. + y : C/F-contiguous array of shape (n_samples,) + Observed, true target values. + X : {array-like, sparse matrix} of shape (n_samples, n_features) + Training data. + sample_weight : None or C/F-contiguous array of shape (n_samples,) + Sample weights. + alpha: float + L2 regularization strength + n_threads : int, default=1 + Number of OpenMP threads to use. + + Returns + ------- + loss : float + Sum of losses per sample plus penalty. + + gradient : ndarray of shape (n_dof,) or (n_classes * n_dof) + The gradient of the loss as ravelled array. + """ + n_features, n_classes = X.shape[1], self._loss.n_classes + n_dof = n_features + self.fit_intercept + w, intercept, raw_prediction = self._w_intercept_raw(coef, X) + + loss, gradient = self._loss.loss_gradient( + y_true=y, + raw_prediction=raw_prediction, + sample_weight=sample_weight, + n_threads=n_threads, + ) + loss = loss.sum() + + if not self._loss.is_multiclass: + loss += 0.5 * alpha * (w @ w) + grad = np.empty_like(coef, dtype=X.dtype) + grad[:n_features] = X.T @ gradient + alpha * w + if self.fit_intercept: + grad[-1] = gradient.sum() + return loss, grad + else: + loss += 0.5 * alpha * squared_norm(w) + grad = np.empty((n_classes, n_dof), dtype=X.dtype) + # gradient.shape = (n_samples, n_classes) + grad[:, :n_features] = gradient.T @ X + alpha * w + if self.fit_intercept: + grad[:, -1] = gradient.sum(axis=0) + return loss, grad.ravel() + + def gradient(self, coef, X, y, sample_weight=None, alpha=0.0, n_threads=1): + """Computes the gradient. + + Parameters + ---------- + coef : ndarray of shape (n_dof,) or (n_classes * n_dof,) + Coefficients of a linear model. + y : C/F-contiguous array of shape (n_samples,) + Observed, true target values. 
+ X : {array-like, sparse matrix} of shape (n_samples, n_features) + Training data. + sample_weight : None or C/F-contiguous array of shape (n_samples,) + Sample weights. + alpha: float + L2 regularization strength + n_threads : int, default=1 + Number of OpenMP threads to use. + + Returns + ------- + gradient : ndarray of shape (n_dof,) or (n_classes * n_dof) + The gradient of the loss as ravelled array. + """ + n_features, n_classes = X.shape[1], self._loss.n_classes + n_dof = n_features + self.fit_intercept + w, intercept, raw_prediction = self._w_intercept_raw(coef, X) + + gradient = self._loss.gradient( + y_true=y, + raw_prediction=raw_prediction, + sample_weight=sample_weight, + n_threads=n_threads, + ) + + if not self._loss.is_multiclass: + grad = np.empty_like(coef, dtype=X.dtype) + grad[:n_features] = X.T @ gradient + alpha * w + if self.fit_intercept: + grad[-1] = gradient.sum() + return grad + else: + grad = np.empty((n_classes, n_dof), dtype=X.dtype) + # gradient.shape = (n_samples, n_classes) + grad[:, :n_features] = gradient.T @ X + alpha * w + if self.fit_intercept: + grad[:, -1] = gradient.sum(axis=0) + return grad.ravel() + + def gradient_hessp(self, coef, X, y, sample_weight=None, alpha=0.0, n_threads=1): + """Computes gradient and hessp (hessian product function). + + Parameters + ---------- + coef : ndarray of shape (n_dof,) or (n_classes * n_dof,) + Coefficients of a linear model. + y : C/F-contiguous array of shape (n_samples,) + Observed, true target values. + X : {array-like, sparse matrix} of shape (n_samples, n_features) + Training data. + sample_weight : None or C/F-contiguous array of shape (n_samples,) + Sample weights. + alpha: float + L2 regularization strength + n_threads : int, default=1 + Number of OpenMP threads to use. + + Returns + ------- + gradient : ndarray of shape (n_dof,) or (n_classes * n_dof) + The gradient of the loss as ravelled array. + + hessp : callable + Function that takes in a vector input of shape of gradient and + and returns matrix-vector product with hessian. + """ + (n_samples, n_features), n_classes = X.shape, self._loss.n_classes + w, intercept, raw_prediction = self._w_intercept_raw(coef, X) + + if not self._loss.is_multiclass: + gradient, hessian = self._loss.gradient_hessian( + y_true=y, + raw_prediction=raw_prediction, + sample_weight=sample_weight, + n_threads=n_threads, + ) + grad = np.empty_like(coef, dtype=X.dtype) + grad[:n_features] = X.T @ gradient + alpha * w + if self.fit_intercept: + grad[-1] = gradient.sum() + + # Precompute as much as possible: hX, hh_intercept and hsum + hsum = hessian.sum() + if sparse.issparse(X): + hX = sparse.dia_matrix((hessian, 0), shape=(n_samples, n_samples)) @ X + else: + # TODO: This may consume a lot of memory. Better preallocate the array. + hX = hessian[:, np.newaxis] * X + + if self.fit_intercept: + # Calculate the double derivative with respect to intercept. + # Note: In case hX is sparse, hX.sum is a matrix object. 
+ hh_intercept = np.squeeze(np.array(hX.sum(axis=0))) + + # With intercept included and alpha = 0, hessp returns + # res = (X, 1)' @ diag(h) @ (X, 1) @ s + # = (X, 1)' @ (hX @ s[:n_features], sum(h) * s[-1]) + # res[:n_features] = X' @ hX @ s[:n_features] + sum(h) * s[-1] + # res[:-1] = 1' @ hX @ s[:n_features] + sum(h) * s[-1] + def hessp(s): + ret = np.empty_like(s) + ret[:n_features] = X.T @ (hX @ s[:n_features]) + ret[:n_features] += alpha * s[:n_features] + + if self.fit_intercept: + ret[:n_features] += s[-1] * hh_intercept + ret[-1] = hh_intercept @ s[:n_features] + hsum * s[-1] + return ret + + else: + # Here we may safely assume HalfMultinomialLoss aka categorical + # cross-entropy. + # HalfMultinomialLoss computes only the diagonal part of the hessian, i.e. + # diagonal in the classes. Here, we want the matrix-vector product of the + # full hessian. Therefore, we call gradient_proba. + gradient, proba = self._loss.gradient_proba( + y_true=y, + raw_prediction=raw_prediction, + sample_weight=sample_weight, + n_threads=n_threads, + ) + grad = np.empty_like(coef.reshape(n_classes, -1), dtype=X.dtype) + grad[:, :n_features] = gradient.T @ X + alpha * w + if self.fit_intercept: + grad[:, -1] = gradient.sum(axis=0) + + # Full hessian-vector product, i.e. not only the diagonal part of the + # hessian. Derivation with some index battle for inupt vector s: + # - sample index i + # - feature indices j, m + # - class indices k, l + # - 1_{k=l} is one if k=l else 0 + # - p_i_k is class probability of sample i and class k + # - s_l_m is input vector for class l and feature m + # - X' = X transposed + # + # Note: Hessian with dropping most indices is just: + # X' @ p_k (1(k=l) - p_l) @ X + # + # result_{k j} = sum_{i, l, m} Hessian_{i, k j, m l} * s_l_m + # = sum_{i, l, m} (X')_{ji} * p_i_k * (1_{k=l} - p_i_l) + # * X_{im} s_l_m + # = sum_{i, m} (X')_{ji} * p_i_k + # * (X_{im} * s_k_m - sum_l p_i_l * X_{im} * s_l_m) + # + # See also https://github.com/scikit-learn/scikit-learn/pull/3646#discussion_r17461411 # noqa + def hessp(s): + s = s.reshape(n_classes, -1) # shape = (n_classes, n_dof) + if self.fit_intercept: + s_intercept = s[:, -1] + s = s[:, :-1] + else: + s_intercept = 0 + tmp = X @ s.T + s_intercept # X_{im} * s_k_m + tmp += (-proba * tmp).sum(axis=1)[:, np.newaxis] # - sum_l .. 
+ tmp *= proba # * p_i_k + if sample_weight is not None: + tmp *= sample_weight[:, np.newaxis] + hessProd = np.empty_like(grad) + hessProd[:, :n_features] = tmp.T @ X + alpha * s + if self.fit_intercept: + hessProd[:, -1] = tmp.sum(axis=0) + return hessProd.ravel() + + return grad.ravel(), hessp diff --git a/sklearn/linear_model/_logistic.py b/sklearn/linear_model/_logistic.py index 08e71edbc69ab..06c93016e9352 100644 --- a/sklearn/linear_model/_logistic.py +++ b/sklearn/linear_model/_logistic.py @@ -19,11 +19,14 @@ from joblib import Parallel, effective_n_jobs from ._base import LinearClassifierMixin, SparseCoefMixin, BaseEstimator +from ._linear_loss import LinearLoss from ._sag import sag_solver +from .._loss.loss import HalfBinomialLoss, HalfMultinomialLoss from ..preprocessing import LabelEncoder, LabelBinarizer from ..svm._base import _fit_liblinear from ..utils import check_array, check_consistent_length, compute_class_weight from ..utils import check_random_state +from ..utils._openmp_helpers import _openmp_effective_n_threads from ..utils.extmath import log_logistic, safe_sparse_dot, softmax, squared_norm from ..utils.extmath import row_norms from ..utils.optimize import _newton_cg, _check_optimize_result @@ -505,6 +508,7 @@ def _logistic_regression_path( max_squared_sum=None, sample_weight=None, l1_ratio=None, + n_threads=1, ): """Compute a Logistic Regression model for a list of regularization parameters. @@ -629,6 +633,9 @@ def _logistic_regression_path( to using ``penalty='l1'``. For ``0 < l1_ratio <1``, the penalty is a combination of L1 and L2. + n_threads : int, default=1 + Number of OpenMP threads to use. + Returns ------- coefs : ndarray of shape (n_cs, n_features) or (n_cs, n_features + 1) @@ -696,12 +703,16 @@ def _logistic_regression_path( # multinomial case this is not necessary. if multi_class == "ovr": w0 = np.zeros(n_features + int(fit_intercept), dtype=X.dtype) - mask_classes = np.array([-1, 1]) mask = y == pos_class y_bin = np.ones(y.shape, dtype=X.dtype) - y_bin[~mask] = -1.0 - # for compute_class_weight + if solver in ["lbfgs", "newton-cg"]: + mask_classes = np.array([0, 1]) + y_bin[~mask] = 0.0 + else: + mask_classes = np.array([-1, 1]) + y_bin[~mask] = -1.0 + # for compute_class_weight if class_weight == "balanced": class_weight_ = compute_class_weight( class_weight, classes=mask_classes, y=y_bin @@ -709,15 +720,17 @@ def _logistic_regression_path( sample_weight *= class_weight_[le.fit_transform(y_bin)] else: - if solver not in ["sag", "saga"]: + if solver in ["sag", "saga", "lbfgs", "newton-cg"]: + # SAG, lbfgs and newton-cg multinomial solvers need LabelEncoder, + # not LabelBinarizer, i.e. y is mapped to integers. + le = LabelEncoder() + Y_multi = le.fit_transform(y).astype(X.dtype, copy=False) + else: + # Apply LabelBinarizer, i.e. y is one-hot encoded. lbin = LabelBinarizer() Y_multi = lbin.fit_transform(y) if Y_multi.shape[1] == 1: Y_multi = np.hstack([1 - Y_multi, Y_multi]) - else: - # SAG multinomial solver needs LabelEncoder, not LabelBinarizer - le = LabelEncoder() - Y_multi = le.fit_transform(y).astype(X.dtype, copy=False) w0 = np.zeros( (classes.size, n_features + int(fit_intercept)), order="F", dtype=X.dtype @@ -767,33 +780,28 @@ def _logistic_regression_path( # ravelled parameters. 
if solver in ["lbfgs", "newton-cg"]: w0 = w0.ravel() + loss = LinearLoss( + loss=HalfMultinomialLoss(n_classes=classes.size), + fit_intercept=fit_intercept, + ) target = Y_multi - if solver == "lbfgs": - - def func(x, *args): - return _multinomial_loss_grad(x, *args)[0:2] - + if solver in "lbfgs": + func = loss.loss_gradient elif solver == "newton-cg": - - def func(x, *args): - return _multinomial_loss(x, *args)[0] - - def grad(x, *args): - return _multinomial_loss_grad(x, *args)[1] - - hess = _multinomial_grad_hess + func = loss.loss + grad = loss.gradient + hess = loss.gradient_hessp # hess = [gradient, hessp] warm_start_sag = {"coef": w0.T} else: target = y_bin if solver == "lbfgs": - func = _logistic_loss_and_grad + loss = LinearLoss(loss=HalfBinomialLoss(), fit_intercept=fit_intercept) + func = loss.loss_gradient elif solver == "newton-cg": - func = _logistic_loss - - def grad(x, *args): - return _logistic_loss_and_grad(x, *args)[1] - - hess = _logistic_grad_hess + loss = LinearLoss(loss=HalfBinomialLoss(), fit_intercept=fit_intercept) + func = loss.loss + grad = loss.gradient + hess = loss.gradient_hessp # hess = [gradient, hessp] warm_start_sag = {"coef": np.expand_dims(w0, axis=1)} coefs = list() @@ -808,7 +816,7 @@ def grad(x, *args): w0, method="L-BFGS-B", jac=True, - args=(X, target, 1.0 / C, sample_weight), + args=(X, target, sample_weight, 1.0 / C, n_threads), options={"iprint": iprint, "gtol": tol, "maxiter": max_iter}, ) n_iter_i = _check_optimize_result( @@ -819,7 +827,7 @@ def grad(x, *args): ) w0, loss = opt_res.x, opt_res.fun elif solver == "newton-cg": - args = (X, target, 1.0 / C, sample_weight) + args = (X, target, sample_weight, 1.0 / C, n_threads) w0, n_iter_i = _newton_cg( hess, func, grad, w0, args=args, maxiter=max_iter, tol=tol ) @@ -1586,6 +1594,10 @@ def fit(self, X, y, sample_weight=None): prefer = "threads" else: prefer = "processes" + if solver in ["lbfgs", "newton-cg"] and len(classes_) == 1: + n_threads = _openmp_effective_n_threads() + else: + n_threads = 1 fold_coefs_ = Parallel( n_jobs=self.n_jobs, verbose=self.verbose, @@ -1610,6 +1622,7 @@ def fit(self, X, y, sample_weight=None): penalty=penalty, max_squared_sum=max_squared_sum, sample_weight=sample_weight, + n_threads=n_threads, ) for class_, warm_start_coef_ in zip(classes_, warm_start_coef) ) diff --git a/sklearn/linear_model/tests/test_logistic.py b/sklearn/linear_model/tests/test_logistic.py index 1171613eb3718..b5c7dc54ac4de 100644 --- a/sklearn/linear_model/tests/test_logistic.py +++ b/sklearn/linear_model/tests/test_logistic.py @@ -9,6 +9,7 @@ import pytest +from sklearn._loss.loss import HalfBinomialLoss, HalfMultinomialLoss from sklearn.base import clone from sklearn.datasets import load_iris, make_classification from sklearn.metrics import log_loss @@ -26,15 +27,12 @@ from sklearn.utils._testing import skip_if_no_parallel from sklearn.exceptions import ConvergenceWarning +from sklearn.linear_model._linear_loss import LinearLoss from sklearn.linear_model._logistic import ( - LogisticRegression, + _log_reg_scoring_path, _logistic_regression_path, + LogisticRegression, LogisticRegressionCV, - _logistic_loss_and_grad, - _logistic_grad_hess, - _multinomial_grad_hess, - _logistic_loss, - _log_reg_scoring_path, ) X = [[-1, 0], [0, 1], [1, 1]] @@ -501,56 +499,82 @@ def test_liblinear_dual_random_state(): def test_logistic_loss_and_grad(): - X_ref, y = make_classification(n_samples=20, random_state=0) + n_samples, n_features = 20, 20 + alpha = 1.0 + X_ref, y = make_classification( + 
n_samples=n_samples, n_features=n_features, random_state=0 + ) + # make last column of 1 to mimic intercept term + X_ref[:, -1] = 1 + X_ref_inter = X_ref[:, :-1] # exclude intercept column + y = y.astype(np.float64) n_features = X_ref.shape[1] X_sp = X_ref.copy() X_sp[X_sp < 0.1] = 0 X_sp = sp.csr_matrix(X_sp) - for X in (X_ref, X_sp): - w = np.zeros(n_features) - + X_sp_inter = sp.lil_matrix(X_sp) # supports slicing + X_sp_inter = sp.csr_matrix(X_sp_inter[:, :-1]) + for X, X_inter in ((X_ref, X_ref_inter), (X_sp, X_sp_inter)): + w = np.ones(n_features) + # make an intercept of 0.5 + w[-1] = 0.5 + + logloss = LinearLoss( + loss=HalfBinomialLoss(), + fit_intercept=False, + ) # First check that our derivation of the grad is correct - loss, grad = _logistic_loss_and_grad(w, X, y, alpha=1.0) + loss, grad = logloss.loss_gradient(w, X, y, alpha=alpha) approx_grad = optimize.approx_fprime( - w, lambda w: _logistic_loss_and_grad(w, X, y, alpha=1.0)[0], 1e-3 + w, lambda w: logloss.loss(w, X, y, alpha=alpha), 1e-3 ) assert_array_almost_equal(grad, approx_grad, decimal=2) # Second check that our intercept implementation is good - w = np.zeros(n_features + 1) - loss_interp, grad_interp = _logistic_loss_and_grad(w, X, y, alpha=1.0) - assert_array_almost_equal(loss, loss_interp) + logloss = LinearLoss( + loss=HalfBinomialLoss(), + fit_intercept=True, + ) + loss_inter, grad_inter = logloss.loss_gradient(w, X_inter, y, alpha=alpha) + # Note, that intercept gets no L2 penalty. + assert loss == pytest.approx(loss_inter + 0.5 * alpha * w[-1] ** 2) approx_grad = optimize.approx_fprime( - w, lambda w: _logistic_loss_and_grad(w, X, y, alpha=1.0)[0], 1e-3 + w, lambda w: logloss.loss(w, X_inter, y, alpha=alpha), 1e-3 ) - assert_array_almost_equal(grad_interp, approx_grad, decimal=2) + assert_array_almost_equal(grad_inter, approx_grad, decimal=2) def test_logistic_grad_hess(): rng = np.random.RandomState(0) n_samples, n_features = 50, 5 + alpha = 1.0 X_ref = rng.randn(n_samples, n_features) y = np.sign(X_ref.dot(5 * rng.randn(n_features))) X_ref -= X_ref.mean() X_ref /= X_ref.std() + # make last column of 1 to mimic intercept term + X_ref[:, :-1] = 1 X_sp = X_ref.copy() X_sp[X_sp < 0.1] = 0 X_sp = sp.csr_matrix(X_sp) for X in (X_ref, X_sp): w = np.full(n_features, 0.1) + logloss = LinearLoss(loss=HalfBinomialLoss(), fit_intercept=False) - # First check that _logistic_grad_hess is consistent - # with _logistic_loss_and_grad - loss, grad = _logistic_loss_and_grad(w, X, y, alpha=1.0) - grad_2, hess = _logistic_grad_hess(w, X, y, alpha=1.0) + # First check that gradients from gradient(), loss_gradient() and + # gradient_hessp() are consistent + grad = logloss.gradient(w, X, y, alpha=alpha) + loss, grad_2 = logloss.loss_gradient(w, X, y, alpha=alpha) + grad_3, hessp = logloss.gradient_hessp(w, X, y, alpha=alpha) assert_array_almost_equal(grad, grad_2) + assert_array_almost_equal(grad, grad_3) # Now check our hessian along the second direction of the grad vector = np.zeros_like(grad) vector[1] = 1 - hess_col = hess(vector) + hess_col = hessp(vector) # Computation of the Hessian is particularly fragile to numerical # errors when doing simple finite differences. 
Here we compute the @@ -559,7 +583,7 @@ def test_logistic_grad_hess(): e = 1e-3 d_x = np.linspace(-e, e, 30) d_grad = np.array( - [_logistic_loss_and_grad(w + t * vector, X, y, alpha=1.0)[1] for t in d_x] + [logloss.gradient(w + t * vector, X, y, alpha=alpha) for t in d_x] ) d_grad -= d_grad.mean(axis=0) @@ -569,11 +593,12 @@ def test_logistic_grad_hess(): # Second check that our intercept implementation is good w = np.zeros(n_features + 1) - loss_interp, grad_interp = _logistic_loss_and_grad(w, X, y, alpha=1.0) - loss_interp_2 = _logistic_loss(w, X, y, alpha=1.0) - grad_interp_2, hess = _logistic_grad_hess(w, X, y, alpha=1.0) - assert_array_almost_equal(loss_interp, loss_interp_2) - assert_array_almost_equal(grad_interp, grad_interp_2) + logloss = LinearLoss(loss=HalfBinomialLoss(), fit_intercept=True) + loss_inter, grad_inter = logloss.loss_gradient(w, X, y, alpha=alpha) + loss_inter_2 = logloss.loss(w, X, y, alpha=alpha) + grad_inter_2, hess = logloss.gradient_hessp(w, X, y, alpha=alpha) + assert_array_almost_equal(loss_inter, loss_inter_2) + assert_array_almost_equal(grad_inter, grad_inter_2) def test_logistic_cv(): @@ -705,31 +730,37 @@ def test_intercept_logistic_helper(): X, y = make_classification( n_samples=n_samples, n_features=n_features, random_state=0 ) + y = y.astype(np.float64) # Fit intercept case. + logloss = LinearLoss(loss=HalfBinomialLoss(), fit_intercept=True) alpha = 1.0 w = np.ones(n_features + 1) - grad_interp, hess_interp = _logistic_grad_hess(w, X, y, alpha) - loss_interp = _logistic_loss(w, X, y, alpha) + grad_inter, hess_inter = logloss.gradient_hessp(w, X, y, alpha=alpha) + loss_inter = logloss.loss(w, X, y, alpha=alpha) # Do not fit intercept. This can be considered equivalent to adding - # a feature vector of ones, i.e column of one vectors. - X_ = np.hstack((X, np.ones(10)[:, np.newaxis])) - grad, hess = _logistic_grad_hess(w, X_, y, alpha) - loss = _logistic_loss(w, X_, y, alpha) + # a feature vector of ones, i.e last column vector's elements are all one. + X_ = np.hstack((X, np.ones(n_samples)[:, np.newaxis])) + logloss = LinearLoss( + loss=HalfBinomialLoss(), + fit_intercept=False, + ) + grad, hessp = logloss.gradient_hessp(w, X_, y, alpha=alpha) + loss = logloss.loss(w, X_, y, alpha=alpha) # In the fit_intercept=False case, the feature vector of ones is # penalized. This should be taken care of. - assert_almost_equal(loss_interp + 0.5 * (w[-1] ** 2), loss) + assert_almost_equal(loss_inter + 0.5 * (w[-1] ** 2), loss) # Check gradient. 
- assert_array_almost_equal(grad_interp[:n_features], grad[:n_features]) - assert_almost_equal(grad_interp[-1] + alpha * w[-1], grad[-1]) + assert_array_almost_equal(grad_inter[:n_features], grad[:n_features]) + assert_almost_equal(grad_inter[-1] + alpha * w[-1], grad[-1]) rng = np.random.RandomState(0) grad = rng.rand(n_features + 1) - hess_interp = hess_interp(grad) - hess = hess(grad) + hess_interp = hess_inter(grad) + hess = hessp(grad) assert_array_almost_equal(hess_interp[:n_features], hess[:n_features]) assert_almost_equal(hess_interp[-1] + alpha * grad[-1], hess[-1]) @@ -1121,13 +1152,16 @@ def test_multinomial_grad_hess(): n_samples, n_features, n_classes = 100, 5, 3 X = rng.randn(n_samples, n_features) w = rng.rand(n_classes, n_features) - Y = np.zeros((n_samples, n_classes)) - ind = np.argmax(np.dot(X, w.T), axis=1) - Y[range(0, n_samples), ind] = 1 + y = np.argmax(np.dot(X, w.T), axis=1).astype(X.dtype) w = w.ravel() sample_weights = np.ones(X.shape[0]) - grad, hessp = _multinomial_grad_hess( - w, X, Y, alpha=1.0, sample_weight=sample_weights + alpha = 1.0 + multinomial = LinearLoss( + loss=HalfMultinomialLoss(n_classes=n_classes), + fit_intercept=False, + ) + grad, hessp = multinomial.gradient_hessp( + w, X, y, alpha=alpha, sample_weight=sample_weights ) # extract first column of hessian matrix vec = np.zeros(n_features * n_classes) @@ -1140,8 +1174,8 @@ def test_multinomial_grad_hess(): d_x = np.linspace(-e, e, 30) d_grad = np.array( [ - _multinomial_grad_hess( - w + t * vec, X, Y, alpha=1.0, sample_weight=sample_weights + multinomial.gradient_hessp( + w + t * vec, X, y, alpha=alpha, sample_weight=sample_weights )[0] for t in d_x ] From f74bdde0f630844dc284ebedb798af9de83d84b3 Mon Sep 17 00:00:00 2001 From: Christian Lorentzen Date: Fri, 1 Jan 2021 22:57:38 +0100 Subject: [PATCH 02/44] MNT remove logistic regression's own loss functions --- sklearn/linear_model/_logistic.py | 391 +------------------------ sklearn/linear_model/tests/test_sag.py | 23 +- 2 files changed, 17 insertions(+), 397 deletions(-) diff --git a/sklearn/linear_model/_logistic.py b/sklearn/linear_model/_logistic.py index 06c93016e9352..12fdd7eb7e5b7 100644 --- a/sklearn/linear_model/_logistic.py +++ b/sklearn/linear_model/_logistic.py @@ -14,8 +14,7 @@ import warnings import numpy as np -from scipy import optimize, sparse -from scipy.special import expit, logsumexp +from scipy import optimize from joblib import Parallel, effective_n_jobs from ._base import LinearClassifierMixin, SparseCoefMixin, BaseEstimator @@ -27,7 +26,7 @@ from ..utils import check_array, check_consistent_length, compute_class_weight from ..utils import check_random_state from ..utils._openmp_helpers import _openmp_effective_n_threads -from ..utils.extmath import log_logistic, safe_sparse_dot, softmax, squared_norm +from ..utils.extmath import softmax from ..utils.extmath import row_norms from ..utils.optimize import _newton_cg, _check_optimize_result from ..utils.validation import check_is_fitted, _check_sample_weight @@ -45,392 +44,6 @@ ) -# .. some helper functions for logistic_regression_path .. -def _intercept_dot(w, X, y): - """Computes y * np.dot(X, w). - - It takes into consideration if the intercept should be fit or not. - - Parameters - ---------- - w : ndarray of shape (n_features,) or (n_features + 1,) - Coefficient vector. - - X : {array-like, sparse matrix} of shape (n_samples, n_features) - Training data. - - y : ndarray of shape (n_samples,) - Array of labels. 
- - Returns - ------- - w : ndarray of shape (n_features,) - Coefficient vector without the intercept weight (w[-1]) if the - intercept should be fit. Unchanged otherwise. - - c : float - The intercept. - - yz : float - y * np.dot(X, w). - """ - c = 0.0 - if w.size == X.shape[1] + 1: - c = w[-1] - w = w[:-1] - - z = safe_sparse_dot(X, w) + c - yz = y * z - return w, c, yz - - -def _logistic_loss_and_grad(w, X, y, alpha, sample_weight=None): - """Computes the logistic loss and gradient. - - Parameters - ---------- - w : ndarray of shape (n_features,) or (n_features + 1,) - Coefficient vector. - - X : {array-like, sparse matrix} of shape (n_samples, n_features) - Training data. - - y : ndarray of shape (n_samples,) - Array of labels. - - alpha : float - Regularization parameter. alpha is equal to 1 / C. - - sample_weight : array-like of shape (n_samples,), default=None - Array of weights that are assigned to individual samples. - If not provided, then each sample is given unit weight. - - Returns - ------- - out : float - Logistic loss. - - grad : ndarray of shape (n_features,) or (n_features + 1,) - Logistic gradient. - """ - n_samples, n_features = X.shape - grad = np.empty_like(w) - - w, c, yz = _intercept_dot(w, X, y) - - if sample_weight is None: - sample_weight = np.ones(n_samples) - - # Logistic loss is the negative of the log of the logistic function. - out = -np.sum(sample_weight * log_logistic(yz)) + 0.5 * alpha * np.dot(w, w) - - z = expit(yz) - z0 = sample_weight * (z - 1) * y - - grad[:n_features] = safe_sparse_dot(X.T, z0) + alpha * w - - # Case where we fit the intercept. - if grad.shape[0] > n_features: - grad[-1] = z0.sum() - return out, grad - - -def _logistic_loss(w, X, y, alpha, sample_weight=None): - """Computes the logistic loss. - - Parameters - ---------- - w : ndarray of shape (n_features,) or (n_features + 1,) - Coefficient vector. - - X : {array-like, sparse matrix} of shape (n_samples, n_features) - Training data. - - y : ndarray of shape (n_samples,) - Array of labels. - - alpha : float - Regularization parameter. alpha is equal to 1 / C. - - sample_weight : array-like of shape (n_samples,) default=None - Array of weights that are assigned to individual samples. - If not provided, then each sample is given unit weight. - - Returns - ------- - out : float - Logistic loss. - """ - w, c, yz = _intercept_dot(w, X, y) - - if sample_weight is None: - sample_weight = np.ones(y.shape[0]) - - # Logistic loss is the negative of the log of the logistic function. - out = -np.sum(sample_weight * log_logistic(yz)) + 0.5 * alpha * np.dot(w, w) - return out - - -def _logistic_grad_hess(w, X, y, alpha, sample_weight=None): - """Computes the gradient and the Hessian, in the case of a logistic loss. - - Parameters - ---------- - w : ndarray of shape (n_features,) or (n_features + 1,) - Coefficient vector. - - X : {array-like, sparse matrix} of shape (n_samples, n_features) - Training data. - - y : ndarray of shape (n_samples,) - Array of labels. - - alpha : float - Regularization parameter. alpha is equal to 1 / C. - - sample_weight : array-like of shape (n_samples,) default=None - Array of weights that are assigned to individual samples. - If not provided, then each sample is given unit weight. - - Returns - ------- - grad : ndarray of shape (n_features,) or (n_features + 1,) - Logistic gradient. - - Hs : callable - Function that takes the gradient as a parameter and returns the - matrix product of the Hessian and gradient. 
- """ - n_samples, n_features = X.shape - grad = np.empty_like(w) - fit_intercept = grad.shape[0] > n_features - - w, c, yz = _intercept_dot(w, X, y) - - if sample_weight is None: - sample_weight = np.ones(y.shape[0]) - - z = expit(yz) - z0 = sample_weight * (z - 1) * y - - grad[:n_features] = safe_sparse_dot(X.T, z0) + alpha * w - - # Case where we fit the intercept. - if fit_intercept: - grad[-1] = z0.sum() - - # The mat-vec product of the Hessian - d = sample_weight * z * (1 - z) - if sparse.issparse(X): - dX = safe_sparse_dot(sparse.dia_matrix((d, 0), shape=(n_samples, n_samples)), X) - else: - # Precompute as much as possible - dX = d[:, np.newaxis] * X - - if fit_intercept: - # Calculate the double derivative with respect to intercept - # In the case of sparse matrices this returns a matrix object. - dd_intercept = np.squeeze(np.array(dX.sum(axis=0))) - - def Hs(s): - ret = np.empty_like(s) - if sparse.issparse(X): - ret[:n_features] = X.T.dot(dX.dot(s[:n_features])) - else: - ret[:n_features] = np.linalg.multi_dot([X.T, dX, s[:n_features]]) - ret[:n_features] += alpha * s[:n_features] - - # For the fit intercept case. - if fit_intercept: - ret[:n_features] += s[-1] * dd_intercept - ret[-1] = dd_intercept.dot(s[:n_features]) - ret[-1] += d.sum() * s[-1] - return ret - - return grad, Hs - - -def _multinomial_loss(w, X, Y, alpha, sample_weight): - """Computes multinomial loss and class probabilities. - - Parameters - ---------- - w : ndarray of shape (n_classes * n_features,) or - (n_classes * (n_features + 1),) - Coefficient vector. - - X : {array-like, sparse matrix} of shape (n_samples, n_features) - Training data. - - Y : ndarray of shape (n_samples, n_classes) - Transformed labels according to the output of LabelBinarizer. - - alpha : float - Regularization parameter. alpha is equal to 1 / C. - - sample_weight : array-like of shape (n_samples,) - Array of weights that are assigned to individual samples. - - Returns - ------- - loss : float - Multinomial loss. - - p : ndarray of shape (n_samples, n_classes) - Estimated class probabilities. - - w : ndarray of shape (n_classes, n_features) - Reshaped param vector excluding intercept terms. - - Reference - --------- - Bishop, C. M. (2006). Pattern recognition and machine learning. - Springer. (Chapter 4.3.4) - """ - n_classes = Y.shape[1] - n_features = X.shape[1] - fit_intercept = w.size == (n_classes * (n_features + 1)) - w = w.reshape(n_classes, -1) - sample_weight = sample_weight[:, np.newaxis] - if fit_intercept: - intercept = w[:, -1] - w = w[:, :-1] - else: - intercept = 0 - p = safe_sparse_dot(X, w.T) - p += intercept - p -= logsumexp(p, axis=1)[:, np.newaxis] - loss = -(sample_weight * Y * p).sum() - loss += 0.5 * alpha * squared_norm(w) - p = np.exp(p, p) - return loss, p, w - - -def _multinomial_loss_grad(w, X, Y, alpha, sample_weight): - """Computes the multinomial loss, gradient and class probabilities. - - Parameters - ---------- - w : ndarray of shape (n_classes * n_features,) or - (n_classes * (n_features + 1),) - Coefficient vector. - - X : {array-like, sparse matrix} of shape (n_samples, n_features) - Training data. - - Y : ndarray of shape (n_samples, n_classes) - Transformed labels according to the output of LabelBinarizer. - - alpha : float - Regularization parameter. alpha is equal to 1 / C. - - sample_weight : array-like of shape (n_samples,) - Array of weights that are assigned to individual samples. - - Returns - ------- - loss : float - Multinomial loss. 
- - grad : ndarray of shape (n_classes * n_features,) or \ - (n_classes * (n_features + 1),) - Ravelled gradient of the multinomial loss. - - p : ndarray of shape (n_samples, n_classes) - Estimated class probabilities - - Reference - --------- - Bishop, C. M. (2006). Pattern recognition and machine learning. - Springer. (Chapter 4.3.4) - """ - n_classes = Y.shape[1] - n_features = X.shape[1] - fit_intercept = w.size == n_classes * (n_features + 1) - grad = np.zeros((n_classes, n_features + bool(fit_intercept)), dtype=X.dtype) - loss, p, w = _multinomial_loss(w, X, Y, alpha, sample_weight) - sample_weight = sample_weight[:, np.newaxis] - diff = sample_weight * (p - Y) - grad[:, :n_features] = safe_sparse_dot(diff.T, X) - grad[:, :n_features] += alpha * w - if fit_intercept: - grad[:, -1] = diff.sum(axis=0) - return loss, grad.ravel(), p - - -def _multinomial_grad_hess(w, X, Y, alpha, sample_weight): - """ - Computes the gradient and the Hessian, in the case of a multinomial loss. - - Parameters - ---------- - w : ndarray of shape (n_classes * n_features,) or - (n_classes * (n_features + 1),) - Coefficient vector. - - X : {array-like, sparse matrix} of shape (n_samples, n_features) - Training data. - - Y : ndarray of shape (n_samples, n_classes) - Transformed labels according to the output of LabelBinarizer. - - alpha : float - Regularization parameter. alpha is equal to 1 / C. - - sample_weight : array-like of shape (n_samples,) - Array of weights that are assigned to individual samples. - - Returns - ------- - grad : ndarray of shape (n_classes * n_features,) or \ - (n_classes * (n_features + 1),) - Ravelled gradient of the multinomial loss. - - hessp : callable - Function that takes in a vector input of shape (n_classes * n_features) - or (n_classes * (n_features + 1)) and returns matrix-vector product - with hessian. - - References - ---------- - Barak A. Pearlmutter (1993). Fast Exact Multiplication by the Hessian. - http://www.bcl.hamilton.ie/~barak/papers/nc-hessian.pdf - """ - n_features = X.shape[1] - n_classes = Y.shape[1] - fit_intercept = w.size == (n_classes * (n_features + 1)) - - # `loss` is unused. Refactoring to avoid computing it does not - # significantly speed up the computation and decreases readability - loss, grad, p = _multinomial_loss_grad(w, X, Y, alpha, sample_weight) - sample_weight = sample_weight[:, np.newaxis] - - # Hessian-vector product derived by applying the R-operator on the gradient - # of the multinomial loss function. - def hessp(v): - v = v.reshape(n_classes, -1) - if fit_intercept: - inter_terms = v[:, -1] - v = v[:, :-1] - else: - inter_terms = 0 - # r_yhat holds the result of applying the R-operator on the multinomial - # estimator. 
- r_yhat = safe_sparse_dot(X, v.T) - r_yhat += inter_terms - r_yhat += (-p * r_yhat).sum(axis=1)[:, np.newaxis] - r_yhat *= p - r_yhat *= sample_weight - hessProd = np.zeros((n_classes, n_features + bool(fit_intercept))) - hessProd[:, :n_features] = safe_sparse_dot(r_yhat.T, X) - hessProd[:, :n_features] += v * alpha - if fit_intercept: - hessProd[:, -1] = r_yhat.sum(axis=0) - return hessProd.ravel() - - return grad, hessp - - def _check_solver(solver, penalty, dual): all_solvers = ["liblinear", "newton-cg", "lbfgs", "sag", "saga"] if solver not in all_solvers: diff --git a/sklearn/linear_model/tests/test_sag.py b/sklearn/linear_model/tests/test_sag.py index 88df6621f8176..935936affe408 100644 --- a/sklearn/linear_model/tests/test_sag.py +++ b/sklearn/linear_model/tests/test_sag.py @@ -10,11 +10,12 @@ import scipy.sparse as sp from scipy.special import logsumexp +from sklearn._loss.loss import HalfMultinomialLoss +from sklearn.linear_model._linear_loss import LinearLoss from sklearn.linear_model._sag import get_auto_step_size from sklearn.linear_model._sag_fast import _multinomial_grad_loss_all_samples from sklearn.linear_model import LogisticRegression, Ridge from sklearn.linear_model._base import make_dataset -from sklearn.linear_model._logistic import _multinomial_loss_grad from sklearn.utils.extmath import row_norms from sklearn.utils._testing import assert_almost_equal @@ -933,11 +934,13 @@ def test_multinomial_loss(): dataset, weights, intercept, n_samples, n_features, n_classes ) # compute loss and gradient like in multinomial LogisticRegression - lbin = LabelBinarizer() - Y_bin = lbin.fit_transform(y) + loss = LinearLoss( + loss=HalfMultinomialLoss(n_classes=n_classes), + fit_intercept=True, + ) weights_intercept = np.vstack((weights, intercept)).T.ravel() - loss_2, grad_2, _ = _multinomial_loss_grad( - weights_intercept, X, Y_bin, 0.0, sample_weights + loss_2, grad_2 = loss.loss_gradient( + weights_intercept, X, y, alpha=0.0, sample_weight=sample_weights ) grad_2 = grad_2.reshape(n_classes, -1) grad_2 = grad_2[:, :-1].T @@ -951,7 +954,7 @@ def test_multinomial_loss_ground_truth(): # n_samples, n_features, n_classes = 4, 2, 3 n_classes = 3 X = np.array([[1.1, 2.2], [2.2, -4.4], [3.3, -2.2], [1.1, 1.1]]) - y = np.array([0, 1, 2, 0]) + y = np.array([0, 1, 2, 0], dtype=np.float64) lbin = LabelBinarizer() Y_bin = lbin.fit_transform(y) @@ -966,9 +969,13 @@ def test_multinomial_loss_ground_truth(): diff = sample_weights[:, np.newaxis] * (np.exp(p) - Y_bin) grad_1 = np.dot(X.T, diff) + loss = LinearLoss( + loss=HalfMultinomialLoss(n_classes=n_classes), + fit_intercept=True, + ) weights_intercept = np.vstack((weights, intercept)).T.ravel() - loss_2, grad_2, _ = _multinomial_loss_grad( - weights_intercept, X, Y_bin, 0.0, sample_weights + loss_2, grad_2 = loss.loss_gradient( + weights_intercept, X, y, alpha=0.0, sample_weight=sample_weights ) grad_2 = grad_2.reshape(n_classes, -1) grad_2 = grad_2[:, :-1].T From 5596ab25b688d4fcb92ff582c4c87d09477548e3 Mon Sep 17 00:00:00 2001 From: Christian Lorentzen Date: Sun, 28 Nov 2021 15:14:01 +0100 Subject: [PATCH 03/44] CLN remove comment --- sklearn/linear_model/_linear_loss.py | 1 - 1 file changed, 1 deletion(-) diff --git a/sklearn/linear_model/_linear_loss.py b/sklearn/linear_model/_linear_loss.py index f3ad51260b2f1..6aa75030f6e4f 100644 --- a/sklearn/linear_model/_linear_loss.py +++ b/sklearn/linear_model/_linear_loss.py @@ -278,7 +278,6 @@ def gradient_hessp(self, coef, X, y, sample_weight=None, alpha=0.0, n_threads=1) if sparse.issparse(X): 
hX = sparse.dia_matrix((hessian, 0), shape=(n_samples, n_samples)) @ X else: - # TODO: This may consume a lot of memory. Better preallocate the array. hX = hessian[:, np.newaxis] * X if self.fit_intercept: From b6a96ce53ef3c6c52351c62b07cc28368cc0c5b8 Mon Sep 17 00:00:00 2001 From: Christian Lorentzen Date: Sun, 28 Nov 2021 15:22:28 +0100 Subject: [PATCH 04/44] DOC add whatsnew --- doc/whats_new/v1.1.rst | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/doc/whats_new/v1.1.rst b/doc/whats_new/v1.1.rst index 6a5b2d226cabe..b4f97f0b4139f 100644 --- a/doc/whats_new/v1.1.rst +++ b/doc/whats_new/v1.1.rst @@ -182,6 +182,11 @@ Changelog estimator of the noise variance cannot be computed. :pr:`21481` by :user:`Guillaume Lemaitre ` +- |Enhancement| :class:`~linear_model.LogisticRegression` is + a bit faster, for binary as well as for multiclass problems thanks the new + private loss function module. + :pr:`20567` and :pr:`21808` by :user:`Christian Lorentzen `. + - |Fix| :class:`linear_model.LassoLarsIC` now correctly computes AIC and BIC. An error is now raised when `n_features > n_samples` and when the noise variance is not provided. From 9803d8a3f835d0df6d569f982fd3f5b0453ec1fe Mon Sep 17 00:00:00 2001 From: Christian Lorentzen Date: Sun, 28 Nov 2021 15:26:49 +0100 Subject: [PATCH 05/44] DOC more precise whatsnew --- doc/whats_new/v1.1.rst | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/doc/whats_new/v1.1.rst b/doc/whats_new/v1.1.rst index b4f97f0b4139f..6862e858e3bb2 100644 --- a/doc/whats_new/v1.1.rst +++ b/doc/whats_new/v1.1.rst @@ -182,9 +182,9 @@ Changelog estimator of the noise variance cannot be computed. :pr:`21481` by :user:`Guillaume Lemaitre ` -- |Enhancement| :class:`~linear_model.LogisticRegression` is - a bit faster, for binary as well as for multiclass problems thanks the new - private loss function module. +- |Enhancement| :class:`~linear_model.LogisticRegression` is a bit faster for + ``solvers="lbfgs"`` and ``solver="newton-cg"``, for binary as well as for multiclass + problems thanks to the new private loss function module. :pr:`20567` and :pr:`21808` by :user:`Christian Lorentzen `. 
- |Fix| :class:`linear_model.LassoLarsIC` now correctly computes AIC From c9e211fff8202a45a19d9c73d2894a6275231c33 Mon Sep 17 00:00:00 2001 From: Christian Lorentzen Date: Sun, 28 Nov 2021 16:22:45 +0100 Subject: [PATCH 06/44] CLN restore improvements of #19571 --- sklearn/linear_model/_linear_loss.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/sklearn/linear_model/_linear_loss.py b/sklearn/linear_model/_linear_loss.py index 6aa75030f6e4f..df0a68a956e6b 100644 --- a/sklearn/linear_model/_linear_loss.py +++ b/sklearn/linear_model/_linear_loss.py @@ -292,7 +292,10 @@ def gradient_hessp(self, coef, X, y, sample_weight=None, alpha=0.0, n_threads=1) # res[:-1] = 1' @ hX @ s[:n_features] + sum(h) * s[-1] def hessp(s): ret = np.empty_like(s) - ret[:n_features] = X.T @ (hX @ s[:n_features]) + if sparse.issparse(X): + ret[:n_features] = X.T @ (hX @ s[:n_features]) + else: + ret[:n_features] = np.linalg.multi_dot([X.T, hX, s[:n_features]]) ret[:n_features] += alpha * s[:n_features] if self.fit_intercept: From d36c17bddcd21b2c0a19892cca323a1151addbe7 Mon Sep 17 00:00:00 2001 From: Christian Lorentzen Date: Wed, 1 Dec 2021 14:25:09 +0100 Subject: [PATCH 07/44] ENH improve fit time by separating mat-vec in multiclass --- sklearn/linear_model/_linear_loss.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sklearn/linear_model/_linear_loss.py b/sklearn/linear_model/_linear_loss.py index df0a68a956e6b..de3af572ec6e1 100644 --- a/sklearn/linear_model/_linear_loss.py +++ b/sklearn/linear_model/_linear_loss.py @@ -79,6 +79,7 @@ def _w_intercept_raw(self, coef, X): else: intercept = 0.0 w = coef + raw_prediction = X @ w + intercept else: # reshape to (n_classes, n_dof) w = coef.reshape(self._loss.n_classes, -1) @@ -87,8 +88,8 @@ def _w_intercept_raw(self, coef, X): w = w[:, :-1] else: intercept = 0.0 + raw_prediction = X @ w.T + intercept - raw_prediction = X @ w.T + intercept return w, intercept, raw_prediction def loss(self, coef, X, y, sample_weight=None, alpha=0.0, n_threads=1): From faa16ec24c80264ace19b94aec2199861b28feb7 Mon Sep 17 00:00:00 2001 From: Christian Lorentzen Date: Wed, 1 Dec 2021 21:54:58 +0100 Subject: [PATCH 08/44] DOC update whatsnew --- doc/whats_new/v1.1.rst | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/doc/whats_new/v1.1.rst b/doc/whats_new/v1.1.rst index 1939fdc58c932..ccfb41a05afab 100644 --- a/doc/whats_new/v1.1.rst +++ b/doc/whats_new/v1.1.rst @@ -197,9 +197,10 @@ Changelog :pr:`21481` by :user:`Guillaume Lemaitre ` - |Enhancement| :class:`~linear_model.LogisticRegression` is a bit faster for - ``solvers="lbfgs"`` and ``solver="newton-cg"``, for binary as well as for multiclass - problems thanks to the new private loss function module. - :pr:`20567` and :pr:`21808` by :user:`Christian Lorentzen `. + ``solvers="lbfgs"`` and ``solver="newton-cg"``, for binary and in particular for + multiclass problems thanks to the new private loss function module. + :pr:`21808`, :pr:`20567` and :pr:`21814` by + :user:`Christian Lorentzen `. - |Fix| :class:`linear_model.LassoLarsIC` now correctly computes AIC and BIC. 
An error is now raised when `n_features > n_samples` and From e4d443252806b7876cf3f8e66191a608c095ec8b Mon Sep 17 00:00:00 2001 From: Christian Lorentzen Date: Wed, 1 Dec 2021 21:56:18 +0100 Subject: [PATCH 09/44] not only a bit ;-) --- doc/whats_new/v1.1.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/whats_new/v1.1.rst b/doc/whats_new/v1.1.rst index ccfb41a05afab..3ac8de2007a5f 100644 --- a/doc/whats_new/v1.1.rst +++ b/doc/whats_new/v1.1.rst @@ -196,7 +196,7 @@ Changelog estimator of the noise variance cannot be computed. :pr:`21481` by :user:`Guillaume Lemaitre ` -- |Enhancement| :class:`~linear_model.LogisticRegression` is a bit faster for +- |Enhancement| :class:`~linear_model.LogisticRegression` is faster for ``solvers="lbfgs"`` and ``solver="newton-cg"``, for binary and in particular for multiclass problems thanks to the new private loss function module. :pr:`21808`, :pr:`20567` and :pr:`21814` by From 3deafabad48a073a1273b939aee2cdd6fc9fc282 Mon Sep 17 00:00:00 2001 From: Christian Lorentzen Date: Wed, 1 Dec 2021 22:01:48 +0100 Subject: [PATCH 10/44] DOC note memory benefit for multiclass case --- doc/whats_new/v1.1.rst | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/doc/whats_new/v1.1.rst b/doc/whats_new/v1.1.rst index 3ac8de2007a5f..bc785aa83525e 100644 --- a/doc/whats_new/v1.1.rst +++ b/doc/whats_new/v1.1.rst @@ -198,7 +198,10 @@ Changelog - |Enhancement| :class:`~linear_model.LogisticRegression` is faster for ``solvers="lbfgs"`` and ``solver="newton-cg"``, for binary and in particular for - multiclass problems thanks to the new private loss function module. + multiclass problems thanks to the new private loss function module. In the multiclass + case, also the memory consumptions is reduced for these solvers as the target is now + label encoded (mapped to integers) instead of label binarized (one-hot encoded). The + more classes, the larger the benefit. :pr:`21808`, :pr:`20567` and :pr:`21814` by :user:`Christian Lorentzen `. 
From 9d10861e841da6736f6e894e11876fdfe5d00117 Mon Sep 17 00:00:00 2001 From: Christian Lorentzen Date: Sat, 4 Dec 2021 14:15:21 +0100 Subject: [PATCH 11/44] trigger CI From 3dba3f88e8ceacc39436a36cf2604cff6f785b61 Mon Sep 17 00:00:00 2001 From: Christian Lorentzen Date: Mon, 6 Dec 2021 20:43:43 +0100 Subject: [PATCH 12/44] trigger CI From e588e33aadd042d95e0c631615e1aba04aa0aecc Mon Sep 17 00:00:00 2001 From: Christian Lorentzen Date: Fri, 10 Dec 2021 08:55:35 +0100 Subject: [PATCH 13/44] CLN rename variable to hess_prod --- sklearn/linear_model/_linear_loss.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/sklearn/linear_model/_linear_loss.py b/sklearn/linear_model/_linear_loss.py index de3af572ec6e1..eabfe3590998b 100644 --- a/sklearn/linear_model/_linear_loss.py +++ b/sklearn/linear_model/_linear_loss.py @@ -353,10 +353,10 @@ def hessp(s): tmp *= proba # * p_i_k if sample_weight is not None: tmp *= sample_weight[:, np.newaxis] - hessProd = np.empty_like(grad) - hessProd[:, :n_features] = tmp.T @ X + alpha * s + hess_prod = np.empty_like(grad) + hess_prod[:, :n_features] = tmp.T @ X + alpha * s if self.fit_intercept: - hessProd[:, -1] = tmp.sum(axis=0) - return hessProd.ravel() + hess_prod[:, -1] = tmp.sum(axis=0) + return hess_prod.ravel() return grad.ravel(), hessp From 31046482bdeba5b6518adb19f826c8fa68f372ee Mon Sep 17 00:00:00 2001 From: Christian Lorentzen Date: Wed, 15 Dec 2021 19:36:47 +0100 Subject: [PATCH 14/44] DOC address reviewer comments --- doc/whats_new/v1.1.rst | 6 +++--- sklearn/linear_model/_linear_loss.py | 7 +++---- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/doc/whats_new/v1.1.rst b/doc/whats_new/v1.1.rst index 34f71732257ee..19f525da4af64 100644 --- a/doc/whats_new/v1.1.rst +++ b/doc/whats_new/v1.1.rst @@ -153,7 +153,7 @@ Changelog - |Fix| :class:`feature_extraction.FeatureHasher` now validates input parameters in `transform` instead of `__init__`. :pr:`21573` by :user:`Hannah Bohle ` and :user:`Maren Westermann `. - + - |API| :func:`decomposition.FastICA` now supports unit variance for whitening. The default value of its `whiten` argument will change from `True` (which behaves like `'arbitrary-variance'`) to `'unit-variance'` in version 1.3. @@ -240,7 +240,7 @@ Changelog - |Enhancement| :class:`~linear_model.LogisticRegression` is faster for ``solvers="lbfgs"`` and ``solver="newton-cg"``, for binary and in particular for multiclass problems thanks to the new private loss function module. In the multiclass - case, also the memory consumptions is reduced for these solvers as the target is now + case, also the memory consumption is reduced for these solvers as the target is now label encoded (mapped to integers) instead of label binarized (one-hot encoded). The more classes, the larger the benefit. :pr:`21808`, :pr:`20567` and :pr:`21814` by @@ -288,7 +288,7 @@ Changelog all the models and all the splits failed. :pr:`21026` by :user:`Loïc Estève `. - |Fix| :class:`model_selection.GridSearchCV`, :class:`model_selection.HalvingGridSearchCV` - now validate input parameters in `fit` instead of `__init__`. + now validate input parameters in `fit` instead of `__init__`. :pr:`21880` by :user:`Mrinal Tyagi `. 
:mod:`sklearn.pipeline` diff --git a/sklearn/linear_model/_linear_loss.py b/sklearn/linear_model/_linear_loss.py index eabfe3590998b..694ff11a6c8b5 100644 --- a/sklearn/linear_model/_linear_loss.py +++ b/sklearn/linear_model/_linear_loss.py @@ -28,16 +28,15 @@ class LinearLoss: if loss.is_multiclass: coef.shape = (n_classes * n_dof,) - intercept.shape = (n_classes) else: coef.shape = (n_dof,) - intercept.shape = (1) or it is a float The intercept term is at the end of the coef array: if loss.is_multiclass: - coef[n_features::n_dof] = coef[(n_dof-1)::n_dof] + intercept = coef[n_features::n_dof] = coef[(n_dof-1)::n_dof] + intercept.shape = (n_classes,) else: - coef[-1] + intercept = coef[-1] Note: If the average loss per sample is wanted instead of the sum of the loss per sample, one can simply use a rescaled sample_weight such that From 639c16c54dcd1bf3529d1161e948d52dc14d50b1 Mon Sep 17 00:00:00 2001 From: Christian Lorentzen Date: Wed, 15 Dec 2021 20:05:06 +0100 Subject: [PATCH 15/44] CLN remove C/F for 1d arrays --- sklearn/linear_model/_linear_loss.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/sklearn/linear_model/_linear_loss.py b/sklearn/linear_model/_linear_loss.py index 694ff11a6c8b5..e276a3bc198f0 100644 --- a/sklearn/linear_model/_linear_loss.py +++ b/sklearn/linear_model/_linear_loss.py @@ -98,11 +98,11 @@ def loss(self, coef, X, y, sample_weight=None, alpha=0.0, n_threads=1): ---------- coef : ndarray of shape (n_dof,) or (n_classes * n_dof,) Coefficients of a linear model. - y : C/F-contiguous array of shape (n_samples,) + y : contiguous array of shape (n_samples,) Observed, true target values. X : {array-like, sparse matrix} of shape (n_samples, n_features) Training data. - sample_weight : None or C/F-contiguous array of shape (n_samples,) + sample_weight : None or contiguous array of shape (n_samples,) Sample weights. alpha: float L2 regularization strength @@ -136,11 +136,11 @@ def loss_gradient(self, coef, X, y, sample_weight=None, alpha=0.0, n_threads=1): ---------- coef : ndarray of shape (n_dof,) or (n_classes * n_dof,) Coefficients of a linear model. - y : C/F-contiguous array of shape (n_samples,) + y : contiguous array of shape (n_samples,) Observed, true target values. X : {array-like, sparse matrix} of shape (n_samples, n_features) Training data. - sample_weight : None or C/F-contiguous array of shape (n_samples,) + sample_weight : None or contiguous array of shape (n_samples,) Sample weights. alpha: float L2 regularization strength @@ -190,11 +190,11 @@ def gradient(self, coef, X, y, sample_weight=None, alpha=0.0, n_threads=1): ---------- coef : ndarray of shape (n_dof,) or (n_classes * n_dof,) Coefficients of a linear model. - y : C/F-contiguous array of shape (n_samples,) + y : contiguous array of shape (n_samples,) Observed, true target values. X : {array-like, sparse matrix} of shape (n_samples, n_features) Training data. - sample_weight : None or C/F-contiguous array of shape (n_samples,) + sample_weight : None or contiguous array of shape (n_samples,) Sample weights. alpha: float L2 regularization strength @@ -238,11 +238,11 @@ def gradient_hessp(self, coef, X, y, sample_weight=None, alpha=0.0, n_threads=1) ---------- coef : ndarray of shape (n_dof,) or (n_classes * n_dof,) Coefficients of a linear model. - y : C/F-contiguous array of shape (n_samples,) + y : contiguous array of shape (n_samples,) Observed, true target values. X : {array-like, sparse matrix} of shape (n_samples, n_features) Training data. 
- sample_weight : None or C/F-contiguous array of shape (n_samples,) + sample_weight : None or contiguous array of shape (n_samples,) Sample weights. alpha: float L2 regularization strength From b7230e1a22d391307a68d7a63430aeea5be55fe0 Mon Sep 17 00:00:00 2001 From: Christian Lorentzen Date: Wed, 15 Dec 2021 21:46:27 +0100 Subject: [PATCH 16/44] CLN rename to gradient_per_sample --- sklearn/linear_model/_linear_loss.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/sklearn/linear_model/_linear_loss.py b/sklearn/linear_model/_linear_loss.py index e276a3bc198f0..98c1bac254273 100644 --- a/sklearn/linear_model/_linear_loss.py +++ b/sklearn/linear_model/_linear_loss.py @@ -159,7 +159,7 @@ def loss_gradient(self, coef, X, y, sample_weight=None, alpha=0.0, n_threads=1): n_dof = n_features + self.fit_intercept w, intercept, raw_prediction = self._w_intercept_raw(coef, X) - loss, gradient = self._loss.loss_gradient( + loss, gradient_per_sample = self._loss.loss_gradient( y_true=y, raw_prediction=raw_prediction, sample_weight=sample_weight, @@ -170,17 +170,17 @@ def loss_gradient(self, coef, X, y, sample_weight=None, alpha=0.0, n_threads=1): if not self._loss.is_multiclass: loss += 0.5 * alpha * (w @ w) grad = np.empty_like(coef, dtype=X.dtype) - grad[:n_features] = X.T @ gradient + alpha * w + grad[:n_features] = X.T @ gradient_per_sample + alpha * w if self.fit_intercept: - grad[-1] = gradient.sum() + grad[-1] = gradient_per_sample.sum() return loss, grad else: loss += 0.5 * alpha * squared_norm(w) grad = np.empty((n_classes, n_dof), dtype=X.dtype) # gradient.shape = (n_samples, n_classes) - grad[:, :n_features] = gradient.T @ X + alpha * w + grad[:, :n_features] = gradient_per_sample.T @ X + alpha * w if self.fit_intercept: - grad[:, -1] = gradient.sum(axis=0) + grad[:, -1] = gradient_per_sample.sum(axis=0) return loss, grad.ravel() def gradient(self, coef, X, y, sample_weight=None, alpha=0.0, n_threads=1): @@ -210,7 +210,7 @@ def gradient(self, coef, X, y, sample_weight=None, alpha=0.0, n_threads=1): n_dof = n_features + self.fit_intercept w, intercept, raw_prediction = self._w_intercept_raw(coef, X) - gradient = self._loss.gradient( + gradient_per_sample = self._loss.gradient( y_true=y, raw_prediction=raw_prediction, sample_weight=sample_weight, @@ -219,16 +219,16 @@ def gradient(self, coef, X, y, sample_weight=None, alpha=0.0, n_threads=1): if not self._loss.is_multiclass: grad = np.empty_like(coef, dtype=X.dtype) - grad[:n_features] = X.T @ gradient + alpha * w + grad[:n_features] = X.T @ gradient_per_sample + alpha * w if self.fit_intercept: - grad[-1] = gradient.sum() + grad[-1] = gradient_per_sample.sum() return grad else: grad = np.empty((n_classes, n_dof), dtype=X.dtype) # gradient.shape = (n_samples, n_classes) - grad[:, :n_features] = gradient.T @ X + alpha * w + grad[:, :n_features] = gradient_per_sample.T @ X + alpha * w if self.fit_intercept: - grad[:, -1] = gradient.sum(axis=0) + grad[:, -1] = gradient_per_sample.sum(axis=0) return grad.ravel() def gradient_hessp(self, coef, X, y, sample_weight=None, alpha=0.0, n_threads=1): From 17c674d05e0c30d45dd1d2c89b477304033811f5 Mon Sep 17 00:00:00 2001 From: Christian Lorentzen Date: Wed, 15 Dec 2021 21:51:04 +0100 Subject: [PATCH 17/44] CLN rename alpha to l2_reg_strength --- sklearn/linear_model/_linear_loss.py | 55 ++++++++++++--------- sklearn/linear_model/tests/test_logistic.py | 36 +++++++------- 2 files changed, 50 insertions(+), 41 deletions(-) diff --git 
a/sklearn/linear_model/_linear_loss.py b/sklearn/linear_model/_linear_loss.py index 98c1bac254273..32d8d9774f6c0 100644 --- a/sklearn/linear_model/_linear_loss.py +++ b/sklearn/linear_model/_linear_loss.py @@ -11,14 +11,15 @@ class LinearLoss: The loss is the sum of per sample losses and includes an L2 term:: - loss = sum_i s_i loss(y_i, X_i @ coef + intercept) + 1/2 * alpha * ||coef||_2^2 + loss = sum_i s_i loss(y_i, X_i @ coef + intercept) + + 1/2 * l2_reg_strength * ||coef||_2^2 with sample weights s_i=1 if sample_weight=None. Gradient and hessian, for simplicity without intercept, are:: - gradient = X.T @ loss.gradient + alpha * coef - hessian = X.T @ diag(loss.hessian) @ X + alpha * identity + gradient = X.T @ loss.gradient + l2_reg_strength * coef + hessian = X.T @ diag(loss.hessian) @ X + l2_reg_strength * identity Conventions: if fit_intercept: @@ -91,7 +92,7 @@ def _w_intercept_raw(self, coef, X): return w, intercept, raw_prediction - def loss(self, coef, X, y, sample_weight=None, alpha=0.0, n_threads=1): + def loss(self, coef, X, y, sample_weight=None, l2_reg_strength=0.0, n_threads=1): """Compute the loss as sum over point-wise losses. Parameters @@ -104,7 +105,7 @@ def loss(self, coef, X, y, sample_weight=None, alpha=0.0, n_threads=1): Training data. sample_weight : None or contiguous array of shape (n_samples,) Sample weights. - alpha: float + l2_reg_strength: float L2 regularization strength n_threads : int, default=1 Number of OpenMP threads to use. @@ -125,11 +126,13 @@ def loss(self, coef, X, y, sample_weight=None, alpha=0.0, n_threads=1): loss = loss.sum() if w.ndim == 1: - return loss + 0.5 * alpha * (w @ w) + return loss + 0.5 * l2_reg_strength * (w @ w) else: - return loss + 0.5 * alpha * squared_norm(w) + return loss + 0.5 * l2_reg_strength * squared_norm(w) - def loss_gradient(self, coef, X, y, sample_weight=None, alpha=0.0, n_threads=1): + def loss_gradient( + self, coef, X, y, sample_weight=None, l2_reg_strength=0.0, n_threads=1 + ): """Computes the sum/average of loss and gradient. Parameters @@ -142,7 +145,7 @@ def loss_gradient(self, coef, X, y, sample_weight=None, alpha=0.0, n_threads=1): Training data. sample_weight : None or contiguous array of shape (n_samples,) Sample weights. - alpha: float + l2_reg_strength: float L2 regularization strength n_threads : int, default=1 Number of OpenMP threads to use. @@ -168,22 +171,24 @@ def loss_gradient(self, coef, X, y, sample_weight=None, alpha=0.0, n_threads=1): loss = loss.sum() if not self._loss.is_multiclass: - loss += 0.5 * alpha * (w @ w) + loss += 0.5 * l2_reg_strength * (w @ w) grad = np.empty_like(coef, dtype=X.dtype) - grad[:n_features] = X.T @ gradient_per_sample + alpha * w + grad[:n_features] = X.T @ gradient_per_sample + l2_reg_strength * w if self.fit_intercept: grad[-1] = gradient_per_sample.sum() return loss, grad else: - loss += 0.5 * alpha * squared_norm(w) + loss += 0.5 * l2_reg_strength * squared_norm(w) grad = np.empty((n_classes, n_dof), dtype=X.dtype) # gradient.shape = (n_samples, n_classes) - grad[:, :n_features] = gradient_per_sample.T @ X + alpha * w + grad[:, :n_features] = gradient_per_sample.T @ X + l2_reg_strength * w if self.fit_intercept: grad[:, -1] = gradient_per_sample.sum(axis=0) return loss, grad.ravel() - def gradient(self, coef, X, y, sample_weight=None, alpha=0.0, n_threads=1): + def gradient( + self, coef, X, y, sample_weight=None, l2_reg_strength=0.0, n_threads=1 + ): """Computes the gradient. 
Parameters @@ -196,7 +201,7 @@ def gradient(self, coef, X, y, sample_weight=None, alpha=0.0, n_threads=1): Training data. sample_weight : None or contiguous array of shape (n_samples,) Sample weights. - alpha: float + l2_reg_strength: float L2 regularization strength n_threads : int, default=1 Number of OpenMP threads to use. @@ -219,19 +224,21 @@ def gradient(self, coef, X, y, sample_weight=None, alpha=0.0, n_threads=1): if not self._loss.is_multiclass: grad = np.empty_like(coef, dtype=X.dtype) - grad[:n_features] = X.T @ gradient_per_sample + alpha * w + grad[:n_features] = X.T @ gradient_per_sample + l2_reg_strength * w if self.fit_intercept: grad[-1] = gradient_per_sample.sum() return grad else: grad = np.empty((n_classes, n_dof), dtype=X.dtype) # gradient.shape = (n_samples, n_classes) - grad[:, :n_features] = gradient_per_sample.T @ X + alpha * w + grad[:, :n_features] = gradient_per_sample.T @ X + l2_reg_strength * w if self.fit_intercept: grad[:, -1] = gradient_per_sample.sum(axis=0) return grad.ravel() - def gradient_hessp(self, coef, X, y, sample_weight=None, alpha=0.0, n_threads=1): + def gradient_hessp( + self, coef, X, y, sample_weight=None, l2_reg_strength=0.0, n_threads=1 + ): """Computes gradient and hessp (hessian product function). Parameters @@ -244,7 +251,7 @@ def gradient_hessp(self, coef, X, y, sample_weight=None, alpha=0.0, n_threads=1) Training data. sample_weight : None or contiguous array of shape (n_samples,) Sample weights. - alpha: float + l2_reg_strength: float L2 regularization strength n_threads : int, default=1 Number of OpenMP threads to use. @@ -269,7 +276,7 @@ def gradient_hessp(self, coef, X, y, sample_weight=None, alpha=0.0, n_threads=1) n_threads=n_threads, ) grad = np.empty_like(coef, dtype=X.dtype) - grad[:n_features] = X.T @ gradient + alpha * w + grad[:n_features] = X.T @ gradient + l2_reg_strength * w if self.fit_intercept: grad[-1] = gradient.sum() @@ -285,7 +292,7 @@ def gradient_hessp(self, coef, X, y, sample_weight=None, alpha=0.0, n_threads=1) # Note: In case hX is sparse, hX.sum is a matrix object. 
hh_intercept = np.squeeze(np.array(hX.sum(axis=0))) - # With intercept included and alpha = 0, hessp returns + # With intercept included and l2_reg_strength = 0, hessp returns # res = (X, 1)' @ diag(h) @ (X, 1) @ s # = (X, 1)' @ (hX @ s[:n_features], sum(h) * s[-1]) # res[:n_features] = X' @ hX @ s[:n_features] + sum(h) * s[-1] @@ -296,7 +303,7 @@ def hessp(s): ret[:n_features] = X.T @ (hX @ s[:n_features]) else: ret[:n_features] = np.linalg.multi_dot([X.T, hX, s[:n_features]]) - ret[:n_features] += alpha * s[:n_features] + ret[:n_features] += l2_reg_strength * s[:n_features] if self.fit_intercept: ret[:n_features] += s[-1] * hh_intercept @@ -316,7 +323,7 @@ def hessp(s): n_threads=n_threads, ) grad = np.empty_like(coef.reshape(n_classes, -1), dtype=X.dtype) - grad[:, :n_features] = gradient.T @ X + alpha * w + grad[:, :n_features] = gradient.T @ X + l2_reg_strength * w if self.fit_intercept: grad[:, -1] = gradient.sum(axis=0) @@ -353,7 +360,7 @@ def hessp(s): if sample_weight is not None: tmp *= sample_weight[:, np.newaxis] hess_prod = np.empty_like(grad) - hess_prod[:, :n_features] = tmp.T @ X + alpha * s + hess_prod[:, :n_features] = tmp.T @ X + l2_reg_strength * s if self.fit_intercept: hess_prod[:, -1] = tmp.sum(axis=0) return hess_prod.ravel() diff --git a/sklearn/linear_model/tests/test_logistic.py b/sklearn/linear_model/tests/test_logistic.py index b5c7dc54ac4de..0f9c8b8aec07c 100644 --- a/sklearn/linear_model/tests/test_logistic.py +++ b/sklearn/linear_model/tests/test_logistic.py @@ -525,9 +525,9 @@ def test_logistic_loss_and_grad(): fit_intercept=False, ) # First check that our derivation of the grad is correct - loss, grad = logloss.loss_gradient(w, X, y, alpha=alpha) + loss, grad = logloss.loss_gradient(w, X, y, l2_reg_strength=alpha) approx_grad = optimize.approx_fprime( - w, lambda w: logloss.loss(w, X, y, alpha=alpha), 1e-3 + w, lambda w: logloss.loss(w, X, y, l2_reg_strength=alpha), 1e-3 ) assert_array_almost_equal(grad, approx_grad, decimal=2) @@ -536,12 +536,14 @@ def test_logistic_loss_and_grad(): loss=HalfBinomialLoss(), fit_intercept=True, ) - loss_inter, grad_inter = logloss.loss_gradient(w, X_inter, y, alpha=alpha) + loss_inter, grad_inter = logloss.loss_gradient( + w, X_inter, y, l2_reg_strength=alpha + ) # Note, that intercept gets no L2 penalty. 
assert loss == pytest.approx(loss_inter + 0.5 * alpha * w[-1] ** 2) approx_grad = optimize.approx_fprime( - w, lambda w: logloss.loss(w, X_inter, y, alpha=alpha), 1e-3 + w, lambda w: logloss.loss(w, X_inter, y, l2_reg_strength=alpha), 1e-3 ) assert_array_almost_equal(grad_inter, approx_grad, decimal=2) @@ -565,9 +567,9 @@ def test_logistic_grad_hess(): # First check that gradients from gradient(), loss_gradient() and # gradient_hessp() are consistent - grad = logloss.gradient(w, X, y, alpha=alpha) - loss, grad_2 = logloss.loss_gradient(w, X, y, alpha=alpha) - grad_3, hessp = logloss.gradient_hessp(w, X, y, alpha=alpha) + grad = logloss.gradient(w, X, y, l2_reg_strength=alpha) + loss, grad_2 = logloss.loss_gradient(w, X, y, l2_reg_strength=alpha) + grad_3, hessp = logloss.gradient_hessp(w, X, y, l2_reg_strength=alpha) assert_array_almost_equal(grad, grad_2) assert_array_almost_equal(grad, grad_3) @@ -583,7 +585,7 @@ def test_logistic_grad_hess(): e = 1e-3 d_x = np.linspace(-e, e, 30) d_grad = np.array( - [logloss.gradient(w + t * vector, X, y, alpha=alpha) for t in d_x] + [logloss.gradient(w + t * vector, X, y, l2_reg_strength=alpha) for t in d_x] ) d_grad -= d_grad.mean(axis=0) @@ -594,9 +596,9 @@ def test_logistic_grad_hess(): # Second check that our intercept implementation is good w = np.zeros(n_features + 1) logloss = LinearLoss(loss=HalfBinomialLoss(), fit_intercept=True) - loss_inter, grad_inter = logloss.loss_gradient(w, X, y, alpha=alpha) - loss_inter_2 = logloss.loss(w, X, y, alpha=alpha) - grad_inter_2, hess = logloss.gradient_hessp(w, X, y, alpha=alpha) + loss_inter, grad_inter = logloss.loss_gradient(w, X, y, l2_reg_strength=alpha) + loss_inter_2 = logloss.loss(w, X, y, l2_reg_strength=alpha) + grad_inter_2, hess = logloss.gradient_hessp(w, X, y, l2_reg_strength=alpha) assert_array_almost_equal(loss_inter, loss_inter_2) assert_array_almost_equal(grad_inter, grad_inter_2) @@ -736,8 +738,8 @@ def test_intercept_logistic_helper(): logloss = LinearLoss(loss=HalfBinomialLoss(), fit_intercept=True) alpha = 1.0 w = np.ones(n_features + 1) - grad_inter, hess_inter = logloss.gradient_hessp(w, X, y, alpha=alpha) - loss_inter = logloss.loss(w, X, y, alpha=alpha) + grad_inter, hess_inter = logloss.gradient_hessp(w, X, y, l2_reg_strength=alpha) + loss_inter = logloss.loss(w, X, y, l2_reg_strength=alpha) # Do not fit intercept. This can be considered equivalent to adding # a feature vector of ones, i.e last column vector's elements are all one. @@ -746,8 +748,8 @@ def test_intercept_logistic_helper(): loss=HalfBinomialLoss(), fit_intercept=False, ) - grad, hessp = logloss.gradient_hessp(w, X_, y, alpha=alpha) - loss = logloss.loss(w, X_, y, alpha=alpha) + grad, hessp = logloss.gradient_hessp(w, X_, y, l2_reg_strength=alpha) + loss = logloss.loss(w, X_, y, l2_reg_strength=alpha) # In the fit_intercept=False case, the feature vector of ones is # penalized. This should be taken care of. 
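[Editor's aside, not part of the patch: the test above relies on the equivalence between fitting an intercept and appending an explicit column of ones to X, the only difference being that the ones-column coefficient receives the L2 penalty in the fit_intercept=False case. Below is a minimal sketch of that identity using the class name as it stands at this point in the series (LinearLoss, renamed to LinearModelLoss in a later commit); the data and dimensions are hypothetical.]

import numpy as np
from sklearn._loss.loss import HalfBinomialLoss
from sklearn.datasets import make_classification
from sklearn.linear_model._linear_loss import LinearLoss

X, y = make_classification(n_samples=10, n_features=5, random_state=0)
y = y.astype(np.float64)
w = np.ones(X.shape[1] + 1)  # last entry plays the role of the intercept
alpha = 1.0

with_intercept = LinearLoss(loss=HalfBinomialLoss(), fit_intercept=True)
no_intercept = LinearLoss(loss=HalfBinomialLoss(), fit_intercept=False)
X_ones = np.hstack([X, np.ones((X.shape[0], 1))])  # explicit column of ones

l_inter = with_intercept.loss(w, X, y, l2_reg_strength=alpha)
l_ones = no_intercept.loss(w, X_ones, y, l2_reg_strength=alpha)
# The two losses differ only by the penalty on the ones-column coefficient.
assert np.isclose(l_ones, l_inter + 0.5 * alpha * w[-1] ** 2)

[End of aside.]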
@@ -1161,7 +1163,7 @@ def test_multinomial_grad_hess(): fit_intercept=False, ) grad, hessp = multinomial.gradient_hessp( - w, X, y, alpha=alpha, sample_weight=sample_weights + w, X, y, l2_reg_strength=alpha, sample_weight=sample_weights ) # extract first column of hessian matrix vec = np.zeros(n_features * n_classes) @@ -1175,7 +1177,7 @@ def test_multinomial_grad_hess(): d_grad = np.array( [ multinomial.gradient_hessp( - w + t * vec, X, y, alpha=alpha, sample_weight=sample_weights + w + t * vec, X, y, l2_reg_strength=alpha, sample_weight=sample_weights )[0] for t in d_x ] From 39ffe9d180bdfc531aa195843fcff5ddef432c50 Mon Sep 17 00:00:00 2001 From: Christian Lorentzen Date: Thu, 16 Dec 2021 12:09:16 +0100 Subject: [PATCH 18/44] ENH respect F-contiguity --- sklearn/linear_model/_linear_loss.py | 94 +++++++++++++++------ sklearn/linear_model/_logistic.py | 13 ++- sklearn/linear_model/tests/test_logistic.py | 4 +- 3 files changed, 79 insertions(+), 32 deletions(-) diff --git a/sklearn/linear_model/_linear_loss.py b/sklearn/linear_model/_linear_loss.py index 32d8d9774f6c0..acdc58471cf71 100644 --- a/sklearn/linear_model/_linear_loss.py +++ b/sklearn/linear_model/_linear_loss.py @@ -7,7 +7,7 @@ class LinearLoss: - """General class for loss functions with raw_prediction = X @ coef. + """General class for loss functions with raw_prediction = X @ coef + intercept. The loss is the sum of per sample losses and includes an L2 term:: @@ -28,19 +28,30 @@ class LinearLoss: n_dof = n_features if loss.is_multiclass: - coef.shape = (n_classes * n_dof,) + coef.shape = (n_classes, n_dof) or ravelled (n_classes * n_dof,) else: coef.shape = (n_dof,) The intercept term is at the end of the coef array: if loss.is_multiclass: - intercept = coef[n_features::n_dof] = coef[(n_dof-1)::n_dof] + if coef.shape (n_classes, n_dof): + intercept = coef[:, -1] + if coef.shape (n_classes * n_dof,) + intercept = coef[n_features::n_dof] = coef[(n_dof-1)::n_dof] intercept.shape = (n_classes,) else: intercept = coef[-1] - Note: If the average loss per sample is wanted instead of the sum of the - loss per sample, one can simply use a rescaled sample_weight such that + Note: If coef has shape (n_classes * n_dof,), the 2d-array can be reconstructed as + + coef.reshape((n_classes, -1), order="F") + + The option order="F" makes coef[:, i] contiguous. This, in turn, makes the + coefficients without intercept, coef[:, :-1], contiguous and speeds up + matrix-vector computations. + + Note: If the average loss per sample is wanted instead of the sum of the loss per + sample, one can simply use a rescaled sample_weight such that sum(sample_weight) = 1. Parameters @@ -58,8 +69,11 @@ def _w_intercept_raw(self, coef, X): Parameters ---------- - coef : ndarray of shape (n_dof,) or (n_classes * n_dof,) + coef : ndarray of shape (n_dof,), (n_classes, n_dof) or (n_classes * n_dof,) Coefficients of a linear model. + If shape (n_classes * n_dof,), the classes of one feature are contiguous, + i.e. one reconstructs the 2d-array via + coef.reshape((n_classes, -1), order="F"). X : {array-like, sparse matrix} of shape (n_samples, n_features) Training data. 
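[Editor's aside, not part of the patch: the order="F" convention documented in the hunk above can be pictured with a small, hypothetical shape, say 3 classes, 4 features and an intercept (n_dof = 5). The sketch below uses only NumPy and shows the ravel/reshape round trip plus the contiguity property that motivates the choice.]

import numpy as np

n_classes, n_features = 3, 4
n_dof = n_features + 1  # including the intercept
coef_2d = np.arange(n_classes * n_dof, dtype=float).reshape(n_classes, n_dof)

# Ravelling in Fortran order keeps the n_classes coefficients of one feature adjacent.
coef_1d = coef_2d.ravel(order="F")
assert np.array_equal(coef_1d[:n_classes], coef_2d[:, 0])

# Round trip as done in _w_intercept_raw.
w = coef_1d.reshape((n_classes, -1), order="F")
assert np.array_equal(w, coef_2d)

# The reshaped view is F-contiguous, so dropping the intercept column still leaves a
# contiguous block of coefficients, which is what speeds up products like X @ w.T.
assert w.flags["F_CONTIGUOUS"] and w[:, :-1].flags["F_CONTIGUOUS"]

[End of aside.]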
@@ -82,7 +96,10 @@ def _w_intercept_raw(self, coef, X): raw_prediction = X @ w + intercept else: # reshape to (n_classes, n_dof) - w = coef.reshape(self._loss.n_classes, -1) + if coef.ndim == 1: + w = coef.reshape((self._loss.n_classes, -1), order="F") + else: + w = coef if self.fit_intercept: intercept = w[:, -1] w = w[:, :-1] @@ -97,8 +114,11 @@ def loss(self, coef, X, y, sample_weight=None, l2_reg_strength=0.0, n_threads=1) Parameters ---------- - coef : ndarray of shape (n_dof,) or (n_classes * n_dof,) + coef : ndarray of shape (n_dof,), (n_classes, n_dof) or (n_classes * n_dof,) Coefficients of a linear model. + If shape (n_classes * n_dof,), the classes of one feature are contiguous, + i.e. one reconstructs the 2d-array via + coef.reshape((n_classes, -1), order="F"). y : contiguous array of shape (n_samples,) Observed, true target values. X : {array-like, sparse matrix} of shape (n_samples, n_features) @@ -137,8 +157,11 @@ def loss_gradient( Parameters ---------- - coef : ndarray of shape (n_dof,) or (n_classes * n_dof,) + coef : ndarray of shape (n_dof,), (n_classes, n_dof) or (n_classes * n_dof,) Coefficients of a linear model. + If shape (n_classes * n_dof,), the classes of one feature are contiguous, + i.e. one reconstructs the 2d-array via + coef.reshape((n_classes, -1), order="F"). y : contiguous array of shape (n_samples,) Observed, true target values. X : {array-like, sparse matrix} of shape (n_samples, n_features) @@ -155,8 +178,8 @@ def loss_gradient( loss : float Sum of losses per sample plus penalty. - gradient : ndarray of shape (n_dof,) or (n_classes * n_dof) - The gradient of the loss as ravelled array. + gradient : ndarray of shape coef.shape + The gradient of the loss. """ n_features, n_classes = X.shape[1], self._loss.n_classes n_dof = n_features + self.fit_intercept @@ -179,12 +202,15 @@ def loss_gradient( return loss, grad else: loss += 0.5 * l2_reg_strength * squared_norm(w) - grad = np.empty((n_classes, n_dof), dtype=X.dtype) + grad = np.empty((n_classes, n_dof), dtype=X.dtype, order="F") # gradient.shape = (n_samples, n_classes) grad[:, :n_features] = gradient_per_sample.T @ X + l2_reg_strength * w if self.fit_intercept: grad[:, -1] = gradient_per_sample.sum(axis=0) - return loss, grad.ravel() + if coef.ndim == 1: + return loss, grad.ravel(order="F") + else: + return loss, grad def gradient( self, coef, X, y, sample_weight=None, l2_reg_strength=0.0, n_threads=1 @@ -193,8 +219,11 @@ def gradient( Parameters ---------- - coef : ndarray of shape (n_dof,) or (n_classes * n_dof,) + coef : ndarray of shape (n_dof,), (n_classes, n_dof) or (n_classes * n_dof,) Coefficients of a linear model. + If shape (n_classes * n_dof,), the classes of one feature are contiguous, + i.e. one reconstructs the 2d-array via + coef.reshape((n_classes, -1), order="F"). y : contiguous array of shape (n_samples,) Observed, true target values. X : {array-like, sparse matrix} of shape (n_samples, n_features) @@ -208,8 +237,8 @@ def gradient( Returns ------- - gradient : ndarray of shape (n_dof,) or (n_classes * n_dof) - The gradient of the loss as ravelled array. + gradient : ndarray of shape coef.shape + The gradient of the loss. 
""" n_features, n_classes = X.shape[1], self._loss.n_classes n_dof = n_features + self.fit_intercept @@ -229,12 +258,15 @@ def gradient( grad[-1] = gradient_per_sample.sum() return grad else: - grad = np.empty((n_classes, n_dof), dtype=X.dtype) + grad = np.empty((n_classes, n_dof), dtype=X.dtype, order="F") # gradient.shape = (n_samples, n_classes) grad[:, :n_features] = gradient_per_sample.T @ X + l2_reg_strength * w if self.fit_intercept: grad[:, -1] = gradient_per_sample.sum(axis=0) - return grad.ravel() + if coef.ndim == 1: + return grad.ravel(order="F") + else: + return grad def gradient_hessp( self, coef, X, y, sample_weight=None, l2_reg_strength=0.0, n_threads=1 @@ -243,8 +275,11 @@ def gradient_hessp( Parameters ---------- - coef : ndarray of shape (n_dof,) or (n_classes * n_dof,) + coef : ndarray of shape (n_dof,), (n_classes, n_dof) or (n_classes * n_dof,) Coefficients of a linear model. + If shape (n_classes * n_dof,), the classes of one feature are contiguous, + i.e. one reconstructs the 2d-array via + coef.reshape((n_classes, -1), order="F"). y : contiguous array of shape (n_samples,) Observed, true target values. X : {array-like, sparse matrix} of shape (n_samples, n_features) @@ -258,14 +293,15 @@ def gradient_hessp( Returns ------- - gradient : ndarray of shape (n_dof,) or (n_classes * n_dof) - The gradient of the loss as ravelled array. + gradient : ndarray of shape coef.shape + The gradient of the loss. hessp : callable Function that takes in a vector input of shape of gradient and and returns matrix-vector product with hessian. """ (n_samples, n_features), n_classes = X.shape, self._loss.n_classes + n_dof = n_features + self.fit_intercept w, intercept, raw_prediction = self._w_intercept_raw(coef, X) if not self._loss.is_multiclass: @@ -322,7 +358,7 @@ def hessp(s): sample_weight=sample_weight, n_threads=n_threads, ) - grad = np.empty_like(coef.reshape(n_classes, -1), dtype=X.dtype) + grad = np.empty((n_classes, n_dof), dtype=X.dtype, order="F") grad[:, :n_features] = gradient.T @ X + l2_reg_strength * w if self.fit_intercept: grad[:, -1] = gradient.sum(axis=0) @@ -348,7 +384,7 @@ def hessp(s): # # See also https://github.com/scikit-learn/scikit-learn/pull/3646#discussion_r17461411 # noqa def hessp(s): - s = s.reshape(n_classes, -1) # shape = (n_classes, n_dof) + s = s.reshape((n_classes, -1), order="F") # shape = (n_classes, n_dof) if self.fit_intercept: s_intercept = s[:, -1] s = s[:, :-1] @@ -359,10 +395,16 @@ def hessp(s): tmp *= proba # * p_i_k if sample_weight is not None: tmp *= sample_weight[:, np.newaxis] - hess_prod = np.empty_like(grad) + hess_prod = np.empty_like(grad, order="F") hess_prod[:, :n_features] = tmp.T @ X + l2_reg_strength * s if self.fit_intercept: hess_prod[:, -1] = tmp.sum(axis=0) - return hess_prod.ravel() + if coef.ndim == 1: + return hess_prod.ravel(order="F") + else: + return hess_prod - return grad.ravel(), hessp + if coef.ndim == 1: + return grad.ravel(order="F"), hessp + else: + return grad, hessp diff --git a/sklearn/linear_model/_logistic.py b/sklearn/linear_model/_logistic.py index 12fdd7eb7e5b7..d78006e3bb213 100644 --- a/sklearn/linear_model/_logistic.py +++ b/sklearn/linear_model/_logistic.py @@ -389,10 +389,12 @@ def _logistic_regression_path( w0[:, : coef.shape[1]] = coef if multi_class == "multinomial": - # scipy.optimize.minimize and newton-cg accepts only - # ravelled parameters. if solver in ["lbfgs", "newton-cg"]: - w0 = w0.ravel() + # scipy.optimize.minimize and newton-cg accept only ravelled parameters, + # i.e. 1d-arrays. 
LinearLoss expects classes to be contiguous and + # reconstructs the 2d-array via w0.reshape((n_classes, -1), order="F"). + # As w0 is F-contiguous, ravel(order="F") also avoids a copy. + w0 = w0.ravel(order="F") loss = LinearLoss( loss=HalfMultinomialLoss(n_classes=classes.size), fit_intercept=fit_intercept, @@ -507,7 +509,10 @@ def _logistic_regression_path( if multi_class == "multinomial": n_classes = max(2, classes.size) - multi_w0 = np.reshape(w0, (n_classes, -1)) + if solver in ["lbfgs", "newton-cg"]: + multi_w0 = np.reshape(w0, (n_classes, -1), order="F") + else: + multi_w0 = w0 if n_classes == 2: multi_w0 = multi_w0[1][np.newaxis, :] coefs.append(multi_w0.copy()) diff --git a/sklearn/linear_model/tests/test_logistic.py b/sklearn/linear_model/tests/test_logistic.py index 0f9c8b8aec07c..29429e29087b7 100644 --- a/sklearn/linear_model/tests/test_logistic.py +++ b/sklearn/linear_model/tests/test_logistic.py @@ -1155,7 +1155,7 @@ def test_multinomial_grad_hess(): X = rng.randn(n_samples, n_features) w = rng.rand(n_classes, n_features) y = np.argmax(np.dot(X, w.T), axis=1).astype(X.dtype) - w = w.ravel() + w = w.ravel(order="F") sample_weights = np.ones(X.shape[0]) alpha = 1.0 multinomial = LinearLoss( @@ -1183,7 +1183,7 @@ def test_multinomial_grad_hess(): ] ) d_grad -= d_grad.mean(axis=0) - approx_hess_col = linalg.lstsq(d_x[:, np.newaxis], d_grad)[0].ravel() + approx_hess_col = linalg.lstsq(d_x[:, np.newaxis], d_grad)[0].ravel(order="F") assert_array_almost_equal(hess_col, approx_hess_col) From 357fe2b7fae35be4c1b22d0bc2ae952a071248c9 Mon Sep 17 00:00:00 2001 From: Christian Lorentzen Date: Thu, 16 Dec 2021 13:09:57 +0100 Subject: [PATCH 19/44] TST fix sag tests --- sklearn/linear_model/tests/test_sag.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/sklearn/linear_model/tests/test_sag.py b/sklearn/linear_model/tests/test_sag.py index 935936affe408..0a1eba14c010d 100644 --- a/sklearn/linear_model/tests/test_sag.py +++ b/sklearn/linear_model/tests/test_sag.py @@ -938,11 +938,10 @@ def test_multinomial_loss(): loss=HalfMultinomialLoss(n_classes=n_classes), fit_intercept=True, ) - weights_intercept = np.vstack((weights, intercept)).T.ravel() + weights_intercept = np.vstack((weights, intercept)).T loss_2, grad_2 = loss.loss_gradient( - weights_intercept, X, y, alpha=0.0, sample_weight=sample_weights + weights_intercept, X, y, l2_reg_strength=0.0, sample_weight=sample_weights ) - grad_2 = grad_2.reshape(n_classes, -1) grad_2 = grad_2[:, :-1].T # comparison @@ -973,11 +972,10 @@ def test_multinomial_loss_ground_truth(): loss=HalfMultinomialLoss(n_classes=n_classes), fit_intercept=True, ) - weights_intercept = np.vstack((weights, intercept)).T.ravel() + weights_intercept = np.vstack((weights, intercept)).T loss_2, grad_2 = loss.loss_gradient( - weights_intercept, X, y, alpha=0.0, sample_weight=sample_weights + weights_intercept, X, y, l2_reg_strength=0.0, sample_weight=sample_weights ) - grad_2 = grad_2.reshape(n_classes, -1) grad_2 = grad_2[:, :-1].T assert_almost_equal(loss_1, loss_2) From 09fe29db462a285a4f3fc25453babbcd51816a46 Mon Sep 17 00:00:00 2001 From: Christian Lorentzen Date: Thu, 16 Dec 2021 13:13:46 +0100 Subject: [PATCH 20/44] CLN rename to LinearModelLoss --- sklearn/linear_model/_linear_loss.py | 2 +- sklearn/linear_model/_logistic.py | 10 +++++----- sklearn/linear_model/tests/test_logistic.py | 16 ++++++++-------- sklearn/linear_model/tests/test_sag.py | 6 +++--- 4 files changed, 17 insertions(+), 17 deletions(-) diff --git 
a/sklearn/linear_model/_linear_loss.py b/sklearn/linear_model/_linear_loss.py index acdc58471cf71..707bfc6df186c 100644 --- a/sklearn/linear_model/_linear_loss.py +++ b/sklearn/linear_model/_linear_loss.py @@ -6,7 +6,7 @@ from ..utils.extmath import squared_norm -class LinearLoss: +class LinearModelLoss: """General class for loss functions with raw_prediction = X @ coef + intercept. The loss is the sum of per sample losses and includes an L2 term:: diff --git a/sklearn/linear_model/_logistic.py b/sklearn/linear_model/_logistic.py index d78006e3bb213..5438bf4454216 100644 --- a/sklearn/linear_model/_logistic.py +++ b/sklearn/linear_model/_logistic.py @@ -18,7 +18,7 @@ from joblib import Parallel, effective_n_jobs from ._base import LinearClassifierMixin, SparseCoefMixin, BaseEstimator -from ._linear_loss import LinearLoss +from ._linear_loss import LinearModelLoss from ._sag import sag_solver from .._loss.loss import HalfBinomialLoss, HalfMultinomialLoss from ..preprocessing import LabelEncoder, LabelBinarizer @@ -391,11 +391,11 @@ def _logistic_regression_path( if multi_class == "multinomial": if solver in ["lbfgs", "newton-cg"]: # scipy.optimize.minimize and newton-cg accept only ravelled parameters, - # i.e. 1d-arrays. LinearLoss expects classes to be contiguous and + # i.e. 1d-arrays. LinearModelLoss expects classes to be contiguous and # reconstructs the 2d-array via w0.reshape((n_classes, -1), order="F"). # As w0 is F-contiguous, ravel(order="F") also avoids a copy. w0 = w0.ravel(order="F") - loss = LinearLoss( + loss = LinearModelLoss( loss=HalfMultinomialLoss(n_classes=classes.size), fit_intercept=fit_intercept, ) @@ -410,10 +410,10 @@ def _logistic_regression_path( else: target = y_bin if solver == "lbfgs": - loss = LinearLoss(loss=HalfBinomialLoss(), fit_intercept=fit_intercept) + loss = LinearModelLoss(loss=HalfBinomialLoss(), fit_intercept=fit_intercept) func = loss.loss_gradient elif solver == "newton-cg": - loss = LinearLoss(loss=HalfBinomialLoss(), fit_intercept=fit_intercept) + loss = LinearModelLoss(loss=HalfBinomialLoss(), fit_intercept=fit_intercept) func = loss.loss grad = loss.gradient hess = loss.gradient_hessp # hess = [gradient, hessp] diff --git a/sklearn/linear_model/tests/test_logistic.py b/sklearn/linear_model/tests/test_logistic.py index 29429e29087b7..77ce28432fb30 100644 --- a/sklearn/linear_model/tests/test_logistic.py +++ b/sklearn/linear_model/tests/test_logistic.py @@ -27,7 +27,7 @@ from sklearn.utils._testing import skip_if_no_parallel from sklearn.exceptions import ConvergenceWarning -from sklearn.linear_model._linear_loss import LinearLoss +from sklearn.linear_model._linear_loss import LinearModelLoss from sklearn.linear_model._logistic import ( _log_reg_scoring_path, _logistic_regression_path, @@ -520,7 +520,7 @@ def test_logistic_loss_and_grad(): # make an intercept of 0.5 w[-1] = 0.5 - logloss = LinearLoss( + logloss = LinearModelLoss( loss=HalfBinomialLoss(), fit_intercept=False, ) @@ -532,7 +532,7 @@ def test_logistic_loss_and_grad(): assert_array_almost_equal(grad, approx_grad, decimal=2) # Second check that our intercept implementation is good - logloss = LinearLoss( + logloss = LinearModelLoss( loss=HalfBinomialLoss(), fit_intercept=True, ) @@ -563,7 +563,7 @@ def test_logistic_grad_hess(): X_sp = sp.csr_matrix(X_sp) for X in (X_ref, X_sp): w = np.full(n_features, 0.1) - logloss = LinearLoss(loss=HalfBinomialLoss(), fit_intercept=False) + logloss = LinearModelLoss(loss=HalfBinomialLoss(), fit_intercept=False) # First check that gradients from 
gradient(), loss_gradient() and # gradient_hessp() are consistent @@ -595,7 +595,7 @@ def test_logistic_grad_hess(): # Second check that our intercept implementation is good w = np.zeros(n_features + 1) - logloss = LinearLoss(loss=HalfBinomialLoss(), fit_intercept=True) + logloss = LinearModelLoss(loss=HalfBinomialLoss(), fit_intercept=True) loss_inter, grad_inter = logloss.loss_gradient(w, X, y, l2_reg_strength=alpha) loss_inter_2 = logloss.loss(w, X, y, l2_reg_strength=alpha) grad_inter_2, hess = logloss.gradient_hessp(w, X, y, l2_reg_strength=alpha) @@ -735,7 +735,7 @@ def test_intercept_logistic_helper(): y = y.astype(np.float64) # Fit intercept case. - logloss = LinearLoss(loss=HalfBinomialLoss(), fit_intercept=True) + logloss = LinearModelLoss(loss=HalfBinomialLoss(), fit_intercept=True) alpha = 1.0 w = np.ones(n_features + 1) grad_inter, hess_inter = logloss.gradient_hessp(w, X, y, l2_reg_strength=alpha) @@ -744,7 +744,7 @@ def test_intercept_logistic_helper(): # Do not fit intercept. This can be considered equivalent to adding # a feature vector of ones, i.e last column vector's elements are all one. X_ = np.hstack((X, np.ones(n_samples)[:, np.newaxis])) - logloss = LinearLoss( + logloss = LinearModelLoss( loss=HalfBinomialLoss(), fit_intercept=False, ) @@ -1158,7 +1158,7 @@ def test_multinomial_grad_hess(): w = w.ravel(order="F") sample_weights = np.ones(X.shape[0]) alpha = 1.0 - multinomial = LinearLoss( + multinomial = LinearModelLoss( loss=HalfMultinomialLoss(n_classes=n_classes), fit_intercept=False, ) diff --git a/sklearn/linear_model/tests/test_sag.py b/sklearn/linear_model/tests/test_sag.py index 0a1eba14c010d..de85638f66b06 100644 --- a/sklearn/linear_model/tests/test_sag.py +++ b/sklearn/linear_model/tests/test_sag.py @@ -11,7 +11,7 @@ from scipy.special import logsumexp from sklearn._loss.loss import HalfMultinomialLoss -from sklearn.linear_model._linear_loss import LinearLoss +from sklearn.linear_model._linear_loss import LinearModelLoss from sklearn.linear_model._sag import get_auto_step_size from sklearn.linear_model._sag_fast import _multinomial_grad_loss_all_samples from sklearn.linear_model import LogisticRegression, Ridge @@ -934,7 +934,7 @@ def test_multinomial_loss(): dataset, weights, intercept, n_samples, n_features, n_classes ) # compute loss and gradient like in multinomial LogisticRegression - loss = LinearLoss( + loss = LinearModelLoss( loss=HalfMultinomialLoss(n_classes=n_classes), fit_intercept=True, ) @@ -968,7 +968,7 @@ def test_multinomial_loss_ground_truth(): diff = sample_weights[:, np.newaxis] * (np.exp(p) - Y_bin) grad_1 = np.dot(X.T, diff) - loss = LinearLoss( + loss = LinearModelLoss( loss=HalfMultinomialLoss(n_classes=n_classes), fit_intercept=True, ) From a4b1b6be3f8e75b90d3d2a9a3e651cb472d74f10 Mon Sep 17 00:00:00 2001 From: Christian Lorentzen Date: Thu, 16 Dec 2021 13:45:36 +0100 Subject: [PATCH 21/44] CLN improve comments according to review --- doc/whats_new/v1.1.rst | 6 +++--- sklearn/linear_model/_linear_loss.py | 5 +++-- sklearn/linear_model/_logistic.py | 6 +++++- 3 files changed, 11 insertions(+), 6 deletions(-) diff --git a/doc/whats_new/v1.1.rst b/doc/whats_new/v1.1.rst index 19f525da4af64..446bb4f1af8a7 100644 --- a/doc/whats_new/v1.1.rst +++ b/doc/whats_new/v1.1.rst @@ -240,9 +240,9 @@ Changelog - |Enhancement| :class:`~linear_model.LogisticRegression` is faster for ``solvers="lbfgs"`` and ``solver="newton-cg"``, for binary and in particular for multiclass problems thanks to the new private loss function module. 
In the multiclass - case, also the memory consumption is reduced for these solvers as the target is now - label encoded (mapped to integers) instead of label binarized (one-hot encoded). The - more classes, the larger the benefit. + case, the memory consumption has also been reduced for these solvers as the target is + now label encoded (mapped to integers) instead of label binarized (one-hot encoded). + The more classes, the larger the benefit. :pr:`21808`, :pr:`20567` and :pr:`21814` by :user:`Christian Lorentzen `. diff --git a/sklearn/linear_model/_linear_loss.py b/sklearn/linear_model/_linear_loss.py index 707bfc6df186c..b992c87e65a16 100644 --- a/sklearn/linear_model/_linear_loss.py +++ b/sklearn/linear_model/_linear_loss.py @@ -364,12 +364,13 @@ def hessp(s): grad[:, -1] = gradient.sum(axis=0) # Full hessian-vector product, i.e. not only the diagonal part of the - # hessian. Derivation with some index battle for inupt vector s: + # hessian. Derivation with some index battle for input vector s: # - sample index i # - feature indices j, m # - class indices k, l # - 1_{k=l} is one if k=l else 0 - # - p_i_k is class probability of sample i and class k + # - p_i_k is the (predicted) probability that sample i belongs to class k + # for all i: sum_k p_i_k = 1 # - s_l_m is input vector for class l and feature m # - X' = X transposed # diff --git a/sklearn/linear_model/_logistic.py b/sklearn/linear_model/_logistic.py index 5438bf4454216..60c5053050cb1 100644 --- a/sklearn/linear_model/_logistic.py +++ b/sklearn/linear_model/_logistic.py @@ -319,6 +319,8 @@ def _logistic_regression_path( mask = y == pos_class y_bin = np.ones(y.shape, dtype=X.dtype) if solver in ["lbfgs", "newton-cg"]: + # HalfBinomialLoss, used for those solvers, represents y in [0, 1] instead + # of in [-1, 1]. mask_classes = np.array([0, 1]) y_bin[~mask] = 0.0 else: @@ -335,7 +337,9 @@ def _logistic_regression_path( else: if solver in ["sag", "saga", "lbfgs", "newton-cg"]: # SAG, lbfgs and newton-cg multinomial solvers need LabelEncoder, - # not LabelBinarizer, i.e. y is mapped to integers. + # not LabelBinarizer, i.e. y as a 1d-array of integers. + # LabelEncoder also saves memory compared to LabelBinarizer, especially + # when n_classes is large. le = LabelEncoder() Y_multi = le.fit_transform(y).astype(X.dtype, copy=False) else: From 8dfe1658717b4c3af5bbc189110832133905d316 Mon Sep 17 00:00:00 2001 From: Christian Lorentzen Date: Thu, 16 Dec 2021 13:49:20 +0100 Subject: [PATCH 22/44] CLN liblinear comment --- sklearn/linear_model/_logistic.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/linear_model/_logistic.py b/sklearn/linear_model/_logistic.py index 60c5053050cb1..3caab6c435ef3 100644 --- a/sklearn/linear_model/_logistic.py +++ b/sklearn/linear_model/_logistic.py @@ -343,7 +343,7 @@ def _logistic_regression_path( le = LabelEncoder() Y_multi = le.fit_transform(y).astype(X.dtype, copy=False) else: - # Apply LabelBinarizer, i.e. y is one-hot encoded. + # For liblinear solver, apply LabelBinarizer, i.e. y is one-hot encoded. 
lbin = LabelBinarizer() Y_multi = lbin.fit_transform(y) if Y_multi.shape[1] == 1: From fa249c34e6d4e8014253bdf3cdeea19505e6c56e Mon Sep 17 00:00:00 2001 From: Christian Lorentzen Date: Fri, 17 Dec 2021 11:14:38 +0100 Subject: [PATCH 23/44] TST add / move test to test_linear_loss.py --- .../linear_model/tests/test_linear_loss.py | 302 ++++++++++++++++++ sklearn/linear_model/tests/test_logistic.py | 202 +----------- 2 files changed, 310 insertions(+), 194 deletions(-) create mode 100644 sklearn/linear_model/tests/test_linear_loss.py diff --git a/sklearn/linear_model/tests/test_linear_loss.py b/sklearn/linear_model/tests/test_linear_loss.py new file mode 100644 index 0000000000000..97bedfb896a88 --- /dev/null +++ b/sklearn/linear_model/tests/test_linear_loss.py @@ -0,0 +1,302 @@ +""" +Tests for LinearModelLoss + +Note that correctness of losses is already well covered in the _loss module. +""" +import pytest +import numpy as np +from numpy.testing import assert_allclose +from scipy import linalg, optimize, sparse + +from sklearn._loss.loss import ( + HalfBinomialLoss, + HalfMultinomialLoss, + HalfPoissonLoss, +) +from sklearn.datasets import make_low_rank_matrix +from sklearn.linear_model._linear_loss import LinearModelLoss +from sklearn.utils.extmath import squared_norm + + +# We don not need to test all losses, just what LinearModelLoss does on top of the +# base losses. +LOSSES = [HalfBinomialLoss, HalfMultinomialLoss, HalfPoissonLoss] + + +def random_X_y_coef( + linear_model_loss, n_samples, n_features, coef_bound=(-2, 2), seed=42 +): + """Random generate y, X and coef in valid range.""" + rng = np.random.RandomState(seed) + n_dof = n_features + linear_model_loss.fit_intercept + X = make_low_rank_matrix( + n_samples=n_samples, + n_features=n_features, + random_state=rng, + ) + + if linear_model_loss._loss.is_multiclass: + n_classes = linear_model_loss._loss.n_classes + coef = np.empty((n_classes, n_dof)) + coef.flat[:] = rng.uniform( + low=coef_bound[0], + high=coef_bound[1], + size=n_classes * n_dof, + ) + if linear_model_loss.fit_intercept: + raw_prediction = X @ coef[:, :-1].T + coef[:, -1] + else: + raw_prediction = X @ coef.T + proba = linear_model_loss._loss.link.inverse(raw_prediction) + + # y = rng.choice(np.arange(n_classes), p=proba) does not work, see + # See https://stackoverflow.com/a/34190035/16761084 + def choice_vectorized(items, p): + s = p.cumsum(axis=1) + r = np.random.rand(p.shape[0])[:, None] + k = (s < r).sum(axis=1) + return items[k] + + y = choice_vectorized(np.arange(n_classes), p=proba).astype(np.float64) + else: + coef = np.empty((n_dof,)) + coef.flat[:] = rng.uniform( + low=coef_bound[0], + high=coef_bound[1], + size=n_dof, + ) + if linear_model_loss.fit_intercept: + raw_prediction = X @ coef[:-1] + coef[-1] + else: + raw_prediction = X @ coef + y = linear_model_loss._loss.link.inverse( + raw_prediction + rng.uniform(low=-1, high=1, size=n_samples) + ) + + return X, y, coef + + +@pytest.mark.parametrize("base_loss", LOSSES) +@pytest.mark.parametrize("fit_intercept", [False, True]) +@pytest.mark.parametrize("sample_weight", [None, "range"]) +@pytest.mark.parametrize("l2_reg_strength", [0, 1]) +def test_loss_gradients_are_the_same( + base_loss, fit_intercept, sample_weight, l2_reg_strength +): + """Test that loss and gradient are the same across different functions.""" + loss = LinearModelLoss(loss=base_loss(), fit_intercept=fit_intercept) + X, y, coef = random_X_y_coef( + linear_model_loss=loss, n_samples=10, n_features=5, seed=42 + ) + + if sample_weight == 
"range": + sample_weight = np.linspace(1, y.shape[0], num=y.shape[0]) + + l1 = loss.loss( + coef, X, y, sample_weight=sample_weight, l2_reg_strength=l2_reg_strength + ) + g1 = loss.gradient( + coef, X, y, sample_weight=sample_weight, l2_reg_strength=l2_reg_strength + ) + l2, g2 = loss.loss_gradient( + coef, X, y, sample_weight=sample_weight, l2_reg_strength=l2_reg_strength + ) + g3, h3 = loss.gradient_hessp( + coef, X, y, sample_weight=sample_weight, l2_reg_strength=l2_reg_strength + ) + + assert_allclose(l1, l2) + assert_allclose(g1, g2) + assert_allclose(g1, g3) + + # same for sparse X + X = sparse.csr_matrix(X) + l1_sp = loss.loss( + coef, X, y, sample_weight=sample_weight, l2_reg_strength=l2_reg_strength + ) + g1_sp = loss.gradient( + coef, X, y, sample_weight=sample_weight, l2_reg_strength=l2_reg_strength + ) + l2_sp, g2_sp = loss.loss_gradient( + coef, X, y, sample_weight=sample_weight, l2_reg_strength=l2_reg_strength + ) + g3_sp, h3_sp = loss.gradient_hessp( + coef, X, y, sample_weight=sample_weight, l2_reg_strength=l2_reg_strength + ) + + assert_allclose(l1, l1_sp) + assert_allclose(l1, l2_sp) + assert_allclose(g1, g1_sp) + assert_allclose(g1, g2_sp) + assert_allclose(g1, g3_sp) + + +@pytest.mark.parametrize("base_loss", LOSSES) +@pytest.mark.parametrize("sample_weight", [None, "range"]) +@pytest.mark.parametrize("l2_reg_strength", [0, 1]) +@pytest.mark.parametrize("X_sparse", [False, True]) +def test_loss_gradients_hessp_intercept( + base_loss, sample_weight, l2_reg_strength, X_sparse +): + """Test that loss and gradient handle intercept correctly.""" + loss = LinearModelLoss(loss=base_loss(), fit_intercept=False) + loss_inter = LinearModelLoss(loss=base_loss(), fit_intercept=True) + n_samples, n_features = 10, 5 + X, y, coef = random_X_y_coef( + linear_model_loss=loss, n_samples=n_samples, n_features=n_features, seed=42 + ) + + X[:, -1] = 1 # make last column of 1 to mimic intercept term + X_inter = X[ + :, :-1 + ] # exclude intercept column as it is added automatically by loss_inter + + if X_sparse: + X = sparse.csr_matrix(X) + + if sample_weight == "range": + sample_weight = np.linspace(1, y.shape[0], num=y.shape[0]) + + l, g = loss.loss_gradient( + coef, X, y, sample_weight=sample_weight, l2_reg_strength=l2_reg_strength + ) + _, hessp = loss.gradient_hessp( + coef, X, y, sample_weight=sample_weight, l2_reg_strength=l2_reg_strength + ) + l_inter, g_inter = loss_inter.loss_gradient( + coef, X_inter, y, sample_weight=sample_weight, l2_reg_strength=l2_reg_strength + ) + _, hessp_inter = loss_inter.gradient_hessp( + coef, X_inter, y, sample_weight=sample_weight, l2_reg_strength=l2_reg_strength + ) + + # Note, that intercept gets no L2 penalty. + assert l == pytest.approx( + l_inter + 0.5 * l2_reg_strength * squared_norm(coef.T[-1]) + ) + + g_inter_corrected = g_inter + g_inter_corrected.T[-1] += l2_reg_strength * coef.T[-1] + assert_allclose(g, g_inter_corrected) + + s = np.random.RandomState(42).randn(*coef.shape) + h = hessp(s) + h_inter = hessp_inter(s) + h_inter_corrected = h_inter + h_inter_corrected.T[-1] += l2_reg_strength * s.T[-1] + assert_allclose(h, h_inter_corrected) + + +@pytest.mark.parametrize("base_loss", LOSSES) +@pytest.mark.parametrize("fit_intercept", [False, True]) +@pytest.mark.parametrize("sample_weight", [None, "range"]) +@pytest.mark.parametrize("l2_reg_strength", [0, 1]) +def test_gradients_hessians_numerically( + base_loss, fit_intercept, sample_weight, l2_reg_strength +): + """Test gradients and hessians with numerical derivatives. 
+ + Gradient should equal the numerical derivatives of the loss function. + Hessians should equal the numerical derivatives of gradients. + """ + loss = LinearModelLoss(loss=base_loss(), fit_intercept=fit_intercept) + n_samples, n_features = 10, 5 + X, y, coef = random_X_y_coef( + linear_model_loss=loss, n_samples=n_samples, n_features=n_features, seed=42 + ) + coef = coef.ravel(order="F") # this is important only for multinomial loss + + if sample_weight == "range": + sample_weight = np.linspace(1, y.shape[0], num=y.shape[0]) + + # 1. Check gradients numerically + eps = 1e-6 + g, hessp = loss.gradient_hessp( + coef, X, y, sample_weight=sample_weight, l2_reg_strength=l2_reg_strength + ) + # Use a trick to get central finte difference of accuracy 4 (five-point stencil) + # https://en.wikipedia.org/wiki/Numerical_differentiation + # https://en.wikipedia.org/wiki/Finite_difference_coefficient + approx_g1 = optimize.approx_fprime( + coef, + lambda coef: loss.loss( + coef - eps, + X, + y, + sample_weight=sample_weight, + l2_reg_strength=l2_reg_strength, + ), + 2 * eps, + ) # (f(x + eps) - f(x - eps)) / (2*eps) + approx_g2 = optimize.approx_fprime( + coef, + lambda coef: loss.loss( + coef - 2 * eps, + X, + y, + sample_weight=sample_weight, + l2_reg_strength=l2_reg_strength, + ), + 4 * eps, + ) # (f(x + 2*eps) - f(x - 2*eps)) / (4*eps) + approx_g = 4 / 3 * approx_g1 - 1 / 3 * approx_g2 + assert_allclose(g, approx_g, rtol=1e-2, atol=1e-8) + + # 2. Check hessp numerically along the second direction of the gradient + vector = np.zeros_like(g) + vector[1] = 1 + hess_col = hessp(vector) + # Computation of the Hessian is particularly fragile to numerical errors when doing + # simple finite differences. Here we compute the grad along a path in the direction + # of the vector and then use a least-square regression to estimate the slope + eps = 1e-3 + d_x = np.linspace(-eps, eps, 30) + d_grad = np.array( + [ + loss.gradient( + coef + t * vector, + X, + y, + sample_weight=sample_weight, + l2_reg_strength=l2_reg_strength, + ) + for t in d_x + ] + ) + d_grad -= d_grad.mean(axis=0) + approx_hess_col = linalg.lstsq(d_x[:, np.newaxis], d_grad)[0].ravel() + assert_allclose(approx_hess_col, hess_col, rtol=1e-3) + + +@pytest.mark.parametrize("fit_intercept", [False, True]) +def test_multinomial_coef_shape(fit_intercept): + """Test that multinomial LinearModelLoss respects shape of coef.""" + loss = LinearModelLoss(loss=HalfMultinomialLoss(), fit_intercept=fit_intercept) + n_samples, n_features = 10, 5 + X, y, coef = random_X_y_coef( + linear_model_loss=loss, n_samples=n_samples, n_features=n_features, seed=42 + ) + s = np.random.RandomState(42).randn(*coef.shape) + + l, g = loss.loss_gradient(coef, X, y) + g1 = loss.gradient(coef, X, y) + g2, hessp = loss.gradient_hessp(coef, X, y) + h = hessp(s) + assert g.shape == coef.shape + assert h.shape == coef.shape + assert_allclose(g, g1) + assert_allclose(g, g2) + + coef_r = coef.ravel(order="F") + s_r = s.ravel(order="F") + l_r, g_r = loss.loss_gradient(coef_r, X, y) + g1_r = loss.gradient(coef_r, X, y) + g2_r, hessp_r = loss.gradient_hessp(coef_r, X, y) + h_r = hessp_r(s_r) + assert g_r.shape == coef_r.shape + assert h_r.shape == coef_r.shape + assert_allclose(g_r, g1_r) + assert_allclose(g_r, g2_r) + + assert_allclose(g, g_r.reshape(loss._loss.n_classes, -1, order="F")) + assert_allclose(h, h_r.reshape(loss._loss.n_classes, -1, order="F")) diff --git a/sklearn/linear_model/tests/test_logistic.py b/sklearn/linear_model/tests/test_logistic.py index 
77ce28432fb30..80befb5ffe645 100644 --- a/sklearn/linear_model/tests/test_logistic.py +++ b/sklearn/linear_model/tests/test_logistic.py @@ -4,12 +4,10 @@ import numpy as np from numpy.testing import assert_allclose, assert_almost_equal from numpy.testing import assert_array_almost_equal, assert_array_equal -import scipy.sparse as sp -from scipy import linalg, optimize, sparse +from scipy import sparse import pytest -from sklearn._loss.loss import HalfBinomialLoss, HalfMultinomialLoss from sklearn.base import clone from sklearn.datasets import load_iris, make_classification from sklearn.metrics import log_loss @@ -27,7 +25,6 @@ from sklearn.utils._testing import skip_if_no_parallel from sklearn.exceptions import ConvergenceWarning -from sklearn.linear_model._linear_loss import LinearModelLoss from sklearn.linear_model._logistic import ( _log_reg_scoring_path, _logistic_regression_path, @@ -36,7 +33,7 @@ ) X = [[-1, 0], [0, 1], [1, 1]] -X_sp = sp.csr_matrix(X) +X_sp = sparse.csr_matrix(X) Y1 = [0, 1, 1] Y2 = [2, 1, 0] iris = load_iris() @@ -314,10 +311,10 @@ def test_sparsify(): pred_d_d = clf.decision_function(iris.data) clf.sparsify() - assert sp.issparse(clf.coef_) + assert sparse.issparse(clf.coef_) pred_s_d = clf.decision_function(iris.data) - sp_data = sp.coo_matrix(iris.data) + sp_data = sparse.coo_matrix(iris.data) pred_s_s = clf.decision_function(sp_data) clf.densify() @@ -498,111 +495,6 @@ def test_liblinear_dual_random_state(): assert_array_almost_equal(lr1.coef_, lr3.coef_) -def test_logistic_loss_and_grad(): - n_samples, n_features = 20, 20 - alpha = 1.0 - X_ref, y = make_classification( - n_samples=n_samples, n_features=n_features, random_state=0 - ) - # make last column of 1 to mimic intercept term - X_ref[:, -1] = 1 - X_ref_inter = X_ref[:, :-1] # exclude intercept column - y = y.astype(np.float64) - n_features = X_ref.shape[1] - - X_sp = X_ref.copy() - X_sp[X_sp < 0.1] = 0 - X_sp = sp.csr_matrix(X_sp) - X_sp_inter = sp.lil_matrix(X_sp) # supports slicing - X_sp_inter = sp.csr_matrix(X_sp_inter[:, :-1]) - for X, X_inter in ((X_ref, X_ref_inter), (X_sp, X_sp_inter)): - w = np.ones(n_features) - # make an intercept of 0.5 - w[-1] = 0.5 - - logloss = LinearModelLoss( - loss=HalfBinomialLoss(), - fit_intercept=False, - ) - # First check that our derivation of the grad is correct - loss, grad = logloss.loss_gradient(w, X, y, l2_reg_strength=alpha) - approx_grad = optimize.approx_fprime( - w, lambda w: logloss.loss(w, X, y, l2_reg_strength=alpha), 1e-3 - ) - assert_array_almost_equal(grad, approx_grad, decimal=2) - - # Second check that our intercept implementation is good - logloss = LinearModelLoss( - loss=HalfBinomialLoss(), - fit_intercept=True, - ) - loss_inter, grad_inter = logloss.loss_gradient( - w, X_inter, y, l2_reg_strength=alpha - ) - # Note, that intercept gets no L2 penalty. 
- assert loss == pytest.approx(loss_inter + 0.5 * alpha * w[-1] ** 2) - - approx_grad = optimize.approx_fprime( - w, lambda w: logloss.loss(w, X_inter, y, l2_reg_strength=alpha), 1e-3 - ) - assert_array_almost_equal(grad_inter, approx_grad, decimal=2) - - -def test_logistic_grad_hess(): - rng = np.random.RandomState(0) - n_samples, n_features = 50, 5 - alpha = 1.0 - X_ref = rng.randn(n_samples, n_features) - y = np.sign(X_ref.dot(5 * rng.randn(n_features))) - X_ref -= X_ref.mean() - X_ref /= X_ref.std() - # make last column of 1 to mimic intercept term - X_ref[:, :-1] = 1 - X_sp = X_ref.copy() - X_sp[X_sp < 0.1] = 0 - X_sp = sp.csr_matrix(X_sp) - for X in (X_ref, X_sp): - w = np.full(n_features, 0.1) - logloss = LinearModelLoss(loss=HalfBinomialLoss(), fit_intercept=False) - - # First check that gradients from gradient(), loss_gradient() and - # gradient_hessp() are consistent - grad = logloss.gradient(w, X, y, l2_reg_strength=alpha) - loss, grad_2 = logloss.loss_gradient(w, X, y, l2_reg_strength=alpha) - grad_3, hessp = logloss.gradient_hessp(w, X, y, l2_reg_strength=alpha) - assert_array_almost_equal(grad, grad_2) - assert_array_almost_equal(grad, grad_3) - - # Now check our hessian along the second direction of the grad - vector = np.zeros_like(grad) - vector[1] = 1 - hess_col = hessp(vector) - - # Computation of the Hessian is particularly fragile to numerical - # errors when doing simple finite differences. Here we compute the - # grad along a path in the direction of the vector and then use a - # least-square regression to estimate the slope - e = 1e-3 - d_x = np.linspace(-e, e, 30) - d_grad = np.array( - [logloss.gradient(w + t * vector, X, y, l2_reg_strength=alpha) for t in d_x] - ) - - d_grad -= d_grad.mean(axis=0) - approx_hess_col = linalg.lstsq(d_x[:, np.newaxis], d_grad)[0].ravel() - - assert_array_almost_equal(approx_hess_col, hess_col, decimal=3) - - # Second check that our intercept implementation is good - w = np.zeros(n_features + 1) - logloss = LinearModelLoss(loss=HalfBinomialLoss(), fit_intercept=True) - loss_inter, grad_inter = logloss.loss_gradient(w, X, y, l2_reg_strength=alpha) - loss_inter_2 = logloss.loss(w, X, y, l2_reg_strength=alpha) - grad_inter_2, hess = logloss.gradient_hessp(w, X, y, l2_reg_strength=alpha) - assert_array_almost_equal(loss_inter, loss_inter_2) - assert_array_almost_equal(grad_inter, grad_inter_2) - - def test_logistic_cv(): # test for LogisticRegressionCV object n_samples, n_features = 50, 5 @@ -716,7 +608,7 @@ def test_multinomial_logistic_regression_string_inputs(): def test_logistic_cv_sparse(): X, y = make_classification(n_samples=50, n_features=5, random_state=0) X[X < 1.0] = 0.0 - csr = sp.csr_matrix(X) + csr = sparse.csr_matrix(X) clf = LogisticRegressionCV() clf.fit(X, y) @@ -727,46 +619,6 @@ def test_logistic_cv_sparse(): assert clfs.C_ == clf.C_ -def test_intercept_logistic_helper(): - n_samples, n_features = 10, 5 - X, y = make_classification( - n_samples=n_samples, n_features=n_features, random_state=0 - ) - y = y.astype(np.float64) - - # Fit intercept case. - logloss = LinearModelLoss(loss=HalfBinomialLoss(), fit_intercept=True) - alpha = 1.0 - w = np.ones(n_features + 1) - grad_inter, hess_inter = logloss.gradient_hessp(w, X, y, l2_reg_strength=alpha) - loss_inter = logloss.loss(w, X, y, l2_reg_strength=alpha) - - # Do not fit intercept. This can be considered equivalent to adding - # a feature vector of ones, i.e last column vector's elements are all one. 
- X_ = np.hstack((X, np.ones(n_samples)[:, np.newaxis])) - logloss = LinearModelLoss( - loss=HalfBinomialLoss(), - fit_intercept=False, - ) - grad, hessp = logloss.gradient_hessp(w, X_, y, l2_reg_strength=alpha) - loss = logloss.loss(w, X_, y, l2_reg_strength=alpha) - - # In the fit_intercept=False case, the feature vector of ones is - # penalized. This should be taken care of. - assert_almost_equal(loss_inter + 0.5 * (w[-1] ** 2), loss) - - # Check gradient. - assert_array_almost_equal(grad_inter[:n_features], grad[:n_features]) - assert_almost_equal(grad_inter[-1] + alpha * w[-1], grad[-1]) - - rng = np.random.RandomState(0) - grad = rng.rand(n_features + 1) - hess_interp = hess_inter(grad) - hess = hessp(grad) - assert_array_almost_equal(hess_interp[:n_features], hess[:n_features]) - assert_almost_equal(hess_interp[-1] + alpha * grad[-1], hess[-1]) - - def test_ovr_multinomial_iris(): # Test that OvR and multinomial are correct using the iris dataset. train, target = iris.data, iris.target @@ -1149,44 +1001,6 @@ def test_logistic_regression_multinomial(): assert_allclose(clf_path.intercept_, ref_i.intercept_, rtol=2e-2) -def test_multinomial_grad_hess(): - rng = np.random.RandomState(0) - n_samples, n_features, n_classes = 100, 5, 3 - X = rng.randn(n_samples, n_features) - w = rng.rand(n_classes, n_features) - y = np.argmax(np.dot(X, w.T), axis=1).astype(X.dtype) - w = w.ravel(order="F") - sample_weights = np.ones(X.shape[0]) - alpha = 1.0 - multinomial = LinearModelLoss( - loss=HalfMultinomialLoss(n_classes=n_classes), - fit_intercept=False, - ) - grad, hessp = multinomial.gradient_hessp( - w, X, y, l2_reg_strength=alpha, sample_weight=sample_weights - ) - # extract first column of hessian matrix - vec = np.zeros(n_features * n_classes) - vec[0] = 1 - hess_col = hessp(vec) - - # Estimate hessian using least squares as done in - # test_logistic_grad_hess - e = 1e-3 - d_x = np.linspace(-e, e, 30) - d_grad = np.array( - [ - multinomial.gradient_hessp( - w + t * vec, X, y, l2_reg_strength=alpha, sample_weight=sample_weights - )[0] - for t in d_x - ] - ) - d_grad -= d_grad.mean(axis=0) - approx_hess_col = linalg.lstsq(d_x[:, np.newaxis], d_grad)[0].ravel(order="F") - assert_array_almost_equal(hess_col, approx_hess_col) - - def test_liblinear_decision_function_zero(): # Test negative prediction when decision_function values are zero. # Liblinear predicts the positive class when decision_function values @@ -1578,8 +1392,8 @@ def test_dtype_match(solver, multi_class, fit_intercept): y_32 = np.array(Y1).astype(np.float32) X_64 = np.array(X).astype(np.float64) y_64 = np.array(Y1).astype(np.float64) - X_sparse_32 = sp.csr_matrix(X, dtype=np.float32) - X_sparse_64 = sp.csr_matrix(X, dtype=np.float64) + X_sparse_32 = sparse.csr_matrix(X, dtype=np.float32) + X_sparse_64 = sparse.csr_matrix(X, dtype=np.float64) solver_tol = 5e-4 lr_templ = LogisticRegression( @@ -2281,7 +2095,7 @@ def test_large_sparse_matrix(solver): # Non-regression test for pull-request #21093. 
# generate sparse matrix with int64 indices - X = sp.rand(20, 10, format="csr") + X = sparse.rand(20, 10, format="csr") for attr in ["indices", "indptr"]: setattr(X, attr, getattr(X, attr).astype("int64")) y = np.random.randint(2, size=X.shape[0]) From dea9bf099517ff85e767fdcc893fc8d24f6e521f Mon Sep 17 00:00:00 2001 From: Christian Lorentzen Date: Fri, 17 Dec 2021 11:29:39 +0100 Subject: [PATCH 24/44] CLN comment placement --- sklearn/linear_model/_logistic.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/sklearn/linear_model/_logistic.py b/sklearn/linear_model/_logistic.py index 3caab6c435ef3..37814a1aaf345 100644 --- a/sklearn/linear_model/_logistic.py +++ b/sklearn/linear_model/_logistic.py @@ -1216,10 +1216,19 @@ def fit(self, X, y, sample_weight=None): prefer = "threads" else: prefer = "processes" - if solver in ["lbfgs", "newton-cg"] and len(classes_) == 1: + + # TODO: Refactor this to avoid joblib parallelism entirely when doing binary + # and multinomial multiclass classification and use joblib only for the + # one-vs-rest multiclass case. + if ( + solver in ["lbfgs", "newton-cg"] + and len(classes_) == 1 + and effective_n_jobs(self.n_jobs) == 1 + ): n_threads = _openmp_effective_n_threads() else: n_threads = 1 + fold_coefs_ = Parallel( n_jobs=self.n_jobs, verbose=self.verbose, From 3731fc736d6da383de8c094ef48521be78776219 Mon Sep 17 00:00:00 2001 From: Christian Lorentzen Date: Fri, 17 Dec 2021 14:28:40 +0100 Subject: [PATCH 25/44] trigger CI From 1e56ae7e6532e3449fd15f90afc589153321e56e Mon Sep 17 00:00:00 2001 From: Christian Lorentzen Date: Sat, 18 Dec 2021 17:58:58 +0100 Subject: [PATCH 26/44] CLN add comment about contiguity of raw_prediction --- sklearn/linear_model/_linear_loss.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/linear_model/_linear_loss.py b/sklearn/linear_model/_linear_loss.py index b992c87e65a16..478a97a60159d 100644 --- a/sklearn/linear_model/_linear_loss.py +++ b/sklearn/linear_model/_linear_loss.py @@ -105,7 +105,7 @@ def _w_intercept_raw(self, coef, X): w = w[:, :-1] else: intercept = 0.0 - raw_prediction = X @ w.T + intercept + raw_prediction = X @ w.T + intercept # ndarray, likely C-contiguous return w, intercept, raw_prediction From ef0b98f23251d1b2b0bd8801e456f258392a8d18 Mon Sep 17 00:00:00 2001 From: Christian Lorentzen Date: Sat, 18 Dec 2021 20:03:21 +0100 Subject: [PATCH 27/44] DEBUG debian-32 --- azure-pipelines.yml | 388 +++++++++++++++---------------- build_tools/azure/test_script.sh | 2 +- 2 files changed, 195 insertions(+), 195 deletions(-) diff --git a/azure-pipelines.yml b/azure-pipelines.yml index 2db2aafb8cc95..d92d2831a8e2d 100644 --- a/azure-pipelines.yml +++ b/azure-pipelines.yml @@ -57,159 +57,159 @@ jobs: mypy sklearn/ displayName: Run mypy -- template: build_tools/azure/posix.yml - parameters: - name: Linux_Nightly - vmImage: ubuntu-20.04 - dependsOn: [git_commit, linting] - condition: | - and( - succeeded(), - not(contains(dependencies['git_commit']['outputs']['commit.message'], '[ci skip]')), - or(eq(variables['Build.Reason'], 'Schedule'), - contains(dependencies['git_commit']['outputs']['commit.message'], '[scipy-dev]' - ) - ) - ) - matrix: - pylatest_pip_scipy_dev: - DISTRIB: 'conda-pip-scipy-dev' - PYTHON_VERSION: '*' - CHECK_WARNINGS: 'true' - CHECK_PYTEST_SOFT_DEPENDENCY: 'true' - TEST_DOCSTRINGS: 'true' - # Tests that require large downloads over the networks are skipped in CI. - # Here we make sure, that they are still run on a regular basis. 
- SKLEARN_SKIP_NETWORK_TESTS: '0' - CREATE_ISSUE_ON_TRACKER: 'true' +# - template: build_tools/azure/posix.yml +# parameters: +# name: Linux_Nightly +# vmImage: ubuntu-20.04 +# dependsOn: [git_commit, linting] +# condition: | +# and( +# succeeded(), +# not(contains(dependencies['git_commit']['outputs']['commit.message'], '[ci skip]')), +# or(eq(variables['Build.Reason'], 'Schedule'), +# contains(dependencies['git_commit']['outputs']['commit.message'], '[scipy-dev]' +# ) +# ) +# ) +# matrix: +# pylatest_pip_scipy_dev: +# DISTRIB: 'conda-pip-scipy-dev' +# PYTHON_VERSION: '*' +# CHECK_WARNINGS: 'true' +# CHECK_PYTEST_SOFT_DEPENDENCY: 'true' +# TEST_DOCSTRINGS: 'true' +# # Tests that require large downloads over the networks are skipped in CI. +# # Here we make sure, that they are still run on a regular basis. +# SKLEARN_SKIP_NETWORK_TESTS: '0' +# CREATE_ISSUE_ON_TRACKER: 'true' -# Check compilation with intel C++ compiler (ICC) -- template: build_tools/azure/posix.yml - parameters: - name: Linux_Nightly_ICC - vmImage: ubuntu-20.04 - dependsOn: [git_commit, linting] - condition: | - and( - succeeded(), - not(contains(dependencies['git_commit']['outputs']['commit.message'], '[ci skip]')), - or(eq(variables['Build.Reason'], 'Schedule'), - contains(dependencies['git_commit']['outputs']['commit.message'], '[icc-build]') - ) - ) - matrix: - pylatest_conda_forge_mkl: - DISTRIB: 'conda' - CONDA_CHANNEL: 'conda-forge' - PYTHON_VERSION: '*' - BLAS: 'mkl' - COVERAGE: 'false' - BUILD_WITH_ICC: 'true' +# # Check compilation with intel C++ compiler (ICC) +# - template: build_tools/azure/posix.yml +# parameters: +# name: Linux_Nightly_ICC +# vmImage: ubuntu-20.04 +# dependsOn: [git_commit, linting] +# condition: | +# and( +# succeeded(), +# not(contains(dependencies['git_commit']['outputs']['commit.message'], '[ci skip]')), +# or(eq(variables['Build.Reason'], 'Schedule'), +# contains(dependencies['git_commit']['outputs']['commit.message'], '[icc-build]') +# ) +# ) +# matrix: +# pylatest_conda_forge_mkl: +# DISTRIB: 'conda' +# CONDA_CHANNEL: 'conda-forge' +# PYTHON_VERSION: '*' +# BLAS: 'mkl' +# COVERAGE: 'false' +# BUILD_WITH_ICC: 'true' -- template: build_tools/azure/posix-docker.yml - parameters: - name: Linux_Nightly_PyPy - vmImage: ubuntu-20.04 - dependsOn: [linting, git_commit] - condition: | - and( - succeeded(), - not(contains(dependencies['git_commit']['outputs']['commit.message'], '[ci skip]')), - or( - eq(variables['Build.Reason'], 'Schedule'), - contains(dependencies['git_commit']['outputs']['commit.message'], '[pypy]') - ) - ) - matrix: - pypy3: - DISTRIB: 'conda-mamba-pypy3' - DOCKER_CONTAINER: 'condaforge/mambaforge-pypy3:4.10.3-5' - PILLOW_VERSION: 'none' - PANDAS_VERSION: 'none' - CREATE_ISSUE_ON_TRACKER: 'true' +# - template: build_tools/azure/posix-docker.yml +# parameters: +# name: Linux_Nightly_PyPy +# vmImage: ubuntu-20.04 +# dependsOn: [linting, git_commit] +# condition: | +# and( +# succeeded(), +# not(contains(dependencies['git_commit']['outputs']['commit.message'], '[ci skip]')), +# or( +# eq(variables['Build.Reason'], 'Schedule'), +# contains(dependencies['git_commit']['outputs']['commit.message'], '[pypy]') +# ) +# ) +# matrix: +# pypy3: +# DISTRIB: 'conda-mamba-pypy3' +# DOCKER_CONTAINER: 'condaforge/mambaforge-pypy3:4.10.3-5' +# PILLOW_VERSION: 'none' +# PANDAS_VERSION: 'none' +# CREATE_ISSUE_ON_TRACKER: 'true' -# Will run all the time regardless of linting outcome. 
-- template: build_tools/azure/posix.yml - parameters: - name: Linux_Runs - vmImage: ubuntu-20.04 - dependsOn: [git_commit] - condition: | - and( - succeeded(), - not(contains(dependencies['git_commit']['outputs']['commit.message'], '[ci skip]')) - ) - matrix: - pylatest_conda_forge_mkl: - DISTRIB: 'conda' - CONDA_CHANNEL: 'conda-forge' - PYTHON_VERSION: '*' - BLAS: 'mkl' - COVERAGE: 'true' - SHOW_SHORT_SUMMARY: 'true' +# # Will run all the time regardless of linting outcome. +# - template: build_tools/azure/posix.yml +# parameters: +# name: Linux_Runs +# vmImage: ubuntu-20.04 +# dependsOn: [git_commit] +# condition: | +# and( +# succeeded(), +# not(contains(dependencies['git_commit']['outputs']['commit.message'], '[ci skip]')) +# ) +# matrix: +# pylatest_conda_forge_mkl: +# DISTRIB: 'conda' +# CONDA_CHANNEL: 'conda-forge' +# PYTHON_VERSION: '*' +# BLAS: 'mkl' +# COVERAGE: 'true' +# SHOW_SHORT_SUMMARY: 'true' -# Check compilation with Ubuntu bionic 18.04 LTS and scipy from conda-forge -- template: build_tools/azure/posix.yml - parameters: - name: Ubuntu_Bionic - vmImage: ubuntu-18.04 - dependsOn: [git_commit, linting] - condition: | - and( - succeeded(), - not(contains(dependencies['git_commit']['outputs']['commit.message'], '[ci skip]')), - ne(variables['Build.Reason'], 'Schedule') - ) - matrix: - py37_conda_forge_openblas_ubuntu_1804: - DISTRIB: 'conda' - CONDA_CHANNEL: 'conda-forge' - PYTHON_VERSION: '3.7' - BLAS: 'openblas' - COVERAGE: 'false' - BUILD_WITH_ICC: 'false' +# # Check compilation with Ubuntu bionic 18.04 LTS and scipy from conda-forge +# - template: build_tools/azure/posix.yml +# parameters: +# name: Ubuntu_Bionic +# vmImage: ubuntu-18.04 +# dependsOn: [git_commit, linting] +# condition: | +# and( +# succeeded(), +# not(contains(dependencies['git_commit']['outputs']['commit.message'], '[ci skip]')), +# ne(variables['Build.Reason'], 'Schedule') +# ) +# matrix: +# py37_conda_forge_openblas_ubuntu_1804: +# DISTRIB: 'conda' +# CONDA_CHANNEL: 'conda-forge' +# PYTHON_VERSION: '3.7' +# BLAS: 'openblas' +# COVERAGE: 'false' +# BUILD_WITH_ICC: 'false' -- template: build_tools/azure/posix.yml - parameters: - name: Linux - vmImage: ubuntu-20.04 - dependsOn: [linting, git_commit] - condition: | - and( - succeeded(), - not(contains(dependencies['git_commit']['outputs']['commit.message'], '[ci skip]')), - ne(variables['Build.Reason'], 'Schedule') - ) - matrix: - # Linux environment to test that scikit-learn can be built against - # versions of numpy, scipy with ATLAS that comes with Ubuntu Focal 20.04 - # i.e. numpy 1.17.4 and scipy 1.3.3 - ubuntu_atlas: - DISTRIB: 'ubuntu' - JOBLIB_VERSION: 'min' - PANDAS_VERSION: 'none' - THREADPOOLCTL_VERSION: 'min' - COVERAGE: 'false' - # Linux + Python 3.7 build with OpenBLAS and without SITE_JOBLIB - py37_conda_defaults_openblas: - DISTRIB: 'conda' - CONDA_CHANNEL: 'defaults' # Anaconda main channel - PYTHON_VERSION: '3.7' - BLAS: 'openblas' - NUMPY_VERSION: 'min' - SCIPY_VERSION: 'min' - MATPLOTLIB_VERSION: 'min' - THREADPOOLCTL_VERSION: '2.2.0' - SKLEARN_ENABLE_DEBUG_CYTHON_DIRECTIVES: '1' - # Linux environment to test the latest available dependencies and MKL. - # It runs tests requiring lightgbm, pandas and PyAMG. 
- pylatest_pip_openblas_pandas: - DISTRIB: 'conda-pip-latest' - PYTHON_VERSION: '3.9' - PANDAS_VERSION: 'none' - CHECK_PYTEST_SOFT_DEPENDENCY: 'true' - TEST_DOCSTRINGS: 'true' - CHECK_WARNINGS: 'true' +# - template: build_tools/azure/posix.yml +# parameters: +# name: Linux +# vmImage: ubuntu-20.04 +# dependsOn: [linting, git_commit] +# condition: | +# and( +# succeeded(), +# not(contains(dependencies['git_commit']['outputs']['commit.message'], '[ci skip]')), +# ne(variables['Build.Reason'], 'Schedule') +# ) +# matrix: +# # Linux environment to test that scikit-learn can be built against +# # versions of numpy, scipy with ATLAS that comes with Ubuntu Focal 20.04 +# # i.e. numpy 1.17.4 and scipy 1.3.3 +# ubuntu_atlas: +# DISTRIB: 'ubuntu' +# JOBLIB_VERSION: 'min' +# PANDAS_VERSION: 'none' +# THREADPOOLCTL_VERSION: 'min' +# COVERAGE: 'false' +# # Linux + Python 3.7 build with OpenBLAS and without SITE_JOBLIB +# py37_conda_defaults_openblas: +# DISTRIB: 'conda' +# CONDA_CHANNEL: 'defaults' # Anaconda main channel +# PYTHON_VERSION: '3.7' +# BLAS: 'openblas' +# NUMPY_VERSION: 'min' +# SCIPY_VERSION: 'min' +# MATPLOTLIB_VERSION: 'min' +# THREADPOOLCTL_VERSION: '2.2.0' +# SKLEARN_ENABLE_DEBUG_CYTHON_DIRECTIVES: '1' +# # Linux environment to test the latest available dependencies and MKL. +# # It runs tests requiring lightgbm, pandas and PyAMG. +# pylatest_pip_openblas_pandas: +# DISTRIB: 'conda-pip-latest' +# PYTHON_VERSION: '3.9' +# PANDAS_VERSION: 'none' +# CHECK_PYTEST_SOFT_DEPENDENCY: 'true' +# TEST_DOCSTRINGS: 'true' +# CHECK_WARNINGS: 'true' - template: build_tools/azure/posix-docker.yml parameters: @@ -232,50 +232,50 @@ jobs: PYTEST_VERSION: 'min' THREADPOOLCTL_VERSION: '2.2.0' -- template: build_tools/azure/posix.yml - parameters: - name: macOS - vmImage: macOS-10.15 - dependsOn: [linting, git_commit] - condition: | - and( - succeeded(), - not(contains(dependencies['git_commit']['outputs']['commit.message'], '[ci skip]')), - ne(variables['Build.Reason'], 'Schedule') - ) - matrix: - pylatest_conda_forge_mkl: - DISTRIB: 'conda' - BLAS: 'mkl' - CONDA_CHANNEL: 'conda-forge' - CPU_COUNT: '3' - pylatest_conda_mkl_no_openmp: - DISTRIB: 'conda' - BLAS: 'mkl' - SKLEARN_TEST_NO_OPENMP: 'true' - SKLEARN_SKIP_OPENMP_TEST: 'true' - CPU_COUNT: '3' +# - template: build_tools/azure/posix.yml +# parameters: +# name: macOS +# vmImage: macOS-10.15 +# dependsOn: [linting, git_commit] +# condition: | +# and( +# succeeded(), +# not(contains(dependencies['git_commit']['outputs']['commit.message'], '[ci skip]')), +# ne(variables['Build.Reason'], 'Schedule') +# ) +# matrix: +# pylatest_conda_forge_mkl: +# DISTRIB: 'conda' +# BLAS: 'mkl' +# CONDA_CHANNEL: 'conda-forge' +# CPU_COUNT: '3' +# pylatest_conda_mkl_no_openmp: +# DISTRIB: 'conda' +# BLAS: 'mkl' +# SKLEARN_TEST_NO_OPENMP: 'true' +# SKLEARN_SKIP_OPENMP_TEST: 'true' +# CPU_COUNT: '3' -- template: build_tools/azure/windows.yml - parameters: - name: Windows - vmImage: windows-latest - dependsOn: [linting, git_commit] - condition: | - and( - succeeded(), - not(contains(dependencies['git_commit']['outputs']['commit.message'], '[ci skip]')), - ne(variables['Build.Reason'], 'Schedule') - ) - matrix: - py37_conda_forge_mkl: - DISTRIB: 'conda' - CONDA_CHANNEL: 'conda-forge' - PYTHON_VERSION: '3.7' - CHECK_WARNINGS: 'true' - PYTHON_ARCH: '64' - PYTEST_VERSION: '*' - COVERAGE: 'true' - py37_pip_openblas_32bit: - PYTHON_VERSION: '3.7' - PYTHON_ARCH: '32' +# - template: build_tools/azure/windows.yml +# parameters: +# name: Windows +# vmImage: windows-latest +# 
dependsOn: [linting, git_commit] +# condition: | +# and( +# succeeded(), +# not(contains(dependencies['git_commit']['outputs']['commit.message'], '[ci skip]')), +# ne(variables['Build.Reason'], 'Schedule') +# ) +# matrix: +# py37_conda_forge_mkl: +# DISTRIB: 'conda' +# CONDA_CHANNEL: 'conda-forge' +# PYTHON_VERSION: '3.7' +# CHECK_WARNINGS: 'true' +# PYTHON_ARCH: '64' +# PYTEST_VERSION: '*' +# COVERAGE: 'true' +# py37_pip_openblas_32bit: +# PYTHON_VERSION: '3.7' +# PYTHON_ARCH: '32' diff --git a/build_tools/azure/test_script.sh b/build_tools/azure/test_script.sh index 2a896c1cf66f7..9f4c0d1a32ead 100755 --- a/build_tools/azure/test_script.sh +++ b/build_tools/azure/test_script.sh @@ -26,7 +26,7 @@ else conda list fi -TEST_CMD="python -m pytest --showlocals --durations=20 --junitxml=$JUNITXML" +TEST_CMD="python -m pytest -vv --tb=long --showlocals --durations=20 --junitxml=$JUNITXML" if [[ "$COVERAGE" == "true" ]]; then # Note: --cov-report= is used to disable to long text output report in the From 9d6e6987ff4fbcd32fc9a07944b260688162e14b Mon Sep 17 00:00:00 2001 From: Christian Lorentzen Date: Sun, 19 Dec 2021 11:53:31 +0100 Subject: [PATCH 28/44] DEBUG test only linear_model module --- build_tools/azure/test_script.sh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/build_tools/azure/test_script.sh b/build_tools/azure/test_script.sh index 9f4c0d1a32ead..598603717c231 100755 --- a/build_tools/azure/test_script.sh +++ b/build_tools/azure/test_script.sh @@ -26,7 +26,7 @@ else conda list fi -TEST_CMD="python -m pytest -vv --tb=long --showlocals --durations=20 --junitxml=$JUNITXML" +TEST_CMD="python -m pytest -s -vv --tb=long --showlocals --durations=20 --junitxml=$JUNITXML" if [[ "$COVERAGE" == "true" ]]; then # Note: --cov-report= is used to disable to long text output report in the @@ -55,5 +55,5 @@ if [[ "$SHOW_SHORT_SUMMARY" == "true" ]]; then fi set -x -eval "$TEST_CMD --pyargs sklearn" +eval "$TEST_CMD --pyargs sklearn.linear_model" set +x From 41d14974292c34f8e16b53361ab1e134c18c7a21 Mon Sep 17 00:00:00 2001 From: Christian Lorentzen Date: Sun, 19 Dec 2021 13:01:28 +0100 Subject: [PATCH 29/44] Revert "DEBUG test only linear_model module" This reverts commit 9d6e6987ff4fbcd32fc9a07944b260688162e14b. 
--- build_tools/azure/test_script.sh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/build_tools/azure/test_script.sh b/build_tools/azure/test_script.sh index 598603717c231..9f4c0d1a32ead 100755 --- a/build_tools/azure/test_script.sh +++ b/build_tools/azure/test_script.sh @@ -26,7 +26,7 @@ else conda list fi -TEST_CMD="python -m pytest -s -vv --tb=long --showlocals --durations=20 --junitxml=$JUNITXML" +TEST_CMD="python -m pytest -vv --tb=long --showlocals --durations=20 --junitxml=$JUNITXML" if [[ "$COVERAGE" == "true" ]]; then # Note: --cov-report= is used to disable to long text output report in the @@ -55,5 +55,5 @@ if [[ "$SHOW_SHORT_SUMMARY" == "true" ]]; then fi set -x -eval "$TEST_CMD --pyargs sklearn.linear_model" +eval "$TEST_CMD --pyargs sklearn" set +x From c20316704185da400857b0a3f32935ee1b56c8d9 Mon Sep 17 00:00:00 2001 From: Christian Lorentzen Date: Sun, 19 Dec 2021 13:02:59 +0100 Subject: [PATCH 30/44] DEBUG test -k LogisticRegression --- build_tools/azure/test_script.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/build_tools/azure/test_script.sh b/build_tools/azure/test_script.sh index 9f4c0d1a32ead..3e2c7346b3f2a 100755 --- a/build_tools/azure/test_script.sh +++ b/build_tools/azure/test_script.sh @@ -55,5 +55,5 @@ if [[ "$SHOW_SHORT_SUMMARY" == "true" ]]; then fi set -x -eval "$TEST_CMD --pyargs sklearn" +eval "$TEST_CMD -k LogisticRegression --pyargs sklearn" set +x From edc87aeb31f04685bac43440e6861fd85fb9e6b5 Mon Sep 17 00:00:00 2001 From: Christian Lorentzen Date: Sun, 19 Dec 2021 13:30:44 +0100 Subject: [PATCH 31/44] Revert "DEBUG test -k LogisticRegression" This reverts commit c20316704185da400857b0a3f32935ee1b56c8d9. --- build_tools/azure/test_script.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/build_tools/azure/test_script.sh b/build_tools/azure/test_script.sh index 3e2c7346b3f2a..9f4c0d1a32ead 100755 --- a/build_tools/azure/test_script.sh +++ b/build_tools/azure/test_script.sh @@ -55,5 +55,5 @@ if [[ "$SHOW_SHORT_SUMMARY" == "true" ]]; then fi set -x -eval "$TEST_CMD -k LogisticRegression --pyargs sklearn" +eval "$TEST_CMD --pyargs sklearn" set +x From 3b890fd127ba48c64e2e12f7764fe8b0cf184ba9 Mon Sep 17 00:00:00 2001 From: Christian Lorentzen Date: Sun, 19 Dec 2021 13:30:51 +0100 Subject: [PATCH 32/44] Revert "DEBUG debian-32" This reverts commit ef0b98f23251d1b2b0bd8801e456f258392a8d18. --- azure-pipelines.yml | 388 +++++++++++++++---------------- build_tools/azure/test_script.sh | 2 +- 2 files changed, 195 insertions(+), 195 deletions(-) diff --git a/azure-pipelines.yml b/azure-pipelines.yml index d92d2831a8e2d..2db2aafb8cc95 100644 --- a/azure-pipelines.yml +++ b/azure-pipelines.yml @@ -57,159 +57,159 @@ jobs: mypy sklearn/ displayName: Run mypy -# - template: build_tools/azure/posix.yml -# parameters: -# name: Linux_Nightly -# vmImage: ubuntu-20.04 -# dependsOn: [git_commit, linting] -# condition: | -# and( -# succeeded(), -# not(contains(dependencies['git_commit']['outputs']['commit.message'], '[ci skip]')), -# or(eq(variables['Build.Reason'], 'Schedule'), -# contains(dependencies['git_commit']['outputs']['commit.message'], '[scipy-dev]' -# ) -# ) -# ) -# matrix: -# pylatest_pip_scipy_dev: -# DISTRIB: 'conda-pip-scipy-dev' -# PYTHON_VERSION: '*' -# CHECK_WARNINGS: 'true' -# CHECK_PYTEST_SOFT_DEPENDENCY: 'true' -# TEST_DOCSTRINGS: 'true' -# # Tests that require large downloads over the networks are skipped in CI. -# # Here we make sure, that they are still run on a regular basis. 
-# SKLEARN_SKIP_NETWORK_TESTS: '0' -# CREATE_ISSUE_ON_TRACKER: 'true' +- template: build_tools/azure/posix.yml + parameters: + name: Linux_Nightly + vmImage: ubuntu-20.04 + dependsOn: [git_commit, linting] + condition: | + and( + succeeded(), + not(contains(dependencies['git_commit']['outputs']['commit.message'], '[ci skip]')), + or(eq(variables['Build.Reason'], 'Schedule'), + contains(dependencies['git_commit']['outputs']['commit.message'], '[scipy-dev]' + ) + ) + ) + matrix: + pylatest_pip_scipy_dev: + DISTRIB: 'conda-pip-scipy-dev' + PYTHON_VERSION: '*' + CHECK_WARNINGS: 'true' + CHECK_PYTEST_SOFT_DEPENDENCY: 'true' + TEST_DOCSTRINGS: 'true' + # Tests that require large downloads over the networks are skipped in CI. + # Here we make sure, that they are still run on a regular basis. + SKLEARN_SKIP_NETWORK_TESTS: '0' + CREATE_ISSUE_ON_TRACKER: 'true' -# # Check compilation with intel C++ compiler (ICC) -# - template: build_tools/azure/posix.yml -# parameters: -# name: Linux_Nightly_ICC -# vmImage: ubuntu-20.04 -# dependsOn: [git_commit, linting] -# condition: | -# and( -# succeeded(), -# not(contains(dependencies['git_commit']['outputs']['commit.message'], '[ci skip]')), -# or(eq(variables['Build.Reason'], 'Schedule'), -# contains(dependencies['git_commit']['outputs']['commit.message'], '[icc-build]') -# ) -# ) -# matrix: -# pylatest_conda_forge_mkl: -# DISTRIB: 'conda' -# CONDA_CHANNEL: 'conda-forge' -# PYTHON_VERSION: '*' -# BLAS: 'mkl' -# COVERAGE: 'false' -# BUILD_WITH_ICC: 'true' +# Check compilation with intel C++ compiler (ICC) +- template: build_tools/azure/posix.yml + parameters: + name: Linux_Nightly_ICC + vmImage: ubuntu-20.04 + dependsOn: [git_commit, linting] + condition: | + and( + succeeded(), + not(contains(dependencies['git_commit']['outputs']['commit.message'], '[ci skip]')), + or(eq(variables['Build.Reason'], 'Schedule'), + contains(dependencies['git_commit']['outputs']['commit.message'], '[icc-build]') + ) + ) + matrix: + pylatest_conda_forge_mkl: + DISTRIB: 'conda' + CONDA_CHANNEL: 'conda-forge' + PYTHON_VERSION: '*' + BLAS: 'mkl' + COVERAGE: 'false' + BUILD_WITH_ICC: 'true' -# - template: build_tools/azure/posix-docker.yml -# parameters: -# name: Linux_Nightly_PyPy -# vmImage: ubuntu-20.04 -# dependsOn: [linting, git_commit] -# condition: | -# and( -# succeeded(), -# not(contains(dependencies['git_commit']['outputs']['commit.message'], '[ci skip]')), -# or( -# eq(variables['Build.Reason'], 'Schedule'), -# contains(dependencies['git_commit']['outputs']['commit.message'], '[pypy]') -# ) -# ) -# matrix: -# pypy3: -# DISTRIB: 'conda-mamba-pypy3' -# DOCKER_CONTAINER: 'condaforge/mambaforge-pypy3:4.10.3-5' -# PILLOW_VERSION: 'none' -# PANDAS_VERSION: 'none' -# CREATE_ISSUE_ON_TRACKER: 'true' +- template: build_tools/azure/posix-docker.yml + parameters: + name: Linux_Nightly_PyPy + vmImage: ubuntu-20.04 + dependsOn: [linting, git_commit] + condition: | + and( + succeeded(), + not(contains(dependencies['git_commit']['outputs']['commit.message'], '[ci skip]')), + or( + eq(variables['Build.Reason'], 'Schedule'), + contains(dependencies['git_commit']['outputs']['commit.message'], '[pypy]') + ) + ) + matrix: + pypy3: + DISTRIB: 'conda-mamba-pypy3' + DOCKER_CONTAINER: 'condaforge/mambaforge-pypy3:4.10.3-5' + PILLOW_VERSION: 'none' + PANDAS_VERSION: 'none' + CREATE_ISSUE_ON_TRACKER: 'true' -# # Will run all the time regardless of linting outcome. 
-# - template: build_tools/azure/posix.yml -# parameters: -# name: Linux_Runs -# vmImage: ubuntu-20.04 -# dependsOn: [git_commit] -# condition: | -# and( -# succeeded(), -# not(contains(dependencies['git_commit']['outputs']['commit.message'], '[ci skip]')) -# ) -# matrix: -# pylatest_conda_forge_mkl: -# DISTRIB: 'conda' -# CONDA_CHANNEL: 'conda-forge' -# PYTHON_VERSION: '*' -# BLAS: 'mkl' -# COVERAGE: 'true' -# SHOW_SHORT_SUMMARY: 'true' +# Will run all the time regardless of linting outcome. +- template: build_tools/azure/posix.yml + parameters: + name: Linux_Runs + vmImage: ubuntu-20.04 + dependsOn: [git_commit] + condition: | + and( + succeeded(), + not(contains(dependencies['git_commit']['outputs']['commit.message'], '[ci skip]')) + ) + matrix: + pylatest_conda_forge_mkl: + DISTRIB: 'conda' + CONDA_CHANNEL: 'conda-forge' + PYTHON_VERSION: '*' + BLAS: 'mkl' + COVERAGE: 'true' + SHOW_SHORT_SUMMARY: 'true' -# # Check compilation with Ubuntu bionic 18.04 LTS and scipy from conda-forge -# - template: build_tools/azure/posix.yml -# parameters: -# name: Ubuntu_Bionic -# vmImage: ubuntu-18.04 -# dependsOn: [git_commit, linting] -# condition: | -# and( -# succeeded(), -# not(contains(dependencies['git_commit']['outputs']['commit.message'], '[ci skip]')), -# ne(variables['Build.Reason'], 'Schedule') -# ) -# matrix: -# py37_conda_forge_openblas_ubuntu_1804: -# DISTRIB: 'conda' -# CONDA_CHANNEL: 'conda-forge' -# PYTHON_VERSION: '3.7' -# BLAS: 'openblas' -# COVERAGE: 'false' -# BUILD_WITH_ICC: 'false' +# Check compilation with Ubuntu bionic 18.04 LTS and scipy from conda-forge +- template: build_tools/azure/posix.yml + parameters: + name: Ubuntu_Bionic + vmImage: ubuntu-18.04 + dependsOn: [git_commit, linting] + condition: | + and( + succeeded(), + not(contains(dependencies['git_commit']['outputs']['commit.message'], '[ci skip]')), + ne(variables['Build.Reason'], 'Schedule') + ) + matrix: + py37_conda_forge_openblas_ubuntu_1804: + DISTRIB: 'conda' + CONDA_CHANNEL: 'conda-forge' + PYTHON_VERSION: '3.7' + BLAS: 'openblas' + COVERAGE: 'false' + BUILD_WITH_ICC: 'false' -# - template: build_tools/azure/posix.yml -# parameters: -# name: Linux -# vmImage: ubuntu-20.04 -# dependsOn: [linting, git_commit] -# condition: | -# and( -# succeeded(), -# not(contains(dependencies['git_commit']['outputs']['commit.message'], '[ci skip]')), -# ne(variables['Build.Reason'], 'Schedule') -# ) -# matrix: -# # Linux environment to test that scikit-learn can be built against -# # versions of numpy, scipy with ATLAS that comes with Ubuntu Focal 20.04 -# # i.e. numpy 1.17.4 and scipy 1.3.3 -# ubuntu_atlas: -# DISTRIB: 'ubuntu' -# JOBLIB_VERSION: 'min' -# PANDAS_VERSION: 'none' -# THREADPOOLCTL_VERSION: 'min' -# COVERAGE: 'false' -# # Linux + Python 3.7 build with OpenBLAS and without SITE_JOBLIB -# py37_conda_defaults_openblas: -# DISTRIB: 'conda' -# CONDA_CHANNEL: 'defaults' # Anaconda main channel -# PYTHON_VERSION: '3.7' -# BLAS: 'openblas' -# NUMPY_VERSION: 'min' -# SCIPY_VERSION: 'min' -# MATPLOTLIB_VERSION: 'min' -# THREADPOOLCTL_VERSION: '2.2.0' -# SKLEARN_ENABLE_DEBUG_CYTHON_DIRECTIVES: '1' -# # Linux environment to test the latest available dependencies and MKL. -# # It runs tests requiring lightgbm, pandas and PyAMG. 
-# pylatest_pip_openblas_pandas: -# DISTRIB: 'conda-pip-latest' -# PYTHON_VERSION: '3.9' -# PANDAS_VERSION: 'none' -# CHECK_PYTEST_SOFT_DEPENDENCY: 'true' -# TEST_DOCSTRINGS: 'true' -# CHECK_WARNINGS: 'true' +- template: build_tools/azure/posix.yml + parameters: + name: Linux + vmImage: ubuntu-20.04 + dependsOn: [linting, git_commit] + condition: | + and( + succeeded(), + not(contains(dependencies['git_commit']['outputs']['commit.message'], '[ci skip]')), + ne(variables['Build.Reason'], 'Schedule') + ) + matrix: + # Linux environment to test that scikit-learn can be built against + # versions of numpy, scipy with ATLAS that comes with Ubuntu Focal 20.04 + # i.e. numpy 1.17.4 and scipy 1.3.3 + ubuntu_atlas: + DISTRIB: 'ubuntu' + JOBLIB_VERSION: 'min' + PANDAS_VERSION: 'none' + THREADPOOLCTL_VERSION: 'min' + COVERAGE: 'false' + # Linux + Python 3.7 build with OpenBLAS and without SITE_JOBLIB + py37_conda_defaults_openblas: + DISTRIB: 'conda' + CONDA_CHANNEL: 'defaults' # Anaconda main channel + PYTHON_VERSION: '3.7' + BLAS: 'openblas' + NUMPY_VERSION: 'min' + SCIPY_VERSION: 'min' + MATPLOTLIB_VERSION: 'min' + THREADPOOLCTL_VERSION: '2.2.0' + SKLEARN_ENABLE_DEBUG_CYTHON_DIRECTIVES: '1' + # Linux environment to test the latest available dependencies and MKL. + # It runs tests requiring lightgbm, pandas and PyAMG. + pylatest_pip_openblas_pandas: + DISTRIB: 'conda-pip-latest' + PYTHON_VERSION: '3.9' + PANDAS_VERSION: 'none' + CHECK_PYTEST_SOFT_DEPENDENCY: 'true' + TEST_DOCSTRINGS: 'true' + CHECK_WARNINGS: 'true' - template: build_tools/azure/posix-docker.yml parameters: @@ -232,50 +232,50 @@ jobs: PYTEST_VERSION: 'min' THREADPOOLCTL_VERSION: '2.2.0' -# - template: build_tools/azure/posix.yml -# parameters: -# name: macOS -# vmImage: macOS-10.15 -# dependsOn: [linting, git_commit] -# condition: | -# and( -# succeeded(), -# not(contains(dependencies['git_commit']['outputs']['commit.message'], '[ci skip]')), -# ne(variables['Build.Reason'], 'Schedule') -# ) -# matrix: -# pylatest_conda_forge_mkl: -# DISTRIB: 'conda' -# BLAS: 'mkl' -# CONDA_CHANNEL: 'conda-forge' -# CPU_COUNT: '3' -# pylatest_conda_mkl_no_openmp: -# DISTRIB: 'conda' -# BLAS: 'mkl' -# SKLEARN_TEST_NO_OPENMP: 'true' -# SKLEARN_SKIP_OPENMP_TEST: 'true' -# CPU_COUNT: '3' +- template: build_tools/azure/posix.yml + parameters: + name: macOS + vmImage: macOS-10.15 + dependsOn: [linting, git_commit] + condition: | + and( + succeeded(), + not(contains(dependencies['git_commit']['outputs']['commit.message'], '[ci skip]')), + ne(variables['Build.Reason'], 'Schedule') + ) + matrix: + pylatest_conda_forge_mkl: + DISTRIB: 'conda' + BLAS: 'mkl' + CONDA_CHANNEL: 'conda-forge' + CPU_COUNT: '3' + pylatest_conda_mkl_no_openmp: + DISTRIB: 'conda' + BLAS: 'mkl' + SKLEARN_TEST_NO_OPENMP: 'true' + SKLEARN_SKIP_OPENMP_TEST: 'true' + CPU_COUNT: '3' -# - template: build_tools/azure/windows.yml -# parameters: -# name: Windows -# vmImage: windows-latest -# dependsOn: [linting, git_commit] -# condition: | -# and( -# succeeded(), -# not(contains(dependencies['git_commit']['outputs']['commit.message'], '[ci skip]')), -# ne(variables['Build.Reason'], 'Schedule') -# ) -# matrix: -# py37_conda_forge_mkl: -# DISTRIB: 'conda' -# CONDA_CHANNEL: 'conda-forge' -# PYTHON_VERSION: '3.7' -# CHECK_WARNINGS: 'true' -# PYTHON_ARCH: '64' -# PYTEST_VERSION: '*' -# COVERAGE: 'true' -# py37_pip_openblas_32bit: -# PYTHON_VERSION: '3.7' -# PYTHON_ARCH: '32' +- template: build_tools/azure/windows.yml + parameters: + name: Windows + vmImage: windows-latest + dependsOn: [linting, 
git_commit] + condition: | + and( + succeeded(), + not(contains(dependencies['git_commit']['outputs']['commit.message'], '[ci skip]')), + ne(variables['Build.Reason'], 'Schedule') + ) + matrix: + py37_conda_forge_mkl: + DISTRIB: 'conda' + CONDA_CHANNEL: 'conda-forge' + PYTHON_VERSION: '3.7' + CHECK_WARNINGS: 'true' + PYTHON_ARCH: '64' + PYTEST_VERSION: '*' + COVERAGE: 'true' + py37_pip_openblas_32bit: + PYTHON_VERSION: '3.7' + PYTHON_ARCH: '32' diff --git a/build_tools/azure/test_script.sh b/build_tools/azure/test_script.sh index 9f4c0d1a32ead..2a896c1cf66f7 100755 --- a/build_tools/azure/test_script.sh +++ b/build_tools/azure/test_script.sh @@ -26,7 +26,7 @@ else conda list fi -TEST_CMD="python -m pytest -vv --tb=long --showlocals --durations=20 --junitxml=$JUNITXML" +TEST_CMD="python -m pytest --showlocals --durations=20 --junitxml=$JUNITXML" if [[ "$COVERAGE" == "true" ]]; then # Note: --cov-report= is used to disable to long text output report in the From c7f6f72a8c1ee21299786130e097df248fc1a0fb Mon Sep 17 00:00:00 2001 From: Christian Lorentzen Date: Mon, 20 Dec 2021 08:54:26 +0100 Subject: [PATCH 33/44] DEBUG set n_jobs=1 --- sklearn/ensemble/tests/test_voting.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/ensemble/tests/test_voting.py b/sklearn/ensemble/tests/test_voting.py index 4bebfaca53709..59a71d7007f5d 100644 --- a/sklearn/ensemble/tests/test_voting.py +++ b/sklearn/ensemble/tests/test_voting.py @@ -291,7 +291,7 @@ def test_parallel_fit(): estimators=[("lr", clf1), ("rf", clf2), ("gnb", clf3)], voting="soft", n_jobs=1 ).fit(X, y) eclf2 = VotingClassifier( - estimators=[("lr", clf1), ("rf", clf2), ("gnb", clf3)], voting="soft", n_jobs=2 + estimators=[("lr", clf1), ("rf", clf2), ("gnb", clf3)], voting="soft", n_jobs=1 ).fit(X, y) assert_array_equal(eclf1.predict(X), eclf2.predict(X)) From 028fc616485fa94a0ea5dbd72b2e2498210737c8 Mon Sep 17 00:00:00 2001 From: Christian Lorentzen Date: Mon, 20 Dec 2021 12:57:03 +0100 Subject: [PATCH 34/44] Revert "DEBUG set n_jobs=1" This reverts commit c7f6f72a8c1ee21299786130e097df248fc1a0fb. 
--- sklearn/ensemble/tests/test_voting.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/ensemble/tests/test_voting.py b/sklearn/ensemble/tests/test_voting.py index 59a71d7007f5d..4bebfaca53709 100644 --- a/sklearn/ensemble/tests/test_voting.py +++ b/sklearn/ensemble/tests/test_voting.py @@ -291,7 +291,7 @@ def test_parallel_fit(): estimators=[("lr", clf1), ("rf", clf2), ("gnb", clf3)], voting="soft", n_jobs=1 ).fit(X, y) eclf2 = VotingClassifier( - estimators=[("lr", clf1), ("rf", clf2), ("gnb", clf3)], voting="soft", n_jobs=1 + estimators=[("lr", clf1), ("rf", clf2), ("gnb", clf3)], voting="soft", n_jobs=2 ).fit(X, y) assert_array_equal(eclf1.predict(X), eclf2.predict(X)) From aadf1d2d78395cb447b0c8bbbcca2224bda8fb29 Mon Sep 17 00:00:00 2001 From: Christian Lorentzen Date: Mon, 20 Dec 2021 19:03:57 +0100 Subject: [PATCH 35/44] CLN always use n_threads=1 --- sklearn/linear_model/_logistic.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/sklearn/linear_model/_logistic.py b/sklearn/linear_model/_logistic.py index 37814a1aaf345..18e6bdeb094a8 100644 --- a/sklearn/linear_model/_logistic.py +++ b/sklearn/linear_model/_logistic.py @@ -25,7 +25,6 @@ from ..svm._base import _fit_liblinear from ..utils import check_array, check_consistent_length, compute_class_weight from ..utils import check_random_state -from ..utils._openmp_helpers import _openmp_effective_n_threads from ..utils.extmath import softmax from ..utils.extmath import row_norms from ..utils.optimize import _newton_cg, _check_optimize_result @@ -1225,7 +1224,9 @@ def fit(self, X, y, sample_weight=None): and len(classes_) == 1 and effective_n_jobs(self.n_jobs) == 1 ): - n_threads = _openmp_effective_n_threads() + # In the future, we would like n_threads = _openmp_effective_n_threads() + # For the time being, we just do + n_threads = 1 else: n_threads = 1 From a45d94c6dac340a9187cc17157cef503dbec0ff8 Mon Sep 17 00:00:00 2001 From: Christian Lorentzen Date: Tue, 21 Dec 2021 16:54:59 +0100 Subject: [PATCH 36/44] CLN address review --- sklearn/linear_model/_linear_loss.py | 33 ++++++++++--------- sklearn/linear_model/_logistic.py | 4 +-- .../linear_model/tests/test_linear_loss.py | 14 ++++---- 3 files changed, 26 insertions(+), 25 deletions(-) diff --git a/sklearn/linear_model/_linear_loss.py b/sklearn/linear_model/_linear_loss.py index 478a97a60159d..6c2a660154c3d 100644 --- a/sklearn/linear_model/_linear_loss.py +++ b/sklearn/linear_model/_linear_loss.py @@ -182,7 +182,7 @@ def loss_gradient( The gradient of the loss. 
""" n_features, n_classes = X.shape[1], self._loss.n_classes - n_dof = n_features + self.fit_intercept + n_dof = n_features + int(self.fit_intercept) w, intercept, raw_prediction = self._w_intercept_raw(coef, X) loss, gradient_per_sample = self._loss.loss_gradient( @@ -199,18 +199,17 @@ def loss_gradient( grad[:n_features] = X.T @ gradient_per_sample + l2_reg_strength * w if self.fit_intercept: grad[-1] = gradient_per_sample.sum() - return loss, grad else: loss += 0.5 * l2_reg_strength * squared_norm(w) grad = np.empty((n_classes, n_dof), dtype=X.dtype, order="F") - # gradient.shape = (n_samples, n_classes) + # gradient_per_sample.shape = (n_samples, n_classes) grad[:, :n_features] = gradient_per_sample.T @ X + l2_reg_strength * w if self.fit_intercept: grad[:, -1] = gradient_per_sample.sum(axis=0) if coef.ndim == 1: - return loss, grad.ravel(order="F") - else: - return loss, grad + grad = grad.ravel(order="F") + + return loss, grad def gradient( self, coef, X, y, sample_weight=None, l2_reg_strength=0.0, n_threads=1 @@ -241,7 +240,7 @@ def gradient( The gradient of the loss. """ n_features, n_classes = X.shape[1], self._loss.n_classes - n_dof = n_features + self.fit_intercept + n_dof = n_features + int(self.fit_intercept) w, intercept, raw_prediction = self._w_intercept_raw(coef, X) gradient_per_sample = self._loss.gradient( @@ -268,7 +267,7 @@ def gradient( else: return grad - def gradient_hessp( + def gradient_hessian_product( self, coef, X, y, sample_weight=None, l2_reg_strength=0.0, n_threads=1 ): """Computes gradient and hessp (hessian product function). @@ -301,7 +300,7 @@ def gradient_hessp( and returns matrix-vector product with hessian. """ (n_samples, n_features), n_classes = X.shape, self._loss.n_classes - n_dof = n_features + self.fit_intercept + n_dof = n_features + int(self.fit_intercept) w, intercept, raw_prediction = self._w_intercept_raw(coef, X) if not self._loss.is_multiclass: @@ -332,7 +331,7 @@ def gradient_hessp( # res = (X, 1)' @ diag(h) @ (X, 1) @ s # = (X, 1)' @ (hX @ s[:n_features], sum(h) * s[-1]) # res[:n_features] = X' @ hX @ s[:n_features] + sum(h) * s[-1] - # res[:-1] = 1' @ hX @ s[:n_features] + sum(h) * s[-1] + # res[-1] = 1' @ hX @ s[:n_features] + sum(h) * s[-1] def hessp(s): ret = np.empty_like(s) if sparse.issparse(X): @@ -388,7 +387,7 @@ def hessp(s): s = s.reshape((n_classes, -1), order="F") # shape = (n_classes, n_dof) if self.fit_intercept: s_intercept = s[:, -1] - s = s[:, :-1] + s = s[:, :-1] # shape = (n_classes, n_features) else: s_intercept = 0 tmp = X @ s.T + s_intercept # X_{im} * s_k_m @@ -396,7 +395,9 @@ def hessp(s): tmp *= proba # * p_i_k if sample_weight is not None: tmp *= sample_weight[:, np.newaxis] - hess_prod = np.empty_like(grad, order="F") + # hess_prod = empty_like(grad), but we ravel grad below and this + # function is run after that. 
+ hess_prod = np.empty((n_classes, n_dof), dtype=X.dtype, order="F") hess_prod[:, :n_features] = tmp.T @ X + l2_reg_strength * s if self.fit_intercept: hess_prod[:, -1] = tmp.sum(axis=0) @@ -405,7 +406,7 @@ def hessp(s): else: return hess_prod - if coef.ndim == 1: - return grad.ravel(order="F"), hessp - else: - return grad, hessp + if coef.ndim == 1: + return grad.ravel(order="F"), hessp + + return grad, hessp diff --git a/sklearn/linear_model/_logistic.py b/sklearn/linear_model/_logistic.py index 18e6bdeb094a8..ea7652ae2ef9c 100644 --- a/sklearn/linear_model/_logistic.py +++ b/sklearn/linear_model/_logistic.py @@ -408,7 +408,7 @@ def _logistic_regression_path( elif solver == "newton-cg": func = loss.loss grad = loss.gradient - hess = loss.gradient_hessp # hess = [gradient, hessp] + hess = loss.gradient_hessian_product # hess = [gradient, hessp] warm_start_sag = {"coef": w0.T} else: target = y_bin @@ -419,7 +419,7 @@ def _logistic_regression_path( loss = LinearModelLoss(loss=HalfBinomialLoss(), fit_intercept=fit_intercept) func = loss.loss grad = loss.gradient - hess = loss.gradient_hessp # hess = [gradient, hessp] + hess = loss.gradient_hessian_product # hess = [gradient, hessp] warm_start_sag = {"coef": np.expand_dims(w0, axis=1)} coefs = list() diff --git a/sklearn/linear_model/tests/test_linear_loss.py b/sklearn/linear_model/tests/test_linear_loss.py index 97bedfb896a88..07d236c50e7c5 100644 --- a/sklearn/linear_model/tests/test_linear_loss.py +++ b/sklearn/linear_model/tests/test_linear_loss.py @@ -101,7 +101,7 @@ def test_loss_gradients_are_the_same( l2, g2 = loss.loss_gradient( coef, X, y, sample_weight=sample_weight, l2_reg_strength=l2_reg_strength ) - g3, h3 = loss.gradient_hessp( + g3, h3 = loss.gradient_hessian_product( coef, X, y, sample_weight=sample_weight, l2_reg_strength=l2_reg_strength ) @@ -120,7 +120,7 @@ def test_loss_gradients_are_the_same( l2_sp, g2_sp = loss.loss_gradient( coef, X, y, sample_weight=sample_weight, l2_reg_strength=l2_reg_strength ) - g3_sp, h3_sp = loss.gradient_hessp( + g3_sp, h3_sp = loss.gradient_hessian_product( coef, X, y, sample_weight=sample_weight, l2_reg_strength=l2_reg_strength ) @@ -160,13 +160,13 @@ def test_loss_gradients_hessp_intercept( l, g = loss.loss_gradient( coef, X, y, sample_weight=sample_weight, l2_reg_strength=l2_reg_strength ) - _, hessp = loss.gradient_hessp( + _, hessp = loss.gradient_hessian_product( coef, X, y, sample_weight=sample_weight, l2_reg_strength=l2_reg_strength ) l_inter, g_inter = loss_inter.loss_gradient( coef, X_inter, y, sample_weight=sample_weight, l2_reg_strength=l2_reg_strength ) - _, hessp_inter = loss_inter.gradient_hessp( + _, hessp_inter = loss_inter.gradient_hessian_product( coef, X_inter, y, sample_weight=sample_weight, l2_reg_strength=l2_reg_strength ) @@ -211,7 +211,7 @@ def test_gradients_hessians_numerically( # 1. 
Check gradients numerically eps = 1e-6 - g, hessp = loss.gradient_hessp( + g, hessp = loss.gradient_hessian_product( coef, X, y, sample_weight=sample_weight, l2_reg_strength=l2_reg_strength ) # Use a trick to get central finte difference of accuracy 4 (five-point stencil) @@ -280,7 +280,7 @@ def test_multinomial_coef_shape(fit_intercept): l, g = loss.loss_gradient(coef, X, y) g1 = loss.gradient(coef, X, y) - g2, hessp = loss.gradient_hessp(coef, X, y) + g2, hessp = loss.gradient_hessian_product(coef, X, y) h = hessp(s) assert g.shape == coef.shape assert h.shape == coef.shape @@ -291,7 +291,7 @@ def test_multinomial_coef_shape(fit_intercept): s_r = s.ravel(order="F") l_r, g_r = loss.loss_gradient(coef_r, X, y) g1_r = loss.gradient(coef_r, X, y) - g2_r, hessp_r = loss.gradient_hessp(coef_r, X, y) + g2_r, hessp_r = loss.gradient_hessian_product(coef_r, X, y) h_r = hessp_r(s_r) assert g_r.shape == coef_r.shape assert h_r.shape == coef_r.shape From cb2a8a2456780705e5031018910ea040abf9e512 Mon Sep 17 00:00:00 2001 From: Christian Lorentzen Date: Tue, 21 Dec 2021 16:55:40 +0100 Subject: [PATCH 37/44] ENH avoid array copy --- sklearn/linear_model/_linear_loss.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/linear_model/_linear_loss.py b/sklearn/linear_model/_linear_loss.py index 6c2a660154c3d..7ee2f03293647 100644 --- a/sklearn/linear_model/_linear_loss.py +++ b/sklearn/linear_model/_linear_loss.py @@ -325,7 +325,7 @@ def gradient_hessian_product( if self.fit_intercept: # Calculate the double derivative with respect to intercept. # Note: In case hX is sparse, hX.sum is a matrix object. - hh_intercept = np.squeeze(np.array(hX.sum(axis=0))) + hh_intercept = np.squeeze(np.asarray(hX.sum(axis=0))) # With intercept included and l2_reg_strength = 0, hessp returns # res = (X, 1)' @ diag(h) @ (X, 1) @ s From 08be48476b401311687d9aedadd41458cfe3bb54 Mon Sep 17 00:00:00 2001 From: Christian Lorentzen Date: Tue, 21 Dec 2021 16:58:07 +0100 Subject: [PATCH 38/44] CLN simplify L2 norm --- sklearn/linear_model/_linear_loss.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/sklearn/linear_model/_linear_loss.py b/sklearn/linear_model/_linear_loss.py index 7ee2f03293647..799545f75d55e 100644 --- a/sklearn/linear_model/_linear_loss.py +++ b/sklearn/linear_model/_linear_loss.py @@ -145,10 +145,8 @@ def loss(self, coef, X, y, sample_weight=None, l2_reg_strength=0.0, n_threads=1) ) loss = loss.sum() - if w.ndim == 1: - return loss + 0.5 * l2_reg_strength * (w @ w) - else: - return loss + 0.5 * l2_reg_strength * squared_norm(w) + norm2_w = w @ w if w.ndim == 1 else squared_norm(w) + return loss + 0.5 * l2_reg_strength * norm2_w def loss_gradient( self, coef, X, y, sample_weight=None, l2_reg_strength=0.0, n_threads=1 From 6325270fbaf8aaab539df2f478d953ee2bf49abb Mon Sep 17 00:00:00 2001 From: Christian Lorentzen Date: Tue, 21 Dec 2021 17:15:34 +0100 Subject: [PATCH 39/44] CLN rename w to weights --- sklearn/linear_model/_linear_loss.py | 60 ++++++++++++++-------------- 1 file changed, 30 insertions(+), 30 deletions(-) diff --git a/sklearn/linear_model/_linear_loss.py b/sklearn/linear_model/_linear_loss.py index 799545f75d55e..c2e2d86c906a6 100644 --- a/sklearn/linear_model/_linear_loss.py +++ b/sklearn/linear_model/_linear_loss.py @@ -79,7 +79,7 @@ def _w_intercept_raw(self, coef, X): Returns ------- - w : ndarray of shape (n_features,) or (n_classes, n_features) + weights : ndarray of shape (n_features,) or (n_classes, n_features) Coefficients without intercept term. 
intercept : float or ndarray of shape (n_classes,) Intercept terms. @@ -89,25 +89,25 @@ def _w_intercept_raw(self, coef, X): if not self._loss.is_multiclass: if self.fit_intercept: intercept = coef[-1] - w = coef[:-1] + weights = coef[:-1] else: intercept = 0.0 - w = coef - raw_prediction = X @ w + intercept + weights = coef + raw_prediction = X @ weights + intercept else: # reshape to (n_classes, n_dof) if coef.ndim == 1: - w = coef.reshape((self._loss.n_classes, -1), order="F") + weights = coef.reshape((self._loss.n_classes, -1), order="F") else: - w = coef + weights = coef if self.fit_intercept: - intercept = w[:, -1] - w = w[:, :-1] + intercept = weights[:, -1] + weights = weights[:, :-1] else: intercept = 0.0 - raw_prediction = X @ w.T + intercept # ndarray, likely C-contiguous + raw_prediction = X @ weights.T + intercept # ndarray, likely C-contiguous - return w, intercept, raw_prediction + return weights, intercept, raw_prediction def loss(self, coef, X, y, sample_weight=None, l2_reg_strength=0.0, n_threads=1): """Compute the loss as sum over point-wise losses. @@ -135,7 +135,7 @@ def loss(self, coef, X, y, sample_weight=None, l2_reg_strength=0.0, n_threads=1) loss : float Sum of losses per sample plus penalty. """ - w, intercept, raw_prediction = self._w_intercept_raw(coef, X) + weights, intercept, raw_prediction = self._w_intercept_raw(coef, X) loss = self._loss.loss( y_true=y, @@ -145,7 +145,7 @@ def loss(self, coef, X, y, sample_weight=None, l2_reg_strength=0.0, n_threads=1) ) loss = loss.sum() - norm2_w = w @ w if w.ndim == 1 else squared_norm(w) + norm2_w = weights @ weights if weights.ndim == 1 else squared_norm(weights) return loss + 0.5 * l2_reg_strength * norm2_w def loss_gradient( @@ -181,9 +181,9 @@ def loss_gradient( """ n_features, n_classes = X.shape[1], self._loss.n_classes n_dof = n_features + int(self.fit_intercept) - w, intercept, raw_prediction = self._w_intercept_raw(coef, X) + weights, intercept, raw_prediction = self._w_intercept_raw(coef, X) - loss, gradient_per_sample = self._loss.loss_gradient( + loss, grad_per_sample = self._loss.loss_gradient( y_true=y, raw_prediction=raw_prediction, sample_weight=sample_weight, @@ -192,18 +192,18 @@ def loss_gradient( loss = loss.sum() if not self._loss.is_multiclass: - loss += 0.5 * l2_reg_strength * (w @ w) + loss += 0.5 * l2_reg_strength * (weights @ weights) grad = np.empty_like(coef, dtype=X.dtype) - grad[:n_features] = X.T @ gradient_per_sample + l2_reg_strength * w + grad[:n_features] = X.T @ grad_per_sample + l2_reg_strength * weights if self.fit_intercept: - grad[-1] = gradient_per_sample.sum() + grad[-1] = grad_per_sample.sum() else: - loss += 0.5 * l2_reg_strength * squared_norm(w) + loss += 0.5 * l2_reg_strength * squared_norm(weights) grad = np.empty((n_classes, n_dof), dtype=X.dtype, order="F") - # gradient_per_sample.shape = (n_samples, n_classes) - grad[:, :n_features] = gradient_per_sample.T @ X + l2_reg_strength * w + # grad_per_sample.shape = (n_samples, n_classes) + grad[:, :n_features] = grad_per_sample.T @ X + l2_reg_strength * weights if self.fit_intercept: - grad[:, -1] = gradient_per_sample.sum(axis=0) + grad[:, -1] = grad_per_sample.sum(axis=0) if coef.ndim == 1: grad = grad.ravel(order="F") @@ -239,9 +239,9 @@ def gradient( """ n_features, n_classes = X.shape[1], self._loss.n_classes n_dof = n_features + int(self.fit_intercept) - w, intercept, raw_prediction = self._w_intercept_raw(coef, X) + weights, intercept, raw_prediction = self._w_intercept_raw(coef, X) - gradient_per_sample = 
self._loss.gradient( + grad_per_sample = self._loss.gradient( y_true=y, raw_prediction=raw_prediction, sample_weight=sample_weight, @@ -250,16 +250,16 @@ def gradient( if not self._loss.is_multiclass: grad = np.empty_like(coef, dtype=X.dtype) - grad[:n_features] = X.T @ gradient_per_sample + l2_reg_strength * w + grad[:n_features] = X.T @ grad_per_sample + l2_reg_strength * weights if self.fit_intercept: - grad[-1] = gradient_per_sample.sum() + grad[-1] = grad_per_sample.sum() return grad else: grad = np.empty((n_classes, n_dof), dtype=X.dtype, order="F") # gradient.shape = (n_samples, n_classes) - grad[:, :n_features] = gradient_per_sample.T @ X + l2_reg_strength * w + grad[:, :n_features] = grad_per_sample.T @ X + l2_reg_strength * weights if self.fit_intercept: - grad[:, -1] = gradient_per_sample.sum(axis=0) + grad[:, -1] = grad_per_sample.sum(axis=0) if coef.ndim == 1: return grad.ravel(order="F") else: @@ -299,7 +299,7 @@ def gradient_hessian_product( """ (n_samples, n_features), n_classes = X.shape, self._loss.n_classes n_dof = n_features + int(self.fit_intercept) - w, intercept, raw_prediction = self._w_intercept_raw(coef, X) + weights, intercept, raw_prediction = self._w_intercept_raw(coef, X) if not self._loss.is_multiclass: gradient, hessian = self._loss.gradient_hessian( @@ -309,7 +309,7 @@ def gradient_hessian_product( n_threads=n_threads, ) grad = np.empty_like(coef, dtype=X.dtype) - grad[:n_features] = X.T @ gradient + l2_reg_strength * w + grad[:n_features] = X.T @ gradient + l2_reg_strength * weights if self.fit_intercept: grad[-1] = gradient.sum() @@ -356,7 +356,7 @@ def hessp(s): n_threads=n_threads, ) grad = np.empty((n_classes, n_dof), dtype=X.dtype, order="F") - grad[:, :n_features] = gradient.T @ X + l2_reg_strength * w + grad[:, :n_features] = gradient.T @ X + l2_reg_strength * weights if self.fit_intercept: grad[:, -1] = gradient.sum(axis=0) From c2b599b87b7129683e5eecfbddb00097aa4a0a59 Mon Sep 17 00:00:00 2001 From: Christian Lorentzen Date: Tue, 21 Dec 2021 17:21:14 +0100 Subject: [PATCH 40/44] CLN rename to hessian_sum and hx_sum --- sklearn/linear_model/_linear_loss.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/sklearn/linear_model/_linear_loss.py b/sklearn/linear_model/_linear_loss.py index c2e2d86c906a6..e6291b19f3fad 100644 --- a/sklearn/linear_model/_linear_loss.py +++ b/sklearn/linear_model/_linear_loss.py @@ -313,8 +313,8 @@ def gradient_hessian_product( if self.fit_intercept: grad[-1] = gradient.sum() - # Precompute as much as possible: hX, hh_intercept and hsum - hsum = hessian.sum() + # Precompute as much as possible: hX, hX_sum and hessian_sum + hessian_sum = hessian.sum() if sparse.issparse(X): hX = sparse.dia_matrix((hessian, 0), shape=(n_samples, n_samples)) @ X else: @@ -323,7 +323,7 @@ def gradient_hessian_product( if self.fit_intercept: # Calculate the double derivative with respect to intercept. # Note: In case hX is sparse, hX.sum is a matrix object. 
- hh_intercept = np.squeeze(np.asarray(hX.sum(axis=0))) + hX_sum = np.squeeze(np.asarray(hX.sum(axis=0))) # With intercept included and l2_reg_strength = 0, hessp returns # res = (X, 1)' @ diag(h) @ (X, 1) @ s @@ -339,8 +339,8 @@ def hessp(s): ret[:n_features] += l2_reg_strength * s[:n_features] if self.fit_intercept: - ret[:n_features] += s[-1] * hh_intercept - ret[-1] = hh_intercept @ s[:n_features] + hsum * s[-1] + ret[:n_features] += s[-1] * hX_sum + ret[-1] = hX_sum @ s[:n_features] + hessian_sum * s[-1] return ret else: From b24afbb00a264634b915074010593bbdce0008c3 Mon Sep 17 00:00:00 2001 From: Christian Lorentzen Date: Sun, 2 Jan 2022 14:57:26 +0100 Subject: [PATCH 41/44] CLN address review --- sklearn/linear_model/_linear_loss.py | 35 ++++++++++--------- sklearn/linear_model/_logistic.py | 6 ++-- .../linear_model/tests/test_linear_loss.py | 20 +++++++---- 3 files changed, 35 insertions(+), 26 deletions(-) diff --git a/sklearn/linear_model/_linear_loss.py b/sklearn/linear_model/_linear_loss.py index e6291b19f3fad..1c13ac8318d2b 100644 --- a/sklearn/linear_model/_linear_loss.py +++ b/sklearn/linear_model/_linear_loss.py @@ -9,7 +9,8 @@ class LinearModelLoss: """General class for loss functions with raw_prediction = X @ coef + intercept. - The loss is the sum of per sample losses and includes an L2 term:: + The loss is the sum of per sample losses and includes a term for L2 + regularization:: loss = sum_i s_i loss(y_i, X_i @ coef + intercept) + 1/2 * l2_reg_strength * ||coef||_2^2 @@ -119,13 +120,13 @@ def loss(self, coef, X, y, sample_weight=None, l2_reg_strength=0.0, n_threads=1) If shape (n_classes * n_dof,), the classes of one feature are contiguous, i.e. one reconstructs the 2d-array via coef.reshape((n_classes, -1), order="F"). - y : contiguous array of shape (n_samples,) - Observed, true target values. X : {array-like, sparse matrix} of shape (n_samples, n_features) Training data. - sample_weight : None or contiguous array of shape (n_samples,) + y : contiguous array of shape (n_samples,) + Observed, true target values. + sample_weight : None or contiguous array of shape (n_samples,), default=None Sample weights. - l2_reg_strength: float + l2_reg_strength : float, default=0.0 L2 regularization strength n_threads : int, default=1 Number of OpenMP threads to use. @@ -160,13 +161,13 @@ def loss_gradient( If shape (n_classes * n_dof,), the classes of one feature are contiguous, i.e. one reconstructs the 2d-array via coef.reshape((n_classes, -1), order="F"). - y : contiguous array of shape (n_samples,) - Observed, true target values. X : {array-like, sparse matrix} of shape (n_samples, n_features) Training data. - sample_weight : None or contiguous array of shape (n_samples,) + y : contiguous array of shape (n_samples,) + Observed, true target values. + sample_weight : None or contiguous array of shape (n_samples,), default=None Sample weights. - l2_reg_strength: float + l2_reg_strength : float, default=0.0 L2 regularization strength n_threads : int, default=1 Number of OpenMP threads to use. @@ -221,13 +222,13 @@ def gradient( If shape (n_classes * n_dof,), the classes of one feature are contiguous, i.e. one reconstructs the 2d-array via coef.reshape((n_classes, -1), order="F"). - y : contiguous array of shape (n_samples,) - Observed, true target values. X : {array-like, sparse matrix} of shape (n_samples, n_features) Training data. - sample_weight : None or contiguous array of shape (n_samples,) + y : contiguous array of shape (n_samples,) + Observed, true target values. 
+ sample_weight : None or contiguous array of shape (n_samples,), default=None Sample weights. - l2_reg_strength: float + l2_reg_strength : float, default=0.0 L2 regularization strength n_threads : int, default=1 Number of OpenMP threads to use. @@ -277,13 +278,13 @@ def gradient_hessian_product( If shape (n_classes * n_dof,), the classes of one feature are contiguous, i.e. one reconstructs the 2d-array via coef.reshape((n_classes, -1), order="F"). - y : contiguous array of shape (n_samples,) - Observed, true target values. X : {array-like, sparse matrix} of shape (n_samples, n_features) Training data. - sample_weight : None or contiguous array of shape (n_samples,) + y : contiguous array of shape (n_samples,) + Observed, true target values. + sample_weight : None or contiguous array of shape (n_samples,), default=None Sample weights. - l2_reg_strength: float + l2_reg_strength : float, default=0.0 L2 regularization strength n_threads : int, default=1 Number of OpenMP threads to use. diff --git a/sklearn/linear_model/_logistic.py b/sklearn/linear_model/_logistic.py index ea7652ae2ef9c..1d1b66db92d63 100644 --- a/sklearn/linear_model/_logistic.py +++ b/sklearn/linear_model/_logistic.py @@ -426,6 +426,7 @@ def _logistic_regression_path( n_iter = np.zeros(len(Cs), dtype=np.int32) for i, C in enumerate(Cs): if solver == "lbfgs": + l2_reg_strength = 1.0 / C iprint = [-1, 50, 1, 100, 101][ np.searchsorted(np.array([0, 1, 2, 3]), verbose) ] @@ -434,7 +435,7 @@ def _logistic_regression_path( w0, method="L-BFGS-B", jac=True, - args=(X, target, sample_weight, 1.0 / C, n_threads), + args=(X, target, sample_weight, l2_reg_strength, n_threads), options={"iprint": iprint, "gtol": tol, "maxiter": max_iter}, ) n_iter_i = _check_optimize_result( @@ -445,7 +446,8 @@ def _logistic_regression_path( ) w0, loss = opt_res.x, opt_res.fun elif solver == "newton-cg": - args = (X, target, sample_weight, 1.0 / C, n_threads) + l2_reg_strength = 1.0 / C + args = (X, target, sample_weight, l2_reg_strength, n_threads) w0, n_iter_i = _newton_cg( hess, func, grad, w0, args=args, maxiter=max_iter, tol=tol ) diff --git a/sklearn/linear_model/tests/test_linear_loss.py b/sklearn/linear_model/tests/test_linear_loss.py index 07d236c50e7c5..61c211dc98f12 100644 --- a/sklearn/linear_model/tests/test_linear_loss.py +++ b/sklearn/linear_model/tests/test_linear_loss.py @@ -1,7 +1,8 @@ """ Tests for LinearModelLoss -Note that correctness of losses is already well covered in the _loss module. +Note that correctness of losses (which compose LinearModelLoss) is already well +covered in the _loss module. """ import pytest import numpy as np @@ -18,7 +19,7 @@ from sklearn.utils.extmath import squared_norm -# We don not need to test all losses, just what LinearModelLoss does on top of the +# We do not need to test all losses, just what LinearModelLoss does on top of the # base losses. LOSSES = [HalfBinomialLoss, HalfMultinomialLoss, HalfPoissonLoss] @@ -49,7 +50,7 @@ def random_X_y_coef( raw_prediction = X @ coef.T proba = linear_model_loss._loss.link.inverse(raw_prediction) - # y = rng.choice(np.arange(n_classes), p=proba) does not work, see + # y = rng.choice(np.arange(n_classes), p=proba) does not work. 
# See https://stackoverflow.com/a/34190035/16761084 def choice_vectorized(items, p): s = p.cumsum(axis=1) @@ -129,6 +130,7 @@ def test_loss_gradients_are_the_same( assert_allclose(g1, g1_sp) assert_allclose(g1, g2_sp) assert_allclose(g1, g3_sp) + assert_allclose(h3(g1), h3_sp(g1_sp)) @pytest.mark.parametrize("base_loss", LOSSES) @@ -214,9 +216,10 @@ def test_gradients_hessians_numerically( g, hessp = loss.gradient_hessian_product( coef, X, y, sample_weight=sample_weight, l2_reg_strength=l2_reg_strength ) - # Use a trick to get central finte difference of accuracy 4 (five-point stencil) + # Use a trick to get central finite difference of accuracy 4 (five-point stencil) # https://en.wikipedia.org/wiki/Numerical_differentiation # https://en.wikipedia.org/wiki/Finite_difference_coefficient + # approx_g1 = (f(x + eps) - f(x - eps)) / (2*eps) approx_g1 = optimize.approx_fprime( coef, lambda coef: loss.loss( @@ -227,7 +230,8 @@ def test_gradients_hessians_numerically( l2_reg_strength=l2_reg_strength, ), 2 * eps, - ) # (f(x + eps) - f(x - eps)) / (2*eps) + ) + # approx_g2 = (f(x + 2*eps) - f(x - 2*eps)) / (4*eps) approx_g2 = optimize.approx_fprime( coef, lambda coef: loss.loss( @@ -238,8 +242,10 @@ def test_gradients_hessians_numerically( l2_reg_strength=l2_reg_strength, ), 4 * eps, - ) # (f(x + 2*eps) - f(x - 2*eps)) / (4*eps) - approx_g = 4 / 3 * approx_g1 - 1 / 3 * approx_g2 + ) + # Five-point stencil approximation + # See: https://en.wikipedia.org/wiki/Five-point_stencil#1D_first_derivative + approx_g = (4 * approx_g1 - approx_g2) / 3 assert_allclose(g, approx_g, rtol=1e-2, atol=1e-8) # 2. Check hessp numerically along the second direction of the gradient From 4e8681bc0c7e4a5e5c2a762b4590b611f45d85de Mon Sep 17 00:00:00 2001 From: Christian Lorentzen Date: Tue, 1 Feb 2022 18:33:30 +0100 Subject: [PATCH 42/44] CLN rename to init arg and attribute to base_loss --- sklearn/linear_model/_linear_loss.py | 42 +++++++++---------- sklearn/linear_model/_logistic.py | 10 +++-- .../linear_model/tests/test_linear_loss.py | 10 ++--- sklearn/linear_model/tests/test_sag.py | 4 +- 4 files changed, 35 insertions(+), 31 deletions(-) diff --git a/sklearn/linear_model/_linear_loss.py b/sklearn/linear_model/_linear_loss.py index 1c13ac8318d2b..93a7684aea5b6 100644 --- a/sklearn/linear_model/_linear_loss.py +++ b/sklearn/linear_model/_linear_loss.py @@ -28,13 +28,13 @@ class LinearModelLoss: else: n_dof = n_features - if loss.is_multiclass: + if base_loss.is_multiclass: coef.shape = (n_classes, n_dof) or ravelled (n_classes * n_dof,) else: coef.shape = (n_dof,) The intercept term is at the end of the coef array: - if loss.is_multiclass: + if base_loss.is_multiclass: if coef.shape (n_classes, n_dof): intercept = coef[:, -1] if coef.shape (n_classes * n_dof,) @@ -57,12 +57,12 @@ class LinearModelLoss: Parameters ---------- - _loss : instance of a loss function from sklearn._loss. + base_loss : instance of class BaseLoss from sklearn._loss. 
fit_intercept : bool """ - def __init__(self, loss, fit_intercept): - self._loss = loss + def __init__(self, base_loss, fit_intercept): + self.base_loss = base_loss self.fit_intercept = fit_intercept def _w_intercept_raw(self, coef, X): @@ -87,7 +87,7 @@ def _w_intercept_raw(self, coef, X): raw_prediction : ndarray of shape (n_samples,) or \ (n_samples, n_classes) """ - if not self._loss.is_multiclass: + if not self.base_loss.is_multiclass: if self.fit_intercept: intercept = coef[-1] weights = coef[:-1] @@ -98,7 +98,7 @@ def _w_intercept_raw(self, coef, X): else: # reshape to (n_classes, n_dof) if coef.ndim == 1: - weights = coef.reshape((self._loss.n_classes, -1), order="F") + weights = coef.reshape((self.base_loss.n_classes, -1), order="F") else: weights = coef if self.fit_intercept: @@ -138,7 +138,7 @@ def loss(self, coef, X, y, sample_weight=None, l2_reg_strength=0.0, n_threads=1) """ weights, intercept, raw_prediction = self._w_intercept_raw(coef, X) - loss = self._loss.loss( + loss = self.base_loss.loss( y_true=y, raw_prediction=raw_prediction, sample_weight=sample_weight, @@ -152,7 +152,7 @@ def loss(self, coef, X, y, sample_weight=None, l2_reg_strength=0.0, n_threads=1) def loss_gradient( self, coef, X, y, sample_weight=None, l2_reg_strength=0.0, n_threads=1 ): - """Computes the sum/average of loss and gradient. + """Computes the sum of loss and gradient w.r.t. coef. Parameters ---------- @@ -180,11 +180,11 @@ def loss_gradient( gradient : ndarray of shape coef.shape The gradient of the loss. """ - n_features, n_classes = X.shape[1], self._loss.n_classes + n_features, n_classes = X.shape[1], self.base_loss.n_classes n_dof = n_features + int(self.fit_intercept) weights, intercept, raw_prediction = self._w_intercept_raw(coef, X) - loss, grad_per_sample = self._loss.loss_gradient( + loss, grad_per_sample = self.base_loss.loss_gradient( y_true=y, raw_prediction=raw_prediction, sample_weight=sample_weight, @@ -192,7 +192,7 @@ def loss_gradient( ) loss = loss.sum() - if not self._loss.is_multiclass: + if not self.base_loss.is_multiclass: loss += 0.5 * l2_reg_strength * (weights @ weights) grad = np.empty_like(coef, dtype=X.dtype) grad[:n_features] = X.T @ grad_per_sample + l2_reg_strength * weights @@ -213,7 +213,7 @@ def loss_gradient( def gradient( self, coef, X, y, sample_weight=None, l2_reg_strength=0.0, n_threads=1 ): - """Computes the gradient. + """Computes the gradient w.r.t. coef. Parameters ---------- @@ -238,18 +238,18 @@ def gradient( gradient : ndarray of shape coef.shape The gradient of the loss. """ - n_features, n_classes = X.shape[1], self._loss.n_classes + n_features, n_classes = X.shape[1], self.base_loss.n_classes n_dof = n_features + int(self.fit_intercept) weights, intercept, raw_prediction = self._w_intercept_raw(coef, X) - grad_per_sample = self._loss.gradient( + grad_per_sample = self.base_loss.gradient( y_true=y, raw_prediction=raw_prediction, sample_weight=sample_weight, n_threads=n_threads, ) - if not self._loss.is_multiclass: + if not self.base_loss.is_multiclass: grad = np.empty_like(coef, dtype=X.dtype) grad[:n_features] = X.T @ grad_per_sample + l2_reg_strength * weights if self.fit_intercept: @@ -269,7 +269,7 @@ def gradient( def gradient_hessian_product( self, coef, X, y, sample_weight=None, l2_reg_strength=0.0, n_threads=1 ): - """Computes gradient and hessp (hessian product function). + """Computes gradient and hessp (hessian product function) w.r.t. coef. 
Parameters ---------- @@ -298,12 +298,12 @@ def gradient_hessian_product( Function that takes in a vector input of shape of gradient and and returns matrix-vector product with hessian. """ - (n_samples, n_features), n_classes = X.shape, self._loss.n_classes + (n_samples, n_features), n_classes = X.shape, self.base_loss.n_classes n_dof = n_features + int(self.fit_intercept) weights, intercept, raw_prediction = self._w_intercept_raw(coef, X) - if not self._loss.is_multiclass: - gradient, hessian = self._loss.gradient_hessian( + if not self.base_loss.is_multiclass: + gradient, hessian = self.base_loss.gradient_hessian( y_true=y, raw_prediction=raw_prediction, sample_weight=sample_weight, @@ -350,7 +350,7 @@ def hessp(s): # HalfMultinomialLoss computes only the diagonal part of the hessian, i.e. # diagonal in the classes. Here, we want the matrix-vector product of the # full hessian. Therefore, we call gradient_proba. - gradient, proba = self._loss.gradient_proba( + gradient, proba = self.base_loss.gradient_proba( y_true=y, raw_prediction=raw_prediction, sample_weight=sample_weight, diff --git a/sklearn/linear_model/_logistic.py b/sklearn/linear_model/_logistic.py index 1d1b66db92d63..648e08f981e0a 100644 --- a/sklearn/linear_model/_logistic.py +++ b/sklearn/linear_model/_logistic.py @@ -399,7 +399,7 @@ def _logistic_regression_path( # As w0 is F-contiguous, ravel(order="F") also avoids a copy. w0 = w0.ravel(order="F") loss = LinearModelLoss( - loss=HalfMultinomialLoss(n_classes=classes.size), + base_loss=HalfMultinomialLoss(n_classes=classes.size), fit_intercept=fit_intercept, ) target = Y_multi @@ -413,10 +413,14 @@ def _logistic_regression_path( else: target = y_bin if solver == "lbfgs": - loss = LinearModelLoss(loss=HalfBinomialLoss(), fit_intercept=fit_intercept) + loss = LinearModelLoss( + base_loss=HalfBinomialLoss(), fit_intercept=fit_intercept + ) func = loss.loss_gradient elif solver == "newton-cg": - loss = LinearModelLoss(loss=HalfBinomialLoss(), fit_intercept=fit_intercept) + loss = LinearModelLoss( + base_loss=HalfBinomialLoss(), fit_intercept=fit_intercept + ) func = loss.loss grad = loss.gradient hess = loss.gradient_hessian_product # hess = [gradient, hessp] diff --git a/sklearn/linear_model/tests/test_linear_loss.py b/sklearn/linear_model/tests/test_linear_loss.py index 61c211dc98f12..a3b5670bb7bc7 100644 --- a/sklearn/linear_model/tests/test_linear_loss.py +++ b/sklearn/linear_model/tests/test_linear_loss.py @@ -85,7 +85,7 @@ def test_loss_gradients_are_the_same( base_loss, fit_intercept, sample_weight, l2_reg_strength ): """Test that loss and gradient are the same across different functions.""" - loss = LinearModelLoss(loss=base_loss(), fit_intercept=fit_intercept) + loss = LinearModelLoss(base_loss=base_loss(), fit_intercept=fit_intercept) X, y, coef = random_X_y_coef( linear_model_loss=loss, n_samples=10, n_features=5, seed=42 ) @@ -141,8 +141,8 @@ def test_loss_gradients_hessp_intercept( base_loss, sample_weight, l2_reg_strength, X_sparse ): """Test that loss and gradient handle intercept correctly.""" - loss = LinearModelLoss(loss=base_loss(), fit_intercept=False) - loss_inter = LinearModelLoss(loss=base_loss(), fit_intercept=True) + loss = LinearModelLoss(base_loss=base_loss(), fit_intercept=False) + loss_inter = LinearModelLoss(base_loss=base_loss(), fit_intercept=True) n_samples, n_features = 10, 5 X, y, coef = random_X_y_coef( linear_model_loss=loss, n_samples=n_samples, n_features=n_features, seed=42 @@ -201,7 +201,7 @@ def test_gradients_hessians_numerically( 
Gradient should equal the numerical derivatives of the loss function. Hessians should equal the numerical derivatives of gradients. """ - loss = LinearModelLoss(loss=base_loss(), fit_intercept=fit_intercept) + loss = LinearModelLoss(base_loss=base_loss(), fit_intercept=fit_intercept) n_samples, n_features = 10, 5 X, y, coef = random_X_y_coef( linear_model_loss=loss, n_samples=n_samples, n_features=n_features, seed=42 @@ -277,7 +277,7 @@ def test_gradients_hessians_numerically( @pytest.mark.parametrize("fit_intercept", [False, True]) def test_multinomial_coef_shape(fit_intercept): """Test that multinomial LinearModelLoss respects shape of coef.""" - loss = LinearModelLoss(loss=HalfMultinomialLoss(), fit_intercept=fit_intercept) + loss = LinearModelLoss(base_loss=HalfMultinomialLoss(), fit_intercept=fit_intercept) n_samples, n_features = 10, 5 X, y, coef = random_X_y_coef( linear_model_loss=loss, n_samples=n_samples, n_features=n_features, seed=42 diff --git a/sklearn/linear_model/tests/test_sag.py b/sklearn/linear_model/tests/test_sag.py index de85638f66b06..d3a27c4088ab7 100644 --- a/sklearn/linear_model/tests/test_sag.py +++ b/sklearn/linear_model/tests/test_sag.py @@ -935,7 +935,7 @@ def test_multinomial_loss(): ) # compute loss and gradient like in multinomial LogisticRegression loss = LinearModelLoss( - loss=HalfMultinomialLoss(n_classes=n_classes), + base_loss=HalfMultinomialLoss(n_classes=n_classes), fit_intercept=True, ) weights_intercept = np.vstack((weights, intercept)).T @@ -969,7 +969,7 @@ def test_multinomial_loss_ground_truth(): grad_1 = np.dot(X.T, diff) loss = LinearModelLoss( - loss=HalfMultinomialLoss(n_classes=n_classes), + base_loss=HalfMultinomialLoss(n_classes=n_classes), fit_intercept=True, ) weights_intercept = np.vstack((weights, intercept)).T From ee066bbcb1fb42605364f09da827997c174eeafc Mon Sep 17 00:00:00 2001 From: Christian Lorentzen Date: Wed, 9 Feb 2022 14:04:45 +0100 Subject: [PATCH 43/44] CLN apply review suggestion Co-authored-by: Alexandre Gramfort --- sklearn/linear_model/tests/test_linear_loss.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/linear_model/tests/test_linear_loss.py b/sklearn/linear_model/tests/test_linear_loss.py index a3b5670bb7bc7..74b02d4da125b 100644 --- a/sklearn/linear_model/tests/test_linear_loss.py +++ b/sklearn/linear_model/tests/test_linear_loss.py @@ -54,7 +54,7 @@ def random_X_y_coef( # See https://stackoverflow.com/a/34190035/16761084 def choice_vectorized(items, p): s = p.cumsum(axis=1) - r = np.random.rand(p.shape[0])[:, None] + r = rng.rand(p.shape[0])[:, None] k = (s < r).sum(axis=1) return items[k] From 190c264f11436fa4ef0374c8278a449491768c2b Mon Sep 17 00:00:00 2001 From: Christian Lorentzen Date: Wed, 9 Feb 2022 20:06:59 +0100 Subject: [PATCH 44/44] CLN base_loss instead of _loss attribute --- sklearn/linear_model/tests/test_linear_loss.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/sklearn/linear_model/tests/test_linear_loss.py b/sklearn/linear_model/tests/test_linear_loss.py index 74b02d4da125b..d4e20ad69ca8a 100644 --- a/sklearn/linear_model/tests/test_linear_loss.py +++ b/sklearn/linear_model/tests/test_linear_loss.py @@ -36,8 +36,8 @@ def random_X_y_coef( random_state=rng, ) - if linear_model_loss._loss.is_multiclass: - n_classes = linear_model_loss._loss.n_classes + if linear_model_loss.base_loss.is_multiclass: + n_classes = linear_model_loss.base_loss.n_classes coef = np.empty((n_classes, n_dof)) coef.flat[:] = rng.uniform( low=coef_bound[0], @@ 
-48,7 +48,7 @@ def random_X_y_coef( raw_prediction = X @ coef[:, :-1].T + coef[:, -1] else: raw_prediction = X @ coef.T - proba = linear_model_loss._loss.link.inverse(raw_prediction) + proba = linear_model_loss.base_loss.link.inverse(raw_prediction) # y = rng.choice(np.arange(n_classes), p=proba) does not work. # See https://stackoverflow.com/a/34190035/16761084 @@ -70,7 +70,7 @@ def choice_vectorized(items, p): raw_prediction = X @ coef[:-1] + coef[-1] else: raw_prediction = X @ coef - y = linear_model_loss._loss.link.inverse( + y = linear_model_loss.base_loss.link.inverse( raw_prediction + rng.uniform(low=-1, high=1, size=n_samples) ) @@ -304,5 +304,5 @@ def test_multinomial_coef_shape(fit_intercept): assert_allclose(g_r, g1_r) assert_allclose(g_r, g2_r) - assert_allclose(g, g_r.reshape(loss._loss.n_classes, -1, order="F")) - assert_allclose(h, h_r.reshape(loss._loss.n_classes, -1, order="F")) + assert_allclose(g, g_r.reshape(loss.base_loss.n_classes, -1, order="F")) + assert_allclose(h, h_r.reshape(loss.base_loss.n_classes, -1, order="F"))
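
The coefficient convention used throughout the series — intercept stored in the last position, multiclass coef kept as (n_classes, n_dof) and handed to solvers as an order="F" ravel — can be restated in a small NumPy sketch. The sizes and names below are made up for illustration and are not taken from the patches; they only mirror the layout described in the LinearModelLoss docstring.

import numpy as np

# Assumed sizes, chosen only for this illustration.
n_samples, n_features, n_classes, fit_intercept = 10, 4, 3, True
n_dof = n_features + int(fit_intercept)

rng = np.random.default_rng(0)
X = rng.normal(size=(n_samples, n_features))
coef = rng.normal(size=(n_classes, n_dof))  # one row per class, intercept in the last column

# Unpack the way the multiclass branch of _w_intercept_raw does.
weights, intercept = coef[:, :n_features], coef[:, -1]
raw_prediction = X @ weights.T + intercept  # shape (n_samples, n_classes)

# The 1-d view given to solvers is the F-order ravel; reshaping it back
# with order="F" recovers the (n_classes, n_dof) layout exactly.
coef_1d = coef.ravel(order="F")
np.testing.assert_allclose(coef_1d.reshape((n_classes, -1), order="F"), coef)

The numerical check in test_gradients_hessians_numerically builds a five-point stencil from two central differences, approx_g = (4 * approx_g1 - approx_g2) / 3. A minimal, self-contained sketch of the same idea is given below; it uses a hand-written L2-penalized binomial loss instead of the patch's LinearModelLoss, and every function and parameter in it is an assumption made for the example, not the library's code.

import numpy as np
from scipy.special import expit

def loss(coef, X, y, l2_reg_strength):
    """Sum of binomial log losses plus an L2 penalty on the weights."""
    w, b = coef[:-1], coef[-1]
    z = X @ w + b
    per_sample = np.logaddexp(0.0, z) - y * z  # log(1 + exp(z)) - y*z, stable
    return per_sample.sum() + 0.5 * l2_reg_strength * (w @ w)

def gradient(coef, X, y, l2_reg_strength):
    """Analytic gradient of the loss above w.r.t. (weights, intercept)."""
    w, b = coef[:-1], coef[-1]
    p = expit(X @ w + b)  # predicted probabilities
    return np.r_[X.T @ (p - y) + l2_reg_strength * w, (p - y).sum()]

rng = np.random.default_rng(42)
X = rng.normal(size=(20, 3))
y = rng.integers(0, 2, size=20).astype(float)
coef = rng.normal(size=4)
alpha, eps = 1.0, 1e-6

g = gradient(coef, X, y, alpha)

f = lambda c: loss(c, X, y, alpha)
# Five-point stencil per coordinate:
#   g1 = (f(x + eps*e_i) - f(x - eps*e_i)) / (2*eps)
#   g2 = (f(x + 2*eps*e_i) - f(x - 2*eps*e_i)) / (4*eps)
#   approx = (4*g1 - g2) / 3
approx_g = np.empty_like(coef)
for i, e_i in enumerate(np.eye(coef.size)):
    g1 = (f(coef + eps * e_i) - f(coef - eps * e_i)) / (2 * eps)
    g2 = (f(coef + 2 * eps * e_i) - f(coef - 2 * eps * e_i)) / (4 * eps)
    approx_g[i] = (4 * g1 - g2) / 3

np.testing.assert_allclose(g, approx_g, rtol=1e-6, atol=1e-8)

Combining the eps and 2*eps central differences cancels the O(eps**2) truncation term, so the stencil is accurate to O(eps**4) and the comparison can use a tight tolerance, which is why the test above keeps rtol small.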