FIX poisson proxy_impurity_improvement #22191


Merged: 10 commits merged into scikit-learn:main on Jan 26, 2022

Conversation

lorentzenchr (Member)

Reference Issues/PRs

Fixes #22186.

What does this implement/fix? Explain your changes.

This fixes proxy_impurity_improvement for the Poisson splitting criterion in DecisionTreeRegressor.

Any other comments?

Tests now pass with tighter bounds.

@RAMitchell

Does the proxy_impurity_improvement return value need to be scaled by the number of instances? I'm not exactly sure how sklearn operates, but you might want to check that this is scaled correctly so that any 'min_impurity'-type checks are consistent with the loss function.

lorentzenchr added this to the 1.1 milestone on Jan 13, 2022
Comment on lines +1414 to +1415
proxy_impurity_left -= self.sum_left[k] * log(y_mean_left)
proxy_impurity_right -= self.sum_right[k] * log(y_mean_right)
lorentzenchr (Member Author)

This is the fix.
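
For readers of the thread: the reason the weighted_n_* factors can be dropped here is that the terms the proxy discards are constant across candidate splits. Below is a minimal NumPy sketch (illustrative function names, not the actual Cython criterion) checking that the fixed proxy ranks splits exactly like the full weighted Poisson impurity improvement:

    import numpy as np

    def half_poisson_deviance(y):
        # Node impurity for criterion="poisson": half Poisson deviance
        # with the node mean as the prediction.
        m = y.mean()
        return np.mean(y * np.log(y / m) - y + m)

    def full_improvement(y_left, y_right):
        # What proxy_impurity_improvement stands in for:
        # -(n_left * impurity_left + n_right * impurity_right).
        return -(y_left.size * half_poisson_deviance(y_left)
                 + y_right.size * half_poisson_deviance(y_right))

    def fixed_proxy(y_left, y_right):
        # The fix: per child, sum(y) * log(mean(y)). The -y + mean terms
        # cancel within each child, and sum(y * log(y)) does not depend
        # on the split, so both can be dropped.
        return (y_left.sum() * np.log(y_left.mean())
                + y_right.sum() * np.log(y_right.mean()))

    rng = np.random.default_rng(0)
    y = rng.poisson(3.0, size=200).astype(float) + 1.0  # shift avoids log(0)
    splits = range(10, y.size - 10)
    best_full = max(splits, key=lambda i: full_improvement(y[:i], y[i:]))
    best_proxy = max(splits, key=lambda i: fixed_proxy(y[:i], y[i:]))
    assert best_full == best_proxy  # both criteria pick the same split

Since the proxy differs from the full improvement only by an additive, split-independent constant, it is sufficient for ranking candidate splits.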

@lorentzenchr (Member Author)

@glemaitre @thomasjpfan Friendly ping.

@thomasjpfan (Member) left a comment

Thanks for working on this PR @lorentzenchr!

I think this bug is very subtle. From my understanding, the proxy_impurity_improvement API expects the impurities to be "scaled back up":

self.children_impurity(&impurity_left, &impurity_right)
return (- self.weighted_n_right * impurity_right
        - self.weighted_n_left * impurity_left)

For the Poisson case, the weighted_n_* factors get canceled out.

Is this your understanding of the bug as well?

@lorentzenchr (Member Author)

For the Poisson case, the weighted_n_* factors get canceled out.
Is this your understanding of the bug as well?

Yes, the weights of the (candidate) left and right nodes get cancelled out in the lines:

self.children_impurity(&impurity_left, &impurity_right)
return (- self.weighted_n_right * impurity_right
        - self.weighted_n_left * impurity_left)

Example: impurity = MSE = 1/n sum_i (y_i - y_pred_i)**2

  • For a single tree node, y_pred_i = const = y_mean = 1/n sum_i y_i, and therefore MSE = 1/n sum_i (y_i - y_mean)**2 = var(y).
  • If one divides the node into left and right, each side gets its own y_pred, so the MSE of the parent can be rewritten as MSE(parent) = 1/n * (n_L * MSE(left) + n_R * MSE(right)), which can then be simplified further. The point is that n_L * MSE(left) is a sum of squares instead of a mean of squares, i.e. the weighted_n_* factors in the proxy convert per-node means back into sums.
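
A quick numeric check of that last point (illustrative NumPy, not part of the PR):

    import numpy as np

    y_left = np.array([1.0, 2.0, 4.0])
    mse_left = np.mean((y_left - y_left.mean()) ** 2)  # mean of squares
    # Weighting by n_L turns the mean of squares back into a sum of squares:
    assert np.isclose(len(y_left) * mse_left,
                      np.sum((y_left - y_left.mean()) ** 2))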

@thomasjpfan (Member) left a comment

LGTM

@glemaitre (Member)

I also see a comment at the beginning of the class regarding the proxy computation. Do you want to remove it then?

@ogrisel (Member)

ogrisel commented Jan 26, 2022

For information, I tried to test the impact of this PR on the Poisson regression example: https://scikit-learn.org/stable/auto_examples/linear_model/plot_poisson_regression_non_normal_loss.html

I defined a `poisson_rf` model, similar to the `poisson_gbrt` pipeline but with:

from sklearn.ensemble import RandomForestRegressor
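# Pipeline, tree_preprocessor, score_estimator, df_train and df_test are
# all defined earlier in the linked example.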


poisson_rf = Pipeline(
    [
        ("preprocessor", tree_preprocessor),
        (
            "regressor",
            RandomForestRegressor(criterion="poisson", min_samples_leaf=10, n_jobs=-1),
        ),
    ]
)
poisson_rf.fit(
    df_train, df_train["Frequency"], regressor__sample_weight=df_train["Exposure"]
)

print("Poisson Random Forest evaluation:")
score_estimator(poisson_rf, df_test)

and then computed the deviance when training on 10% of the data (due to the longer training times of RFs compared to linear models and GBDTs), and I found that:

  • (a) this PR can improve the Poisson deviance a bit compared to main, but not by much (I am not sure whether the difference exceeds the statistical noise);
  • (b) Poisson RF cannot compete with Poisson GBRT or even linear Poisson regression on this data (or even Ridge regression), in terms of either Poisson deviance or Gini;
  • (c) MSE RF is not necessarily worse than Poisson RF in terms of Poisson deviance.

I find (b) a bit worrying, but maybe this is expected?

Disclaimer: I did not run a full hyper-parameter optimization for RF models. In particular, it's possible that the RF models would perform better with more trees (but those are slow...).

Edit: I made a mistake in my first batch of experiments: I forgot to recompile when switching branches... but after re-compiling, the previous comments still mostly hold.

In any case, I made a surprising observation with RFs on this example: the following RF:

    RandomForestRegressor(
        criterion="poisson", n_estimators=5, min_samples_leaf=1000, n_jobs=-1
    )

performs as well as the same RF with hundreds of trees, and much better than deep RFs with lower values of min_samples_leaf. With those hyperparameters, RFs are competitive with the linear models and GBRT.

I wonder if simply averaging the y_hats in an RF is optimal for Poisson regression.

@ogrisel (Member)

ogrisel commented Jan 26, 2022

I confirm that this PR fixes the case originally reported as #22186 (comment):

[screenshot confirming the fix]

@ogrisel (Member) left a comment

The code looks good, thanks for the new comments.

Based on #22191 (comment), the original problem is fixed, +1 for merge.

@ogrisel ogrisel merged commit 2b15b90 into scikit-learn:main Jan 26, 2022
- |Fix| Fix a bug in the Poisson splitting criterion for
:class:`tree.DecisionTreeRegressor`.
:pr:`22191` by :user:`Christian Lorentzen <lorentzenchr>`.


Oops, I realize that I merged too quickly: the sklearn.svm section has been split. Let me open a PR to fix this.

Successfully merging this pull request may close: Incorrect Poisson objective for decision tree/random forest (#22186)