Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 88aebc7

Browse files
thomasjpfanogrisel
authored andcommitted
FIX SplineTransformer.get_feature_names_out returns correct names for extrapolations=periodic (scikit-learn#25296)
Co-authored-by: Olivier Grisel <[email protected]>
1 parent eab63b7 commit 88aebc7

File tree

3 files changed

+25
-1
lines changed

3 files changed

+25
-1
lines changed

doc/whats_new/v1.2.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,10 @@ Changelog
6060
supports DataFrames that are all numerical when `check_inverse=True`.
6161
:pr:`25274` by `Thomas Fan`_.
6262

63+
- |Fix| :meth:`preprocessing.SplineTransformer.get_feature_names_out` correctly
64+
returns feature names when `extrapolations="periodic"`. :pr:`25296` by
65+
`Thomas Fan`_.
66+
6367
:mod:`sklearn.tree`
6468
...................
6569

sklearn/preprocessing/_polynomial.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -673,7 +673,8 @@ def get_feature_names_out(self, input_features=None):
673673
feature_names_out : ndarray of str objects
674674
Transformed feature names.
675675
"""
676-
n_splines = self.bsplines_[0].c.shape[0]
676+
n_splines = self.bsplines_[0].c.shape[1]
677+
677678
input_features = _check_feature_names_in(self, input_features)
678679
feature_names = []
679680
for i in range(self.n_features_in_):

sklearn/preprocessing/tests/test_polynomial.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,25 @@ def test_spline_transformer_feature_names():
7676
)
7777

7878

79+
@pytest.mark.parametrize(
80+
"extrapolation",
81+
["constant", "linear", "continue", "periodic"],
82+
)
83+
@pytest.mark.parametrize("degree", [2, 3])
84+
def test_split_transform_feature_names_extrapolation_degree(extrapolation, degree):
85+
"""Test feature names are correct for different extrapolations and degree.
86+
87+
Non-regression test for gh-25292.
88+
"""
89+
X = np.arange(20).reshape(10, 2)
90+
splt = SplineTransformer(degree=degree, extrapolation=extrapolation).fit(X)
91+
feature_names = splt.get_feature_names_out(["a", "b"])
92+
assert len(feature_names) == splt.n_features_out_
93+
94+
X_trans = splt.transform(X)
95+
assert X_trans.shape[1] == len(feature_names)
96+
97+
7998
@pytest.mark.parametrize("degree", range(1, 5))
8099
@pytest.mark.parametrize("n_knots", range(3, 5))
81100
@pytest.mark.parametrize("knots", ["uniform", "quantile"])

0 commit comments

Comments
 (0)