Merged
Changes from all commits
39 commits
b280798
FIX use objective instead of loss for convergence in SGD
glemaitre Oct 8, 2024
5164376
remove debug
glemaitre Oct 8, 2024
8c38943
add more info regarding the objective or validation loss
glemaitre Oct 8, 2024
9af0140
BUG: fix termination criterion of SGD, use objective instead of loss
kostayScr Jul 30, 2025
af0270b
added tests
kostayScr Jul 30, 2025
ff65305
fixed code formatting
kostayScr Jul 30, 2025
af9f2c1
Merge branch 'main' into is/30027
kostayScr Jul 30, 2025
611e7cd
fixed code formatting, number 2
kostayScr Jul 30, 2025
10aa131
fixes test, so that SGD converges, by setting tol=None
kostayScr Jul 31, 2025
66e133a
fix doctest for SGDOneClassSVM
kostayScr Jul 31, 2025
3c9032f
fixed L1 norm accumulation in WeightVector
kostayScr Jul 31, 2025
6099adf
fallback to loss for PA1/PA2; respect penalty_type
kostayScr Jul 31, 2025
4aea4a3
remove debug
kostayScr Jul 31, 2025
4ac8494
fix typo
kostayScr Aug 4, 2025
35eb695
added changelog entry
kostayScr Aug 29, 2025
55846bc
Merge branch 'main' into is/30027
kostayScr Aug 30, 2025
36da843
modified changelog
kostayScr Sep 3, 2025
90be844
Merge branch 'is/30027' of https://github.com/kostayScr/scikit-learn …
kostayScr Sep 3, 2025
d8c9a44
refactor to remove variable
kostayScr Sep 3, 2025
09085b1
Update doc/whats_new/upcoming_changes/sklearn.linear_model/31856.fix.rst
kostayScr Sep 9, 2025
3584744
Update sklearn/linear_model/_sgd_fast.pyx.tp
kostayScr Sep 9, 2025
260c8d0
update comment about 0.5 coef of L2 reg term due to weight decay
kostayScr Sep 9, 2025
d83e0d8
update comment
kostayScr Sep 9, 2025
e89cbc6
update test comment
kostayScr Sep 9, 2025
c0a25b4
remove unused line from test
kostayScr Sep 9, 2025
a4d1b19
update test comments
kostayScr Sep 9, 2025
ba17482
FIX WeightVector norm accumulation
kostayScr Sep 11, 2025
1a675a5
update comment
kostayScr Sep 11, 2025
019ebca
update test comment
kostayScr Sep 11, 2025
9627cd1
update comment about nu / 2
kostayScr Sep 11, 2025
9c903af
update changelog
kostayScr Sep 11, 2025
088bfc4
refactor loss addition
kostayScr Sep 11, 2025
b8c3771
add comment
kostayScr Sep 11, 2025
2b23223
Update sklearn/linear_model/_sgd_fast.pyx.tp
kostayScr Sep 12, 2025
f03c622
rename variable
kostayScr Sep 12, 2025
41b1081
Update changelog
kostayScr Sep 12, 2025
d55c5af
fix typo in comment
kostayScr Oct 1, 2025
8e61f02
remove unnecessary test assert message
kostayScr Oct 1, 2025
a6a9367
Merge branch 'main' into is/30027
OmarManzoor Oct 3, 2025
6 changes: 6 additions & 0 deletions doc/whats_new/upcoming_changes/sklearn.linear_model/31856.fix.rst
@@ -0,0 +1,6 @@
- Fix the convergence criterion for SGD models to avoid premature convergence when
  `tol != None`. This primarily impacts :class:`SGDOneClassSVM` but also affects
  :class:`SGDClassifier` and :class:`SGDRegressor`. Before this fix, only the loss
  without the penalty was used in the convergence check; now the full objective,
  including the regularization term, is used.
  By :user:`Guillaume Lemaitre <glemaitre>` and :user:`kostayScr <kostayScr>`
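
As a rough illustration of the quantity the new check tracks, here is a minimal NumPy sketch of the per-epoch objective (the helper name epoch_objective and the plain-array setup are illustrative only, not scikit-learn API):

import numpy as np

def epoch_objective(losses, w, alpha, l1_ratio):
    """Average elastic-net objective for one epoch: mean loss plus penalty."""
    penalty = alpha * (
        (1 - l1_ratio) * 0.5 * np.dot(w, w)  # 0.5 * ||w||_2**2, matching the SGD updates
        + l1_ratio * np.abs(w).sum()         # ||w||_1
    )
    return np.mean(losses) + penalty

# Before the fix, convergence was judged on np.mean(losses) alone; the loss can
# plateau while the full objective is still decreasing, which triggered a
# premature stop.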
60 changes: 45 additions & 15 deletions sklearn/linear_model/_sgd_fast.pyx.tp
@@ -422,8 +422,10 @@ def _plain_sgd{{name_suffix}}(
cdef double update = 0.0
cdef double intercept_update = 0.0
cdef double sumloss = 0.0
cdef double cur_loss_val = 0.0
cdef double score = 0.0
cdef double best_loss = INFINITY
cdef double objective_sum = 0.0
cdef double best_objective = INFINITY
cdef double best_score = -INFINITY
cdef {{c_type}} y = 0.0
cdef {{c_type}} sample_weight
@@ -465,6 +467,7 @@
with nogil:
for epoch in range(max_iter):
sumloss = 0
objective_sum = 0
if verbose > 0:
with gil:
print("-- Epoch %d" % (epoch + 1))
@@ -486,7 +489,23 @@
eta = eta0 / pow(t, power_t)

if verbose or not early_stopping:
sumloss += loss.cy_loss(y, p)
cur_loss_val = loss.cy_loss(y, p)
sumloss += cur_loss_val
objective_sum += cur_loss_val
# for PA1/PA2 (Passive-Aggressive, an online algorithm) use only the loss
if learning_rate != PA1 and learning_rate != PA2:
# sum up all the terms of the optimization objective function
# (i.e. include the regularization in addition to the loss)
# Note: for the L2 term, SGD optimizes 0.5 * ||w||_2**2 because the update
# uses weight decay; that's why the 0.5 coefficient is required
if penalty_type > 0: # if regularization is enabled
objective_sum += alpha * (
(1 - l1_ratio) * 0.5 * w.norm() ** 2 +
l1_ratio * w.l1norm()
)
if one_class: # specific to One-Class SVM
# nu is alpha * 2 (alpha is set as nu / 2 by the caller)
objective_sum += intercept * (alpha * 2)

if y > 0.0:
class_weight = weight_pos
@@ -552,16 +571,6 @@
t += 1
count += 1

# report epoch information
if verbose > 0:
with gil:
print("Norm: %.2f, NNZs: %d, Bias: %.6f, T: %d, "
"Avg. loss: %f"
% (w.norm(), np.nonzero(weights)[0].shape[0],
intercept, count, sumloss / train_count))
print("Total training time: %.2f seconds."
% (time() - t_start))

# floating-point under-/overflow check.
if (not isfinite(intercept) or any_nonfinite(weights)):
infinity = True
@@ -571,6 +580,14 @@
if early_stopping:
with gil:
score = validation_score_cb(weights.base, intercept)
if verbose > 0: # report epoch information
print("Norm: %.2f, NNZs: %d, Bias: %.6f, T: %d, "
"Avg. loss: %f, Objective: %f, Validation score: %f"
% (w.norm(), np.nonzero(weights)[0].shape[0],
intercept, count, sumloss / train_count,
objective_sum / train_count, score))
print("Total training time: %.2f seconds."
% (time() - t_start))
if tol > -INFINITY and score < best_score + tol:
no_improvement_count += 1
else:
@@ -579,12 +596,25 @@
best_score = score
# or evaluate the loss on the training set
else:
if tol > -INFINITY and sumloss > best_loss - tol * train_count:
if verbose > 0: # report epoch information
with gil:
print("Norm: %.2f, NNZs: %d, Bias: %.6f, T: %d, "
"Avg. loss: %f, Objective: %f"
% (w.norm(), np.nonzero(weights)[0].shape[0],
intercept, count, sumloss / train_count,
objective_sum / train_count))
print("Total training time: %.2f seconds."
% (time() - t_start))
# true objective = objective_sum / number of samples
if (
tol > -INFINITY
and objective_sum / train_count > best_objective - tol
):
no_improvement_count += 1
else:
no_improvement_count = 0
if sumloss < best_loss:
best_loss = sumloss
if objective_sum / train_count < best_objective:
best_objective = objective_sum / train_count

# if there is no improvement several times in a row
if no_improvement_count >= n_iter_no_change:
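
The hunk above swaps the loss-based test for an objective-based one. A pure-Python replay of that stopping rule, as a sketch only (tol=None here stands in for the Cython tol > -INFINITY guard; the names are illustrative):

def epochs_until_stop(objectives, tol=1e-3, n_iter_no_change=5):
    """Replay the stopping rule on a list of per-epoch average objectives."""
    best_objective = float("inf")
    no_improvement_count = 0
    for epoch, objective in enumerate(objectives, start=1):
        if tol is not None and objective > best_objective - tol:
            no_improvement_count += 1
        else:
            no_improvement_count = 0
        if objective < best_objective:
            best_objective = objective
        if no_improvement_count >= n_iter_no_change:
            break
    return epoch

# A steadily decreasing objective keeps training running for all 100 epochs:
print(epochs_until_stop([1.0 - 0.01 * i for i in range(100)]))  # 100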
4 changes: 2 additions & 2 deletions sklearn/linear_model/_stochastic_gradient.py
@@ -2252,9 +2252,9 @@ class SGDOneClassSVM(OutlierMixin, BaseSGD):
>>> import numpy as np
>>> from sklearn import linear_model
>>> X = np.array([[-1, -1], [-2, -1], [1, 1], [2, 1]])
>>> clf = linear_model.SGDOneClassSVM(random_state=42)
>>> clf = linear_model.SGDOneClassSVM(random_state=42, tol=None)
>>> clf.fit(X)
SGDOneClassSVM(random_state=42)
SGDOneClassSVM(random_state=42, tol=None)

>>> print(clf.predict([[4, 4]]))
[1]
47 changes: 47 additions & 0 deletions sklearn/linear_model/tests/test_sgd.py
@@ -1771,6 +1771,53 @@ def test_ocsvm_vs_sgdocsvm():
assert corrcoef >= 0.9


def test_sgd_oneclass_convergence():
# Check that the optimization does not end early and that the stopping criterion
# is working. Non-regression test for #30027
for nu in [0.1, 0.5, 0.9]:
# no need for large max_iter
model = SGDOneClassSVM(
nu=nu, max_iter=100, tol=1e-3, learning_rate="constant", eta0=1e-3
)
model.fit(iris.data)
# 6 is the minimal number of iterations that should be surpassed, after which
# the optimization can stop
assert model.n_iter_ > 6


def test_sgd_oneclass_vs_linear_oneclass():
# Test convergence vs. the libsvm-based `OneClassSVM` with kernel="linear"
for nu in [0.1, 0.5, 0.9]:
# allow enough iterations, small dataset
model = SGDOneClassSVM(
nu=nu, max_iter=20000, tol=None, learning_rate="constant", eta0=1e-3
)
model_ref = OneClassSVM(kernel="linear", nu=nu, tol=1e-6) # reference model
model.fit(iris.data)
model_ref.fit(iris.data)

preds = model.predict(iris.data)
dec_fn = model.decision_function(iris.data)

preds_ref = model_ref.predict(iris.data)
dec_fn_ref = model_ref.decision_function(iris.data)

dec_fn_corr = np.corrcoef(dec_fn, dec_fn_ref)[0, 1]
preds_corr = np.corrcoef(preds, preds_ref)[0, 1]
# check weights and intercept concatenated together for correlation
coef_corr = np.corrcoef(
np.concatenate([model.coef_, -model.offset_]),
np.concatenate([model_ref.coef_.flatten(), model_ref.intercept_]),
)[0, 1]
# share of predicted 1's
share_ones = (preds == 1).sum() / len(preds)

assert dec_fn_corr > 0.99
assert preds_corr > 0.95
assert coef_corr > 0.99
assert_allclose(1 - share_ones, nu)


def test_l1_ratio():
# Test if l1 ratio extremes match L1 and L2 penalty settings.
X, y = datasets.make_classification(
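
A hedged reproduction of the behaviour the first test above guards against (the parameters mirror the test; the n_iter_ > 6 threshold is the test's choice, not a general guarantee):

from sklearn.datasets import load_iris
from sklearn.linear_model import SGDOneClassSVM

X = load_iris().data
model = SGDOneClassSVM(
    nu=0.5, max_iter=100, tol=1e-3, learning_rate="constant", eta0=1e-3
).fit(X)
# With the old loss-only criterion the fit could stop after only a few epochs;
# with the objective-based criterion it keeps improving well past that point.
print(model.n_iter_)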
4 changes: 2 additions & 2 deletions sklearn/tests/test_multioutput.py
@@ -425,14 +425,14 @@ def test_multi_output_classification_partial_fit_sample_weights():
Xw = [[1, 2, 3], [4, 5, 6], [1.5, 2.5, 3.5]]
yw = [[3, 2], [2, 3], [3, 2]]
w = np.asarray([2.0, 1.0, 1.0])
sgd_linear_clf = SGDClassifier(random_state=1, max_iter=20)
sgd_linear_clf = SGDClassifier(random_state=1, max_iter=20, tol=None)
clf_w = MultiOutputClassifier(sgd_linear_clf)
clf_w.fit(Xw, yw, w)

# unweighted, but with repeated samples
X = [[1, 2, 3], [1, 2, 3], [4, 5, 6], [1.5, 2.5, 3.5]]
y = [[3, 2], [3, 2], [2, 3], [3, 2]]
sgd_linear_clf = SGDClassifier(random_state=1, max_iter=20)
sgd_linear_clf = SGDClassifier(random_state=1, max_iter=20, tol=None)
clf = MultiOutputClassifier(sgd_linear_clf)
clf.fit(X, y)
X_test = [[1.5, 2.5, 3.5]]
2 changes: 2 additions & 0 deletions sklearn/utils/_weight_vector.pxd.tp
@@ -31,6 +31,7 @@ cdef class WeightVector{{name_suffix}}(object):
cdef double average_b
cdef int n_features
cdef double sq_norm
cdef double l1_norm

cdef void add(self, {{c_type}} *x_data_ptr, int *x_ind_ptr,
int xnnz, {{c_type}} c) noexcept nogil
@@ -41,5 +42,6 @@
cdef void scale(self, {{c_type}} c) noexcept nogil
cdef void reset_wscale(self) noexcept nogil
cdef {{c_type}} norm(self) noexcept nogil
cdef {{c_type}} l1norm(self) noexcept nogil

{{endfor}}
30 changes: 20 additions & 10 deletions sklearn/utils/_weight_vector.pyx.tp
@@ -25,9 +25,9 @@ dtypes = [('64', 'double', 1e-9),

cimport cython
from libc.limits cimport INT_MAX
from libc.math cimport sqrt
from libc.math cimport sqrt, fabs

from sklearn.utils._cython_blas cimport _dot, _scal, _axpy
from sklearn.utils._cython_blas cimport _dot, _scal, _axpy, _asum

{{for name_suffix, c_type, reset_wscale_threshold in dtypes}}

@@ -53,6 +53,8 @@ cdef class WeightVector{{name_suffix}}(object):
The number of features (= dimensionality of ``w``).
sq_norm : {{c_type}}
The squared norm of ``w``.
l1_norm : {{c_type}}
The L1 norm of ``w``.
"""

def __cinit__(self,
@@ -67,6 +69,7 @@
self.wscale = 1.0
self.n_features = w.shape[0]
self.sq_norm = _dot(self.n_features, self.w_data_ptr, 1, self.w_data_ptr, 1)
self.l1_norm = _asum(self.n_features, self.w_data_ptr, 1)

self.aw = aw
if self.aw is not None:
@@ -78,7 +81,7 @@
{{c_type}} c) noexcept nogil:
"""Scales sample x by constant c and adds it to the weight vector.

This operation updates ``sq_norm``.
This operation updates ``sq_norm`` and ``l1_norm``.

Parameters
----------
@@ -94,8 +97,8 @@
cdef int j
cdef int idx
cdef double val
cdef double innerprod = 0.0
cdef double xsqnorm = 0.0
cdef double l2norm_accumulator = 0.0
cdef double l1norm_accumulator = 0.0

# the next two lines save a factor of 2!
cdef {{c_type}} wscale = self.wscale
@@ -104,11 +107,13 @@
for j in range(xnnz):
idx = x_ind_ptr[j]
val = x_data_ptr[j]
innerprod += (w_data_ptr[idx] * val)
xsqnorm += (val * val)
w_data_ptr[idx] += val * (c / wscale)

self.sq_norm += (xsqnorm * c * c) + (2.0 * innerprod * wscale * c)
l2norm_accumulator += w_data_ptr[idx] * w_data_ptr[idx]
l1norm_accumulator += fabs(w_data_ptr[idx])

self.sq_norm = l2norm_accumulator * (wscale * wscale)
self.l1_norm = l1norm_accumulator * wscale

# Update the average weights according to the sparse trick defined
# here: https://research.microsoft.com/pubs/192769/tricks-2012.pdf
@@ -180,10 +185,11 @@ cdef class WeightVector{{name_suffix}}(object):
cdef void scale(self, {{c_type}} c) noexcept nogil:
"""Scales the weight vector by a constant ``c``.

It updates ``wscale`` and ``sq_norm``. If ``wscale`` gets too
small we call ``reset_swcale``."""
It updates ``wscale``, ``sq_norm``, and ``l1_norm``. If ``wscale`` gets too
small we call ``reset_wscale``."""
self.wscale *= c
self.sq_norm *= (c * c)
self.l1_norm *= fabs(c)

if self.wscale < {{reset_wscale_threshold}}:
self.reset_wscale()
@@ -204,4 +210,8 @@
"""The L2 norm of the weight vector. """
return sqrt(self.sq_norm)

cdef {{c_type}} l1norm(self) noexcept nogil:
"""The L1 norm of the weight vector. """
return self.l1_norm

{{endfor}}
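
The WeightVector changes add an L1-norm cache next to the existing squared-L2 cache on the implicitly scaled vector (the stored array times wscale). A small Python model of that bookkeeping, recomputing the caches in full for clarity rather than incrementally as the Cython class does (the class and method names are illustrative):

import numpy as np

class ScaledWeights:
    """Weights represented as wscale * w, with cached squared-L2 and L1 norms."""

    def __init__(self, w):
        self.w = np.asarray(w, dtype=float)
        self.wscale = 1.0
        self.sq_norm = float(self.w @ self.w)
        self.l1_norm = float(np.abs(self.w).sum())

    def scale(self, c):
        # Scaling the whole vector only touches the scalar and the two caches.
        self.wscale *= c
        self.sq_norm *= c * c
        self.l1_norm *= abs(c)

    def add(self, x, c):
        # Effective update: (wscale * w) += c * x, applied to the stored array.
        self.w += np.asarray(x, dtype=float) * (c / self.wscale)
        # Refresh the caches so norm() and l1norm() stay consistent.
        self.sq_norm = float(self.w @ self.w) * self.wscale ** 2
        self.l1_norm = float(np.abs(self.w).sum()) * self.wscale

    def norm(self):
        # L2 norm of the effective vector wscale * w
        return self.sq_norm ** 0.5

    def l1norm(self):
        # L1 norm of the effective vector wscale * w
        return self.l1_norm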