diff --git a/sklearn/compose/tests/test_column_transformer.py b/sklearn/compose/tests/test_column_transformer.py index d26f6895427bb..ca1c185c91e06 100644 --- a/sklearn/compose/tests/test_column_transformer.py +++ b/sklearn/compose/tests/test_column_transformer.py @@ -1040,7 +1040,7 @@ def test_column_transformer_no_estimators_set_params(): def test_column_transformer_callable_specifier(): - # assert that function gets the full array / dataframe + # assert that function gets the full array X_array = np.array([[0, 1, 2], [2, 4, 6]]).T X_res_first = np.array([[0, 1, 2]]).T @@ -1055,7 +1055,13 @@ def func(X): assert callable(ct.transformers[0][2]) assert ct.transformers_[0][2] == [0] + +def test_column_transformer_callable_specifier_dataframe(): + # assert that function gets the full dataframe pd = pytest.importorskip('pandas') + X_array = np.array([[0, 1, 2], [2, 4, 6]]).T + X_res_first = np.array([[0, 1, 2]]).T + X_df = pd.DataFrame(X_array, columns=['first', 'second']) def func(X):