Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 30c86ea

Browse files
Move set_params back to fit_grid_point
1 parent 38081fd commit 30c86ea

File tree

2 files changed

+20
-26
lines changed

2 files changed

+20
-26
lines changed

sklearn/cross_validation.py

Lines changed: 4 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
from .utils import check_arrays, check_random_state, safe_mask
2626
from .utils.validation import _num_samples
2727
from .utils.fixes import unique
28-
from .externals.joblib import Parallel, delayed, logger
28+
from .externals.joblib import Parallel, delayed
2929
from .externals.six import string_types, with_metaclass
3030
from .metrics.scorer import _deprecate_loss_and_score_funcs
3131

@@ -1096,18 +1096,14 @@ def cross_val_score(estimator, X, y=None, scoring=None, cv=None, n_jobs=1,
10961096
pre_dispatch=pre_dispatch)
10971097
scores = parallel(
10981098
delayed(_cross_val_score)(clone(estimator), X, y, scorer, train, test,
1099-
parameters=None, verbose=verbose,
1100-
fit_params=fit_params,
1101-
log_label="cross_val_score")
1099+
verbose=verbose, fit_params=fit_params)
11021100
for train, test in cv)
11031101
return np.array(scores)[:, 0]
11041102

11051103

1106-
def _cross_val_score(estimator, X, y, scorer, train, test, parameters,
1107-
verbose, fit_params, log_label):
1104+
def _cross_val_score(estimator, X, y, scorer, train, test,
1105+
verbose, fit_params):
11081106
"""Inner loop for cross validation"""
1109-
if parameters is not None:
1110-
estimator.set_params(**parameters)
11111107
n_samples = _num_samples(X)
11121108
fit_params = fit_params if fit_params is not None else {}
11131109
fit_params = dict([(k, np.asarray(v)[train]
@@ -1116,25 +1112,12 @@ def _cross_val_score(estimator, X, y, scorer, train, test, parameters,
11161112

11171113
start_time = time.time()
11181114

1119-
if verbose > 1:
1120-
if parameters is None:
1121-
msg = ""
1122-
else:
1123-
msg = '%s' % (', '.join('%s=%s' % (k, v)
1124-
for k, v in parameters.items()))
1125-
print("[%s] %s %s" % (log_label, msg, (64 - len(msg)) * '.'))
1126-
11271115
X_train, y_train = _split(estimator, X, y, train)
11281116
X_test, y_test = _split(estimator, X, y, test, train)
11291117
_fit(estimator.fit, X_train, y_train, **fit_params)
11301118
score = _score(estimator, X_test, y_test, scorer)
11311119

11321120
scoring_time = time.time() - start_time
1133-
if verbose > 2:
1134-
msg += ", score=%f" % score
1135-
if verbose > 1:
1136-
end_msg = "%s -%s" % (msg, logger.short_format_time(scoring_time))
1137-
print("[%s] %s %s" % (log_label, (64 - len(end_msg)) * '.', end_msg))
11381121

11391122
return score, _num_samples(X_test), scoring_time
11401123

sklearn/grid_search.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
from .base import MetaEstimatorMixin
2525
from .cross_validation import _check_cv as check_cv
2626
from .cross_validation import _check_scorable, _cross_val_score
27-
from .externals.joblib import Parallel, delayed
27+
from .externals.joblib import Parallel, delayed, logger
2828
from .externals import six
2929
from .utils import safe_mask, check_random_state
3030
from .utils.validation import _num_samples, check_arrays
@@ -229,10 +229,21 @@ def fit_grid_point(X, y, estimator, parameters, train, test, scorer,
229229
n_samples_test : int
230230
Number of test samples in this split.
231231
"""
232-
score, n_samples_test, _ = _cross_val_score(estimator, X, y, scorer,
233-
train, test, parameters,
234-
verbose, fit_params,
235-
log_label="GridSearchCV")
232+
if verbose > 1:
233+
msg = '%s' % (', '.join('%s=%s' % (k, v)
234+
for k, v in parameters.items()))
235+
print("[GridSearchCV] %s %s" % (msg, (64 - len(msg)) * '.'))
236+
237+
estimator.set_params(**parameters)
238+
score, n_samples_test, scoring_time = _cross_val_score(
239+
estimator, X, y, scorer, train, test, verbose, fit_params)
240+
241+
if verbose > 2:
242+
msg += ", score=%f" % score
243+
if verbose > 1:
244+
end_msg = "%s -%s" % (msg, logger.short_format_time(scoring_time))
245+
print("[GridSearchCV] %s %s" % ((64 - len(end_msg)) * '.', end_msg))
246+
236247
return score, parameters, n_samples_test
237248

238249

0 commit comments

Comments
 (0)