From 794dbf0b5520c09053b54a677a147f2a9ee932f2 Mon Sep 17 00:00:00 2001 From: Lucy Liu Date: Sat, 11 May 2024 21:47:48 +1000 Subject: [PATCH 1/2] complete warm start example --- doc/modules/ensemble.rst | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/doc/modules/ensemble.rst b/doc/modules/ensemble.rst index 4237d023973f7..6622414ddd68b 100644 --- a/doc/modules/ensemble.rst +++ b/doc/modules/ensemble.rst @@ -603,7 +603,21 @@ fitted model. :: - >>> _ = est.set_params(n_estimators=200, warm_start=True) # set warm_start and new nr of trees + >>> import numpy as np + >>> from sklearn.metrics import mean_squared_error + >>> from sklearn.datasets import make_friedman1 + >>> from sklearn.ensemble import GradientBoostingRegressor + + >>> X, y = make_friedman1(n_samples=1200, random_state=0, noise=1.0) + >>> X_train, X_test = X[:200], X[200:] + >>> y_train, y_test = y[:200], y[200:] + >>> est = GradientBoostingRegressor( + ... n_estimators=100, learning_rate=0.1, max_depth=1, random_state=0, + ... loss='squared_error' + ... ).fit(X_train, y_train) + >>> mean_squared_error(y_test, est.predict(X_test)) + 5.00... + >>> _ = est.set_params(n_estimators=200, warm_start=True) # set warm_start and increase num of trees >>> _ = est.fit(X_train, y_train) # fit additional 100 trees to est >>> mean_squared_error(y_test, est.predict(X_test)) 3.84... From 4d8124f36de709b5c06e10de31cac3e34384407d Mon Sep 17 00:00:00 2001 From: Lucy Liu Date: Sun, 12 May 2024 13:56:02 +1000 Subject: [PATCH 2/2] review --- doc/modules/ensemble.rst | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/doc/modules/ensemble.rst b/doc/modules/ensemble.rst index 6622414ddd68b..d18dd2f65009e 100644 --- a/doc/modules/ensemble.rst +++ b/doc/modules/ensemble.rst @@ -614,7 +614,8 @@ fitted model. >>> est = GradientBoostingRegressor( ... n_estimators=100, learning_rate=0.1, max_depth=1, random_state=0, ... loss='squared_error' - ... ).fit(X_train, y_train) + ... ) + >>> est = est.fit(X_train, y_train) # fit with 100 trees >>> mean_squared_error(y_test, est.predict(X_test)) 5.00... >>> _ = est.set_params(n_estimators=200, warm_start=True) # set warm_start and increase num of trees