Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit a53f8e4

Browse files
committed
feature: add get_feature_names() and tests to FunctionTransformer
1 parent 3f88b98 commit a53f8e4

File tree

2 files changed

+32
-5
lines changed

2 files changed

+32
-5
lines changed

sklearn/preprocessing/_function_transformer.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,11 +64,31 @@ def fit(self, X, y=None):
6464
check_array(X, self.accept_sparse)
6565
return self
6666

67+
def get_feature_names(self, input_features=None):
68+
"""
69+
Return feature names for output features
70+
71+
Parameters
72+
----------
73+
input_features : list of string, length len(input_features), optional
74+
String names for input features if available. By default,
75+
None is used.
76+
77+
Returns
78+
-------
79+
output_feature_names : list of string, length len(input_features)
80+
81+
"""
82+
if input_features is not None:
83+
input_features = list(map(lambda input_feature: 'f(' +
84+
input_feature + ')',
85+
input_features))
86+
return input_features
87+
6788
def transform(self, X, y=None):
6889
if self.validate:
6990
X = check_array(X, self.accept_sparse)
7091
func = self.func if self.func is not None else _identity
7192

72-
7393
return func(X, *((y,) if self.pass_y else ()),
7494
**(self.kw_args if self.kw_args else {}))

sklearn/preprocessing/tests/test_function_transformer.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,14 @@ def test_delegate_to_func():
7474
)
7575

7676

77+
def test_get_feature_names():
78+
F = FunctionTransformer()
79+
feature_names = F.get_feature_names(["a", "b", "c"])
80+
testing.assert_array_equal(['f(a)', 'f(b)', 'f(c)'], feature_names)
81+
feature_names_default = F.get_feature_names()
82+
testing.assert_array_equal(None, feature_names_default)
83+
84+
7785
def test_np_log():
7886
X = np.arange(10).reshape((5, 2))
7987

@@ -91,7 +99,7 @@ def test_kw_arg():
9199

92100
# Test that rounding is correct
93101
testing.assert_array_equal(F.transform(X),
94-
np.around(X, decimals=3))
102+
np.around(X, decimals=3))
95103

96104

97105
def test_kw_arg_update():
@@ -100,10 +108,9 @@ def test_kw_arg_update():
100108
F = FunctionTransformer(np.around, kw_args=dict(decimals=3))
101109

102110
F.kw_args['decimals'] = 1
103-
104111
# Test that rounding is correct
105112
testing.assert_array_equal(F.transform(X),
106-
np.around(X, decimals=1))
113+
np.around(X, decimals=1))
107114

108115

109116
def test_kw_arg_reset():
@@ -115,4 +122,4 @@ def test_kw_arg_reset():
115122

116123
# Test that rounding is correct
117124
testing.assert_array_equal(F.transform(X),
118-
np.around(X, decimals=1))
125+
np.around(X, decimals=1))

0 commit comments

Comments
 (0)