Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit e58f366

Browse files
janvanrijnrth
authored andcommitted
ColumnTransformer generalization to work on empty lists (#12084)
1 parent 661a8b4 commit e58f366

File tree

2 files changed

+67
-2
lines changed

2 files changed

+67
-2
lines changed

sklearn/compose/_column_transformer.py

+22-2
Original file line numberDiff line numberDiff line change
@@ -108,8 +108,10 @@ class ColumnTransformer(_BaseComposition, TransformerMixin):
108108
transformers_ : list
109109
The collection of fitted transformers as tuples of
110110
(name, fitted_transformer, column). `fitted_transformer` can be an
111-
estimator, 'drop', or 'passthrough'. If there are remaining columns,
112-
the final element is a tuple of the form:
111+
estimator, 'drop', or 'passthrough'. In case there were no columns
112+
selected, this will be the unfitted transformer.
113+
If there are remaining columns, the final element is a tuple of the
114+
form:
113115
('remainder', transformer, remaining_columns) corresponding to the
114116
``remainder`` parameter. If there are remaining columns, then
115117
``len(transformers_)==len(transformers)+1``, otherwise
@@ -242,6 +244,8 @@ def _iter(self, fitted=False, replace_strings=False):
242244
check_inverse=False)
243245
elif trans == 'drop':
244246
continue
247+
elif _is_empty_column_selection(column):
248+
continue
245249

246250
yield (name, trans, column, get_weight(name))
247251

@@ -350,6 +354,8 @@ def _update_fitted_transformers(self, transformers):
350354
# so get next transformer, but save original string
351355
next(fitted_transformers)
352356
trans = 'passthrough'
357+
elif _is_empty_column_selection(column):
358+
trans = old
353359
else:
354360
trans = next(fitted_transformers)
355361
transformers_.append((name, trans, column))
@@ -652,6 +658,20 @@ def _get_column_indices(X, key):
652658
"strings, or boolean mask is allowed")
653659

654660

661+
def _is_empty_column_selection(column):
662+
"""
663+
Return True if the column selection is empty (empty list or all-False
664+
boolean array).
665+
666+
"""
667+
if hasattr(column, 'dtype') and np.issubdtype(column.dtype, np.bool_):
668+
return not column.any()
669+
elif hasattr(column, '__len__'):
670+
return len(column) == 0
671+
else:
672+
return False
673+
674+
655675
def _get_transformer_list(estimators):
656676
"""
657677
Construct (name, trans, column) tuples from list

sklearn/compose/tests/test_column_transformer.py

+45
Original file line numberDiff line numberDiff line change
@@ -251,6 +251,51 @@ def transform(self, X, y=None):
251251
assert_array_equal(ct.transformers_[-1][2], [1])
252252

253253

254+
@pytest.mark.parametrize("pandas", [True, False], ids=['pandas', 'numpy'])
255+
@pytest.mark.parametrize("column", [[], np.array([False, False])],
256+
ids=['list', 'bool'])
257+
def test_column_transformer_empty_columns(pandas, column):
258+
# test case that ensures that the column transformer does also work when
259+
# a given transformer doesn't have any columns to work on
260+
X_array = np.array([[0, 1, 2], [2, 4, 6]]).T
261+
X_res_both = X_array
262+
263+
if pandas:
264+
pd = pytest.importorskip('pandas')
265+
X = pd.DataFrame(X_array, columns=['first', 'second'])
266+
else:
267+
X = X_array
268+
269+
ct = ColumnTransformer([('trans1', Trans(), [0, 1]),
270+
('trans2', Trans(), column)])
271+
assert_array_equal(ct.fit_transform(X), X_res_both)
272+
assert_array_equal(ct.fit(X).transform(X), X_res_both)
273+
assert len(ct.transformers_) == 2
274+
assert isinstance(ct.transformers_[1][1], Trans)
275+
276+
ct = ColumnTransformer([('trans1', Trans(), column),
277+
('trans2', Trans(), [0, 1])])
278+
assert_array_equal(ct.fit_transform(X), X_res_both)
279+
assert_array_equal(ct.fit(X).transform(X), X_res_both)
280+
assert len(ct.transformers_) == 2
281+
assert isinstance(ct.transformers_[0][1], Trans)
282+
283+
ct = ColumnTransformer([('trans', Trans(), column)],
284+
remainder='passthrough')
285+
assert_array_equal(ct.fit_transform(X), X_res_both)
286+
assert_array_equal(ct.fit(X).transform(X), X_res_both)
287+
assert len(ct.transformers_) == 2 # including remainder
288+
assert isinstance(ct.transformers_[0][1], Trans)
289+
290+
fixture = np.array([[], [], []])
291+
ct = ColumnTransformer([('trans', Trans(), column)],
292+
remainder='drop')
293+
assert_array_equal(ct.fit_transform(X), fixture)
294+
assert_array_equal(ct.fit(X).transform(X), fixture)
295+
assert len(ct.transformers_) == 2 # including remainder
296+
assert isinstance(ct.transformers_[0][1], Trans)
297+
298+
254299
def test_column_transformer_sparse_array():
255300
X_sparse = sparse.eye(3, 2).tocsr()
256301

0 commit comments

Comments
 (0)