diff --git a/sklearn/cross_validation.py b/sklearn/cross_validation.py
index 88b171eaf3ab0..67bbdc45f9782 100644
--- a/sklearn/cross_validation.py
+++ b/sklearn/cross_validation.py
@@ -15,18 +15,22 @@
 from math import ceil, floor, factorial
 import numbers
 from abc import ABCMeta, abstractmethod
+import time
+import operator
 
 import numpy as np
 import scipy.sparse as sp
 
 from .base import is_classifier, clone
 from .utils import check_arrays, check_random_state, safe_mask
+from .utils.validation import _num_samples
 from .utils.fixes import unique
-from .externals.joblib import Parallel, delayed
+from .externals.joblib import Parallel, delayed, logger
 from .externals.six import string_types, with_metaclass
 from .metrics import SCORERS, Scorer
 
 __all__ = ['Bootstrap',
+           'CVScorer',
            'KFold',
            'LeaveOneLabelOut',
            'LeaveOneOut',
@@ -1024,13 +1028,74 @@ def __len__(self):
 
 
 ##############################################################################
 
-def _cross_val_score(estimator, X, y, scorer, train, test, verbose,
-                     fit_params):
-    """Inner loop for cross validation"""
-    n_samples = X.shape[0] if sp.issparse(X) else len(X)
+def _fit_fold(estimator, X, y, fold_index, train, test, scoring, verbose,
+              parameters=None, candidate_index=None, fit_params=None):
+    """Run fit on one set of parameters.
+
+    Parameters
+    ----------
+    estimator : estimator object
+        This estimator will be cloned and then fitted.
+
+    X : array-like, sparse matrix or list
+        Input data.
+
+    y : array-like or None
+        Targets for input data.
+
+    fold_index : int or None
+        Index of the cross-validation fold, echoed in the results.
+
+    train : ndarray, dtype int or bool
+        Boolean mask or indices for training set.
+
+    test : ndarray, dtype int or bool
+        Boolean mask or indices for test set.
+
+    scoring : callable or None
+        If provided must be a scoring object / function with signature
+        ``scoring(estimator, X, y)``.
+
+    verbose : int
+        Verbosity level.
+
+    parameters : dict
+        Parameters to be set on the estimator for this fold.
+
+    candidate_index : int or None
+        Index of the candidate parameter setting, echoed in the results.
+
+    fit_params : dict or None
+        Additional parameters passed to the fit method of the estimator.
+
+    Returns
+    -------
+    results : dict of string to any
+        An extensible storage of fold results including the following keys:
+
+        ``'test_score'`` : float
+            The estimator's score on the test set.
+        ``'test_n_samples'`` : int
+            The number of samples in the test set.
+    """
+    if verbose > 1:
+        start_time = time.time()
+        if parameters is None:
+            msg = ''
+        else:
+            msg = '%s' % (', '.join('%s=%s' % (k, v)
+                          for k, v in parameters.items()))
+        print("Fitting fold %s %s" % (msg, (64 - len(msg)) * '.'))
+
+    n_samples = _num_samples(X)
+
+    # Adapt fit_params to train portion only
+    if fit_params is None:
+        fit_params = {}
     fit_params = dict([(k, np.asarray(v)[train]
                        if hasattr(v, '__len__') and len(v) == n_samples
                        else v) for k, v in fit_params.items()])
+
+    if hasattr(estimator, 'kernel') and callable(estimator.kernel):
+        # cannot compute the kernel values with custom function
+        raise ValueError("Cannot use a custom kernel function. "
" + "Precompute the kernel matrix instead.") + if not hasattr(X, "shape"): if getattr(estimator, "_pairwise", False): raise ValueError("Precomputed kernels or affinity matrices have " @@ -1048,23 +1113,193 @@ def _cross_val_score(estimator, X, y, scorer, train, test, verbose, X_train = X[safe_mask(X, train)] X_test = X[safe_mask(X, test)] - if y is None: - y_train = None - y_test = None - else: - y_train = y[train] - y_test = y[test] - estimator.fit(X_train, y_train, **fit_params) - if scorer is None: - score = estimator.score(X_test, y_test) + # update parameters of the classifier after a copy of its base structure + if parameters is not None: + estimator = clone(estimator) + estimator.set_params(**parameters) + + if scoring is None: + scoring = lambda estimator, *args: estimator.score(*args) + + if y is not None: + y_test = y[safe_mask(y, test)] + y_train = y[safe_mask(y, train)] + fit_args = (X_train, y_train) + score_args = (X_test, y_test) else: - score = scorer(estimator, X_test, y_test) - if not isinstance(score, numbers.Number): - raise ValueError("scoring must return a number, got %s (%s)" - " instead." % (str(score), type(score))) + fit_args = (X_train,) + score_args = (X_test,) + + # do actual fitting + estimator.fit(*fit_args, **fit_params) + test_score = scoring(estimator, *score_args) + + if not isinstance(test_score, numbers.Number): + raise ValueError("scoring must return a number, got %s (%s)" + " instead." % (str(test_score), type(test_score))) + + if verbose > 2: + msg += ", score=%f" % test_score if verbose > 1: - print("score: %f" % score) - return score + end_msg = "%s -%s" % (msg, + logger.short_format_time(time.time() - + start_time)) + print("Fitting fold %s %s" % ((64 - len(end_msg)) * '.', end_msg)) + return { + 'candidate_index': candidate_index, + 'fold_index': fold_index, + 'parameters': parameters, + 'test_score': test_score, + 'train_n_samples': _num_samples(X_train), + 'test_n_samples': _num_samples(X_test), + } + + +class CVScorer(object): + """Parallelized cross-validation for a given estimator and dataset + + Parameters + ---------- + cv : integer or cross-validation generator, optional + A cross-validation generator or number of folds (default 3). Folds will + be stratified if the estimator is a classifier. + + scoring : string or callable, optional + Either one of either a string ("zero_one", "f1", "roc_auc", ... for + classification, "mse", "r2", ... for regression) or a callable. + See 'Scoring objects' in the model evaluation section of the user guide + for details. + + iid : boolean, optional + If True (default), the data is assumed to be identically distributed + across the folds, and the mean score is the total score per sample, + and not the mean score across the folds. + + n_jobs : integer, optional + The number of jobs to run in parallel (default 1). -1 means 'all CPUs'. + + pre_dispatch : int, or string, optional + Controls the number of jobs that get dispatched during parallel + execution. Reducing this number can be useful to avoid an + - None, in which case all the jobs are immediatly + created and spawned. Use this for lightweight and + fast-running jobs, to avoid delays due to on-demand + spawning of the jobs + + - An int, giving the exact number of total jobs that are + spawned + + - A string, giving an expression as a function of n_jobs, + as in '2*n_jobs' + + verbose : integer, optional + The verbosity level. + + fit_params : dict, optional + Parameters to pass to the fit method of the estimator. 
+ """ + + def __init__(self, cv=None, scoring=None, iid=True, + n_jobs=1, pre_dispatch='2*n_jobs', verbose=0, fit_params=None, + score_func=None): + + self.cv = cv + self.iid = iid + self.verbose = verbose + self.fit_params = fit_params + + if score_func is not None: + warnings.warn("Passing function as ``score_func`` is " + "deprecated and will be removed in 0.15. " + "Either use strings or score objects.", stacklevel=2) + self.scoring = Scorer(score_func) + elif isinstance(scoring, string_types): + self.scoring = SCORERS[scoring] + else: + self.scoring = scoring + + self.parallel = Parallel(n_jobs=n_jobs, verbose=verbose, + pre_dispatch=pre_dispatch) + + def __call__(self, estimator, X, y=None): + """Cross-validate the estimator on the given data. + + Parameters + ---------- + estimator : estimator object implementing 'fit' + The object to use to fit the data. + + X : array-like of shape at least 2D + The data to fit. + + y : array-like, optional + The target variable to try to predict in the case of + supervised learning. + + Returns + ------- + fold_results : list of dicts + For each cross-validation fold, this provides fields: + * ``test_score``, the score for this fold + * ``fold_index``, the index of the fold + * ``train_n_samples``, the number of samples in training + * ``test_n_samples``, the number of samples in testing + """ + return sorted(self.search([{}], estimator, X, y), + key=operator.itemgetter('fold_index')) + + def search(self, candidates, estimator, X, y=None): + """Cross-validate the estimator for candidate parameter settings. + + Parameters + ---------- + candidates : iterable of dicts + The estimator will be cloned and have these parameters + set for each candidate. + + estimator : estimator object implementing 'fit' + The object to use to fit the data. + + X : array-like of shape at least 2D + The data to fit. + + y : array-like, optional + The target variable to try to predict in the case of + supervised learning. + + Returns + ------- + fold_results : iterable of dicts + For each cross-validation fold, this provides fields: + * ``test_score``, the score for this fold + * ``parameters``, a dictionary of parameter settings + * ``candidate_index``, the index of the candidate + * ``fold_index``, the index of the fold + * ``train_n_samples``, the number of samples in training + * ``test_n_samples``, the number of samples in testing + """ + X, y = check_arrays(X, y, sparse_format='csr', allow_lists=True) + cv = check_cv(self.cv, X, y, classifier=is_classifier(estimator)) + cv = list(cv) + if self.scoring is None and not hasattr(estimator, 'score'): + raise TypeError( + "If no scoring is specified, the estimator passed " + "should have a 'score' method. The estimator %s " + "does not." % estimator) + + out = self.parallel( + delayed(_fit_fold)( + estimator, X, y, + fold_index, train=train, test=test, + parameters=parameters, candidate_index=candidate_index, + scoring=self.scoring, verbose=self.verbose, + fit_params=self.fit_params + ) + for candidate_index, parameters in enumerate(candidates) + for fold_index, (train, test) in enumerate(cv)) + + # return iterator for future compatibility + return iter(out) def cross_val_score(estimator, X, y=None, scoring=None, cv=None, n_jobs=1, @@ -1090,20 +1325,17 @@ def cross_val_score(estimator, X, y=None, scoring=None, cv=None, n_jobs=1, See 'Scoring objects' in the model evaluation section of the user guide for details. - cv : cross-validation generator, optional - A cross-validation generator. 
-        validation is used or 3-fold stratified cross-validation
-        when y is supplied and estimator is a classifier.
-
-    n_jobs : integer, optional
-        The number of CPUs to use to do the computation. -1 means
-        'all CPUs'.
+    cv : integer or cross-validation generator, optional
+        A cross-validation generator or number of folds (default 3). Folds will
+        be stratified if the estimator is a classifier.
 
-    verbose : integer, optional
-        The verbosity level.
-
-    fit_params : dict, optional
-        Parameters to pass to the fit method of the estimator.
+    n_jobs : integer, optional
+        The number of jobs to run in parallel (default 1). -1 means 'all CPUs'.
 
     pre_dispatch : int, or string, optional
         Controls the number of jobs that get dispatched during parallel
@@ -1122,37 +1354,22 @@ def cross_val_score(estimator, X, y=None, scoring=None, cv=None, n_jobs=1,
             - A string, giving an expression as a function of n_jobs,
              as in '2*n_jobs'
 
+    verbose : integer, optional
+        The verbosity level.
+
+    fit_params : dict, optional
+        Parameters to pass to the fit method of the estimator.
+
     Returns
     -------
-    scores : array of float, shape=(len(list(cv)),)
+    scores : array of float, shape=[len(cv)]
         Array of scores of the estimator for each run of the cross validation.
     """
-    X, y = check_arrays(X, y, sparse_format='csr', allow_lists=True)
-    cv = check_cv(cv, X, y, classifier=is_classifier(estimator))
-    if score_func is not None:
-        warnings.warn("Passing function as ``score_func`` is "
-                      "deprecated and will be removed in 0.15. "
-                      "Either use strings or score objects.", stacklevel=2)
-        scorer = Scorer(score_func)
-    elif isinstance(scoring, string_types):
-        scorer = SCORERS[scoring]
-    else:
-        scorer = scoring
-    if scorer is None and not hasattr(estimator, 'score'):
-        raise TypeError(
-            "If no scoring is specified, the estimator passed "
-            "should have a 'score' method. The estimator %s "
-            "does not." % estimator)
-    # We clone the estimator to make sure that all the folds are
-    # independent, and that it is pickle-able.
-    fit_params = fit_params if fit_params is not None else {}
-    parallel = Parallel(n_jobs=n_jobs, verbose=verbose,
-                        pre_dispatch=pre_dispatch)
-    scores = parallel(
-        delayed(_cross_val_score)(clone(estimator), X, y, scorer, train, test,
-                                  verbose, fit_params)
-        for train, test in cv)
-    return np.array(scores)
+    cv_score = CVScorer(
+        scoring=scoring, cv=cv, n_jobs=n_jobs,
+        verbose=verbose, fit_params=fit_params, score_func=score_func,
+        pre_dispatch=pre_dispatch)
+    return np.array([fold['test_score']
+                     for fold in cv_score(estimator, X, y)])
 
 
 def _permutation_test_score(estimator, X, y, cv, scorer):
diff --git a/sklearn/grid_search.py b/sklearn/grid_search.py
index 319d7a58f1f58..472dec98a15e6 100644
--- a/sklearn/grid_search.py
+++ b/sklearn/grid_search.py
@@ -14,20 +14,16 @@
 from collections import Mapping, namedtuple
 from functools import partial, reduce
 from itertools import product
-import numbers
 import operator
-import time
 import warnings
 
 import numpy as np
 
-from .base import BaseEstimator, is_classifier, clone
+from .base import BaseEstimator, clone, is_classifier
 from .base import MetaEstimatorMixin
-from .cross_validation import check_cv
-from .externals.joblib import Parallel, delayed, logger
+from .cross_validation import CVScorer, _fit_fold, check_cv
 from .externals import six
-from .utils import safe_mask, check_random_state
-from .utils.validation import _num_samples, check_arrays
+from .utils import check_random_state, deprecated
 from .metrics import SCORERS, Scorer
 
 
@@ -202,10 +198,13 @@ def __len__(self):
         return self.n_iter
 
 
+@deprecated('fit_grid_point is deprecated and will be removed in 0.15.')
 def fit_grid_point(X, y, base_estimator, parameters, train, test, scorer,
                    verbose, loss_func=None, **fit_params):
     """Run fit on one set of parameters.
 
+    This function is DEPRECATED. Use ``sklearn.cross_validation.CVScorer``
+    instead.
+
     Parameters
     ----------
     X : array-like, sparse matrix or list
@@ -248,66 +247,11 @@ def fit_grid_point(X, y, base_estimator, parameters, train, test, scorer,
     n_samples_test : int
         Number of test samples in this split.
     """
-    if verbose > 1:
-        start_time = time.time()
-        msg = '%s' % (', '.join('%s=%s' % (k, v)
-                      for k, v in parameters.items()))
-        print("[GridSearchCV] %s %s" % (msg, (64 - len(msg)) * '.'))
-
-    # update parameters of the classifier after a copy of its base structure
-    clf = clone(base_estimator)
-    clf.set_params(**parameters)
-
-    if hasattr(base_estimator, 'kernel') and callable(base_estimator.kernel):
-        # cannot compute the kernel values with custom function
-        raise ValueError("Cannot use a custom kernel function. "
" - "Precompute the kernel matrix instead.") - - if not hasattr(X, "shape"): - if getattr(base_estimator, "_pairwise", False): - raise ValueError("Precomputed kernels or affinity matrices have " - "to be passed as arrays or sparse matrices.") - X_train = [X[idx] for idx in train] - X_test = [X[idx] for idx in test] - else: - if getattr(base_estimator, "_pairwise", False): - # X is a precomputed square kernel matrix - if X.shape[0] != X.shape[1]: - raise ValueError("X should be a square kernel matrix") - X_train = X[np.ix_(train, train)] - X_test = X[np.ix_(test, train)] - else: - X_train = X[safe_mask(X, train)] - X_test = X[safe_mask(X, test)] - - if y is not None: - y_test = y[safe_mask(y, test)] - y_train = y[safe_mask(y, train)] - clf.fit(X_train, y_train, **fit_params) - - if scorer is not None: - this_score = scorer(clf, X_test, y_test) - else: - this_score = clf.score(X_test, y_test) - else: - clf.fit(X_train, **fit_params) - if scorer is not None: - this_score = scorer(clf, X_test) - else: - this_score = clf.score(X_test) - - if not isinstance(this_score, numbers.Number): - raise ValueError("scoring must return a number, got %s (%s)" - " instead." % (str(this_score), type(this_score))) - - if verbose > 2: - msg += ", score=%f" % this_score - if verbose > 1: - end_msg = "%s -%s" % (msg, - logger.short_format_time(time.time() - - start_time)) - print("[GridSearchCV] %s %s" % ((64 - len(end_msg)) * '.', end_msg)) - return this_score, parameters, _num_samples(X_test) + res = _fit_fold( + base_estimator, X, y, train, test, scorer, verbose, + loss_func=None, parameters=parameters, fit_params=fit_params + ) + return res['test_score'], parameters, res['test_n_samples'] def _check_param_grid(param_grid): @@ -433,15 +377,9 @@ def _check_estimator(self): def _fit(self, X, y, parameter_iterable, **params): """Actual fitting, performing the search over parameters.""" - if params: warnings.warn("Passing additional parameters to GridSearchCV " "is ignored! 
                           "is ignored! The option will be removed in 0.15.")
-        estimator = self.estimator
-        cv = self.cv
-
-        n_samples = _num_samples(X)
-        X, y = check_arrays(X, y, allow_lists=True, sparse_format='csr')
 
         if self.loss_func is not None:
             warnings.warn("Passing a loss function is "
@@ -459,55 +397,42 @@ def _fit(self, X, y, parameter_iterable, **params):
             scorer = SCORERS[self.scoring]
         else:
             scorer = self.scoring
-        self.scorer_ = scorer
 
-        if y is not None:
-            if len(y) != n_samples:
-                raise ValueError('Target variable (y) has a different number '
-                                 'of samples (%i) than data (X: %i samples)'
-                                 % (len(y), n_samples))
-            y = np.asarray(y)
-        cv = check_cv(cv, X, y, classifier=is_classifier(estimator))
-
-        base_estimator = clone(self.estimator)
-
-        pre_dispatch = self.pre_dispatch
+        cv_scorer = CVScorer(cv=self.cv, scoring=scorer,
+                             iid=self.iid, fit_params=self.fit_params,
+                             n_jobs=self.n_jobs,
+                             pre_dispatch=self.pre_dispatch,
+                             verbose=self.verbose)
 
-        out = Parallel(
-            n_jobs=self.n_jobs, verbose=self.verbose,
-            pre_dispatch=pre_dispatch)(
-                delayed(fit_grid_point)(
-                    X, y, base_estimator, parameters, train, test, scorer,
-                    self.verbose, **self.fit_params)
-                for parameters in parameter_iterable
-                for train, test in cv)
+        out = cv_scorer.search(parameter_iterable, clone(self.estimator), X, y)
+        out = list(out)
+        out.sort(key=lambda fold_dict: (fold_dict['candidate_index'],
+                                        fold_dict['fold_index']))
 
-        # Out is a list of triplet: score, estimator, n_test_samples
         n_fits = len(out)
+        cv = check_cv(self.cv, X, y, classifier=is_classifier(self.estimator))
         n_folds = len(cv)
 
-        scores = list()
         cv_scores = list()
         for grid_start in range(0, n_fits, n_folds):
             n_test_samples = 0
             score = 0
             all_scores = []
-            for this_score, parameters, this_n_test_samples in \
-                    out[grid_start:grid_start + n_folds]:
+            for fold_dict in out[grid_start:grid_start + n_folds]:
+                this_score = fold_dict['test_score']
                 all_scores.append(this_score)
                 if self.iid:
-                    this_score *= this_n_test_samples
-                    n_test_samples += this_n_test_samples
+                    this_score *= fold_dict['test_n_samples']
+                    n_test_samples += fold_dict['test_n_samples']
                 score += this_score
            if self.iid:
                 score /= float(n_test_samples)
             else:
                 score /= float(n_folds)
-            scores.append((score, parameters))
             # TODO: shall we also store the test_fold_sizes?
             cv_scores.append(_CVScoreTuple(
-                parameters,
+                fold_dict['parameters'],
                 score,
                 np.array(all_scores)))
         # Store the computed scores
@@ -524,13 +449,14 @@ def _fit(self, X, y, parameter_iterable, **params):
         if self.refit:
             # fit the best estimator using the entire dataset
             # clone first to work around broken estimators
-            best_estimator = clone(base_estimator).set_params(
+            best_estimator = clone(self.estimator).set_params(
                 **best.parameters)
             if y is not None:
                 best_estimator.fit(X, y, **self.fit_params)
             else:
                 best_estimator.fit(X, **self.fit_params)
             self.best_estimator_ = best_estimator
+
         return self
 
 
@@ -585,9 +511,9 @@ class GridSearchCV(BaseSearchCV):
               as in '2*n_jobs'
 
     iid : boolean, optional
-        If True, the data is assumed to be identically distributed across
-        the folds, and the loss minimized is the total loss per sample,
-        and not the mean loss across the folds.
+        If True (default), the data is assumed to be identically distributed
+        across the folds, and the mean score is the total score per sample,
+        and not the mean score across the folds.
 
     cv : integer or cross-validation generator, optional
         If an integer is passed, it is the number of folds (default 3).
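
A minimal usage sketch of the CVScorer API introduced by this patch (not part of the diff; it assumes the patch is applied, and the LinearSVC estimator and iris data are illustrative choices only):

    from sklearn.cross_validation import CVScorer
    from sklearn.datasets import load_iris
    from sklearn.svm import LinearSVC

    iris = load_iris()

    # Score a single estimator: with scoring=None the estimator's own
    # ``score`` method is used, and one result dict is returned per fold,
    # sorted by 'fold_index'.
    cv_scorer = CVScorer(cv=5, n_jobs=1)
    for fold in cv_scorer(LinearSVC(), iris.data, iris.target):
        print("fold %d: score=%.3f on %d test samples"
              % (fold['fold_index'], fold['test_score'],
                 fold['test_n_samples']))

    # Evaluate several candidate parameter settings with ``search``; each
    # result dict additionally carries 'candidate_index' and 'parameters'.
    candidates = [{'C': 0.1}, {'C': 1.0}]
    for fold in cv_scorer.search(candidates, LinearSVC(),
                                 iris.data, iris.target):
        print("candidate %d %r: score=%.3f"
              % (fold['candidate_index'], fold['parameters'],
                 fold['test_score']))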