[MRG] Improve stability of SGDClassifier / SGDRegressor with gradient clipping #3883
Conversation
The l2 weight decay rescaling is also kept positive (or zero) in case of strong regularization.
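Concretely, keeping the rescaling nonnegative amounts to something like the following (a plain-Python sketch with assumed names, not the actual sgd_fast.pyx code):

# Assumed names: eta is the learning rate, alpha the l2 penalty strength,
# wscale the running scale factor applied to the weight vector.
wscale *= max(0.0, 1.0 - eta * alpha)  # clamp at zero instead of going negative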
Can you bench this against master please? |
+1 for a benchmark, otherwise looks good to me. |
I am wondering whether there is a better way to compute the clipping in Cython. |
LGTM. As long as you use |
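For illustration, the clipping itself can be written with nested min/max; this is a plain-Python sketch, and the Cython code in the PR may express it differently:

# Hypothetical helper showing symmetric clipping of the loss gradient
# at +/-1e12; assumes dloss is a plain float.
def clip_dloss(dloss, limit=1e12):
    return min(max(dloss, -limit), limit)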
The benchmark seems to show that the change is fine. Here is my script:

import numpy as np
from time import time

from sklearn.linear_model import SGDClassifier

rng = np.random.RandomState(42)
n_samples = int(1e6)
data = rng.randn(n_samples, 100)
target = rng.randint(0, 2, n_samples)

durations = []
for i in range(10):
    t0 = time()
    SGDClassifier(n_iter=5, random_state=10).fit(data, target)
    d = time() - t0
    durations.append(d)
    print("%0.3fs" % d)

print("%0.3f+/-%0.3fs" % (np.mean(durations), np.std(durations)))
|
Thanks @larsmans for the tip. Shall I merge? |
Improve stability of SGDClassifier / SGDRegressor with gradient clipping
Thanks! Let me add a whats_new.rst entry. |
Great job! |
The squared_hinge loss of SGDClassifier (and potentially the squared loss of SGDRegressor) tends to trigger numerical overflows even on normalized data for some hyperparameter combinations. This PR fixes that issue by clipping dloss to 1e12. All existing tests still pass.

I have also had to prevent strong l2 regularization with large learning rates from triggering negative scales (which are meaningless and can also cause numerical divergence if lower than -1). Instead I set the weights to zero in that case. A new non-regression test highlights this case as well.

Both non-regression tests were inspired by #3040. They both fail at epochs #2 and #3 on the iris data with the sgd_fast.pyx implementation from master.
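For illustration, a check in the spirit of those non-regression tests might look like this (a hypothetical sketch with assumed hyperparameters, not the actual tests added by the PR; n_iter matches the scikit-learn API of that era):

# Hypothetical non-regression-style check: with aggressive hyperparameters
# that used to overflow, the fitted coefficients must stay finite.
import numpy as np
from sklearn.datasets import load_iris
from sklearn.linear_model import SGDClassifier

iris = load_iris()
clf = SGDClassifier(loss="squared_hinge", alpha=100.0, eta0=10.0,
                    learning_rate="constant", n_iter=5, random_state=0)
clf.fit(iris.data, iris.target)
assert np.isfinite(clf.coef_).all()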