Allow column names to pass through when fitting narwhals dataframes #31019

Closed as not planned
@ryansheabla


Describe the workflow you want to enable

Currently, when an estimator is fit on a narwhals DataFrame, the feature names are not passed through, because the narwhals DataFrame does not implement a __dataframe__ method.

Example:

import narwhals as nw
import pandas as pd
import polars as pl
from sklearn.preprocessing import StandardScaler

df_pd = pd.DataFrame({"a": [0, 1, 2], "b": [3, 4, 5]})
df_pl = pl.DataFrame(df_pd)
df_nw = nw.from_native(df_pd)

s_pd, s_pl, s_nw = StandardScaler(), StandardScaler(), StandardScaler()
s_pd.fit(df_pd)
s_pl.fit(df_pl)
s_nw.fit(df_nw)

print(s_pd.feature_names_in_)
print(s_pl.feature_names_in_)
print(s_nw.feature_names_in_)

Expected output

['a' 'b']
['a' 'b']
['a' 'b']

Actual output

['a' 'b']
['a' 'b']
AttributeError: 'StandardScaler' object has no attribute 'feature_names_in_'

All other attributes on s_nw are what I'd expect.

Describe your proposed solution

This should be easy enough to implement by adding another check within sklearn.utils.validation._get_feature_names:

  1. Add an _is_narwhals_df function, borrowing logic from _is_pandas_df:

import sys

def _is_narwhals_df(X):
    """Return True if X is a narwhals dataframe."""
    try:
        # Look narwhals up lazily so it never becomes a hard dependency.
        nw = sys.modules["narwhals"]
    except KeyError:
        return False
    return isinstance(X, nw.DataFrame)

  2. Add an additional check to _get_feature_names:

    elif _is_narwhals_df(X):
        feature_names = np.asarray(X.columns, dtype=object)

Describe alternatives you've considered, if relevant

No response

Additional context

narwhals-dev/narwhals#355 (comment)
