Thanks to visit codestin.com
Credit goes to github.com

Skip to content

FIX Preserve y shape in TransformedTargetRegressor #31563

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

kryggird
Copy link

Reference Issues/PRs

Fixes 26530.

What does this implement/fix? Explain your changes.

This PR uses the existing self._training_dim to decide whether to squeeze y before passing it the inner regressor in TransformedTargetRegressor.

Any other comments?

I've also added a test in test_metaestimators_metadata_routing.py.

@betatim @glemaitre

Copy link

❌ Linting issues

This PR is introducing linting issues. Here's a summary of the issues. Note that you can avoid having linting issues by enabling pre-commit hooks. Instructions to enable them can be found here.

You can see the details of the linting issues under the lint job here


ruff check

ruff detected issues. Please run ruff check --fix --output-format=full locally, fix the remaining issues, and push the changes. Here you can see the detected issues. Note that the installed ruff version is ruff=0.11.7.


sklearn/tests/test_metaestimators_metadata_routing.py:1:1: I001 [*] Import block is un-sorted or un-formatted
   |
 1 | / import copy
 2 | | import re
 3 | |
 4 | | import numpy as np
 5 | | import pytest
 6 | |
 7 | | from ..base import (
 8 | |         BaseEstimator,
 9 | |         RegressorMixin
10 | | )
11 | |
12 | | from sklearn import config_context
13 | | from sklearn.base import BaseEstimator, is_classifier
14 | | from sklearn.calibration import CalibratedClassifierCV
15 | | from sklearn.compose import TransformedTargetRegressor
16 | | from sklearn.covariance import GraphicalLassoCV
17 | | from sklearn.ensemble import (
18 | |     AdaBoostClassifier,
19 | |     AdaBoostRegressor,
20 | |     BaggingClassifier,
21 | |     BaggingRegressor,
22 | | )
23 | | from sklearn.exceptions import UnsetMetadataPassedError
24 | | from sklearn.experimental import (
25 | |     enable_halving_search_cv,  # noqa: F401
26 | |     enable_iterative_imputer,  # noqa: F401
27 | | )
28 | | from sklearn.feature_selection import (
29 | |     RFE,
30 | |     RFECV,
31 | |     SelectFromModel,
32 | |     SequentialFeatureSelector,
33 | | )
34 | | from sklearn.impute import IterativeImputer
35 | | from sklearn.linear_model import (
36 | |     ElasticNetCV,
37 | |     LarsCV,
38 | |     LassoCV,
39 | |     LassoLarsCV,
40 | |     LogisticRegressionCV,
41 | |     MultiTaskElasticNetCV,
42 | |     MultiTaskLassoCV,
43 | |     OrthogonalMatchingPursuitCV,
44 | |     RANSACRegressor,
45 | |     RidgeClassifierCV,
46 | |     RidgeCV,
47 | | )
48 | | from sklearn.metrics._regression import mean_squared_error
49 | | from sklearn.metrics._scorer import make_scorer
50 | | from sklearn.model_selection import (
51 | |     FixedThresholdClassifier,
52 | |     GridSearchCV,
53 | |     GroupKFold,
54 | |     HalvingGridSearchCV,
55 | |     HalvingRandomSearchCV,
56 | |     RandomizedSearchCV,
57 | |     TunedThresholdClassifierCV,
58 | |     cross_validate,
59 | | )
60 | | from sklearn.multiclass import (
61 | |     OneVsOneClassifier,
62 | |     OneVsRestClassifier,
63 | |     OutputCodeClassifier,
64 | | )
65 | | from sklearn.multioutput import (
66 | |     ClassifierChain,
67 | |     MultiOutputClassifier,
68 | |     MultiOutputRegressor,
69 | |     RegressorChain,
70 | | )
71 | | from sklearn.preprocessing import FunctionTransformer
72 | | from sklearn.semi_supervised import SelfTrainingClassifier
73 | | from sklearn.tests.metadata_routing_common import (
74 | |     ConsumingClassifier,
75 | |     ConsumingRegressor,
76 | |     ConsumingScorer,
77 | |     ConsumingSplitter,
78 | |     NonConsumingClassifier,
79 | |     NonConsumingRegressor,
80 | |     _Registry,
81 | |     assert_request_is_empty,
82 | |     check_recorded_metadata,
83 | | )
84 | | from sklearn.utils.metadata_routing import MetadataRouter
   | |_________________________________________________________^ I001
85 |
86 |   rng = np.random.RandomState(42)
   |
   = help: Organize imports

sklearn/tests/test_metaestimators_metadata_routing.py:13:26: F811 Redefinition of unused `BaseEstimator` from line 8
   |
12 | from sklearn import config_context
13 | from sklearn.base import BaseEstimator, is_classifier
   |                          ^^^^^^^^^^^^^ F811
14 | from sklearn.calibration import CalibratedClassifierCV
15 | from sklearn.compose import TransformedTargetRegressor
   |
   = help: Remove definition: `BaseEstimator`

Found 2 errors.
[*] 1 fixable with the `--fix` option.

ruff format

ruff detected issues. Please run ruff format locally and push the changes. Here you can see the detected issues. Note that the installed ruff version is ruff=0.11.7.


--- sklearn/tests/test_metaestimators_metadata_routing.py
+++ sklearn/tests/test_metaestimators_metadata_routing.py
@@ -4,10 +4,7 @@
 import numpy as np
 import pytest
 
-from ..base import (
-        BaseEstimator,
-        RegressorMixin
-)
+from ..base import BaseEstimator, RegressorMixin
 
 from sklearn import config_context
 from sklearn.base import BaseEstimator, is_classifier
@@ -933,12 +930,13 @@
         scoring=make_scorer(mean_squared_error, response_method="predict"),
     )
 
+
 class ValidateDimensionRegressor(BaseEstimator, RegressorMixin):
     def __init__(self, ndim):
         self.ndim = ndim
 
     def fit(self, X, y):
-        assert(y.ndim == self.ndim)
+        assert y.ndim == self.ndim
 
     def predict(self, X):
         pass

1 file would be reformatted, 923 files already formatted

Generated for commit: bf931c0. Link to the linter CI: here

@betatim
Copy link
Member

betatim commented Jun 17, 2025

Thanks for the Pull Request @kryggird! If you have the time, could you take a look at the linter's complaints - it has instructions on how to install and run the linters as well which in 99.9% will automagically fix all the complaints.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

TransformedTargetRegressor forces 1d y shape to regressor
2 participants