Source: github.com — scraped pull-request diff page (scikit-learn; Array API support for preprocessing.Normalizer). Note: diff +/- markers were lost in scraping, so removed and added lines appear interleaved below.

Skip to content
1 change: 1 addition & 0 deletions doc/modules/array_api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ Estimators
- :class:`preprocessing.KernelCenterer`
- :class:`preprocessing.MaxAbsScaler`
- :class:`preprocessing.MinMaxScaler`
- :class:`preprocessing.Normalizer`

Metrics
-------
Expand Down
12 changes: 7 additions & 5 deletions doc/whats_new/v1.4.rst
Original file line number Diff line number Diff line change
Expand Up @@ -410,14 +410,16 @@ Changelog
:mod:`sklearn.preprocessing`
............................

- |MajorFeature| :class:`preprocessing.MinMaxScaler`, :class:`preprocessing.MaxAbsScaler`
and :class:`preprocessing.KernelCenterer` now
support the `Array API <https://data-apis.org/array-api/latest/>`_. Array API
- |MajorFeature| The following classes now support the
`Array API <https://data-apis.org/array-api/latest/>`_. Array API
support is considered experimental and might evolve without being subject to
our usual rolling deprecation cycle policy. See
:ref:`array_api` for more details.
:pr:`26243` by `Tim Head`_, :pr:`27110` by :user:`Edoardo Abati <EdAbati>` and
:pr:`27556` by :user:`Edoardo Abati <EdAbati>`.

- :class:`preprocessing.MinMaxScaler` :pr:`26243` by `Tim Head`_
- :class:`preprocessing.MaxAbsScaler` :pr:`27110` by :user:`Edoardo Abati <EdAbati>`
- :class:`preprocessing.KernelCenterer` :pr:`27556` by :user:`Edoardo Abati <EdAbati>`
- :class:`preprocessing.Normalizer` :pr:`27558` by :user:`Edoardo Abati <EdAbati>`

- |Efficiency| :class:`preprocessing.OrdinalEncoder` avoids calculating
missing indices twice to improve efficiency.
Expand Down
12 changes: 7 additions & 5 deletions sklearn/preprocessing/_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -1858,12 +1858,14 @@ def normalize(X, norm="l2", *, axis=1, copy=True, return_norm=False):
else: # axis == 1:
sparse_format = "csr"

xp, _ = get_namespace(X)

X = check_array(
X,
accept_sparse=sparse_format,
copy=copy,
estimator="the normalize function",
dtype=FLOAT_DTYPES,
dtype=_array_api.supported_float_dtypes(xp),
)
if axis == 0:
X = X.T
Expand All @@ -1887,13 +1889,13 @@ def normalize(X, norm="l2", *, axis=1, copy=True, return_norm=False):
X.data[mask] /= norms_elementwise[mask]
else:
if norm == "l1":
norms = np.abs(X).sum(axis=1)
norms = xp.sum(xp.abs(X), axis=1)
elif norm == "l2":
norms = row_norms(X)
elif norm == "max":
norms = np.max(abs(X), axis=1)
norms = xp.max(xp.abs(X), axis=1)
norms = _handle_zeros_in_scale(norms, copy=False)
X /= norms[:, np.newaxis]
X /= norms[:, None]

if axis == 0:
X = X.T
Expand Down Expand Up @@ -2031,7 +2033,7 @@ def transform(self, X, copy=None):
return normalize(X, norm=self.norm, axis=1, copy=copy)

def _more_tags(self):
return {"stateless": True}
return {"stateless": True, "array_api_support": True}


@validate_params(
Expand Down
9 changes: 8 additions & 1 deletion sklearn/preprocessing/tests/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -691,7 +691,14 @@ def test_standard_check_array_of_inverse_transform():
)
@pytest.mark.parametrize(
"estimator",
[MaxAbsScaler(), MinMaxScaler(), KernelCenterer()],
[
MaxAbsScaler(),
MinMaxScaler(),
KernelCenterer(),
Normalizer(norm="l1"),
Normalizer(norm="l2"),
Normalizer(norm="max"),
],
ids=_get_check_estimator_ids,
)
def test_scaler_array_api_compliance(estimator, check, array_namespace, device, dtype):
Expand Down
15 changes: 11 additions & 4 deletions sklearn/utils/extmath.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,11 +78,18 @@ def row_norms(X, squared=False):
if sparse.issparse(X):
X = X.tocsr()
norms = csr_row_norms(X)
if not squared:
norms = np.sqrt(norms)
else:
norms = np.einsum("ij,ij->i", X, X)

if not squared:
np.sqrt(norms, norms)
xp, _ = get_namespace(X)
if _is_numpy_namespace(xp):
X = np.asarray(X)
norms = np.einsum("ij,ij->i", X, X)
norms = xp.asarray(norms)
else:
norms = xp.sum(xp.multiply(X, X), axis=1)
if not squared:
norms = xp.sqrt(norms)
return norms


Expand Down