From 27b80561b3a8abfcab05a7e11e34b2dbfb868c68 Mon Sep 17 00:00:00 2001 From: Vlad Niculae Date: Wed, 16 Jan 2013 01:35:29 +0200 Subject: [PATCH] DOC FIX: multi-target linear model attribute shapes DOC fixes and standardization in least_angle docstrings More standardization in linear model docs --- sklearn/linear_model/coordinate_descent.py | 238 +++++++++++---------- sklearn/linear_model/least_angle.py | 163 ++++++++------ 2 files changed, 213 insertions(+), 188 deletions(-) diff --git a/sklearn/linear_model/coordinate_descent.py b/sklearn/linear_model/coordinate_descent.py index b3619df264efd..10c6bc8704b81 100644 --- a/sklearn/linear_model/coordinate_descent.py +++ b/sklearn/linear_model/coordinate_descent.py @@ -57,73 +57,74 @@ class ElasticNet(LinearModel, RegressorMixin): alpha : float Constant that multiplies the penalty terms. Defaults to 1.0 See the notes for the exact mathematical meaning of this - parameter - alpha = 0 is equivalent to an ordinary least square, solved - by the LinearRegression object in the scikit. For numerical - reasons, using alpha = 0 is with the Lasso object is not advised + parameter. + ``alpha = 0`` is equivalent to an ordinary least square, solved + by the :class:`LinearRegression` object. For numerical + reasons, using ``alpha = 0`` with the Lasso object is not advised and you should prefer the LinearRegression object. l1_ratio : float - The ElasticNet mixing parameter, with 0 <= l1_ratio <= 1. For - l1_ratio = 0 the penalty is an L2 penalty. For l1_ratio = 1 it is an L1 - penalty. For 0 < l1_ratio < 1, the penalty is a combination of L1 and - L2. + The ElasticNet mixing parameter, with ``0 <= l1_ratio <= 1``. For + ``l1_ratio = 0`` the penalty is an L2 penalty. For ``l1_ratio = 1`` it + is an L1 penalty. For ``0 < l1_ratio < 1``, the penalty is a + combination of L1 and L2. fit_intercept: bool - Whether the intercept should be estimated or not. If False, the + Whether the intercept should be estimated or not. If ``False``, the data is assumed to be already centered. normalize : boolean, optional - If True, the regressors X are normalized + If ``True``, the regressors X are normalized precompute : True | False | 'auto' | array-like Whether to use a precomputed Gram matrix to speed up - calculations. If set to 'auto' let us decide. The Gram + calculations. If set to ``'auto'`` let us decide. The Gram matrix can also be passed as argument. For sparse input - this option is always True to preserve sparsity. + this option is always ``True`` to preserve sparsity. max_iter: int, optional The maximum number of iterations copy_X : boolean, optional, default False - If True, X will be copied; else, it may be overwritten. + If ``True``, X will be copied; else, it may be overwritten. tol: float, optional The tolerance for the optimization: if the updates are - smaller than 'tol', the optimization code checks the + smaller than ``tol``, the optimization code checks the dual gap for optimality and continues until it is smaller - than tol. + than ``tol``. warm_start : bool, optional - When set to True, reuse the solution of the previous call to fit as + When set to ``True``, reuse the solution of the previous call to fit as initialization, otherwise, just erase the previous solution. positive: bool, optional - When set to True, forces the coefficients to be positive.
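A minimal sketch of how the parameters above combine in practice (synthetic data and illustrative values, assuming scikit-learn is importable as ``sklearn``):

    import numpy as np
    from sklearn.linear_model import ElasticNet

    rng = np.random.RandomState(0)
    X = rng.randn(50, 4)
    y = np.dot(X, np.array([1.0, 2.0, 0.0, 0.0])) + 0.01 * rng.randn(50)

    # alpha scales the overall penalty; l1_ratio mixes the L1 and L2 terms.
    model = ElasticNet(alpha=0.1, l1_ratio=0.7, fit_intercept=True)
    model.fit(X, y)
    print(model.coef_)       # one coefficient per feature
    print(model.intercept_)  # a single float for a single target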
Attributes ---------- - `coef_` : array, shape = (n_features,) + ``coef_`` : array, shape = (n_features,) | (n_targets, n_features) parameter vector (w in the cost function formula) - `sparse_coef_` : scipy.sparse matrix, shape = (n_features, 1) - `sparse_coef_` is a readonly property derived from `coef_` + ``sparse_coef_`` : scipy.sparse matrix, shape = (n_features, 1) | \ + (n_targets, n_features) + ``sparse_coef_`` is a readonly property derived from ``coef_`` - `intercept_` : float | array, shape = (n_targets,) + ``intercept_`` : float | array, shape = (n_targets,) independent term in decision function. - `dual_gap_` : float + ``dual_gap_`` : float | array, shape = (n_targets,) the current fit is guaranteed to be epsilon-suboptimal with - epsilon := `dual_gap_` + epsilon := ``dual_gap_`` - `eps_` : float - `eps_` is used to check if the fit converged to the requested - `tol` + ``eps_`` : float | array, shape = (n_targets,) + ``eps_`` is used to check if the fit converged to the requested + ``tol`` Notes ----- To avoid unnecessary memory duplication the X argument of the fit method - should be directly passed as a fortran contiguous numpy array. + should be directly passed as a Fortran-contiguous numpy array. """ def __init__(self, alpha=1.0, l1_ratio=0.5, fit_intercept=True, normalize=False, precompute='auto', max_iter=1000, @@ -166,7 +167,7 @@ def fit(self, X, y, Xy=None, coef_init=None): Coordinate descent is an algorithm that considers each column of data at a time hence it will automatically convert the X input - as a fortran contiguous numpy array if necessary. + as a Fortran-contiguous numpy array if necessary. To avoid memory re-allocation it is advised to allocate the initial data in memory directly using that format. @@ -347,15 +348,15 @@ class Lasso(ElasticNet): (1 / (2 * n_samples)) * ||y - Xw||^2_2 + alpha * ||w||_1 Technically the Lasso model is optimizing the same objective function as - the Elastic Net with l1_ratio=1.0 (no L2 penalty). + the Elastic Net with ``l1_ratio=1.0`` (no L2 penalty). Parameters ---------- alpha : float, optional - Constant that multiplies the L1 term. Defaults to 1.0 - alpha = 0 is equivalent to an ordinary least square, solved - by the LinearRegression object in the scikit. For numerical - reasons, using alpha = 0 is with the Lasso object is not advised + Constant that multiplies the L1 term. Defaults to 1.0. + ``alpha = 0`` is equivalent to an ordinary least square, solved + by the :class:`LinearRegression` object. For numerical + reasons, using ``alpha = 0`` with the Lasso object is not advised and you should prefer the LinearRegression object. fit_intercept : boolean @@ -364,52 +365,53 @@ class Lasso(ElasticNet): (e.g. data is expected to be already centered). normalize : boolean, optional - If True, the regressors X are normalized + If ``True``, the regressors X are normalized copy_X : boolean, optional, default True - If True, X will be copied; else, it may be overwritten. + If ``True``, X will be copied; else, it may be overwritten. precompute : True | False | 'auto' | array-like Whether to use a precomputed Gram matrix to speed up - calculations. If set to 'auto' let us decide. The Gram + calculations. If set to ``'auto'`` let us decide. The Gram matrix can also be passed as argument. For sparse input - this option is always True to preserve sparsity. + this option is always ``True`` to preserve sparsity.
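The multi-target shapes documented above can be checked directly; a short sketch with synthetic data:

    import numpy as np
    from sklearn.linear_model import ElasticNet

    rng = np.random.RandomState(0)
    X = rng.randn(30, 5)
    Y = rng.randn(30, 2)  # two targets

    enet = ElasticNet(alpha=0.5).fit(X, Y)
    print(enet.coef_.shape)       # (2, 5) -> (n_targets, n_features)
    print(enet.intercept_.shape)  # (2,)   -> (n_targets,)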
max_iter: int, optional The maximum number of iterations tol : float, optional The tolerance for the optimization: if the updates are - smaller than 'tol', the optimization code checks the + smaller than ``tol``, the optimization code checks the dual gap for optimality and continues until it is smaller - than tol. + than ``tol``. warm_start : bool, optional When set to True, reuse the solution of the previous call to fit as initialization, otherwise, just erase the previous solution. positive : bool, optional - When set to True, forces the coefficients to be positive. + When set to ``True``, forces the coefficients to be positive. Attributes ---------- - `coef_` : array, shape = (n_features,) + ``coef_`` : array, shape = (n_features,) | (n_targets, n_features) parameter vector (w in the cost function formula) - `sparse_coef_` : scipy.sparse matrix, shape = (n_features, 1) - `sparse_coef_` is a readonly property derived from `coef_` + ``sparse_coef_`` : scipy.sparse matrix, shape = (n_features, 1) | \ + (n_targets, n_features) + ``sparse_coef_`` is a readonly property derived from ``coef_`` - `intercept_` : float + ``intercept_`` : float | array, shape = (n_targets,) independent term in decision function. - `dual_gap_` : float + ``dual_gap_`` : float | array, shape = (n_targets,) the current fit is guaranteed to be epsilon-suboptimal with - epsilon := `dual_gap_` + epsilon := ``dual_gap_`` - `eps_` : float - `eps_` is used to check if the fit converged to the requested - `tol` + ``eps_`` : float | array, shape = (n_targets,) + ``eps_`` is used to check if the fit converged to the requested + ``tol`` Examples -------- @@ -438,7 +440,7 @@ class Lasso(ElasticNet): The algorithm used to fit the model is coordinate descent. To avoid unnecessary memory duplication the X argument of the fit method - should be directly passed as a fortran contiguous numpy array. + should be directly passed as a Fortran-contiguous numpy array. """ def __init__(self, alpha=1.0, fit_intercept=True, normalize=False, @@ -467,26 +469,26 @@ def lasso_path(X, y, eps=1e-3, n_alphas=100, alphas=None, Parameters ---------- X : ndarray, shape = (n_samples, n_features) - Training data. Pass directly as fortran contiguous data to avoid + Training data. Pass directly as Fortran-contiguous data to avoid unnecessary memory duplication y : ndarray, shape = (n_samples,) Target values eps : float, optional - Length of the path. eps=1e-3 means that - alpha_min / alpha_max = 1e-3 + Length of the path. ``eps=1e-3`` means that + ``alpha_min / alpha_max = 1e-3`` n_alphas : int, optional Number of alphas along the regularization path alphas : ndarray, optional List of alphas where to compute the models. - If None alphas are set automatically + If ``None`` alphas are set automatically precompute : True | False | 'auto' | array-like Whether to use a precomputed Gram matrix to speed up - calculations. If set to 'auto' let us decide. The Gram + calculations. If set to ``'auto'`` let us decide. The Gram matrix can also be passed as argument. Xy : array-like, optional @@ -497,10 +499,10 @@ def lasso_path(X, y, eps=1e-3, n_alphas=100, alphas=None, Fit or not an intercept normalize : boolean, optional - If True, the regressors X are normalized + If ``True``, the regressors X are normalized copy_X : boolean, optional, default True - If True, X will be copied; else, it may be overwritten. + If ``True``, X will be copied; else, it may be overwritten. 
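A small sketch of the Fortran-contiguous input recommendation from the Notes above (synthetic data; combining ``np.asfortranarray`` with ``copy_X=False`` is meant to avoid an internal copy of X):

    import numpy as np
    from sklearn.linear_model import Lasso

    rng = np.random.RandomState(0)
    X = np.asfortranarray(rng.randn(100, 10))  # Fortran-contiguous, as advised
    y = rng.randn(100)

    lasso = Lasso(alpha=0.1, copy_X=False)
    lasso.fit(X, y)
    print(lasso.coef_.shape)  # (10,)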
verbose : bool or integer Amount of verbosity @@ -518,7 +520,7 @@ def lasso_path(X, y, eps=1e-3, n_alphas=100, alphas=None, for an example. To avoid unnecessary memory duplication the X argument of the fit method - should be directly passed as a fortran contiguous numpy array. + should be directly passed as a Fortran-contiguous numpy array. See also -------- @@ -550,7 +552,7 @@ def enet_path(X, y, l1_ratio=0.5, eps=1e-3, n_alphas=100, alphas=None, Parameters ---------- X : ndarray, shape = (n_samples, n_features) - Training data. Pass directly as fortran contiguous data to avoid + Training data. Pass directly as Fortran-contiguous data to avoid unnecessary memory duplication y : ndarray, shape = (n_samples,) @@ -558,11 +560,11 @@ def enet_path(X, y, l1_ratio=0.5, eps=1e-3, n_alphas=100, alphas=None, l1_ratio : float, optional float between 0 and 1 passed to ElasticNet (scaling between - l1 and l2 penalties). l1_ratio=1 corresponds to the Lasso + l1 and l2 penalties). ``l1_ratio=1`` corresponds to the Lasso eps : float - Length of the path. eps=1e-3 means that - alpha_min / alpha_max = 1e-3 + Length of the path. ``eps=1e-3`` means that + ``alpha_min / alpha_max = 1e-3`` n_alphas : int, optional Number of alphas along the regularization path @@ -573,7 +575,7 @@ def enet_path(X, y, l1_ratio=0.5, eps=1e-3, n_alphas=100, alphas=None, precompute : True | False | 'auto' | array-like Whether to use a precomputed Gram matrix to speed up - calculations. If set to 'auto' let us decide. The Gram + calculations. If set to ``'auto'`` let us decide. The Gram matrix can also be passed as argument. Xy : array-like, optional @@ -584,10 +586,10 @@ def enet_path(X, y, l1_ratio=0.5, eps=1e-3, n_alphas=100, alphas=None, Fit or not an intercept normalize : boolean, optional - If True, the regressors X are normalized + If ``True``, the regressors X are normalized copy_X : boolean, optional, default True - If True, X will be copied; else, it may be overwritten. + If ``True``, X will be copied; else, it may be overwritten. verbose : bool or integer Amount of verbosity @@ -716,7 +718,7 @@ def fit(self, X, y): ---------- X : array-like, shape (n_samples, n_features) - Training data. Pass directly as fortran contiguous data to avoid + Training data. Pass directly as Fortran-contiguous data to avoid unnecessary memory duplication y : narray, shape (n_samples,) or (n_samples, n_targets) @@ -821,19 +823,19 @@ class LassoCV(LinearModelCV, RegressorMixin): Parameters ---------- eps : float, optional - Length of the path. eps=1e-3 means that - alpha_min / alpha_max = 1e-3. + Length of the path. ``eps=1e-3`` means that + ``alpha_min / alpha_max = 1e-3``. n_alphas : int, optional Number of alphas along the regularization path alphas : numpy array, optional List of alphas where to compute the models. - If None alphas are set automatically + If ``None`` alphas are set automatically precompute : True | False | 'auto' | array-like Whether to use a precomputed Gram matrix to speed up - calculations. If set to 'auto' let us decide. The Gram + calculations. If set to ``'auto'`` let us decide. The Gram matrix can also be passed as argument. max_iter: int, optional @@ -841,33 +843,33 @@ class LassoCV(LinearModelCV, RegressorMixin): tol: float, optional The tolerance for the optimization: if the updates are - smaller than 'tol', the optimization code checks the + smaller than ``tol``, the optimization code checks the dual gap for optimality and continues until it is smaller - than tol. + than ``tol``. 
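The ``eps`` semantics above imply a geometric grid running from ``alpha_max`` down to ``eps * alpha_max``. A rough illustration of how such a grid could be built with plain numpy (a sketch of the idea only, not the library's internal routine):

    import numpy as np

    rng = np.random.RandomState(0)
    X = rng.randn(20, 3)
    y = rng.randn(20)

    eps, n_alphas = 1e-3, 100
    n_samples = X.shape[0]
    # Smallest alpha for which all Lasso coefficients are exactly zero.
    alpha_max = np.abs(np.dot(X.T, y)).max() / n_samples
    # Geometric grid from alpha_max down to eps * alpha_max, in decreasing order.
    alphas = np.logspace(np.log10(alpha_max * eps), np.log10(alpha_max),
                         num=n_alphas)[::-1]
    print(alphas[0], alphas[-1])  # roughly alpha_max and alpha_max * eps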
cv : integer or crossvalidation generator, optional If an integer is passed, it is the number of fold (default 3). - Specific crossvalidation objects can be passed, see - sklearn.cross_validation module for the list of possible objects + Specific crossvalidation objects can be passed, see the + :mod:`sklearn.cross_validation` module for the list of possible objects. verbose : bool or integer amount of verbosity Attributes ---------- - `alpha_`: float + ``alpha_`` : float The amount of penalization choosen by cross validation - `coef_` : array, shape = (n_features,) + ``coef_`` : array, shape = (n_features,) | (n_targets, n_features) parameter vector (w in the cost function formula) - `intercept_` : float + ``intercept_`` : float | array, shape = (n_targets,) independent term in decision function. - `mse_path_`: array, shape = (n_alphas, n_folds) + ``mse_path_`` : array, shape = (n_alphas, n_folds) mean square error for the test set on each fold, varying alpha - `alphas_`: numpy array + ``alphas_`` : numpy array The grid of alphas used for fitting Notes @@ -876,7 +878,7 @@ class LassoCV(LinearModelCV, RegressorMixin): for an example. To avoid unnecessary memory duplication the X argument of the fit method - should be directly passed as a fortran contiguous numpy array. + should be directly passed as a Fortran-contiguous numpy array. See also -------- @@ -908,19 +910,19 @@ class ElasticNetCV(LinearModelCV, RegressorMixin): ---------- l1_ratio : float, optional float between 0 and 1 passed to ElasticNet (scaling between - l1 and l2 penalties). For l1_ratio = 0 - the penalty is an L2 penalty. For l1_ratio = 1 it is an L1 penalty. - For 0 < l1_ratio < 1, the penalty is a combination of L1 and L2 + l1 and l2 penalties). For ``l1_ratio = 0`` + the penalty is an L2 penalty. For ``l1_ratio = 1`` it is an L1 penalty. + For ``0 < l1_ratio < 1``, the penalty is a combination of L1 and L2 This parameter can be a list, in which case the different values are tested by cross-validation and the one giving the best prediction score is used. Note that a good choice of list of values for l1_ratio is often to put more values close to 1 - (i.e. Lasso) and less close to 0 (i.e. Ridge), as in [.1, .5, .7, - .9, .95, .99, 1] + (i.e. Lasso) and less close to 0 (i.e. Ridge), as in ``[.1, .5, .7, + .9, .95, .99, 1]`` eps : float, optional - Length of the path. eps=1e-3 means that - alpha_min / alpha_max = 1e-3. + Length of the path. ``eps=1e-3`` means that + ``alpha_min / alpha_max = 1e-3``. n_alphas : int, optional Number of alphas along the regularization path @@ -931,7 +933,7 @@ class ElasticNetCV(LinearModelCV, RegressorMixin): precompute : True | False | 'auto' | array-like Whether to use a precomputed Gram matrix to speed up - calculations. If set to 'auto' let us decide. The Gram + calculations. If set to ``'auto'`` let us decide. The Gram matrix can also be passed as argument. max_iter : int, optional @@ -939,39 +941,39 @@ class ElasticNetCV(LinearModelCV, RegressorMixin): tol : float, optional The tolerance for the optimization: if the updates are - smaller than 'tol', the optimization code checks the + smaller than ``tol``, the optimization code checks the dual gap for optimality and continues until it is smaller - than tol. + than ``tol``. cv : integer or crossvalidation generator, optional If an integer is passed, it is the number of fold (default 3). 
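A short LassoCV sketch showing the ``alpha_`` and ``mse_path_`` attributes documented above (synthetic data):

    import numpy as np
    from sklearn.linear_model import LassoCV

    rng = np.random.RandomState(0)
    X = rng.randn(60, 8)
    y = np.dot(X, rng.randn(8)) + 0.1 * rng.randn(60)

    cv_model = LassoCV(eps=1e-3, n_alphas=100, cv=3).fit(X, y)
    print(cv_model.alpha_)           # penalty chosen by cross-validation
    print(cv_model.mse_path_.shape)  # (n_alphas, n_folds) == (100, 3)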
- Specific crossvalidation objects can be passed, see - sklearn.cross_validation module for the list of possible objects + Specific crossvalidation objects can be passed, see the + :mod:`sklearn.cross_validation` module for the list of possible objects. verbose : bool or integer amount of verbosity n_jobs : integer, optional - Number of CPUs to use during the cross validation. If '-1', use + Number of CPUs to use during the cross validation. If ``-1``, use all the CPUs. Note that this is used only if multiple values for l1_ratio are given. Attributes ---------- - `alpha_` : float + ``alpha_`` : float The amount of penalization choosen by cross validation - `l1_ratio_` : float + ``l1_ratio_`` : float The compromise between l1 and l2 penalization choosen by cross validation - `coef_` : array, shape = (n_features,) + ``coef_`` : array, shape = (n_features,) | (n_targets, n_features) Parameter vector (w in the cost function formula), - `intercept_` : float + ``intercept_`` : float | array, shape = (n_targets,) Independent term in the decision function. - `mse_path_` : array, shape = (n_l1_ratio, n_alpha, n_folds) + ``mse_path_`` : array, shape = (n_l1_ratio, n_alpha, n_folds) Mean square error for the test set on each fold, varying l1_ratio and alpha. @@ -981,7 +983,7 @@ class ElasticNetCV(LinearModelCV, RegressorMixin): for an example. To avoid unnecessary memory duplication the X argument of the fit method - should be directly passed as a fortran contiguous numpy array. + should be directly passed as a Fortran-contiguous numpy array. The parameter l1_ratio corresponds to alpha in the glmnet R package while alpha corresponds to the lambda parameter in glmnet. @@ -1058,7 +1060,7 @@ class MultiTaskElasticNet(Lasso): The ElasticNet mixing parameter, with 0 < l1_ratio <= 1. For l1_ratio = 0 the penalty is an L1/L2 penalty. For l1_ratio = 1 it is an L1 penalty. - For 0 < l1_ratio < 1, the penalty is a combination of L1/L2 and L2. + For ``0 < l1_ratio < 1``, the penalty is a combination of L1/L2 and L2. fit_intercept : boolean whether to calculate the intercept for this model. If set @@ -1066,32 +1068,32 @@ class MultiTaskElasticNet(Lasso): to false, no intercept will be used in calculations (e.g. data is expected to be already centered). normalize : boolean, optional - If True, the regressors X are normalized + If ``True``, the regressors X are normalized copy_X : boolean, optional, default True - If True, X will be copied; else, it may be overwritten. + If ``True``, X will be copied; else, it may be overwritten. max_iter : int, optional The maximum number of iterations tol : float, optional The tolerance for the optimization: if the updates are - smaller than 'tol', the optimization code checks the + smaller than ``tol``, the optimization code checks the dual gap for optimality and continues until it is smaller - than tol. + than ``tol``. warm_start : bool, optional - When set to True, reuse the solution of the previous call to fit as + When set to ``True``, reuse the solution of the previous call to fit as initialization, otherwise, just erase the previous solution. Attributes ---------- - `intercept_` : array, shape = (n_tasks,) + ``intercept_`` : array, shape = (n_tasks,) Independent term in decision function. - `coef_` : array, shape = (n_tasks, n_features) + ``coef_`` : array, shape = (n_tasks, n_features) Parameter vector (W in the cost function formula).
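A sketch of cross-validating over a list of ``l1_ratio`` values, as suggested above (synthetic data; the grid is the illustrative one from the parameter description):

    import numpy as np
    from sklearn.linear_model import ElasticNetCV

    rng = np.random.RandomState(0)
    X = rng.randn(80, 10)
    y = np.dot(X, rng.randn(10)) + 0.1 * rng.randn(80)

    enet_cv = ElasticNetCV(l1_ratio=[.1, .5, .7, .9, .95, .99, 1], cv=3)
    enet_cv.fit(X, y)
    print(enet_cv.l1_ratio_)  # mixing parameter chosen by cross-validation
    print(enet_cv.alpha_)     # penalty strength chosen by cross-validation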
If a 1D y is \ - passed in at fit (non multi-task usage), `coef_` is then a 1D array + passed in at fit (non multi-task usage), ``coef_`` is then a 1D array Examples -------- @@ -1117,7 +1119,7 @@ class MultiTaskElasticNet(Lasso): The algorithm used to fit the model is coordinate descent. To avoid unnecessary memory duplication the X argument of the fit method - should be directly passed as a fortran contiguous numpy array. + should be directly passed as a Fortran-contiguous numpy array. """ def __init__(self, alpha=1.0, l1_ratio=0.5, fit_intercept=True, normalize=False, copy_X=True, max_iter=1000, tol=1e-4, @@ -1153,7 +1155,7 @@ def fit(self, X, y, Xy=None, coef_init=None): Coordinate descent is an algorithm that considers each column of data at a time hence it will automatically convert the X input - as a fortran contiguous numpy array if necessary. + as a Fortran-contiguous numpy array if necessary. To avoid memory re-allocation it is advised to allocate the initial data in memory directly using that format. @@ -1229,30 +1231,30 @@ class MultiTaskLasso(MultiTaskElasticNet): (e.g. data is expected to be already centered). normalize : boolean, optional - If True, the regressors X are normalized + If ``True``, the regressors X are normalized copy_X : boolean, optional, default True - If True, X will be copied; else, it may be overwritten. + If ``True``, X will be copied; else, it may be overwritten. max_iter : int, optional The maximum number of iterations tol : float, optional The tolerance for the optimization: if the updates are - smaller than 'tol', the optimization code checks the + smaller than ``tol``, the optimization code checks the dual gap for optimality and continues until it is smaller - than tol. + than ``tol``. warm_start : bool, optional - When set to True, reuse the solution of the previous call to fit as + When set to ``True``, reuse the solution of the previous call to fit as initialization, otherwise, just erase the previous solution. Attributes ---------- - `coef_` : array, shape = (n_tasks, n_features) + ``coef_`` : array, shape = (n_tasks, n_features) parameter vector (W in the cost function formula) - `intercept_` : array, shape = (n_tasks,) + ``intercept_`` : array, shape = (n_tasks,) independent term in decision function. Examples @@ -1277,7 +1279,7 @@ class MultiTaskLasso(MultiTaskElasticNet): The algorithm used to fit the model is coordinate descent. To avoid unnecessary memory duplication the X argument of the fit method - should be directly passed as a fortran contiguous numpy array. + should be directly passed as a Fortran-contiguous numpy array. """ def __init__(self, alpha=1.0, fit_intercept=True, normalize=False, copy_X=True, max_iter=1000, tol=1e-4, warm_start=False): diff --git a/sklearn/linear_model/least_angle.py b/sklearn/linear_model/least_angle.py index a14692fd7b7c7..5bd4d9fb74f70 100644 --- a/sklearn/linear_model/least_angle.py +++ b/sklearn/linear_model/least_angle.py @@ -46,7 +46,7 @@ def lars_path(X, y, Xy=None, Gram=None, max_iter=500, Maximum number of iterations to perform, set to infinity for no limit. Gram : None, 'auto', array, shape: (n_features, n_features), optional - Precomputed Gram matrix (X' * X), if 'auto', the Gram + Precomputed Gram matrix (X' * X), if ``'auto'``, the Gram matrix is precomputed from the given X, if there are more samples than features. @@ -55,8 +55,8 @@ def lars_path(X, y, Xy=None, Gram=None, max_iter=500, regularization parameter alpha parameter in the Lasso. 
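A minimal multi-task sketch illustrating the ``(n_tasks, n_features)`` coefficient shape documented above (synthetic data):

    import numpy as np
    from sklearn.linear_model import MultiTaskLasso

    rng = np.random.RandomState(0)
    X = rng.randn(40, 6)
    Y = rng.randn(40, 3)  # three tasks sharing the same design matrix

    mtl = MultiTaskLasso(alpha=0.5).fit(X, Y)
    print(mtl.coef_.shape)       # (3, 6) -> (n_tasks, n_features)
    print(mtl.intercept_.shape)  # (3,)   -> (n_tasks,)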
method : {'lar', 'lasso'} - Specifies the returned model. Select 'lar' for Least Angle - Regression, 'lasso' for the Lasso. + Specifies the returned model. Select ``'lar'`` for Least Angle + Regression, ``'lasso'`` for the Lasso. eps : float, optional The machine-precision regularization in the computation of the @@ -64,23 +64,26 @@ def lars_path(X, y, Xy=None, Gram=None, max_iter=500, systems. copy_X : bool - If False, X is overwritten. + If ``False``, ``X`` is overwritten. copy_Gram : bool - If False, Gram is overwritten. + If ``False``, ``Gram`` is overwritten. verbose : int (default=0) Controls output verbosity. Returns -------- - alphas: array, shape: (max_features + 1,) + alphas: array, shape: [n_alphas + 1] Maximum of covariances (in absolute value) at each iteration. + ``n_alphas`` is either ``max_iter``, ``n_features`` or the + number of nodes in the path with ``alpha >= alpha_min``, whichever + is smaller. - active: array, shape (max_features,) + active: array, shape [n_alphas] Indices of active variables at the end of the path. - coefs: array, shape (n_features, max_features + 1) + coefs: array, shape (n_features, n_alphas + 1) Coefficients along the path See also @@ -404,7 +407,7 @@ class Lars(LinearModel, RegressorMixin): Parameters ---------- n_nonzero_coefs : int, optional - Target number of non-zero coefficients. Use np.inf for no limit. + Target number of non-zero coefficients. Use ``np.inf`` for no limit. fit_intercept : boolean Whether to calculate the intercept for this model. If set @@ -415,40 +418,47 @@ class Lars(LinearModel, RegressorMixin): Sets the verbosity amount normalize : boolean, optional - If True, the regressors X are normalized + If ``True``, the regressors X are normalized precompute : True | False | 'auto' | array-like Whether to use a precomputed Gram matrix to speed up - calculations. If set to 'auto' let us decide. The Gram + calculations. If set to ``'auto'`` let us decide. The Gram matrix can also be passed as argument. copy_X : boolean, optional, default True - If True, X will be copied; else, it may be overwritten. + If ``True``, X will be copied; else, it may be overwritten. eps: float, optional The machine-precision regularization in the computation of the Cholesky diagonal factors. Increase this for very ill-conditioned - systems. Unlike the 'tol' parameter in some iterative + systems. Unlike the ``tol`` parameter in some iterative optimization-based algorithms, this parameter does not control the tolerance of the optimization. fit_path : boolean - If True the full path is stored in the `coef_path_` attribute. + If True the full path is stored in the ``coef_path_`` attribute. If you compute the solution for a large problem or many targets, - setting fit_path to False will lead to a speedup, especially + setting ``fit_path`` to ``False`` will lead to a speedup, especially with a small alpha. Attributes ---------- - `coef_path_` : array, shape = [n_features, n_alpha] - The varying values of the coefficients along the path. It is not \ - present if the fit_path parameter is False. + ``alphas_`` : array, shape = [n_alphas + 1] | list of n_targets such arrays + Maximum of covariances (in absolute value) at each iteration. \ + ``n_alphas`` is either ``n_nonzero_coefs`` or ``n_features``, \ + whichever is smaller. + + ``active_`` : list, length: [n_alphas] | list of n_targets such lists + Indices of active variables at the end of the path. 
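A short sketch of the ``lars_path`` return values described above (synthetic data):

    import numpy as np
    from sklearn.linear_model import lars_path

    rng = np.random.RandomState(0)
    X = rng.randn(50, 5)
    y = np.dot(X, np.array([3.0, -2.0, 0.0, 0.0, 0.0])) + 0.01 * rng.randn(50)

    alphas, active, coefs = lars_path(X, y, method='lasso')
    print(alphas.shape)  # (n_alphas + 1,)
    print(active)        # indices of the active variables at the end of the path
    print(coefs.shape)   # (n_features, n_alphas + 1)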
+ ``coef_path_`` : array, shape = [n_features, n_alphas + 1] | list of n_targets such arrays + The varying values of the coefficients along the path. It is not + present if the ``fit_path`` parameter is ``False``. - `coef_` : array, shape = [n_features] + ``coef_`` : array, shape = [n_features] | [n_targets, n_features] Parameter vector (w in the fomulation formula). - `intercept_` : float + ``intercept_`` : float | array of shape [n_targets] Independent term in decision function. Examples @@ -467,7 +477,6 @@ class Lars(LinearModel, RegressorMixin): lars_path, LarsCV sklearn.decomposition.sparse_encode - http://en.wikipedia.org/wiki/Least_angle_regression """ def __init__(self, fit_intercept=True, verbose=False, normalize=True, precompute='auto', n_nonzero_coefs=500, @@ -501,10 +510,11 @@ def fit(self, X, y, Xy=None): X : array-like, shape = [n_samples, n_features] Training data. - y : array-like, shape = [n_samples] or [n_samples, n_targets] + y : array-like, shape = [n_samples] | [n_samples, n_targets] Target values. - Xy : array-like, shape = [n_samples] or [n_samples, n_targets], optional + Xy : array-like, shape = [n_samples] | [n_samples, n_targets], \ + optional Xy = np.dot(X.T, y) that can be precomputed. It is useful only when the Gram matrix is precomputed. @@ -592,6 +602,10 @@ class LassoLars(Lars): Parameters ---------- + alpha : float, optional, default: 1.0 + Lasso regularization parameter. The regularization path is computed + for alphas greater or equal to this parameter. + fit_intercept : boolean whether to calculate the intercept for this model. If set to false, no intercept will be used in calculations @@ -608,7 +622,7 @@ class LassoLars(Lars): precompute : True | False | 'auto' | array-like Whether to use a precomputed Gram matrix to speed up - calculations. If set to 'auto' let us decide. The Gram + calculations. If set to ``'auto'`` let us decide. The Gram matrix can also be passed as argument. max_iter: integer, optional @@ -617,26 +631,35 @@ class LassoLars(Lars): eps: float, optional The machine-precision regularization in the computation of the Cholesky diagonal factors. Increase this for very ill-conditioned - systems. Unlike the 'tol' parameter in some iterative + systems. Unlike the ``tol`` parameter in some iterative optimization-based algorithms, this parameter does not control the tolerance of the optimization. fit_path : boolean - If True the full path is stored in the `coef_path_` attribute. + If ``True`` the full path is stored in the ``coef_path_`` attribute. If you compute the solution for a large problem or many targets, - setting fit_path to False will lead to a speedup, especially + setting ``fit_path`` to ``False`` will lead to a speedup, especially with a small alpha. Attributes ---------- - `coef_path_` : array, shape = [n_features, n_alpha] - The varying values of the coefficients along the path. It is not \ - present if fit_path parameter is False. + ``alphas_`` : array, shape = [n_alphas + 1] | list of n_targets such arrays + Maximum of covariances (in absolute value) at each iteration. \ + ``n_alphas`` is either ``max_iter``, ``n_features``, or the number of \ + nodes in the path with correlation greater than ``alpha``, whichever \ + is smaller. - `coef_` : array, shape = [n_features] + ``active_`` : list, length = [n_alphas] | list of n_targets such lists + Indices of active variables at the end of the path. 
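A sketch of the per-target layout described in the attribute table above when ``y`` has several columns (synthetic data; the list-per-target behaviour is the one documented for ``coef_path_``):

    import numpy as np
    from sklearn.linear_model import Lars

    rng = np.random.RandomState(0)
    X = rng.randn(30, 4)
    Y = rng.randn(30, 2)  # two targets

    reg = Lars(n_nonzero_coefs=3).fit(X, Y)
    print(reg.coef_.shape)      # (2, 4) -> (n_targets, n_features)
    print(len(reg.coef_path_))  # one entry per target, as documented above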
+ + ``coef_path_`` : array, shape = [n_features, n_alphas + 1] | list of n_targets such arrays + The varying values of the coefficients along the path. It is not + present if the ``fit_path`` parameter is ``False``. + + ``coef_`` : array, shape = n_features or n_targets, n_features Parameter vector (w in the fomulation formula). - `intercept_` : float + ``intercept_`` : float or array of shape [n_targets] Independent term in decision function. Examples @@ -660,7 +683,6 @@ class LassoLars(Lars): LassoLarsCV sklearn.decomposition.sparse_encode - http://en.wikipedia.org/wiki/Least_angle_regression """ def __init__(self, alpha=1.0, fit_intercept=True, verbose=False, @@ -698,15 +720,15 @@ def _lars_path_residues(X_train, y_train, X_test, y_test, Gram=None, y_test: array, shape (n_samples) The target variable to compute the residues on Gram: None, 'auto', array, shape: (n_features, n_features), optional - Precomputed Gram matrix (X' * X), if 'auto', the Gram + Precomputed Gram matrix (X' * X), if ``'auto'``, the Gram matrix is precomputed from the given X, if there are more samples than features copy: boolean, optional Whether X_train, X_test, y_train and y_test should be copied; if False, they may be overwritten. method: 'lar' | 'lasso' - Specifies the returned model. Select 'lar' for Least Angle - Regression, 'lasso' for the Lasso. + Specifies the returned model. Select ``'lar'`` for Least Angle + Regression, ``'lasso'`` for the Lasso. verbose: integer, optional Sets the amount of verbosity fit_intercept : boolean @@ -720,24 +742,25 @@ def _lars_path_residues(X_train, y_train, X_test, y_test, Gram=None, eps: float, optional The machine-precision regularization in the computation of the Cholesky diagonal factors. Increase this for very ill-conditioned - systems. Unlike the 'tol' parameter in some iterative + systems. Unlike the ``tol`` parameter in some iterative optimization-based algorithms, this parameter does not control the tolerance of the optimization. Returns -------- - alphas: array, shape: (max_features + 1,) - Maximum of covariances (in absolute value) at each - iteration. + alphas: array, shape: [n_alphas + 1] + Maximum of covariances (in absolute value) at each iteration. + ``n_alphas`` is either ``max_iter`` or ``n_features``, whichever + is smaller. - active: array, shape (max_features,) + active: array, shape [n_alphas] Indices of active variables at the end of the path. - coefs: array, shape (n_features, max_features + 1) + coefs: array, shape [n_features, n_alphas + 1) Coefficients along the path - residues: array, shape (n_features, max_features + 1) + residues: array, shape [n_features, n_alphas + 1] Residues of the prediction on the test data """ if copy: @@ -788,18 +811,18 @@ class LarsCV(Lars): If True, the regressors X are normalized copy_X : boolean, optional, default True - If True, X will be copied; else, it may be overwritten. + If ``True``, X will be copied; else, it may be overwritten. precompute : True | False | 'auto' | array-like Whether to use a precomputed Gram matrix to speed up - calculations. If set to 'auto' let us decide. The Gram + calculations. If set to ``'auto'`` let us decide. The Gram matrix can also be passed as argument. max_iter: integer, optional Maximum number of iterations to perform. cv : crossvalidation generator, optional - see sklearn.cross_validation module. If None is passed, default to + see :mod:`sklearn.cross_validation`. 
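A sketch of the ``fit_path=False`` speed-up mentioned above (synthetic data; per the attribute description, the full path is then not stored and only the final coefficients are kept):

    import numpy as np
    from sklearn.linear_model import LassoLars

    rng = np.random.RandomState(0)
    X = rng.randn(200, 50)
    y = np.dot(X, rng.randn(50)) + 0.1 * rng.randn(200)

    clf = LassoLars(alpha=0.01, fit_path=False).fit(X, y)
    print(clf.coef_.shape)             # final coefficients only
    print(hasattr(clf, 'coef_path_'))  # False: the path was not stored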
If ``None`` is passed, default to a 5-fold strategy max_n_alphas : integer, optional @@ -807,7 +830,7 @@ class LarsCV(Lars): residuals in the cross-validation n_jobs : integer, optional - Number of CPUs to use during the cross validation. If '-1', use + Number of CPUs to use during the cross validation. If ``-1``, use all the CPUs eps: float, optional @@ -818,27 +841,27 @@ class LarsCV(Lars): Attributes ---------- - `coef_` : array, shape = [n_features] + ``coef_`` : array, shape = [n_features] parameter vector (w in the fomulation formula) - `intercept_` : float + ``intercept_`` : float independent term in decision function - `coef_path_`: array, shape = [n_features, n_alpha] + ``coef_path_`` : array, shape = [n_features, n_alphas] the varying values of the coefficients along the path - `alpha_`: float + ``alpha_`` : float the estimated regularization parameter alpha - `alphas_`: array, shape = [n_alpha] + ``alphas_`` : array, shape = [n_alphas] the different values of alpha along the path - `cv_alphas_`: array, shape = [n_cv_alphas] + ``cv_alphas_`` : array, shape = [n_cv_alphas] all the values of alpha along the path for the different folds - `cv_mse_path_`: array, shape = [n_folds, n_cv_alphas] + ``cv_mse_path_`` : array, shape = [n_folds, n_cv_alphas] the mean square error on left-out for each fold along the path - (alpha values given by cv_alphas) + (alpha values given by ``cv_alphas``) See also -------- @@ -970,7 +993,7 @@ class LassoLarsCV(LarsCV): precompute : True | False | 'auto' | array-like Whether to use a precomputed Gram matrix to speed up - calculations. If set to 'auto' let us decide. The Gram + calculations. If set to ``'auto'`` let us decide. The Gram matrix can also be passed as argument. max_iter: integer, optional @@ -985,7 +1008,7 @@ class LassoLarsCV(LarsCV): residuals in the cross-validation n_jobs : integer, optional - Number of CPUs to use during the cross validation. If '-1', use + Number of CPUs to use during the cross validation. If ``-1``, use all the CPUs eps: float, optional @@ -998,27 +1021,27 @@ class LassoLarsCV(LarsCV): Attributes ---------- - `coef_` : array, shape = [n_features] + ``coef_`` : array, shape = [n_features] parameter vector (w in the fomulation formula) - `intercept_` : float + ``intercept_`` : float independent term in decision function. - `coef_path_`: array, shape = [n_features, n_alpha] + ``coef_path_`` : array, shape = [n_features, n_alphas] the varying values of the coefficients along the path - `alpha_`: float + ``alpha_`` : float the estimated regularization parameter alpha - `alphas_`: array, shape = [n_alpha] + ``alphas_`` : array, shape = [n_alphas] the different values of alpha along the path - `cv_alphas_`: array, shape = [n_cv_alphas] + ``cv_alphas_`` : array, shape = [n_cv_alphas] all the values of alpha along the path for the different folds - `cv_mse_path_`: array, shape = [n_folds, n_cv_alphas] + ``cv_mse_path_`` : array, shape = [n_folds, n_cv_alphas] the mean square error on left-out for each fold along the path - (alpha values given by cv_alphas) + (alpha values given by ``cv_alphas``) Notes ----- @@ -1074,7 +1097,7 @@ class LassoLarsIC(LassoLars): precompute : True | False | 'auto' | array-like Whether to use a precomputed Gram matrix to speed up - calculations. If set to 'auto' let us decide. The Gram + calculations. If set to ``'auto'`` let us decide. The Gram matrix can also be passed as argument. 
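A minimal LassoLarsCV sketch showing the cross-validated choice of alpha (synthetic data):

    import numpy as np
    from sklearn.linear_model import LassoLarsCV

    rng = np.random.RandomState(0)
    X = rng.randn(100, 20)
    y = np.dot(X, rng.randn(20)) + 0.1 * rng.randn(100)

    reg = LassoLarsCV(cv=5).fit(X, y)
    print(reg.alpha_)       # regularization strength picked by cross-validation
    print(reg.coef_.shape)  # (n_features,)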
max_iter: integer, optional @@ -1084,20 +1107,20 @@ class LassoLarsIC(LassoLars): eps: float, optional The machine-precision regularization in the computation of the Cholesky diagonal factors. Increase this for very ill-conditioned - systems. Unlike the 'tol' parameter in some iterative + systems. Unlike the ``tol`` parameter in some iterative optimization-based algorithms, this parameter does not control the tolerance of the optimization. Attributes ---------- - `coef_` : array, shape = [n_features] + ``coef_`` : array, shape = [n_features] parameter vector (w in the fomulation formula) - `intercept_` : float + ``intercept_`` : float independent term in decision function. - `alpha_` : float + ``alpha_`` : float the alpha parameter chosen by the information criterion Examples
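A minimal LassoLarsIC sketch (synthetic data; ``criterion='bic'`` is an illustrative choice, the default being ``'aic'``):

    import numpy as np
    from sklearn.linear_model import LassoLarsIC

    rng = np.random.RandomState(0)
    X = rng.randn(100, 15)
    y = np.dot(X, rng.randn(15)) + 0.1 * rng.randn(100)

    reg = LassoLarsIC(criterion='bic').fit(X, y)
    print(reg.alpha_)       # alpha chosen by the information criterion
    print(reg.coef_.shape)  # (n_features,)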