diff --git a/sklearn/preprocessing/tests/test_data.py b/sklearn/preprocessing/tests/test_data.py index 2ff9dfd776a03..1d45b035333c5 100644 --- a/sklearn/preprocessing/tests/test_data.py +++ b/sklearn/preprocessing/tests/test_data.py @@ -873,6 +873,30 @@ def test_scaler_without_copy(): assert_array_equal(X_csc.toarray(), X_csc_copy.toarray()) +def test_scaler_partial_fit_overflow(): + # Test StandardScaler does not overflow in partial_fit #5602 + rng = np.random.RandomState(0) + + def gen_1d_uniform_batch(min_, max_, n): + return rng.uniform(min_, max_, size=(n, 1)) + + max_f = np.finfo(np.float64).max / 1e5 + min_f = max_f / 1e2 + stream_dim = 100 + batch_dim = 500000 + + X = gen_1d_uniform_batch(min_=min_f, max_=min_f, n=batch_dim) + + scaler = StandardScaler(with_std=False).fit(X) + + iscaler = StandardScaler(with_std=False) + batch = gen_1d_uniform_batch(min_=min_f, max_=min_f, n=batch_dim) + for _ in range(stream_dim): + iscaler = iscaler.partial_fit(batch) + + assert_allclose(iscaler.mean_, scaler.mean_) + + def test_scale_sparse_with_mean_raise_exception(): rng = np.random.RandomState(42) X = rng.randn(4, 5) diff --git a/sklearn/utils/extmath.py b/sklearn/utils/extmath.py index 218733145a0de..4d731fe676a31 100644 --- a/sklearn/utils/extmath.py +++ b/sklearn/utils/extmath.py @@ -696,13 +696,16 @@ def _incremental_mean_and_var(X, last_mean, last_variance, last_sample_count): # old = stats until now # new = the current increment # updated = the aggregated stats - last_sum = last_mean * last_sample_count new_sum = np.nansum(X, axis=0) new_sample_count = np.sum(~np.isnan(X), axis=0) updated_sample_count = last_sample_count + new_sample_count - updated_mean = (last_sum + new_sum) / updated_sample_count + new_mean = new_sum / new_sample_count + delta = new_mean - last_mean + updated_mean = \ + last_mean + (delta * new_sample_count) / updated_sample_count # fixes test + # (last_mean * last_sample_count + new_sum) / updated_sample_count # breaks stability test if last_variance is None: updated_variance = None @@ -710,6 +713,7 @@ def _incremental_mean_and_var(X, last_mean, last_variance, last_sample_count): new_unnormalized_variance = np.nanvar(X, axis=0) * new_sample_count last_unnormalized_variance = last_variance * last_sample_count + last_sum = last_mean * last_sample_count with np.errstate(divide='ignore', invalid='ignore'): last_over_new_count = last_sample_count / new_sample_count updated_unnormalized_variance = ( diff --git a/sklearn/utils/tests/test_extmath.py b/sklearn/utils/tests/test_extmath.py index ee08e016abe68..5373f5072a171 100644 --- a/sklearn/utils/tests/test_extmath.py +++ b/sklearn/utils/tests/test_extmath.py @@ -609,7 +609,7 @@ def naive_mean_variance_update(x, last_mean, last_variance, _incremental_mean_and_var(A1[i, :].reshape((1, A1.shape[1])), mean, var, n) assert_array_equal(n, A.shape[0]) - assert_array_almost_equal(A.mean(axis=0), mean) + assert_allclose(A.mean(axis=0), mean, rtol=1e-12) assert_greater(tol, np.abs(stable_var(A) - var).max())