
Commit 12d8261

ENH Restructure grid_scores_ into future proof eff. data structure
1 parent 1d487fb commit 12d8261

1 file changed: 156 additions & 52 deletions

sklearn/model_selection/_search.py

@@ -320,24 +320,57 @@ def fit_grid_point(X, y, estimator, parameters, train, test, scorer,
     return score, parameters, n_samples_test
 
 
-def _check_param_grid(param_grid):
-    if hasattr(param_grid, 'items'):
-        param_grid = [param_grid]
+def _check_param_grid_or_dist(param_grid_or_dist):
+    """Validate param_grid/distribution and return the unique parameters"""
+    parameter_names = set()
 
-    for p in param_grid:
+    if hasattr(param_grid_or_dist, 'items'):
+        param_grid_or_dist = [param_grid_or_dist]
+
+    for p in param_grid_or_dist:
         for v in p.values():
             if isinstance(v, np.ndarray) and v.ndim > 1:
                 raise ValueError("Parameter array should be one-dimensional.")
 
-            check = [isinstance(v, k) for k in (list, tuple, np.ndarray)]
-            if True not in check:
+            if not isinstance(v, (list, tuple, np.ndarray)):
                 raise ValueError("Parameter values should be a list.")
 
             if len(v) == 0:
                 raise ValueError("Parameter values should be a non-empty "
                                  "list.")
 
+        parameter_names.update(p.keys())
+
+    return list(parameter_names)
+
+
+def _get_metric_names(scoring):
+    """Generate the list of metric name(s) given the scoring parameter"""
+    metric_names = list()
+    # XXX Do we index from 0?
+    # NOTE we need this to prevent collisions between similarly named
+    # custom metric (i.e [foo.bar, bar])
+    n_custom_metrics = 1
 
+    if not isinstance(scoring, (list, tuple)):
+        scoring = [scoring]
+
+    for metric in scoring:
+        if callable(metric):
+            metric_names.append("custom_metric_%s_%s" %
+                                (n_custom_metrics, metric.__name__))
+            n_custom_metrics += 1
+
+        elif isinstance(metric, six.string_types):
+            metric_names.append(metric)
+
+        else:
+            raise ValueError("Unknown metric type - %r" % type(metric))
+
+    return metric_names
+
+
+# XXX Remove in 0.20
 class _CVScoreTuple (namedtuple('_CVScoreTuple',
                                 ('parameters',
                                  'mean_validation_score',
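A minimal usage sketch of the two helpers added above, assuming they are imported from the patched sklearn/model_selection/_search.py; the parameter grid and the commented outputs are illustrative, not taken from the commit:

from sklearn.metrics import f1_score
from sklearn.model_selection._search import (_check_param_grid_or_dist,
                                              _get_metric_names)

# Validation returns the unique parameter names found across all grids
# (the order is arbitrary, since the names are collected in a set).
param_grid = [{'kernel': ['poly'], 'degree': [2, 3]},
              {'kernel': ['rbf'], 'gamma': [0.1, 0.2]}]
_check_param_grid_or_dist(param_grid)      # e.g. ['kernel', 'gamma', 'degree']

# String scorers keep their name; callables get a "custom_metric_<n>_" prefix
# so that similarly named callables do not collide.
_get_metric_names('accuracy')              # ['accuracy']
_get_metric_names(['accuracy', f1_score])  # ['accuracy',
                                           #  'custom_metric_1_f1_score']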
@@ -380,6 +413,7 @@ def __init__(self, estimator, scoring=None,
         self.verbose = verbose
         self.pre_dispatch = pre_dispatch
         self.error_score = error_score
+        self.metric_names_ = _get_metric_names(scoring)
 
     @property
     def _estimator_type(self):
@@ -520,6 +554,12 @@ def inverse_transform(self, Xt):
         """
         return self.best_estimator_.transform(Xt)
 
+    @property
+    @deprecated("The grid_scores_ attribute is deprecated in favor of the "
+                "search_results_ and will be removed in version 0.20.")
+    def grid_scores_(self):
+        return self._grid_scores
+
     def _fit(self, X, y, labels, parameter_iterable):
         """Actual fitting, performing the search over parameters."""
 
@@ -560,38 +600,67 @@ def _fit(self, X, y, labels, parameter_iterable):
         # Out is a list of triplet: score, estimator, n_test_samples
         n_fits = len(out)
 
-        scores = list()
-        grid_scores = list()
-        for grid_start in range(0, n_fits, n_splits):
-            n_test_samples = 0
-            score = 0
-            all_scores = []
-            for this_score, this_n_test_samples, _, parameters in \
-                    out[grid_start:grid_start + n_splits]:
-                all_scores.append(this_score)
+        self._grid_scores = list()
+
+        # XXX Do we want to store these?
+        n_candidates = n_fits / n_splits
+        n_parameters = len(self.parameter_names_)
+        n_metrics = len(scoring)
+
+        search_results_ = dict()
+
+        for param in self.parameter_names_:
+            search_results_[param] = np.empty((n_candidates,), dtype=object)
+
+        for metric in self.metric_names_:
+            # Make a column for each split
+            # XXX To make it future proof
+            for split_i in range(n_splits):
+                search_results_["%s_split_%s" % (metric, split_i)] = (
+                    np.empty((n_candidates,), dtype=np.float32))
+
+            search_results_["%s_aggregated" % metric] = np.empty(
+                (n_candidates,), dtype=np.float32)
+            search_results_["%s_rank" % metric] = np.empty((n_candidates,),
+                                                           dtype=int)
+
+        for grid_start in range(0, n_fits, n_splits):
+            n_test_samples = 0
+            aggregated_score = 0
+            all_scores = []
+
+            # XXX Loop this when multiple metric support is enabled
+            for i, (this_score, this_n_test_samples, _, parameters) in \
+                    enumerate(out[grid_start:grid_start + n_splits]):
+                all_scores.append(this_score)
+
+                if self.iid:
+                    this_score *= this_n_test_samples
+                    n_test_samples += this_n_test_samples
+                aggregated_score += this_score
+                search_results_["%s_split_%s" % (metric, i)] = this_score
+
             if self.iid:
-                this_score *= this_n_test_samples
-                n_test_samples += this_n_test_samples
-                score += this_score
-            if self.iid:
-                score /= float(n_test_samples)
-            else:
-                score /= float(n_splits)
-            scores.append((score, parameters))
-            # TODO: shall we also store the test_fold_sizes?
-            grid_scores.append(_CVScoreTuple(
+                aggregated_score /= float(n_test_samples)
+            else:
+                aggregated_score /= float(n_splits)
+
+            search_results_["%s_aggregated" % metric] = aggregated_score
+
+            # XXX Remove in version 0.20
+            self._grid_scores.append(_CVScoreTuple(
                 parameters,
                 score,
                 np.array(all_scores)))
-        # Store the computed scores
-        self.grid_scores_ = grid_scores
 
-        # Find the best parameters by comparing on the mean validation score:
-        # note that `sorted` is deterministic in the way it breaks ties
-        best = sorted(grid_scores, key=lambda x: x.mean_validation_score,
-                      reverse=True)[0]
-        self.best_params_ = best.parameters
-        self.best_score_ = best.mean_validation_score
+        # Find the best parameters by comparing on the mean validation score:
+        # note that `sorted` is deterministic in the way it breaks ties
+        np.argsort(search_results_["%s"])
+        search_results_["%s_aggregated" % metric] = aggregated_score
+
+        best = sorted(self._grid_scores, key=lambda x: x.mean_validation_score,
+                      reverse=True)[0]
+        self.best_params_ = best.parameters
+        self.best_score_ = best.mean_validation_score
 
         if self.refit:
             # fit the best estimator using the entire dataset
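The per-metric rank column initialised above is not yet filled in by this commit (the np.argsort call near the end of the hunk is left dangling). A minimal sketch of one way such ranks could be derived from the aggregated scores, assuming higher scores are better and rank 1 denotes the best candidate; the values are the illustrative ones from the docstring below, not real output:

import numpy as np

aggregated = np.array([0.81, 0.60, 0.75, 0.82])  # one entry per candidate

# Indices of candidates ordered from best to worst score.
order = np.argsort(-aggregated)
# Invert the permutation to get a 1-based rank for every candidate.
ranks = np.empty_like(order)
ranks[order] = np.arange(1, len(aggregated) + 1)
# ranks -> array([2, 4, 3, 1])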
@@ -722,15 +791,32 @@ class GridSearchCV(BaseSearchCV):
 
     Attributes
     ----------
-    grid_scores_ : list of named tuples
-        Contains scores for all parameter combinations in param_grid.
-        Each entry corresponds to one parameter setting.
-        Each named tuple has the attributes:
-
-        * ``parameters``, a dict of parameter settings
-        * ``mean_validation_score``, the mean score over the
-          cross-validation folds
-        * ``cv_validation_scores``, the list of scores for each fold
+    search_results_ : dict of numpy (masked) ndarrays
+        A dict with keys as column headers and values as columns, that can be
+        imported into a pandas DataFrame.
+
+        For instance, the below given table
+
+        kernel|gamma|degree|accuracy_score_split_0...|accuracy_score_mean ...|
+        =====================================================================
+        'poly'|  -  |  2   |          0.8            |        0.81           |
+        'poly'|  -  |  3   |          0.7            |        0.60           |
+        'rbf' | 0.1 |  -   |          0.8            |        0.75           |
+        'rbf' | 0.2 |  -   |          0.9            |        0.82           |
+
+        will be represented by a search_results_ dict of:
+
+        {'kernel' : masked_array(data = ['poly', 'poly', 'rbf', 'rbf'],
+                                 mask = [False False False False]...),
+         'gamma'  : masked_array(data = [-- -- 0.1 0.2],
+                                 mask = [ True  True False False]...),
+         'degree' : masked_array(data = [2.0 3.0 -- --],
+                                 mask = [False False  True  True]...),
+         'accuracy_score_split_0' : [0.8, 0.7, 0.8, 0.9],
+         'accuracy_score_split_1' : [0.82, 0.5, 0.7, 0.78],
+         'accuracy_score_mean'    : [0.81, 0.60, 0.75, 0.82],
+         'candidate_rank'         : [2, 4, 3, 1],
+         }
 
     best_estimator_ : estimator
         Estimator that was chosen by the search, i.e. estimator
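A minimal sketch of the "imported into a pandas DataFrame" claim above, rebuilding the illustrative dict from the docstring by hand; the values are the example ones, not real fit output, and pandas turns masked entries into NaN:

import numpy as np
import pandas as pd

search_results = {
    'kernel': np.ma.masked_array(['poly', 'poly', 'rbf', 'rbf'],
                                 mask=[False, False, False, False]),
    'gamma': np.ma.masked_array([0.0, 0.0, 0.1, 0.2],
                                mask=[True, True, False, False]),
    'degree': np.ma.masked_array([2.0, 3.0, 0.0, 0.0],
                                 mask=[False, False, True, True]),
    'accuracy_score_split_0': [0.8, 0.7, 0.8, 0.9],
    'accuracy_score_split_1': [0.82, 0.5, 0.7, 0.78],
    'accuracy_score_mean': [0.81, 0.60, 0.75, 0.82],
    'candidate_rank': [2, 4, 3, 1],
}

# One row per candidate parameter setting, one column per key.
df = pd.DataFrame(search_results)
print(df.sort_values('candidate_rank'))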
@@ -784,7 +870,7 @@ def __init__(self, estimator, param_grid, scoring=None, fit_params=None,
             n_jobs=n_jobs, iid=iid, refit=refit, cv=cv, verbose=verbose,
             pre_dispatch=pre_dispatch, error_score=error_score)
         self.param_grid = param_grid
-        _check_param_grid(param_grid)
+        self.parameter_names_ = _check_param_grid_or_dist(param_grid)
 
     def fit(self, X, y=None, labels=None):
         """Run fit with all sets of parameters.
@@ -918,15 +1004,32 @@ class RandomizedSearchCV(BaseSearchCV):
 
     Attributes
    ----------
-    grid_scores_ : list of named tuples
-        Contains scores for all parameter combinations in param_grid.
-        Each entry corresponds to one parameter setting.
-        Each named tuple has the attributes:
-
-        * ``parameters``, a dict of parameter settings
-        * ``mean_validation_score``, the mean score over the
-          cross-validation folds
-        * ``cv_validation_scores``, the list of scores for each fold
+    search_results_ : dict of numpy (masked) ndarrays
+        A dict with keys as column headers and values as columns, that can be
+        imported into a pandas DataFrame.
+
+        For instance, the below given table
+
+        kernel|gamma|degree|accuracy_score_split_0...|accuracy_score_mean ...|
+        =====================================================================
+        'poly'|  -  |  2   |          0.8            |        0.81           |
+        'poly'|  -  |  3   |          0.7            |        0.60           |
+        'rbf' | 0.1 |  -   |          0.8            |        0.75           |
+        'rbf' | 0.2 |  -   |          0.9            |        0.82           |
+
+        will be represented by a search_results_ dict of:
+
+        {'kernel' : masked_array(data = ['poly', 'poly', 'rbf', 'rbf'],
+                                 mask = [False False False False]...),
+         'gamma'  : masked_array(data = [-- -- 0.1 0.2],
+                                 mask = [ True  True False False]...),
+         'degree' : masked_array(data = [2.0 3.0 -- --],
+                                 mask = [False False  True  True]...),
+         'accuracy_score_split_0' : [0.8, 0.7, 0.8, 0.9],
+         'accuracy_score_split_1' : [0.82, 0.5, 0.7, 0.78],
+         'accuracy_score_mean'    : [0.81, 0.60, 0.75, 0.82],
+         'candidate_rank'         : [2, 4, 3, 1],
+         }
 
     best_estimator_ : estimator
         Estimator that was chosen by the search, i.e. estimator
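A minimal sketch of how the masked parameter columns shown above could be assembled from a list of sampled candidate settings: a cell is masked whenever that candidate did not set the parameter. The candidate dicts and the loop below are illustrative, not code from this commit:

import numpy as np

candidates = [{'kernel': 'poly', 'degree': 2},
              {'kernel': 'poly', 'degree': 3},
              {'kernel': 'rbf', 'gamma': 0.1},
              {'kernel': 'rbf', 'gamma': 0.2}]
parameter_names = ['kernel', 'gamma', 'degree']

columns = {}
for name in parameter_names:
    # Collect the value (or None) for every candidate and mask the gaps.
    data = [c.get(name) for c in candidates]
    mask = [name not in c for c in candidates]
    columns[name] = np.ma.masked_array(data, mask=mask)

# columns['gamma'] -> masked_array(data=[--, --, 0.1, 0.2],
#                                  mask=[True, True, False, False])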
@@ -969,6 +1072,7 @@ def __init__(self, estimator, param_distributions, n_iter=10, scoring=None,
                  error_score='raise'):
 
         self.param_distributions = param_distributions
+        self.parameter_names_ = _check_param_grid_or_dist(param_distributions)
         self.n_iter = n_iter
         self.random_state = random_state
         super(RandomizedSearchCV, self).__init__(
