Thanks to visit codestin.com
Credit goes to github.com

Skip to content

DOC Fix FutureWarning in ensemble/plot_gradient_boosting_regularization.html #24960

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Nov 21, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions examples/ensemble/plot_gradient_boosting_regularization.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@

from sklearn import ensemble
from sklearn import datasets

from sklearn.metrics import log_loss
from sklearn.model_selection import train_test_split

X, y = datasets.make_hastie_10_2(n_samples=4000, random_state=1)
Expand Down Expand Up @@ -74,9 +74,8 @@
# compute test set deviance
test_deviance = np.zeros((params["n_estimators"],), dtype=np.float64)

for i, y_pred in enumerate(clf.staged_decision_function(X_test)):
# clf.loss_ assumes that y_test[i] in {0, 1}
test_deviance[i] = clf.loss_(y_test, y_pred)
for i, y_proba in enumerate(clf.staged_predict_proba(X_test)):
test_deviance[i] = 2 * log_loss(y_test, y_proba[:, 1])

plt.plot(
(np.arange(test_deviance.shape[0]) + 1)[::5],
Expand Down