From 9005683d3d862a5738b99b816aa486892ed913d3 Mon Sep 17 00:00:00 2001 From: Joel Nothman Date: Mon, 18 Mar 2013 14:34:55 +1100 Subject: [PATCH 1/9] ENH extensible parameter search results using structured arrays --- sklearn/grid_search.py | 278 ++++++++++++++++++++---------- sklearn/tests/test_grid_search.py | 57 ++++-- 2 files changed, 236 insertions(+), 99 deletions(-) diff --git a/sklearn/grid_search.py b/sklearn/grid_search.py index 0d6615a6e0408..6fbcd827d6ac2 100644 --- a/sklearn/grid_search.py +++ b/sklearn/grid_search.py @@ -9,7 +9,7 @@ # License: BSD Style. from abc import ABCMeta, abstractmethod -from collections import Mapping, namedtuple +from collections import Mapping from functools import partial, reduce from itertools import product import numbers @@ -23,12 +23,12 @@ from .base import MetaEstimatorMixin from .cross_validation import check_cv from .externals.joblib import Parallel, delayed, logger -from .externals.six import string_types -from .utils import safe_mask, check_random_state +from .externals.six import string_types, iterkeys +from .utils import safe_mask, check_random_state, deprecated from .utils.validation import _num_samples, check_arrays from .metrics import SCORERS, Scorer -__all__ = ['GridSearchCV', 'ParameterGrid', 'fit_grid_point', +__all__ = ['GridSearchCV', 'ParameterGrid', 'fit_fold', 'fit_grid_point', 'ParameterSampler', 'RandomizedSearchCV'] @@ -189,7 +189,7 @@ def __len__(self): return self.n_iter -def fit_grid_point(X, y, base_clf, clf_params, train, test, scorer, +def fit_fold(X, y, base_clf, clf_params, train, test, scorer, verbose, loss_func=None, **fit_params): """Run fit on one set of parameters. @@ -226,15 +226,13 @@ def fit_grid_point(X, y, base_clf, clf_params, train, test, scorer, Returns ------- - score : float - Score of this parameter setting on given training / test split. + results : dict of string to any + An extensible storage of fold results including the following keys: - estimator : estimator object - Estimator object of type base_clf that was fitted using clf_params - and provided train / test split. - - n_samples_test : int - Number of test samples in this split. + ``'test_score'`` : float + The estimator's score on the test set. + ``'test_n_samples'`` : int + The number of samples in the test set. """ if verbose > 1: start_time = time.time() @@ -268,34 +266,92 @@ def fit_grid_point(X, y, base_clf, clf_params, train, test, scorer, X_train = X[safe_mask(X, train)] X_test = X[safe_mask(X, test)] + if scorer is None: + scorer = lambda clf, *args: clf.score(*args) + if y is not None: y_test = y[safe_mask(y, test)] y_train = y[safe_mask(y, train)] - clf.fit(X_train, y_train, **fit_params) - - if scorer is not None: - this_score = scorer(clf, X_test, y_test) - else: - this_score = clf.score(X_test, y_test) + fit_args = (X_train, y_train) + score_args = (X_test, y_test) else: - clf.fit(X_train, **fit_params) - if scorer is not None: - this_score = scorer(clf, X_test) - else: - this_score = clf.score(X_test) + fit_args = (X_train,) + score_args = (X_test,) + + # do actual fitting + clf.fit(*fit_args, **fit_params) + test_score = scorer(clf, *score_args) - if not isinstance(this_score, numbers.Number): + if not isinstance(test_score, numbers.Number): raise ValueError("scoring must return a number, got %s (%s)" - " instead." % (str(this_score), type(this_score))) + " instead." 
% (str(test_score), type(test_score))) if verbose > 2: - msg += ", score=%f" % this_score + msg += ", score=%f" % test_score if verbose > 1: end_msg = "%s -%s" % (msg, logger.short_format_time(time.time() - start_time)) print("[GridSearchCV] %s %s" % ((64 - len(end_msg)) * '.', end_msg)) - return this_score, clf_params, _num_samples(X_test) + return { + 'test_score': test_score, + 'test_n_samples': _num_samples(X_test), + } + + +@deprecated('fit_grid_point is deprecated and will be removed in 0.15. ' + 'Use fit_fold instead.') +def fit_grid_point(X, y, base_clf, clf_params, train, test, scorer, + verbose, loss_func=None, **fit_params): + """Run fit on one set of parameters. + + This function is DEPRECATED. Use `fit_fold` instead. + + Parameters + ---------- + X : array-like, sparse matrix or list + Input data. + + y : array-like or None + Targets for input data. + + base_clf : estimator object + This estimator will be cloned and then fitted. + + clf_params : dict + Parameters to be set on base_estimator clone for this grid point. + + train : ndarray, dtype int or bool + Boolean mask or indices for training set. + + test : ndarray, dtype int or bool + Boolean mask or indices for test set. + + scorer : callable or None. + If provided must be a scoring object / function with signature + ``scorer(estimator, X, y)``. + + verbose : int + Verbosity level. + + **fit_params : kwargs + Additional parameter passed to the fit function of the estimator. + + + Returns + ------- + score : float + Score of this parameter setting on given training / test split. + + clf_params : dict + The parameters used to train this estimator. + + n_samples_test : int + Number of test samples in this split. + """ + res = fit_fold(X, y, base_clf, clf_params, train, test, scorer, + verbose, loss_func=None, **fit_params) + return res['test_score'], clf_params, res['test_n_samples'] def _check_param_grid(param_grid): @@ -316,11 +372,6 @@ def _check_param_grid(param_grid): "list.") -_CVScoreTuple = namedtuple('_CVScoreTuple', - ('parameters', 'mean_validation_score', - 'cv_validation_scores')) - - class BaseSearchCV(BaseEstimator, MetaEstimatorMixin): """Base class for hyper parameter search with cross-validation. """ @@ -387,6 +438,26 @@ def decision_function(self): def transform(self): return self.best_estimator_.transform + @property + def grid_scores_(self): + warnings.warn("grid_scores_ is deprecated and will be removed in 0.15." + " Use grid_results_ and fold_results_ instead.", DeprecationWarning) + return zip(self.grid_results_['parameters'], + self.grid_results_['test_score'], + self.fold_results_['test_score']) + + @property + def best_score_(self): + if not hasattr(self, 'best_index_'): + raise AttributeError('Call fit() to calculate best_score_') + return self.grid_results_['test_score'][self.best_index_] + + @property + def best_params_(self): + if not hasattr(self, 'best_index_'): + raise AttributeError('Call fit() to calculate best_params_') + return self.grid_results_['parameters'][self.best_index_] + def _check_estimator(self): """Check that estimator can be fitted and score can be computed.""" if (not hasattr(self.estimator, 'fit') or @@ -404,6 +475,36 @@ def _check_estimator(self): "should have a 'score' method. The estimator %s " "does not." 
% self.estimator) + def _set_methods(self): + """Create predict and predict_proba if present in best estimator.""" + if hasattr(self.best_estimator_, 'predict'): + self.predict = self.best_estimator_.predict + if hasattr(self.best_estimator_, 'predict_proba'): + self.predict_proba = self.best_estimator_.predict_proba + + def _aggregate_scores(self, scores, n_samples): + """Take 2d arrays of scores and samples and calculate weighted + means/sums of each row""" + if self.iid: + scores = scores * n_samples + scores = scores.sum(axis=1) / n_samples.sum(axis=1) + else: + scores = scores.sum(axis=1) / scores.shape[1] + return scores + + def _merge_result_dicts(self, result_dicts): + """ + From a result dict for each fold, produce a single dict with an array + for each key. + For example [[{'score': 1}, {'score': 2}], [{'score': 3}, {'score': 4}]] + -> {'score': np.array([[1, 2], [3, 4]])}""" + # assume keys are same throughout + result_keys = list(iterkeys(result_dicts[0][0])) + arrays = ([[fold_results[key] for fold_results in point] + for point in result_dicts] + for key in result_keys) + return np.rec.fromarrays(arrays, names=result_keys) + def _fit(self, X, y, parameter_iterator, **params): """Actual fitting, performing the search over parameters.""" estimator = self.estimator @@ -446,37 +547,25 @@ def _fit(self, X, y, parameter_iterator, **params): out = Parallel( n_jobs=self.n_jobs, verbose=self.verbose, pre_dispatch=pre_dispatch)( - delayed(fit_grid_point)( + delayed(fit_fold)( X, y, base_clf, clf_params, train, test, scorer, self.verbose, **self.fit_params) for clf_params in parameter_iterator for train, test in cv) - # Out is a list of triplet: score, estimator, n_test_samples n_param_points = len(list(parameter_iterator)) n_fits = len(out) n_folds = n_fits // n_param_points - scores = list() - cv_scores = list() - for grid_start in range(0, n_fits, n_folds): - n_test_samples = 0 - score = 0 - these_points = list() - for this_score, clf_params, this_n_test_samples in \ - out[grid_start:grid_start + n_folds]: - these_points.append(this_score) - if self.iid: - this_score *= this_n_test_samples - n_test_samples += this_n_test_samples - score += this_score - if self.iid: - score /= float(n_test_samples) - else: - score /= float(n_folds) - scores.append((score, clf_params)) - cv_scores.append(these_points) + cv_results = self._merge_result_dicts([ + [fold_results for fold_results in out[start:start + n_folds]] + for start in range(0, n_fits, n_folds) + ]) - cv_scores = np.asarray(cv_scores) + field_defs = [('parameters', 'object'), ('test_score', cv_results['test_score'].dtype)] + grid_results = np.zeros(n_param_points, dtype=field_defs) + grid_results['parameters'] = list(parameter_iterator) + grid_results['test_score'] = self._aggregate_scores( + cv_results['test_score'], cv_results['test_n_samples']) # Note: we do not use max(out) to make ties deterministic even if # comparison on estimator instances is not deterministic @@ -490,30 +579,27 @@ def _fit(self, X, y, parameter_iterator, **params): else: best_score = np.inf - for score, params in scores: + for i, score in enumerate(grid_results['test_score']): if ((score > best_score and greater_is_better) - or (score < best_score and not greater_is_better)): + or (score < best_score + and not greater_is_better)): best_score = score - best_params = params + best_index = i - self.best_params_ = best_params - self.best_score_ = best_score + self.best_index_ = best_index + self.fold_results_ = cv_results + self.grid_results_ = grid_results if 
self.refit: # fit the best estimator using the entire dataset # clone first to work around broken estimators - best_estimator = clone(base_clf).set_params(**best_params) + best_estimator = clone(base_clf).set_params(**self.best_params_) if y is not None: best_estimator.fit(X, y, **self.fit_params) else: best_estimator.fit(X, **self.fit_params) self.best_estimator_ = best_estimator - # Store the computed scores - self.cv_scores_ = [ - _CVScoreTuple(clf_params, score, all_scores) - for clf_params, (score, _), all_scores - in zip(parameter_iterator, scores, cv_scores)] return self @@ -604,20 +690,27 @@ class GridSearchCV(BaseSearchCV): Attributes ---------- - `cv_scores_` : list of named tuples - Contains scores for all parameter combinations in param_grid. - Each entry corresponds to one parameter setting. - Each named tuple has the attributes: + `grid_results_` : structured array of shape [# param combinations] + For each parameter combination in ``param_grid`` includes these fields: - * ``parameters``, a dict of parameter settings - * ``mean_validation_score``, the mean score over the + * ``parameters``, dict of parameter settings + * ``test_score``, the mean score over the cross-validation folds - * ``cv_validation_scores``, the list of scores for each fold + + `fold_results_` : structured array of shape [# param combinations, # folds] + For each cross-validation fold includes these fields: + + * ``test_score``, the score for this fold + * ``test_n_samples``, the number of samples in testing `best_estimator_` : estimator Estimator that was choosen by the search, i.e. estimator which gave highest score (or smallest loss if specified) - on the left out data. + on the left out data. Available only if refit=True. + + `best_index_` : int + The index of the best parameter setting into ``grid_results_`` and + ``fold_results_`` data. `best_score_` : float Score of best_estimator on the left out data. @@ -625,6 +718,10 @@ class GridSearchCV(BaseSearchCV): `best_params_` : dict Parameter setting that gave the best results on the hold out data. + `grid_scores_` : list of tuples (deprecated) + Contains scores for all parameter combinations in ``param_grid``: + each tuple is (parameters, mean score, fold scores). + Notes ------ The parameters selected are those that maximize the score of the left out @@ -659,12 +756,6 @@ def __init__(self, estimator, param_grid, scoring=None, loss_func=None, self.param_grid = param_grid _check_param_grid(param_grid) - @property - def grid_scores_(self): - warnings.warn("grid_scores_ is deprecated and will be removed in 0.15." - " Use cv_scores_ instead.", DeprecationWarning) - return self.cv_scores_ - def fit(self, X, y=None, **params): """Run fit with all sets of parameters. @@ -760,20 +851,27 @@ class RandomizedSearchCV(BaseSearchCV): Attributes ---------- - `cv_scores_` : list of named tuples - Contains scores for all parameter combinations in param_grid. - Each entry corresponds to one parameter setting. 
- Each named tuple has the attributes: + `grid_results_` : structured array of shape [# param combinations] + For each parameter combination in ``param_grid`` includes these fields: - * ``parameters``, a dict of parameter settings - * ``mean_validation_score``, the mean score over the + * ``parameters``, dict of parameter settings + * ``test_score``, the mean score over the cross-validation folds - * ``cv_validation_scores``, the list of scores for each fold + + `fold_results_` : structured array of shape [# param combinations, # folds] + For each cross-validation fold includes these fields: + + * ``test_score``, the score for this fold + * ``test_n_samples``, the number of samples in testing `best_estimator_` : estimator Estimator that was choosen by the search, i.e. estimator which gave highest score (or smallest loss if specified) - on the left out data. + on the left out data. Available only if refit=True. + + `best_index_` : int + The index of the best parameter setting into ``grid_results_`` and + ``fold_results_`` data. `best_score_` : float Score of best_estimator on the left out data. @@ -781,6 +879,10 @@ class RandomizedSearchCV(BaseSearchCV): `best_params_` : dict Parameter setting that gave the best results on the hold out data. + `grid_scores_` : list of tuples (deprecated) + Contains scores for all parameter combinations in ``param_grid``: + each tuple is (parameters, mean score, fold scores). + Notes ----- The parameters selected are those that maximize the score of the left out diff --git a/sklearn/tests/test_grid_search.py b/sklearn/tests/test_grid_search.py index b06e9e7ed6a26..ab52d19f2b298 100644 --- a/sklearn/tests/test_grid_search.py +++ b/sklearn/tests/test_grid_search.py @@ -134,9 +134,11 @@ def test_grid_search(): grid_search.fit(X, y) sys.stdout = old_stdout assert_equal(grid_search.best_estimator_.foo_param, 2) + assert_equal(grid_search.best_params_, {'foo_param': 2}) + assert_equal(grid_search.best_score_, 1.) for i, foo_i in enumerate([1, 2, 3]): - assert_true(grid_search.cv_scores_[i][0] + assert_true(grid_search.grid_results_['parameters'][i] == {'foo_param': foo_i}) # Smoke test the score etc: grid_search.score(X, y) @@ -145,19 +147,48 @@ def test_grid_search(): grid_search.transform(X) -def test_trivial_cv_scores(): +def test_grid_scores(): + """Test that GridSearchCV.grid_scores_ is filled in the correct format""" + clf = MockClassifier() + grid_search = GridSearchCV(clf, {'foo_param': [1, 2, 3]}, verbose=3) + # make sure it selects the smallest parameter in case of ties + old_stdout = sys.stdout + sys.stdout = StringIO() + grid_search.fit(X, y) + sys.stdout = old_stdout + assert_equal(grid_search.best_estimator_.foo_param, 2) + + n_folds = 3 + with warnings.catch_warnings(record=True): + for i, foo_i in enumerate([1, 2, 3]): + assert_true(grid_search.grid_scores_[i][0] + == {'foo_param': foo_i}) + # mean score + assert_almost_equal(grid_search.grid_scores_[i][1], + (1. if foo_i > 1 else 0.)) + # all fold scores + assert_array_equal(grid_search.grid_scores_[i][2], + [1. if foo_i > 1 else 0.] * n_folds) + + +def test_trivial_results(): """Test search over a "grid" with only one point. - Non-regression test: cv_scores_ wouldn't be set by GridSearchCV. + Non-regression test: grid_results_, etc. wouldn't be set by GridSearchCV. 
""" clf = MockClassifier() grid_search = GridSearchCV(clf, {'foo_param': [1]}) grid_search.fit(X, y) - assert_true(hasattr(grid_search, "cv_scores_")) + # Ensure attributes are set + grid_search.grid_results_ + grid_search.fold_results_ + grid_search.best_index_ random_search = RandomizedSearchCV(clf, {'foo_param': [0]}) random_search.fit(X, y) - assert_true(hasattr(random_search, "cv_scores_")) + grid_search.grid_results_ + grid_search.fold_results_ + grid_search.best_index_ def test_no_refit(): @@ -196,20 +227,22 @@ def test_grid_search_iid(): # once with iid=True (default) grid_search = GridSearchCV(svm, param_grid={'C': [1, 10]}, cv=cv) grid_search.fit(X, y) - _, average_score, scores = grid_search.cv_scores_[0] + scores = grid_search.fold_results_[0]['test_score'] assert_array_almost_equal(scores, [1, 1. / 3.]) # for first split, 1/4 of dataset is in test, for second 3/4. # take weighted average + average_score = grid_search.grid_results_[0]['test_score'] assert_almost_equal(average_score, 1 * 1. / 4. + 1. / 3. * 3. / 4.) # once with iid=False (default) grid_search = GridSearchCV(svm, param_grid={'C': [1, 10]}, cv=cv, iid=False) grid_search.fit(X, y) - _, average_score, scores = grid_search.cv_scores_[0] # scores are the same as above + scores = grid_search.fold_results_[0]['test_score'] assert_array_almost_equal(scores, [1, 1. / 3.]) # averaged score is just mean of scores + average_score = grid_search.grid_results_[0]['test_score'] assert_almost_equal(average_score, np.mean(scores)) @@ -419,7 +452,10 @@ def test_X_as_list(): cv = KFold(n=len(X), n_folds=3) grid_search = GridSearchCV(clf, {'foo_param': [1, 2, 3]}, cv=cv) grid_search.fit(X.tolist(), y).score(X, y) - assert_true(hasattr(grid_search, "cv_scores_")) + # Ensure result attributes are set + grid_search.grid_results_ + grid_search.fold_results_ + grid_search.best_index_ def test_unsupervised_grid_search(): @@ -466,7 +502,7 @@ def test_randomized_search(): params = dict(C=distributions.expon()) search = RandomizedSearchCV(LinearSVC(), param_distributions=params) search.fit(X, y) - assert_equal(len(search.cv_scores_), 10) + assert_equal(len(search.grid_results_['test_score']), 10) def test_grid_search_score_consistency(): @@ -479,9 +515,8 @@ def test_grid_search_score_consistency(): grid_search = GridSearchCV(clf, {'C': Cs}, scoring=score) grid_search.fit(X, y) cv = StratifiedKFold(n_folds=3, y=y) - for C, scores in zip(Cs, grid_search.cv_scores_): + for C, scores in zip(Cs, grid_search.fold_results_['test_score']): clf.set_params(C=C) - scores = scores[2] # get the separate runs from grid scores i = 0 for train, test in cv: clf.fit(X[train], y[train]) From 6d969235baef4ac6d84942f2024500fd2e452122 Mon Sep 17 00:00:00 2001 From: Joel Nothman Date: Sun, 7 Apr 2013 18:36:44 +1000 Subject: [PATCH 2/9] EXAMPLE remove examples' uses of *SearchCV.cv_scores_ --- examples/grid_search_digits.py | 6 ++++-- examples/svm/plot_rbf_parameters.py | 8 ++------ examples/svm/plot_svm_scale_c.py | 2 +- 3 files changed, 7 insertions(+), 9 deletions(-) diff --git a/examples/grid_search_digits.py b/examples/grid_search_digits.py index a4914609f9225..09b2de98b9d20 100644 --- a/examples/grid_search_digits.py +++ b/examples/grid_search_digits.py @@ -59,9 +59,11 @@ print() print("Grid scores on development set:") print() - for params, mean_score, scores in clf.cv_scores_: + means = clf.grid_results_['test_score'] + stds = clf.fold_results_['test_score'].std(axis=1) + for params, mean, std in zip(clf.grid_results_['parameters'], means, stds): 
print("%0.3f (+/-%0.03f) for %r"
-              % (mean_score, scores.std() / 2, params))
+              % (mean, std / 2, params))
     print()
     print("Detailed classification report:")
diff --git a/examples/svm/plot_rbf_parameters.py b/examples/svm/plot_rbf_parameters.py
index f298ebf01205c..39252d5ab7183 100644
--- a/examples/svm/plot_rbf_parameters.py
+++ b/examples/svm/plot_rbf_parameters.py
@@ -105,12 +105,8 @@
 pl.axis('tight')
 
 # plot the scores of the grid
-# cv_scores_ contains parameter settings and scores
-score_dict = grid.cv_scores_
-
-# We extract just the scores
-scores = [x[1] for x in score_dict]
-scores = np.array(scores).reshape(len(C_range), len(gamma_range))
+scores = grid.grid_results_['test_score']
+scores = scores.reshape(len(C_range), len(gamma_range))
 
 # draw heatmap of accuracy as a function of gamma and C
 pl.figure(figsize=(8, 6))
diff --git a/examples/svm/plot_svm_scale_c.py b/examples/svm/plot_svm_scale_c.py
index 13fe094e10de2..1c17a531aff15 100644
--- a/examples/svm/plot_svm_scale_c.py
+++ b/examples/svm/plot_svm_scale_c.py
@@ -131,7 +131,7 @@
                         cv=ShuffleSplit(n=n_samples, train_size=train_size,
                                         n_iter=250, random_state=1))
     grid.fit(X, y)
-    scores = [x[1] for x in grid.cv_scores_]
+    scores = grid.grid_results_['test_score']
 
     scales = [(1, 'No scaling'),
               ((n_samples * train_size), '1/n_samples'),
From b18a2780f9159e8a74c90e9d51e4bf25757a3d03 Mon Sep 17 00:00:00 2001
From: Joel Nothman
Date: Tue, 9 Apr 2013 16:15:00 +1000
Subject: [PATCH 3/9] REFACTOR/ENH refactor CV scoring from grid_search and
 cross_validation

Provides consistent and enhanced structured-array result style for
non-search CV evaluation. So far only regression-tested.
---
 sklearn/cross_validation.py | 365 ++++++++++++++++++++++++++++++------
 sklearn/grid_search.py      | 197 +++----------------
 2 files changed, 339 insertions(+), 223 deletions(-)

diff --git a/sklearn/cross_validation.py b/sklearn/cross_validation.py
index 9a44c9f36bc55..b8706b9d7ff56 100644
--- a/sklearn/cross_validation.py
+++ b/sklearn/cross_validation.py
@@ -14,18 +14,21 @@
 from itertools import combinations
 from math import ceil, floor, factorial
 import numbers
+import time
 
 import numpy as np
 import scipy.sparse as sp
 
 from .base import is_classifier, clone
 from .utils import check_arrays, check_random_state, safe_mask
+from .utils.validation import _num_samples
 from .utils.fixes import unique
-from .externals.joblib import Parallel, delayed
-from .externals.six import string_types
+from .externals.joblib import Parallel, delayed, logger
+from .externals.six import string_types, iterkeys
 from .metrics import SCORERS, Scorer
 
 __all__ = ['Bootstrap',
+           'CVEvaluator',
            'KFold',
            'LeaveOneLabelOut',
            'LeaveOneOut',
@@ -1038,13 +1041,74 @@ def __len__(self):
 
 
 ##############################################################################
-def _cross_val_score(estimator, X, y, scorer, train, test, verbose,
-                     fit_params):
-    """Inner loop for cross validation"""
-    n_samples = X.shape[0] if sp.issparse(X) else len(X)
+def fit_fold(estimator, X, y, train, test, scorer,
+             verbose, est_params=None, fit_params=None):
+    """Run fit on one set of parameters.
+
+    Parameters
+    ----------
+    estimator : estimator object
+        This estimator will be cloned and then fitted.
+
+    X : array-like, sparse matrix or list
+        Input data.
+
+    y : array-like or None
+        Targets for input data.
+
+    train : ndarray, dtype int or bool
+        Boolean mask or indices for training set.
+
+    test : ndarray, dtype int or bool
+        Boolean mask or indices for test set.
+
+    scorer : callable or None.
+ If provided must be a scoring object / function with signature + ``scorer(estimator, X, y)``. + + verbose : int + Verbosity level. + + est_params : dict + Parameters to be set on estimator for this grid point. + + **fit_params : kwargs + Additional parameter passed to the fit function of the estimator. + + + Returns + ------- + results : dict of string to any + An extensible storage of fold results including the following keys: + + ``'test_score'`` : float + The estimator's score on the test set. + ``'test_n_samples'`` : int + The number of samples in the test set. + """ + if verbose > 1: + start_time = time.time() + if est_params is None: + msg = '' + else: + msg = '%s' % (', '.join('%s=%s' % (k, v) + for k, v in est_params.items())) + print("[CVEvaluator]%s %s" % (msg, (64 - len(msg)) * '.')) + + n_samples = _num_samples(X) + + # Adapt fit_params to train portion only + if fit_params is None: + fit_params = {} fit_params = dict([(k, np.asarray(v)[train] if hasattr(v, '__len__') and len(v) == n_samples else v) for k, v in fit_params.items()]) + + if hasattr(estimator, 'kernel') and callable(estimator.kernel): + # cannot compute the kernel values with custom function + raise ValueError("Cannot use a custom kernel function. " + "Precompute the kernel matrix instead.") + if not hasattr(X, "shape"): if getattr(estimator, "_pairwise", False): raise ValueError("Precomputed kernels or affinity matrices have " @@ -1062,27 +1126,45 @@ def _cross_val_score(estimator, X, y, scorer, train, test, verbose, X_train = X[safe_mask(X, train)] X_test = X[safe_mask(X, test)] - if y is None: - y_train = None - y_test = None - else: - y_train = y[train] - y_test = y[test] - estimator.fit(X_train, y_train, **fit_params) + # update parameters of the classifier after a copy of its base structure + if est_params is not None: + estimator = clone(estimator) + estimator.set_params(**est_params) + if scorer is None: - score = estimator.score(X_test, y_test) + scorer = lambda estimator, *args: estimator.score(*args) + + if y is not None: + y_test = y[safe_mask(y, test)] + y_train = y[safe_mask(y, train)] + fit_args = (X_train, y_train) + score_args = (X_test, y_test) else: - score = scorer(estimator, X_test, y_test) - if not isinstance(score, numbers.Number): - raise ValueError("scoring must return a number, got %s (%s)" - " instead." % (str(score), type(score))) + fit_args = (X_train,) + score_args = (X_test,) + + # do actual fitting + estimator.fit(*fit_args, **fit_params) + test_score = scorer(estimator, *score_args) + + if not isinstance(test_score, numbers.Number): + raise ValueError("scoring must return a number, got %s (%s)" + " instead." % (str(test_score), type(test_score))) + + if verbose > 2: + msg += ", score=%f" % test_score if verbose > 1: - print("score: %f" % score) - return score + end_msg = "%s -%s" % (msg, + logger.short_format_time(time.time() - + start_time)) + print("[CVEvaluator]%s %s" % ((64 - len(end_msg)) * '.', end_msg)) + return { + 'test_score': test_score, + 'test_n_samples': _num_samples(X_test), + } -def cross_val_score(estimator, X, y=None, scoring=None, cv=None, n_jobs=1, - verbose=0, fit_params=None, score_func=None): +class CVEvaluator(object): """Evaluate a score by cross-validation Parameters @@ -1103,50 +1185,219 @@ def cross_val_score(estimator, X, y=None, scoring=None, cv=None, n_jobs=1, See 'Scoring objects' in the model evaluation section of the user guide for details. - cv : cross-validation generator, optional - A cross-validation generator. 
If None, a 3-fold cross - validation is used or 3-fold stratified cross-validation - when y is supplied and estimator is a classifier. + cv : integer or cross-validation generator, optional + A cross-validation generator or number of stratified folds (default 3). + + iid : boolean, optional + If True (default), the data is assumed to be identically distributed + across the folds, and the mean score is the total score per sample, + and not the mean score across the folds. n_jobs : integer, optional - The number of CPUs to use to do the computation. -1 means - 'all CPUs'. + The number of jobs to run in parallel (default 1). -1 means 'all CPUs'. + + pre_dispatch : int, or string, optional + Controls the number of jobs that get dispatched during parallel + execution. Reducing this number can be useful to avoid an + explosion of memory consumption when more jobs get dispatched + than CPUs can process. This parameter can be: + + - None, in which case all the jobs are immediatly + created and spawned. Use this for lightweight and + fast-running jobs, to avoid delays due to on-demand + spawning of the jobs + + - An int, giving the exact number of total jobs that are + spawned + + - A string, giving an expression as a function of n_jobs, + as in '2*n_jobs' verbose : integer, optional The verbosity level. fit_params : dict, optional Parameters to pass to the fit method of the estimator. - - Returns - ------- - scores : array of float, shape=(len(list(cv)),) - Array of scores of the estimator for each run of the cross validation. """ - X, y = check_arrays(X, y, sparse_format='csr', allow_lists=True) - cv = check_cv(cv, X, y, classifier=is_classifier(estimator)) - if score_func is not None: - warnings.warn("Passing function as ``score_func`` is " - "deprecated and will be removed in 0.15. " - "Either use strings or score objects.", stacklevel=2) - scorer = Scorer(score_func) - elif isinstance(scoring, string_types): - scorer = SCORERS[scoring] - else: - scorer = scoring - if scorer is None and not hasattr(estimator, 'score'): - raise TypeError( - "If no scoring is specified, the estimator passed " - "should have a 'score' method. The estimator %s " - "does not." % estimator) - # We clone the estimator to make sure that all the folds are - # independent, and that it is pickle-able. - fit_params = fit_params if fit_params is not None else {} - scores = Parallel(n_jobs=n_jobs, verbose=verbose)( - delayed(_cross_val_score)( - clone(estimator), X, y, scorer, train, test, verbose, fit_params) - for train, test in cv) - return np.array(scores) + + def __init__(self, estimator, X, y=None, scoring=None, cv=None, iid=True, + n_jobs=1, pre_dispatch='2*n_jobs', verbose=0, fit_params=None, + score_func=None): + + X, y = check_arrays(X, y, sparse_format='csr', allow_lists=True) + self.cv = check_cv(cv, X, y, classifier=is_classifier(estimator)) + self.n_folds = len(self.cv) + self.iid = iid + + if score_func is not None: + warnings.warn("Passing function as ``score_func`` is " + "deprecated and will be removed in 0.15. " + "Either use strings or score objects.", stacklevel=2) + scorer = Scorer(score_func) + elif isinstance(scoring, string_types): + scorer = SCORERS[scoring] + else: + scorer = scoring + if scorer is None and not hasattr(estimator, 'score'): + raise TypeError( + "If no scoring is specified, the estimator passed " + "should have a 'score' method. The estimator %s " + "does not." 
% estimator) + + self.parallel = Parallel(n_jobs=n_jobs, verbose=verbose, + pre_dispatch=pre_dispatch) + self.fit_fold_kwargs = { + 'X': X, + 'y': y, + 'estimator': clone(estimator), + 'scorer': scorer, + 'verbose': verbose, + 'fit_params': fit_params, + } + + def calc_means(self, scores, n_samples): + """ + Calculate means of the final dimension of `scores`, weighted by + `n_samples` if `iid` is True. + + Parameters + ---------- + scores : array-like of floats + The scores to aggregate. + + n_samples : array-like of integers with same shape as `scores` + The number of samples considered in calculating each score. + + Returns + ------- + means : ndarray with shape of `scores` except for last dimension + The means of the last dimension of scores. + """ + scores = np.asarray(scores) + n_samples = np.asarray(n_samples) + if self.iid: + scores = scores * n_samples + scores = scores.sum(axis=-1) / n_samples.sum(axis=-1) + else: + scores = scores.sum(axis=-1) / scores.shape[-1] + return scores + + def _format_results(self, out): + # group by params + out = [ + [fold_results for fold_results in out[start:start + self.n_folds]] + for start in range(0, len(out), self.n_folds) + ] + + # dicts to structured arrays (assume keys are same throughout): + keys = sorted(iterkeys(out[0][0])) + arrays = ([[fold_results[key] for fold_results in point] + for point in out] + for key in keys) + out = np.rec.fromarrays(arrays, names=keys) + + # for now, only one mean: + means = np.rec.fromarrays( + [self.calc_means(out['test_score'], out['test_n_samples'])], + names=['test_score'] + ) + + return means, out + + def __call__(self, parameters=None): + """ + Cross-validate the estimator, optionally with the given parameters. + + Parameters + ---------- + parameters : dict or iterable of dicts, optional + If provided, the estimator will be cloned and have these parameters + set. If an iterable of parameter settings is given, cross-validation + is performed for each set of parameters. + + Returns + ------- + means : structured array + This provides fields: + * ``test_score``, the mean test score across folds + The array has one dimension corresponding to parameter settings if + an iterable is provided, and zero dimensions otherwise. + fold_results : structured array + For each cross-validation fold, this provides fields: + * ``test_score``, the score for this fold + * ``test_n_samples``, the number of samples in testing + The first axis indexes parameter settings where an iterable is + provided. + """ + if parameters is None: + # unchanged parameters + param_iter = [None] + out_slice = 0 + elif hasattr(parameters, 'items'): + # one set of parameters + param_iter = [parameters] + out_slice = 0 + else: + # sequence of parameters + param_iter = parameters + out_slice = slice(None) + + out = self.parallel( + delayed(fit_fold)(est_params=est_params, + train=train, test=test, + **self.fit_fold_kwargs) + for est_params in param_iter for train, test in self.cv) + + means, out = self._format_results(out) + + return means[out_slice], out[out_slice] + + def score_folds(self, parameters=None): + """ + Cross-validate the estimator, optionally with the given parameters, + and return the scores for each fold. + + Parameters + ---------- + parameters : dict or iterable of dicts, optional + If provided, the estimator will be cloned and have these parameters + set. If an iterable of parameter settings is given, cross-validation + is performed for each set of parameters. 
+ + Returns + ------- + fold_scores : array of floats + The score for each fold evaluated. The first axis indexes parameter + settings where an iterable is provided. + """ + return self(parameters)[1]['test_score'] + + def score_means(self, parameters=None): + """ + Cross-validate the estimator, optionally with the given parameters, and + return the mean score across folds. + + Parameters + ---------- + parameters : dict or iterable of dicts, optional + If provided, the estimator will be cloned and have these parameters + set. If an iterable of parameter settings is given, cross-validation + is performed for each set of parameters. + + Returns + ------- + means : float, or array of floats + The mean score over folds evaluated, or an array of means given an + iterable of parameter settings. + """ + return self(parameters)[0]['test_score'] + + +def cross_val_score(estimator, X, y=None, scoring=None, cv=None, n_jobs=1, + verbose=0, fit_params=None, score_func=None): + cv_eval = CVEvaluator(estimator, X, y, scoring=scoring, cv=cv, n_jobs=n_jobs, + verbose=verbose, fit_params=fit_params, score_func=score_func) + return cv_eval.score_folds() def _permutation_test_score(estimator, X, y, cv, scorer): diff --git a/sklearn/grid_search.py b/sklearn/grid_search.py index 6fbcd827d6ac2..b891119cccf8b 100644 --- a/sklearn/grid_search.py +++ b/sklearn/grid_search.py @@ -12,23 +12,21 @@ from collections import Mapping from functools import partial, reduce from itertools import product -import numbers import operator -import time import warnings import numpy as np +from numpy.lib import recfunctions -from .base import BaseEstimator, is_classifier, clone +from .base import BaseEstimator, clone from .base import MetaEstimatorMixin -from .cross_validation import check_cv -from .externals.joblib import Parallel, delayed, logger +from .cross_validation import CVEvaluator, fit_fold from .externals.six import string_types, iterkeys -from .utils import safe_mask, check_random_state, deprecated +from .utils import check_random_state, deprecated from .utils.validation import _num_samples, check_arrays from .metrics import SCORERS, Scorer -__all__ = ['GridSearchCV', 'ParameterGrid', 'fit_fold', 'fit_grid_point', +__all__ = ['GridSearchCV', 'ParameterGrid', 'fit_grid_point', 'ParameterSampler', 'RandomizedSearchCV'] @@ -189,123 +187,15 @@ def __len__(self): return self.n_iter -def fit_fold(X, y, base_clf, clf_params, train, test, scorer, - verbose, loss_func=None, **fit_params): - """Run fit on one set of parameters. - - Parameters - ---------- - X : array-like, sparse matrix or list - Input data. - - y : array-like or None - Targets for input data. - - base_clf : estimator object - This estimator will be cloned and then fitted. - - clf_params : dict - Parameters to be set on base_estimator clone for this grid point. - - train : ndarray, dtype int or bool - Boolean mask or indices for training set. - - test : ndarray, dtype int or bool - Boolean mask or indices for test set. - - scorer : callable or None. - If provided must be a scoring object / function with signature - ``scorer(estimator, X, y)``. - - verbose : int - Verbosity level. - - **fit_params : kwargs - Additional parameter passed to the fit function of the estimator. - - - Returns - ------- - results : dict of string to any - An extensible storage of fold results including the following keys: - - ``'test_score'`` : float - The estimator's score on the test set. - ``'test_n_samples'`` : int - The number of samples in the test set. 
- """ - if verbose > 1: - start_time = time.time() - msg = '%s' % (', '.join('%s=%s' % (k, v) - for k, v in clf_params.items())) - print("[GridSearchCV] %s %s" % (msg, (64 - len(msg)) * '.')) - - # update parameters of the classifier after a copy of its base structure - clf = clone(base_clf) - clf.set_params(**clf_params) - - if hasattr(base_clf, 'kernel') and callable(base_clf.kernel): - # cannot compute the kernel values with custom function - raise ValueError("Cannot use a custom kernel function. " - "Precompute the kernel matrix instead.") - - if not hasattr(X, "shape"): - if getattr(base_clf, "_pairwise", False): - raise ValueError("Precomputed kernels or affinity matrices have " - "to be passed as arrays or sparse matrices.") - X_train = [X[idx] for idx in train] - X_test = [X[idx] for idx in test] - else: - if getattr(base_clf, "_pairwise", False): - # X is a precomputed square kernel matrix - if X.shape[0] != X.shape[1]: - raise ValueError("X should be a square kernel matrix") - X_train = X[np.ix_(train, train)] - X_test = X[np.ix_(test, train)] - else: - X_train = X[safe_mask(X, train)] - X_test = X[safe_mask(X, test)] - - if scorer is None: - scorer = lambda clf, *args: clf.score(*args) - - if y is not None: - y_test = y[safe_mask(y, test)] - y_train = y[safe_mask(y, train)] - fit_args = (X_train, y_train) - score_args = (X_test, y_test) - else: - fit_args = (X_train,) - score_args = (X_test,) - - # do actual fitting - clf.fit(*fit_args, **fit_params) - test_score = scorer(clf, *score_args) - - if not isinstance(test_score, numbers.Number): - raise ValueError("scoring must return a number, got %s (%s)" - " instead." % (str(test_score), type(test_score))) - - if verbose > 2: - msg += ", score=%f" % test_score - if verbose > 1: - end_msg = "%s -%s" % (msg, - logger.short_format_time(time.time() - - start_time)) - print("[GridSearchCV] %s %s" % ((64 - len(end_msg)) * '.', end_msg)) - return { - 'test_score': test_score, - 'test_n_samples': _num_samples(X_test), - } @deprecated('fit_grid_point is deprecated and will be removed in 0.15. ' - 'Use fit_fold instead.') + 'Use cross_validation.fit_fold instead.') def fit_grid_point(X, y, base_clf, clf_params, train, test, scorer, verbose, loss_func=None, **fit_params): """Run fit on one set of parameters. - This function is DEPRECATED. Use `fit_fold` instead. + This function is DEPRECATED. Use `cross_validation.fit_fold` instead. Parameters ---------- @@ -349,8 +239,8 @@ def fit_grid_point(X, y, base_clf, clf_params, train, test, scorer, n_samples_test : int Number of test samples in this split. 
""" - res = fit_fold(X, y, base_clf, clf_params, train, test, scorer, - verbose, loss_func=None, **fit_params) + res = fit_fold(base_clf, X, y, train, test, scorer, verbose, + loss_func=None, est_params=clf_params, fit_params=fit_params) return res['test_score'], clf_params, res['test_n_samples'] @@ -482,16 +372,6 @@ def _set_methods(self): if hasattr(self.best_estimator_, 'predict_proba'): self.predict_proba = self.best_estimator_.predict_proba - def _aggregate_scores(self, scores, n_samples): - """Take 2d arrays of scores and samples and calculate weighted - means/sums of each row""" - if self.iid: - scores = scores * n_samples - scores = scores.sum(axis=1) / n_samples.sum(axis=1) - else: - scores = scores.sum(axis=1) / scores.shape[1] - return scores - def _merge_result_dicts(self, result_dicts): """ From a result dict for each fold, produce a single dict with an array @@ -507,11 +387,6 @@ def _merge_result_dicts(self, result_dicts): def _fit(self, X, y, parameter_iterator, **params): """Actual fitting, performing the search over parameters.""" - estimator = self.estimator - cv = self.cv - - n_samples = _num_samples(X) - X, y = check_arrays(X, y, allow_lists=True, sparse_format='csr') if self.loss_func is not None: warnings.warn("Passing a loss function is " @@ -529,43 +404,33 @@ def _fit(self, X, y, parameter_iterator, **params): scorer = SCORERS[self.scoring] else: scorer = self.scoring - self.scorer_ = scorer + n_samples = _num_samples(X) + X, y = check_arrays(X, y, allow_lists=True, sparse_format='csr') if y is not None: if len(y) != n_samples: raise ValueError('Target variable (y) has a different number ' 'of samples (%i) than data (X: %i samples)' % (len(y), n_samples)) y = np.asarray(y) - cv = check_cv(cv, X, y, classifier=is_classifier(estimator)) - - base_clf = clone(self.estimator) - - pre_dispatch = self.pre_dispatch - out = Parallel( - n_jobs=self.n_jobs, verbose=self.verbose, - pre_dispatch=pre_dispatch)( - delayed(fit_fold)( - X, y, base_clf, clf_params, train, test, scorer, - self.verbose, **self.fit_params) for clf_params in - parameter_iterator for train, test in cv) - - n_param_points = len(list(parameter_iterator)) - n_fits = len(out) - n_folds = n_fits // n_param_points - - cv_results = self._merge_result_dicts([ - [fold_results for fold_results in out[start:start + n_folds]] - for start in range(0, n_fits, n_folds) - ]) - - field_defs = [('parameters', 'object'), ('test_score', cv_results['test_score'].dtype)] - grid_results = np.zeros(n_param_points, dtype=field_defs) - grid_results['parameters'] = list(parameter_iterator) - grid_results['test_score'] = self._aggregate_scores( - cv_results['test_score'], cv_results['test_n_samples']) + cv_eval = CVEvaluator(self.estimator, X, y, scoring=self.scorer_, + cv=self.cv, iid=self.iid, fit_params=self.fit_params, + n_jobs=self.n_jobs, pre_dispatch=self.pre_dispatch, + verbose=self.verbose) + grid_results, cv_results = cv_eval(parameter_iterator) + + # Append 'parameters' to grid_results + # Broken due to https://github.com/numpy/numpy/issues/2346: + # grid_results = recfunctions.append_fields(grid_results, 'parameters', + # np.asarray(list(parameter_iterator)), usemask=False) + new_grid_results = np.zeros(grid_results.shape, + dtype=grid_results.dtype.descr + [('parameters', 'O')]) + for name in grid_results.dtype.names: + new_grid_results[name] = grid_results[name] + new_grid_results['parameters'] = list(parameter_iterator) + grid_results = new_grid_results # Note: we do not use max(out) to make ties deterministic even if 
# comparison on estimator instances is not deterministic @@ -593,7 +458,7 @@ def _fit(self, X, y, parameter_iterator, **params): if self.refit: # fit the best estimator using the entire dataset # clone first to work around broken estimators - best_estimator = clone(base_clf).set_params(**self.best_params_) + best_estimator = clone(self.estimator).set_params(**self.best_params_) if y is not None: best_estimator.fit(X, y, **self.fit_params) else: @@ -654,9 +519,9 @@ class GridSearchCV(BaseSearchCV): as in '2*n_jobs' iid : boolean, optional - If True, the data is assumed to be identically distributed across - the folds, and the loss minimized is the total loss per sample, - and not the mean loss across the folds. + If True (default), the data is assumed to be identically distributed + across the folds, and the mean score is the total score per sample, + and not the mean score across the folds. cv : integer or cross-validation generator, optional If an integer is passed, it is the number of folds (default 3). From c5da116a0801bfb5ebb42e9ac89ce4219ff945e8 Mon Sep 17 00:00:00 2001 From: Joel Nothman Date: Thu, 2 May 2013 09:26:46 +1000 Subject: [PATCH 4/9] COSMIT fix pep8 violations --- sklearn/cross_validation.py | 60 ++++++++++++++++--------------- sklearn/grid_search.py | 41 ++++++++++++--------- sklearn/tests/test_grid_search.py | 12 ++++--- 3 files changed, 63 insertions(+), 50 deletions(-) diff --git a/sklearn/cross_validation.py b/sklearn/cross_validation.py index b8706b9d7ff56..e9b330ecce96b 100644 --- a/sklearn/cross_validation.py +++ b/sklearn/cross_validation.py @@ -1042,7 +1042,7 @@ def __len__(self): ############################################################################## def fit_fold(estimator, X, y, train, test, scorer, - verbose, est_params=None, fit_params=None): + verbose, est_params=None, fit_params=None): """Run fit on one set of parameters. Parameters @@ -1159,8 +1159,8 @@ def fit_fold(estimator, X, y, train, test, scorer, start_time)) print("[CVEvaluator]%s %s" % ((64 - len(end_msg)) * '.', end_msg)) return { - 'test_score': test_score, - 'test_n_samples': _num_samples(X_test), + 'test_score': test_score, + 'test_n_samples': _num_samples(X_test), } @@ -1221,8 +1221,8 @@ class CVEvaluator(object): """ def __init__(self, estimator, X, y=None, scoring=None, cv=None, iid=True, - n_jobs=1, pre_dispatch='2*n_jobs', verbose=0, fit_params=None, - score_func=None): + n_jobs=1, pre_dispatch='2*n_jobs', verbose=0, fit_params=None, + score_func=None): X, y = check_arrays(X, y, sparse_format='csr', allow_lists=True) self.cv = check_cv(cv, X, y, classifier=is_classifier(estimator)) @@ -1245,14 +1245,14 @@ def __init__(self, estimator, X, y=None, scoring=None, cv=None, iid=True, "does not." 
% estimator) self.parallel = Parallel(n_jobs=n_jobs, verbose=verbose, - pre_dispatch=pre_dispatch) + pre_dispatch=pre_dispatch) self.fit_fold_kwargs = { - 'X': X, - 'y': y, - 'estimator': clone(estimator), - 'scorer': scorer, - 'verbose': verbose, - 'fit_params': fit_params, + 'X': X, + 'y': y, + 'estimator': clone(estimator), + 'scorer': scorer, + 'verbose': verbose, + 'fit_params': fit_params, } def calc_means(self, scores, n_samples): @@ -1290,16 +1290,16 @@ def _format_results(self, out): ] # dicts to structured arrays (assume keys are same throughout): - keys = sorted(iterkeys(out[0][0])) - arrays = ([[fold_results[key] for fold_results in point] - for point in out] - for key in keys) + keys = sorted(iterkeys(out[0][0])) + arrays = ( + [[fold_results[key] for fold_results in point] for point in out] + for key in keys) out = np.rec.fromarrays(arrays, names=keys) # for now, only one mean: means = np.rec.fromarrays( - [self.calc_means(out['test_score'], out['test_n_samples'])], - names=['test_score'] + [self.calc_means(out['test_score'], out['test_n_samples'])], + names=['test_score'] ) return means, out @@ -1312,8 +1312,8 @@ def __call__(self, parameters=None): ---------- parameters : dict or iterable of dicts, optional If provided, the estimator will be cloned and have these parameters - set. If an iterable of parameter settings is given, cross-validation - is performed for each set of parameters. + set. If an iterable of parameter settings is given, + cross-validation is performed for each set of parameters. Returns ------- @@ -1343,9 +1343,10 @@ def __call__(self, parameters=None): out_slice = slice(None) out = self.parallel( - delayed(fit_fold)(est_params=est_params, - train=train, test=test, - **self.fit_fold_kwargs) + delayed(fit_fold)( + est_params=est_params, train=train, test=test, + **self.fit_fold_kwargs + ) for est_params in param_iter for train, test in self.cv) means, out = self._format_results(out) @@ -1361,8 +1362,8 @@ def score_folds(self, parameters=None): ---------- parameters : dict or iterable of dicts, optional If provided, the estimator will be cloned and have these parameters - set. If an iterable of parameter settings is given, cross-validation - is performed for each set of parameters. + set. If an iterable of parameter settings is given, + cross-validation is performed for each set of parameters. Returns ------- @@ -1381,8 +1382,8 @@ def score_means(self, parameters=None): ---------- parameters : dict or iterable of dicts, optional If provided, the estimator will be cloned and have these parameters - set. If an iterable of parameter settings is given, cross-validation - is performed for each set of parameters. + set. If an iterable of parameter settings is given, + cross-validation is performed for each set of parameters. 
Returns ------- @@ -1395,8 +1396,9 @@ def score_means(self, parameters=None): def cross_val_score(estimator, X, y=None, scoring=None, cv=None, n_jobs=1, verbose=0, fit_params=None, score_func=None): - cv_eval = CVEvaluator(estimator, X, y, scoring=scoring, cv=cv, n_jobs=n_jobs, - verbose=verbose, fit_params=fit_params, score_func=score_func) + cv_eval = CVEvaluator( + estimator, X, y, scoring=scoring, cv=cv, n_jobs=n_jobs, + verbose=verbose, fit_params=fit_params, score_func=score_func) return cv_eval.score_folds() diff --git a/sklearn/grid_search.py b/sklearn/grid_search.py index b891119cccf8b..f620ebb954cad 100644 --- a/sklearn/grid_search.py +++ b/sklearn/grid_search.py @@ -187,10 +187,8 @@ def __len__(self): return self.n_iter - - @deprecated('fit_grid_point is deprecated and will be removed in 0.15. ' - 'Use cross_validation.fit_fold instead.') + 'Use cross_validation.fit_fold instead.') def fit_grid_point(X, y, base_clf, clf_params, train, test, scorer, verbose, loss_func=None, **fit_params): """Run fit on one set of parameters. @@ -239,8 +237,10 @@ def fit_grid_point(X, y, base_clf, clf_params, train, test, scorer, n_samples_test : int Number of test samples in this split. """ - res = fit_fold(base_clf, X, y, train, test, scorer, verbose, - loss_func=None, est_params=clf_params, fit_params=fit_params) + res = fit_fold( + base_clf, X, y, train, test, scorer, verbose, + loss_func=None, est_params=clf_params, fit_params=fit_params + ) return res['test_score'], clf_params, res['test_n_samples'] @@ -331,7 +331,8 @@ def transform(self): @property def grid_scores_(self): warnings.warn("grid_scores_ is deprecated and will be removed in 0.15." - " Use grid_results_ and fold_results_ instead.", DeprecationWarning) + " Use grid_results_ and fold_results_ instead.", + DeprecationWarning) return zip(self.grid_results_['parameters'], self.grid_results_['test_score'], self.fold_results_['test_score']) @@ -376,13 +377,14 @@ def _merge_result_dicts(self, result_dicts): """ From a result dict for each fold, produce a single dict with an array for each key. 
- For example [[{'score': 1}, {'score': 2}], [{'score': 3}, {'score': 4}]] + For example [[{'score': 1}, {'score': 2}], + [{'score': 3}, {'score': 4}]] -> {'score': np.array([[1, 2], [3, 4]])}""" # assume keys are same throughout - result_keys = list(iterkeys(result_dicts[0][0])) + result_keys = list(iterkeys(result_dicts[0][0])) arrays = ([[fold_results[key] for fold_results in point] - for point in result_dicts] - for key in result_keys) + for point in result_dicts] + for key in result_keys) return np.rec.fromarrays(arrays, names=result_keys) def _fit(self, X, y, parameter_iterator, **params): @@ -415,18 +417,21 @@ def _fit(self, X, y, parameter_iterator, **params): % (len(y), n_samples)) y = np.asarray(y) - cv_eval = CVEvaluator(self.estimator, X, y, scoring=self.scorer_, - cv=self.cv, iid=self.iid, fit_params=self.fit_params, - n_jobs=self.n_jobs, pre_dispatch=self.pre_dispatch, - verbose=self.verbose) + cv_eval = CVEvaluator( + self.estimator, X, y, scoring=self.scorer_, cv=self.cv, + iid=self.iid, fit_params=self.fit_params, n_jobs=self.n_jobs, + pre_dispatch=self.pre_dispatch, verbose=self.verbose + ) grid_results, cv_results = cv_eval(parameter_iterator) # Append 'parameters' to grid_results # Broken due to https://github.com/numpy/numpy/issues/2346: # grid_results = recfunctions.append_fields(grid_results, 'parameters', # np.asarray(list(parameter_iterator)), usemask=False) - new_grid_results = np.zeros(grid_results.shape, - dtype=grid_results.dtype.descr + [('parameters', 'O')]) + new_grid_results = np.zeros( + grid_results.shape, + dtype=grid_results.dtype.descr + [('parameters', 'O')] + ) for name in grid_results.dtype.names: new_grid_results[name] = grid_results[name] new_grid_results['parameters'] = list(parameter_iterator) @@ -458,7 +463,9 @@ def _fit(self, X, y, parameter_iterator, **params): if self.refit: # fit the best estimator using the entire dataset # clone first to work around broken estimators - best_estimator = clone(self.estimator).set_params(**self.best_params_) + best_estimator = clone(self.estimator).set_params( + **self.best_params_ + ) if y is not None: best_estimator.fit(X, y, **self.fit_params) else: diff --git a/sklearn/tests/test_grid_search.py b/sklearn/tests/test_grid_search.py index ab52d19f2b298..12e53513b0710 100644 --- a/sklearn/tests/test_grid_search.py +++ b/sklearn/tests/test_grid_search.py @@ -164,11 +164,15 @@ def test_grid_scores(): assert_true(grid_search.grid_scores_[i][0] == {'foo_param': foo_i}) # mean score - assert_almost_equal(grid_search.grid_scores_[i][1], - (1. if foo_i > 1 else 0.)) + assert_almost_equal( + grid_search.grid_scores_[i][1], + (1. if foo_i > 1 else 0.) + ) # all fold scores - assert_array_equal(grid_search.grid_scores_[i][2], - [1. if foo_i > 1 else 0.] * n_folds) + assert_array_equal( + grid_search.grid_scores_[i][2], + [1. if foo_i > 1 else 0.] * n_folds + ) def test_trivial_results(): From 4a8215ec2115a8811a1a5f99a51654e445c4d31f Mon Sep 17 00:00:00 2001 From: Joel Nothman Date: Thu, 2 May 2013 09:27:17 +1000 Subject: [PATCH 5/9] COSMIT Remove unused private methods --- sklearn/grid_search.py | 21 --------------------- 1 file changed, 21 deletions(-) diff --git a/sklearn/grid_search.py b/sklearn/grid_search.py index f620ebb954cad..0bf46ff83235b 100644 --- a/sklearn/grid_search.py +++ b/sklearn/grid_search.py @@ -366,27 +366,6 @@ def _check_estimator(self): "should have a 'score' method. The estimator %s " "does not." 
% self.estimator) - def _set_methods(self): - """Create predict and predict_proba if present in best estimator.""" - if hasattr(self.best_estimator_, 'predict'): - self.predict = self.best_estimator_.predict - if hasattr(self.best_estimator_, 'predict_proba'): - self.predict_proba = self.best_estimator_.predict_proba - - def _merge_result_dicts(self, result_dicts): - """ - From a result dict for each fold, produce a single dict with an array - for each key. - For example [[{'score': 1}, {'score': 2}], - [{'score': 3}, {'score': 4}]] - -> {'score': np.array([[1, 2], [3, 4]])}""" - # assume keys are same throughout - result_keys = list(iterkeys(result_dicts[0][0])) - arrays = ([[fold_results[key] for fold_results in point] - for point in result_dicts] - for key in result_keys) - return np.rec.fromarrays(arrays, names=result_keys) - def _fit(self, X, y, parameter_iterator, **params): """Actual fitting, performing the search over parameters.""" From 51ff593e4737b24cf8a3a4417133900531acf1c3 Mon Sep 17 00:00:00 2001 From: Joel Nothman Date: Sat, 11 May 2013 22:14:51 +1000 Subject: [PATCH 6/9] DOC Fix comments for cross_val_score and CVEvaluator also, add iid parameter to cross_val_score --- sklearn/cross_validation.py | 64 ++++++++++++++++++++++++++++++++++--- 1 file changed, 60 insertions(+), 4 deletions(-) diff --git a/sklearn/cross_validation.py b/sklearn/cross_validation.py index e9b330ecce96b..97e8d787cf83c 100644 --- a/sklearn/cross_validation.py +++ b/sklearn/cross_validation.py @@ -1165,7 +1165,7 @@ def fit_fold(estimator, X, y, train, test, scorer, class CVEvaluator(object): - """Evaluate a score by cross-validation + """Parallelized cross-validation for a given estimator and dataset Parameters ---------- @@ -1186,7 +1186,8 @@ class CVEvaluator(object): for details. cv : integer or cross-validation generator, optional - A cross-validation generator or number of stratified folds (default 3). + A cross-validation generator or number of folds (default 3). Folds will + be stratified if the estimator is a classifier. iid : boolean, optional If True (default), the data is assumed to be identically distributed @@ -1394,8 +1395,63 @@ def score_means(self, parameters=None): return self(parameters)[0]['test_score'] -def cross_val_score(estimator, X, y=None, scoring=None, cv=None, n_jobs=1, - verbose=0, fit_params=None, score_func=None): +def cross_val_score(estimator, X, y=None, scoring=None, cv=None, iid=True, + n_jobs=1, verbose=0, fit_params=None, score_func=None): + """Evaluate a score by cross-validation + + Parameters + ---------- + estimator : estimator object implementing 'fit' + The object to use to fit the data. + + X : array-like of shape at least 2D + The data to fit. + + y : array-like, optional + The target variable to try to predict in the case of + supervised learning. + + scoring : string or callable, optional + Either one of either a string ("zero_one", "f1", "roc_auc", ... for + classification, "mse", "r2", ... for regression) or a callable. + See 'Scoring objects' in the model evaluation section of the user guide + for details. + + cv : integer or cross-validation generator, optional + A cross-validation generator or number of folds (default 3). Folds will + be stratified if the estimator is a classifier. + + iid : boolean, optional + If True (default), the data is assumed to be identically distributed + across the folds, and the mean score is the total score per sample, + and not the mean score across the folds. 
+
+    n_jobs : integer, optional
+        The number of jobs to run in parallel (default 1). -1 means 'all CPUs'.
+
+    pre_dispatch : int, or string, optional
+        Controls the number of jobs that get dispatched during parallel
+        execution. Reducing this number can be useful to avoid an
+        explosion of memory consumption when more jobs get dispatched
+        than CPUs can process. This parameter can be:
+
+            - None, in which case all the jobs are immediately
+              created and spawned. Use this for lightweight and
+              fast-running jobs, to avoid delays due to on-demand
+              spawning of the jobs
+
+            - An int, giving the exact number of total jobs that are
+              spawned
+
+            - A string, giving an expression as a function of n_jobs,
+              as in '2*n_jobs'
+
+    verbose : integer, optional
+        The verbosity level.
+
+    fit_params : dict, optional
+        Parameters to pass to the fit method of the estimator.
+    """
     cv_eval = CVEvaluator(
         estimator, X, y, scoring=scoring, cv=cv, n_jobs=n_jobs,
         verbose=verbose, fit_params=fit_params, score_func=score_func)

From 66a84631529c8adfc48a83facc6d83c6ec3b9867 Mon Sep 17 00:00:00 2001
From: Joel Nothman
Date: Sat, 11 May 2013 22:17:35 +1000
Subject: [PATCH 7/9] Debug messages now print `fit_fold` in place of
 `CVEvaluator`

---
 sklearn/cross_validation.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/sklearn/cross_validation.py b/sklearn/cross_validation.py
index 97e8d787cf83c..49de24ee876a5 100644
--- a/sklearn/cross_validation.py
+++ b/sklearn/cross_validation.py
@@ -1093,7 +1093,7 @@ def fit_fold(estimator, X, y, train, test, scorer,
     else:
         msg = '%s' % (', '.join('%s=%s' % (k, v)
                       for k, v in est_params.items()))
-        print("[CVEvaluator]%s %s" % (msg, (64 - len(msg)) * '.'))
+        print("[fit_fold]%s %s" % (msg, (64 - len(msg)) * '.'))
 
     n_samples = _num_samples(X)
 
@@ -1157,7 +1157,7 @@ def fit_fold(estimator, X, y, train, test, scorer,
 
         end_msg = "%s -%s" % (msg,
                               logger.short_format_time(time.time() -
                                                        start_time))
-        print("[CVEvaluator]%s %s" % ((64 - len(end_msg)) * '.', end_msg))
+        print("[fit_fold]%s %s" % ((64 - len(end_msg)) * '.', end_msg))
     return {
         'test_score': test_score,
         'test_n_samples': _num_samples(X_test),

From d4c10f6f7f8e19f82b9d65b8e9b6e2362ebedc76 Mon Sep 17 00:00:00 2001
From: Joel Nothman
Date: Sat, 11 May 2013 23:45:04 +1000
Subject: [PATCH 8/9] Rename grid_results_ to search_results_

---
 examples/grid_search_digits.py      |  5 ++--
 examples/svm/plot_rbf_parameters.py |  2 +-
 examples/svm/plot_svm_scale_c.py    |  2 +-
 sklearn/cross_validation.py         |  2 +-
 sklearn/grid_search.py              | 45 +++++++++++++++--------------
 sklearn/tests/test_grid_search.py   | 16 +++++-----
 6 files changed, 37 insertions(+), 35 deletions(-)

diff --git a/examples/grid_search_digits.py b/examples/grid_search_digits.py
index 09b2de98b9d20..8e445847d045b 100644
--- a/examples/grid_search_digits.py
+++ b/examples/grid_search_digits.py
@@ -59,9 +59,10 @@
     print()
     print("Grid scores on development set:")
     print()
-    means = clf.grid_results_['test_score']
+    candidates = clf.search_results_['parameters']
+    means = clf.search_results_['test_score']
     stds = clf.fold_results_['test_score'].std(axis=1)
-    for params, mean, std in zip(clf.grid_results_['parameters'], means, stds):
+    for params, mean, std in zip(candidates, means, stds):
         print("%0.3f (+/-%0.03f) for %r" % (mean, std / 2, params))
     print()
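The example above iterates candidate parameters, mean scores and per-fold standard deviations in lockstep; the same structured-array access works on any fitted search object. A hedged sketch, assuming ``clf`` has been fitted as in the example:

    best = clf.search_results_[clf.best_index_]
    print(best['parameters'], best['test_score'])

    per_fold = clf.fold_results_['test_score']  # shape: candidates x folds
    print(per_fold.mean(axis=1), per_fold.std(axis=1))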
diff --git a/examples/svm/plot_rbf_parameters.py b/examples/svm/plot_rbf_parameters.py
index 39252d5ab7183..1ba5ae375399c 100644
--- a/examples/svm/plot_rbf_parameters.py
+++ b/examples/svm/plot_rbf_parameters.py
@@ -105,7 +105,7 @@ pl.axis('tight')
 
 # plot the scores of the grid
-scores = grid.grid_results_['test_score']
+scores = grid.search_results_['test_score']
 scores = scores.reshape(len(C_range), len(gamma_range))
 
 # draw heatmap of accuracy as a function of gamma and C

diff --git a/examples/svm/plot_svm_scale_c.py b/examples/svm/plot_svm_scale_c.py
index 1c17a531aff15..c675788301d5a 100644
--- a/examples/svm/plot_svm_scale_c.py
+++ b/examples/svm/plot_svm_scale_c.py
@@ -131,7 +131,7 @@
                         cv=ShuffleSplit(n=n_samples, train_size=train_size,
                                         n_iter=250, random_state=1))
     grid.fit(X, y)
-    scores = grid.grid_results_['test_score']
+    scores = grid.search_results_['test_score']
 
     scales = [(1, 'No scaling'),
               ((n_samples * train_size), '1/n_samples'),

diff --git a/sklearn/cross_validation.py b/sklearn/cross_validation.py
index 49de24ee876a5..5c4ee02fe5975 100644
--- a/sklearn/cross_validation.py
+++ b/sklearn/cross_validation.py
@@ -1070,7 +1070,7 @@ def fit_fold(estimator, X, y, train, test, scorer,
         Verbosity level.
 
     est_params : dict
-        Parameters to be set on estimator for this grid point.
+        Parameters to be set on estimator for this fold.
 
     **fit_params : kwargs
         Additional parameters passed to the fit function of the estimator.

diff --git a/sklearn/grid_search.py b/sklearn/grid_search.py
index 0bf46ff83235b..09e8c2fe0e995 100644
--- a/sklearn/grid_search.py
+++ b/sklearn/grid_search.py
@@ -331,23 +331,23 @@ def transform(self):
     @property
     def grid_scores_(self):
         warnings.warn("grid_scores_ is deprecated and will be removed in 0.15."
-                      " Use grid_results_ and fold_results_ instead.",
+                      " Use search_results_ and fold_results_ instead.",
                       DeprecationWarning)
-        return zip(self.grid_results_['parameters'],
-                   self.grid_results_['test_score'],
+        return zip(self.search_results_['parameters'],
+                   self.search_results_['test_score'],
                    self.fold_results_['test_score'])
 
     @property
     def best_score_(self):
         if not hasattr(self, 'best_index_'):
             raise AttributeError('Call fit() to calculate best_score_')
-        return self.grid_results_['test_score'][self.best_index_]
+        return self.search_results_['test_score'][self.best_index_]
 
     @property
     def best_params_(self):
         if not hasattr(self, 'best_index_'):
             raise AttributeError('Call fit() to calculate best_params_')
-        return self.grid_results_['parameters'][self.best_index_]
+        return self.search_results_['parameters'][self.best_index_]
 
     def _check_estimator(self):
         """Check that estimator can be fitted and score can be computed."""
@@ -401,20 +401,21 @@ def _fit(self, X, y, parameter_iterator, **params):
             iid=self.iid, fit_params=self.fit_params, n_jobs=self.n_jobs,
             pre_dispatch=self.pre_dispatch, verbose=self.verbose
         )
-        grid_results, cv_results = cv_eval(parameter_iterator)
+        search_results, cv_results = cv_eval(parameter_iterator)
 
-        # Append 'parameters' to grid_results
+        # Append 'parameters' to search_results
         # Broken due to https://github.com/numpy/numpy/issues/2346:
-        # grid_results = recfunctions.append_fields(grid_results, 'parameters',
-        #     np.asarray(list(parameter_iterator)), usemask=False)
+        # search_results = recfunctions.append_fields(
+        #     search_results, 'parameters',
+        #     np.asarray(list(parameter_iterator)), usemask=False)
-        new_grid_results = np.zeros(
-            grid_results.shape,
-            dtype=grid_results.dtype.descr + [('parameters', 'O')]
-        )
+        new_search_results = np.zeros(
+            search_results.shape,
+            dtype=search_results.dtype.descr + [('parameters', 'O')]
+        )
-        for name in grid_results.dtype.names:
-            new_grid_results[name] = grid_results[name]
-        new_grid_results['parameters'] = list(parameter_iterator)
-        grid_results = new_grid_results
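Because `recfunctions.append_fields` is broken for object dtypes (the numpy issue cited in the comment above), the replacement lines below keep the manual widen-and-copy workaround. The pattern, reduced to a self-contained sketch with a made-up three-candidate array:

    import numpy as np

    results = np.zeros(3, dtype=[('test_score', 'f8')])
    wider = np.zeros(results.shape,
                     dtype=results.dtype.descr + [('parameters', 'O')])
    for name in results.dtype.names:
        wider[name] = results[name]
    wider['parameters'] = [{'C': 1}, {'C': 10}, {'C': 100}]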
+        for name in search_results.dtype.names:
+            new_search_results[name] = search_results[name]
+        new_search_results['parameters'] = list(parameter_iterator)
+        search_results = new_search_results
 
         # Note: we do not use max(out) to make ties deterministic even if
         # comparison on estimator instances is not deterministic
@@ -428,7 +429,7 @@ def _fit(self, X, y, parameter_iterator, **params):
         else:
             best_score = np.inf
 
-        for i, score in enumerate(grid_results['test_score']):
+        for i, score in enumerate(search_results['test_score']):
             if ((score > best_score and greater_is_better)
                     or (score < best_score and not greater_is_better)):
@@ -437,7 +438,7 @@ def _fit(self, X, y, parameter_iterator, **params):
 
         self.best_index_ = best_index
         self.fold_results_ = cv_results
-        self.grid_results_ = grid_results
+        self.search_results_ = search_results
 
         if self.refit:
             # fit the best estimator using the entire dataset
@@ -541,7 +542,7 @@ class GridSearchCV(BaseSearchCV):
 
     Attributes
     ----------
-    `grid_results_` : structured array of shape [# param combinations]
+    `search_results_` : structured array of shape [# param combinations]
         For each parameter combination in ``param_grid`` includes these fields:
 
             * ``parameters``, dict of parameter settings
@@ -560,7 +561,7 @@ class GridSearchCV(BaseSearchCV):
         on the left out data. Available only if refit=True.
 
     `best_index_` : int
-        The index of the best parameter setting into ``grid_results_`` and
+        The index of the best parameter setting into ``search_results_`` and
         ``fold_results_`` data.
 
     `best_score_` : float
@@ -702,7 +703,7 @@ class RandomizedSearchCV(BaseSearchCV):
 
     Attributes
    ----------
-    `grid_results_` : structured array of shape [# param combinations]
+    `search_results_` : structured array of shape [# param combinations]
        For each parameter combination in ``param_distributions`` includes
        these fields:
 
            * ``parameters``, dict of parameter settings
@@ -721,7 +722,7 @@ class RandomizedSearchCV(BaseSearchCV):
        on the left out data. Available only if refit=True.
 
    `best_index_` : int
-        The index of the best parameter setting into ``grid_results_`` and
+        The index of the best parameter setting into ``search_results_`` and
        ``fold_results_`` data.
 
    `best_score_` : float
diff --git a/sklearn/tests/test_grid_search.py b/sklearn/tests/test_grid_search.py
index 12e53513b0710..7af334200dfd0 100644
--- a/sklearn/tests/test_grid_search.py
+++ b/sklearn/tests/test_grid_search.py
@@ -138,7 +138,7 @@ def test_grid_search():
     assert_equal(grid_search.best_score_, 1.)
 
     for i, foo_i in enumerate([1, 2, 3]):
-        assert_true(grid_search.grid_results_['parameters'][i]
+        assert_true(grid_search.search_results_['parameters'][i]
                     == {'foo_param': foo_i})
     # Smoke test the score etc:
     grid_search.score(X, y)
@@ -178,19 +178,19 @@ def test_grid_scores():
 
 def test_trivial_results():
     """Test search over a "grid" with only one point.
 
-    Non-regression test: grid_results_, etc. wouldn't be set by GridSearchCV.
+    Non-regression test: search_results_, etc. wouldn't be set by GridSearchCV.
     """
     clf = MockClassifier()
     grid_search = GridSearchCV(clf, {'foo_param': [1]})
     grid_search.fit(X, y)
     # Ensure attributes are set
-    grid_search.grid_results_
+    grid_search.search_results_
     grid_search.fold_results_
     grid_search.best_index_
 
     random_search = RandomizedSearchCV(clf, {'foo_param': [0]})
     random_search.fit(X, y)
-    grid_search.grid_results_
-    grid_search.fold_results_
-    grid_search.best_index_
+    random_search.search_results_
+    random_search.fold_results_
+    random_search.best_index_
 
@@ -235,7 +235,7 @@ def test_grid_search_iid():
     assert_array_almost_equal(scores, [1, 1. / 3.])
     # for first split, 1/4 of dataset is in test, for second 3/4.
     # take weighted average
-    average_score = grid_search.grid_results_[0]['test_score']
+    average_score = grid_search.search_results_[0]['test_score']
     assert_almost_equal(average_score, 1 * 1. / 4. + 1. / 3. * 3. / 4.)
 
     # once with iid=False (default)
@@ -246,7 +246,7 @@ def test_grid_search_iid():
     scores = grid_search.fold_results_[0]['test_score']
     assert_array_almost_equal(scores, [1, 1. / 3.])
     # averaged score is just mean of scores
-    average_score = grid_search.grid_results_[0]['test_score']
+    average_score = grid_search.search_results_[0]['test_score']
     assert_almost_equal(average_score, np.mean(scores))
 
@@ -457,7 +457,7 @@ def test_X_as_list():
     grid_search = GridSearchCV(clf, {'foo_param': [1, 2, 3]}, cv=cv)
     grid_search.fit(X.tolist(), y).score(X, y)
     # Ensure result attributes are set
-    grid_search.grid_results_
+    grid_search.search_results_
     grid_search.fold_results_
     grid_search.best_index_
 
@@ -506,7 +506,7 @@ def test_randomized_search():
     params = dict(C=distributions.expon())
     search = RandomizedSearchCV(LinearSVC(), param_distributions=params)
     search.fit(X, y)
-    assert_equal(len(search.grid_results_['test_score']), 10)
+    assert_equal(len(search.search_results_['test_score']), 10)
 
 
 def test_grid_search_score_consistency():

From 72f228644595d1af667cc6b7ce510db5ed3fcae6 Mon Sep 17 00:00:00 2001
From: Joel Nothman
Date: Sat, 8 Jun 2013 20:06:05 +1000
Subject: [PATCH 9/9] Change CVEvaluator to approximate scorer interface

---
 sklearn/cross_validation.py | 200 ++++++++++++++++--------------------
 sklearn/grid_search.py      |  19 ++--
 2 files changed, 97 insertions(+), 122 deletions(-)

diff --git a/sklearn/cross_validation.py b/sklearn/cross_validation.py
index 5c4ee02fe5975..5a117281d462e 100644
--- a/sklearn/cross_validation.py
+++ b/sklearn/cross_validation.py
@@ -28,7 +28,7 @@
 from .metrics import SCORERS, Scorer
 
 __all__ = ['Bootstrap',
-           'CVEvaluator',
+           'CVScorer',
           'KFold',
           'LeaveOneLabelOut',
           'LeaveOneOut',
@@ -1041,8 +1041,8 @@ def __len__(self):
 
 ##############################################################################
 
-def fit_fold(estimator, X, y, train, test, scorer,
-             verbose, est_params=None, fit_params=None):
+def _fit_fold(estimator, X, y, train, test, scoring,
+              verbose, est_params=None, fit_params=None):
     """Run fit on one set of parameters.
 
     Parameters
     ----------
     estimator : estimator object implementing 'fit'
         The object to use to fit the data.
 
     X : array-like of shape at least 2D
         The data to fit.
 
     y : array-like or None
         Targets for input data.
 
     train : ndarray, dtype int or bool
         Boolean mask or indices for training set.
 
     test : ndarray, dtype int or bool
         Boolean mask or indices for test set.
 
-    scorer : callable or None
+    scoring : callable or None
         If provided must be a scoring object / function with signature
-        ``scorer(estimator, X, y)``.
+        ``scoring(estimator, X, y)``.
 
     verbose : int
         Verbosity level.
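Patch 9 reshapes `CVEvaluator` into `CVScorer`, configured up front and then applied to an estimator and data, so that it approximates the ``scoring(estimator, X, y)`` signature documented in the hunk above (when ``scoring`` is None, the fold fitter falls back to the estimator's own ``score`` method). An illustrative sketch of the resulting call patterns, assuming an estimator ``clf`` exposing a ``C`` parameter and data ``X, y``:

    from sklearn.cross_validation import CVScorer

    def accuracy_scoring(estimator, X, y):
        # matches the documented scoring(estimator, X, y) signature
        return (estimator.predict(X) == y).mean()

    cv_score = CVScorer(cv=3, scoring=accuracy_scoring)

    # scorer-like call on a single candidate setting:
    mean, folds = cv_score(clf, X, y)
    print(mean['test_score'], folds['test_score'])

    # search over several candidate parameter settings:
    means, fold_results = cv_score.search([{'C': 1}, {'C': 10}], clf, X, y)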
@@ -1093,7 +1093,7 @@ def fit_fold(estimator, X, y, train, test, scorer,
     else:
         msg = '%s' % (', '.join('%s=%s' % (k, v)
                       for k, v in est_params.items()))
-        print("[fit_fold]%s %s" % (msg, (64 - len(msg)) * '.'))
+        print("Fitting fold %s %s" % (msg, (64 - len(msg)) * '.'))
 
     n_samples = _num_samples(X)
 
@@ -1131,8 +1131,8 @@ def fit_fold(estimator, X, y, train, test, scorer,
     estimator = clone(estimator)
     estimator.set_params(**est_params)
 
-    if scorer is None:
-        scorer = lambda estimator, *args: estimator.score(*args)
+    if scoring is None:
+        scoring = lambda estimator, *args: estimator.score(*args)
 
     if y is not None:
         y_test = y[safe_mask(y, test)]
@@ -1145,7 +1145,7 @@ def fit_fold(estimator, X, y, train, test, scorer,
 
     # do actual fitting
     estimator.fit(*fit_args, **fit_params)
-    test_score = scorer(estimator, *score_args)
+    test_score = scoring(estimator, *score_args)
 
     if not isinstance(test_score, numbers.Number):
         raise ValueError("scoring must return a number, got %s (%s)"
@@ -1157,27 +1157,21 @@ def fit_fold(estimator, X, y, train, test, scorer,
 
         end_msg = "%s -%s" % (msg,
                               logger.short_format_time(time.time() -
                                                        start_time))
-        print("[fit_fold]%s %s" % ((64 - len(end_msg)) * '.', end_msg))
+        print("Fitting fold %s %s" % ((64 - len(end_msg)) * '.', end_msg))
     return {
         'test_score': test_score,
         'test_n_samples': _num_samples(X_test),
     }
 
 
-class CVEvaluator(object):
+class CVScorer(object):
     """Parallelized cross-validation for a given estimator and dataset
 
     Parameters
     ----------
-    estimator : estimator object implementing 'fit'
-        The object to use to fit the data.
-
-    X : array-like of shape at least 2D
-        The data to fit.
-
-    y : array-like, optional
-        The target variable to try to predict in the case of
-        supervised learning.
+    cv : integer or cross-validation generator, optional
+        A cross-validation generator or number of folds (default 3). Folds will
+        be stratified if the estimator is a classifier.
 
     scoring : string or callable, optional
         Either a string ("zero_one", "f1", "roc_auc", ... for
         classification, "mse", "r2", ... for regression) or a callable.
         See 'Scoring objects' in the model evaluation section of the user guide
         for details.
 
-    cv : integer or cross-validation generator, optional
-        A cross-validation generator or number of folds (default 3). Folds will
-        be stratified if the estimator is a classifier.
-
     iid : boolean, optional
         If True (default), the data is assumed to be identically distributed
         across the folds, and the mean score is the total score per sample,
@@ -1221,42 +1211,29 @@ class CVEvaluator(object):
         Parameters to pass to the fit method of the estimator.
     """
 
-    def __init__(self, estimator, X, y=None, scoring=None, cv=None, iid=True,
+    def __init__(self, cv=None, scoring=None, iid=True,
                  n_jobs=1, pre_dispatch='2*n_jobs', verbose=0,
                  fit_params=None, score_func=None):
-        X, y = check_arrays(X, y, sparse_format='csr', allow_lists=True)
-        self.cv = check_cv(cv, X, y, classifier=is_classifier(estimator))
-        self.n_folds = len(self.cv)
+        self.cv = cv
         self.iid = iid
+        self.verbose = verbose
+        self.fit_params = fit_params
         if score_func is not None:
             warnings.warn("Passing function as ``score_func`` is "
                           "deprecated and will be removed in 0.15. "
                           "Either use strings or score objects.",
                           stacklevel=2)
-            scorer = Scorer(score_func)
+            self.scoring = Scorer(score_func)
         elif isinstance(scoring, string_types):
-            scorer = SCORERS[scoring]
+            self.scoring = SCORERS[scoring]
         else:
-            scorer = scoring
-        if scorer is None and not hasattr(estimator, 'score'):
-            raise TypeError(
-                "If no scoring is specified, the estimator passed "
-                "should have a 'score' method. The estimator %s "
-                "does not." % estimator)
+            self.scoring = scoring
         self.parallel = Parallel(n_jobs=n_jobs, verbose=verbose,
                                  pre_dispatch=pre_dispatch)
-        self.fit_fold_kwargs = {
-            'X': X,
-            'y': y,
-            'estimator': clone(estimator),
-            'scorer': scorer,
-            'verbose': verbose,
-            'fit_params': fit_params,
-        }
-
-    def calc_means(self, scores, n_samples):
+
+    def _calc_means(self, scores, n_samples):
         """
         Calculate means of the final dimension of `scores`, weighted by
         `n_samples` if `iid` is True.
@@ -1283,11 +1260,11 @@ def calc_means(self, scores, n_samples):
         scores = scores.sum(axis=-1) / scores.shape[-1]
         return scores
 
-    def _format_results(self, out):
+    def _format_results(self, out, n_folds):
         # group by params
         out = [
-            [fold_results for fold_results in out[start:start + self.n_folds]]
-            for start in range(0, len(out), self.n_folds)
+            [fold_results for fold_results in out[start:start + n_folds]]
+            for start in range(0, len(out), n_folds)
         ]
 
         # dicts to structured arrays (assume keys are same throughout):
@@ -1299,22 +1276,26 @@ def _format_results(self, out):
 
         # for now, only one mean:
         means = np.rec.fromarrays(
-            [self.calc_means(out['test_score'], out['test_n_samples'])],
+            [self._calc_means(out['test_score'], out['test_n_samples'])],
             names=['test_score']
         )
 
         return means, out
 
-    def __call__(self, parameters=None):
-        """
-        Cross-validate the estimator, optionally with the given parameters.
+    def __call__(self, estimator, X, y=None):
+        """Cross-validate the estimator on the given data.
 
         Parameters
         ----------
-        parameters : dict or iterable of dicts, optional
-            If provided, the estimator will be cloned and have these parameters
-            set. If an iterable of parameter settings is given,
-            cross-validation is performed for each set of parameters.
+        estimator : estimator object implementing 'fit'
+            The object to use to fit the data.
+
+        X : array-like of shape at least 2D
+            The data to fit.
+
+        y : array-like, optional
+            The target variable to try to predict in the case of
+            supervised learning.
 
         Returns
         -------
@@ -1330,69 +1311,59 @@ def __call__(self, parameters=None):
             The first axis indexes parameter settings where an iterable
             is provided.
         """
-        if parameters is None:
-            # unchanged parameters
-            param_iter = [None]
-            out_slice = 0
-        elif hasattr(parameters, 'items'):
-            # one set of parameters
-            param_iter = [parameters]
-            out_slice = 0
-        else:
-            # sequence of parameters
-            param_iter = parameters
-            out_slice = slice(None)
-
-        out = self.parallel(
-            delayed(fit_fold)(
-                est_params=est_params, train=train, test=test,
-                **self.fit_fold_kwargs
-            )
-            for est_params in param_iter for train, test in self.cv)
-
-        means, out = self._format_results(out)
-
-        return means[out_slice], out[out_slice]
+        means, folds = self.search([{}], estimator, X, y)
+        return means[0], folds[0]
 
-    def score_folds(self, parameters=None):
-        """
-        Cross-validate the estimator, optionally with the given parameters,
-        and return the scores for each fold.
+    def search(self, candidates, estimator, X, y=None):
+        """Cross-validate the estimator for candidate parameter settings.
 
         Parameters
         ----------
-        parameters : dict or iterable of dicts, optional
-            If provided, the estimator will be cloned and have these parameters
-            set. If an iterable of parameter settings is given,
-            cross-validation is performed for each set of parameters.
+        candidates : iterable of dicts
+            The estimator will be cloned and have these parameters
+            set for each candidate.
 
-        Returns
-        -------
-        fold_scores : array of floats
-            The score for each fold evaluated. The first axis indexes parameter
-            settings where an iterable is provided.
-        """
-        return self(parameters)[1]['test_score']
+        estimator : estimator object implementing 'fit'
+            The object to use to fit the data.
 
-    def score_means(self, parameters=None):
-        """
-        Cross-validate the estimator, optionally with the given parameters, and
-        return the mean score across folds.
+        X : array-like of shape at least 2D
+            The data to fit.
 
-        Parameters
-        ----------
-        parameters : dict or iterable of dicts, optional
-            If provided, the estimator will be cloned and have these parameters
-            set. If an iterable of parameter settings is given,
-            cross-validation is performed for each set of parameters.
+        y : array-like, optional
+            The target variable to try to predict in the case of
+            supervised learning.
 
         Returns
         -------
-        means : float, or array of floats
-            The mean score over folds evaluated, or an array of means given an
-            iterable of parameter settings.
+        means : structured array of shape (n_candidates,)
+            This provides fields:
+                * ``test_score``, the mean test score across folds
+        fold_results : structured array of shape (n_candidates, n_folds)
+            For each cross-validation fold, this provides fields:
+                * ``test_score``, the score for this fold
+                * ``test_n_samples``, the number of samples in testing
         """
-        return self(parameters)[0]['test_score']
+        X, y = check_arrays(X, y, sparse_format='csr', allow_lists=True)
+        cv = check_cv(self.cv, X, y, classifier=is_classifier(estimator))
+        cv = list(cv)
+        n_folds = len(cv)
+        if self.scoring is None and not hasattr(estimator, 'score'):
+            raise TypeError(
+                "If no scoring is specified, the estimator passed "
+                "should have a 'score' method. The estimator %s "
+                "does not." % estimator)
+
+        out = self.parallel(
+            delayed(_fit_fold)(
+                estimator, X, y,
+                est_params=est_params, train=train, test=test,
+                scoring=self.scoring, verbose=self.verbose,
+                fit_params=self.fit_params
+            )
+            for est_params in candidates for train, test in cv)
+
+        means, folds = self._format_results(out, n_folds)
+        return means, folds
 
 
 def cross_val_score(estimator, X, y=None, scoring=None, cv=None, iid=True,
@@ -1451,11 +1422,16 @@ def cross_val_score(estimator, X, y=None, scoring=None, cv=None, iid=True,
 
     fit_params : dict, optional
         Parameters to pass to the fit method of the estimator.
+
+    Returns
+    -------
+    scores : array of float, shape=[len(cv)]
+        Array of scores of the estimator for each run of the cross validation.
     """
-    cv_eval = CVEvaluator(
-        estimator, X, y, scoring=scoring, cv=cv, n_jobs=n_jobs,
+    cv_score = CVScorer(
+        scoring=scoring, cv=cv, n_jobs=n_jobs,
         verbose=verbose, fit_params=fit_params, score_func=score_func)
-    return cv_eval.score_folds()
+    return cv_score(estimator, X, y)[1]['test_score']
 
 
 def _permutation_test_score(estimator, X, y, cv, scorer):
diff --git a/sklearn/grid_search.py b/sklearn/grid_search.py
index 09e8c2fe0e995..46c950b5170e3 100644
--- a/sklearn/grid_search.py
+++ b/sklearn/grid_search.py
@@ -20,7 +20,7 @@
 from .base import BaseEstimator, clone
 from .base import MetaEstimatorMixin
-from .cross_validation import CVEvaluator, fit_fold
+from .cross_validation import CVScorer, _fit_fold
 from .externals.six import string_types, iterkeys
 from .utils import check_random_state, deprecated
 from .utils.validation import _num_samples, check_arrays
@@ -187,8 +187,7 @@ def __len__(self):
         return self.n_iter
 
 
-@deprecated('fit_grid_point is deprecated and will be removed in 0.15. '
-            'Use cross_validation.fit_fold instead.')
+@deprecated('fit_grid_point is deprecated and will be removed in 0.15.')
 def fit_grid_point(X, y, base_clf, clf_params, train, test, scorer,
                    verbose, loss_func=None, **fit_params):
     """Run fit on one set of parameters.
@@ -237,7 +236,7 @@ def fit_grid_point(X, y, base_clf, clf_params, train, test, scorer,
     n_samples_test : int
         Number of test samples in this split.
     """
-    res = fit_fold(
+    res = _fit_fold(
         base_clf, X, y, train, test, scorer, verbose,
         est_params=clf_params, fit_params=fit_params
     )
@@ -396,12 +395,12 @@ def _fit(self, X, y, parameter_iterator, **params):
                                  % (len(y), n_samples))
             y = np.asarray(y)
 
-        cv_eval = CVScorer(cv=self.cv, scoring=self.scorer_,
-                           iid=self.iid, fit_params=self.fit_params,
-                           n_jobs=self.n_jobs, pre_dispatch=self.pre_dispatch,
-                           verbose=self.verbose)
-        search_results, cv_results = cv_eval(parameter_iterator)
+        cv_eval = CVScorer(cv=self.cv, scoring=self.scorer_,
+                           iid=self.iid, fit_params=self.fit_params,
+                           n_jobs=self.n_jobs, pre_dispatch=self.pre_dispatch,
+                           verbose=self.verbose)
+        search_results, cv_results = cv_eval.search(parameter_iterator,
+                                                    self.estimator, X, y)
 
         # Append 'parameters' to search_results
         # Broken due to https://github.com/numpy/numpy/issues/2346: