[WIP] use more robust mean online computation in StandardScaler #11549
Conversation
@@ -609,7 +609,7 @@ def naive_mean_variance_update(x, last_mean, last_variance,
         _incremental_mean_and_var(A1[i, :].reshape((1, A1.shape[1])),
                                   mean, var, n)
     assert_array_equal(n, A.shape[0])
-    assert_array_almost_equal(A.mean(axis=0), mean)
+    assert_allclose(A.mean(axis=0), mean, rtol=1e-12)
     assert_greater(tol, np.abs(stable_var(A) - var).max())
this test comes from agramfort@6a5a2f7 by @giorgiop
        delta = new_mean - last_mean
        updated_mean = \
            last_mean + (delta * new_sample_count) / updated_sample_count  # fixes test
        # (last_mean * last_sample_count + new_sum) / updated_sample_count  # breaks stability test
The commented-out line is how it was before, and it is necessary for the stability test below to pass. However, that line is the cause of the reported issue, since last_mean * last_sample_count will overflow.
So I feel a bit stuck. Either I change the test or ... Any thoughts?
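To make the trade-off concrete, here is a minimal self-contained sketch of the two mean-update formulas discussed above (function names and the numeric values are illustrative, not scikit-learn internals):

```python
import numpy as np

def update_mean_delta(last_mean, last_count, new_sum, new_count):
    # Delta form: last_mean + delta * new_count / updated_count.
    # Never forms last_mean * last_count, so it cannot overflow there.
    updated_count = last_count + new_count
    new_mean = new_sum / new_count
    delta = new_mean - last_mean
    return last_mean + (delta * new_count) / updated_count

def update_mean_weighted(last_mean, last_count, new_sum, new_count):
    # Weighted-sum form: last_mean * last_count can overflow for
    # large running means and sample counts.
    updated_count = last_count + new_count
    return (last_mean * last_count + new_sum) / updated_count

# With moderate values both forms agree:
m1 = update_mean_delta(5.0, 1000, 600.0, 100)
m2 = update_mean_weighted(5.0, 1000, 600.0, 100)
assert np.isclose(m1, m2)

# With a huge running mean the weighted form overflows to inf,
# while the delta form stays finite:
big = np.float64(1e300)
assert np.isinf(update_mean_weighted(big, 1e10, big, 1))
assert np.isfinite(update_mean_delta(big, 1e10, big, 1))
```

The delta form trades that overflow safety for slightly different rounding, which is what perturbs the variance test.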
So you're saying last_mean + (delta * new_sample_count) / updated_sample_count
is unstable, but avoids overflow?
Some notes on this issue (in the context of our test_incremental_variance_numerical_stability test):
The mean in this PR causes the tolerance in the variance to go to ~458, up from the original 177. All it takes to get ~458 is factoring out last_sample_count:
(last_mean + new_sum / last_sample_count) * last_sample_count / updated_sample_count
This PR's updated_means differ from master on the order of 1e-9, which is enough to cause the instability in the variance. (The variance is on the order of 1e+15.)
Is this something you're still completing, @agramfort?
@jnothman please take over as it cannot be priority for me right now :(
feel free to push a fix directly on my branch
Hello folks, any workaround while this gets fixed? I do not have an online source, so I would be fine using the batch flavor, but I'm not sure if that is an option in StandardScaler ... or maybe I should just do the scaling manually in numpy meanwhile. Thanks.
This PR uses the parallel algorithm, which is less numerically stable, as described here: https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm
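For reference, a minimal sketch of the pairwise merge from the linked Wikipedia section (Chan et al.); the function name is mine, not scikit-learn's, and this is just the textbook formula, not the PR's code:

```python
import numpy as np

def merge_mean_var(n_a, mean_a, m2_a, n_b, mean_b, m2_b):
    # Merge two (count, mean, sum-of-squared-deviations) summaries.
    # This is the "parallel algorithm" from the Wikipedia article:
    # the delta**2 * n_a * n_b / n term is where precision is lost
    # when the two chunks are very unbalanced.
    n = n_a + n_b
    delta = mean_b - mean_a
    mean = mean_a + delta * n_b / n
    m2 = m2_a + m2_b + delta ** 2 * n_a * n_b / n
    return n, mean, m2

rng = np.random.RandomState(0)
a, b = rng.randn(1000), rng.randn(500) + 3.0
stats_a = (a.size, a.mean(), ((a - a.mean()) ** 2).sum())
stats_b = (b.size, b.mean(), ((b - b.mean()) ** 2).sum())
n, mean, m2 = merge_mean_var(*stats_a, *stats_b)
x = np.concatenate([a, b])
assert np.isclose(mean, x.mean())
assert np.isclose(m2 / n, x.var())
```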
Indeed. As far as I remember, my fix was breaking some tests. Either we relax the tests, or I am not sure what to do here.
Reference Issues/PRs
Closes #5602
What does this implement/fix? Explain your changes.
use parallel algo from https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Online_algorithm
I still need to check if the same problem occurs with sparse matrices and for the variance... which I suspect it does
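For context, a minimal sketch of the single-sample online update (Welford's algorithm) from the linked Wikipedia article; names are illustrative, and this is the reference formula rather than scikit-learn's implementation:

```python
import numpy as np

def welford(xs):
    # Welford's online algorithm: one pass, numerically stable
    # running mean and (population) variance.
    n, mean, m2 = 0, 0.0, 0.0
    for x in xs:
        n += 1
        delta = x - mean
        mean += delta / n
        m2 += delta * (x - mean)  # uses the *updated* mean
    return mean, m2 / n

# Shifted data where a naive sum-of-squares formula loses precision:
data = 1e9 + np.arange(10.0)
mean, var = welford(data)
assert np.isclose(mean, data.mean())
assert np.isclose(var, data.var())
```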