From f08e20f0687f388ab07a58da2e62c9458899bdbf Mon Sep 17 00:00:00 2001 From: Christian Date: Sat, 3 Sep 2022 20:06:03 +0200 Subject: [PATCH] DOC ensure sklearn/utils/extmath/stable_cumsum passes numpydoc --- sklearn/tests/test_docstrings.py | 1 - sklearn/utils/extmath.py | 8 ++++++++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/sklearn/tests/test_docstrings.py b/sklearn/tests/test_docstrings.py index 6118d4827c523..31554e765b18c 100644 --- a/sklearn/tests/test_docstrings.py +++ b/sklearn/tests/test_docstrings.py @@ -36,7 +36,6 @@ "sklearn.utils.extmath.randomized_svd", "sklearn.utils.extmath.safe_sparse_dot", "sklearn.utils.extmath.squared_norm", - "sklearn.utils.extmath.stable_cumsum", "sklearn.utils.extmath.svd_flip", "sklearn.utils.extmath.weighted_mode", "sklearn.utils.fixes.delayed", diff --git a/sklearn/utils/extmath.py b/sklearn/utils/extmath.py index e4513a62bf07e..bc4e18e7d55e3 100644 --- a/sklearn/utils/extmath.py +++ b/sklearn/utils/extmath.py @@ -1057,6 +1057,9 @@ def _deterministic_vector_sign_flip(u): def stable_cumsum(arr, axis=None, rtol=1e-05, atol=1e-08): """Use high precision for cumsum and check that final value matches sum. + Warns if the final cumulative sum does not match the sum (up to the chosen + tolerance). + Parameters ---------- arr : array-like @@ -1068,6 +1071,11 @@ def stable_cumsum(arr, axis=None, rtol=1e-05, atol=1e-08): Relative tolerance, see ``np.allclose``. atol : float, default=1e-08 Absolute tolerance, see ``np.allclose``. + + Returns + ------- + out : ndarray + Array with the cumulative sums along the chosen axis. """ out = np.cumsum(arr, axis=axis, dtype=np.float64) expected = np.sum(arr, axis=axis, dtype=np.float64)