DOC improve inline comments in SAGA #25100

Merged
merged 6 commits into from
Dec 13, 2022
45 changes: 41 additions & 4 deletions sklearn/linear_model/_sag_fast.pyx.tp
@@ -235,6 +235,29 @@ def sag{{name_suffix}}(

Used in Ridge and LogisticRegression.

Some implementation details:

- Just-in-time (JIT) update: In SAG(A), the average-gradient update is
collinear with the drawn sample X_i. Therefore, if the data is sparse, the
random sample X_i will change the average gradient only on features j where
X_ij != 0. As a result, the average gradient on feature j may stay unchanged
for k consecutive random samples. In that case, instead of applying the same
gradient step k times on feature j, we apply the gradient step only once,
scaled by k. This is called the "just-in-time update", and
it is performed in `lagged_update{{name_suffix}}`. This function also
applies the proximal operator after the gradient step (if L1 regularization
is used in SAGA).

- Weight scale: In SAG(A), the weights are scaled down at each iteration
due to the L2 regularization. To avoid updating all the weights at each
iteration, the weight scale is factored out into a separate variable `wscale`
which is only used in the JIT update. When this variable is too small, it
is reset for numerical stability using the function
`scale_weights{{name_suffix}}`. This reset requires applying all remaining
JIT updates. This reset is also performed every `n_samples` iterations
before each convergence check, so when the algorithm stops, we are sure
that there are no remaining JIT updates.

Reference
---------
Schmidt, M., Roux, N. L., & Bach, F. (2013).
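
To make these two notes concrete, here is a minimal NumPy sketch (not part of the patch; `w_j`, `cum_sum`, etc. are illustrative names, not the solver's variables) of how the deferred gradient steps and the factored-out weight scale fit together for a single feature whose average gradient stays constant:

```python
import numpy as np

# Illustrative sketch: the true weight is wscale * w_j.  While feature j is
# not seen, only the scalar wscale and a running sum of step coefficients are
# updated; the k missed gradient steps are then applied in a single shot.
step_size, alpha = 0.1, 0.01
avg_grad_j = 0.3   # average gradient on feature j, unchanged for k iterations
k = 5

# eager reference: decay and update the weight at every iteration
w_eager = 1.0
for _ in range(k):
    w_eager = w_eager * (1 - step_size * alpha) - step_size * avg_grad_j

# lazy version: per-iteration work for feature j is O(1) bookkeeping only
w_j, wscale, cum_sum = 1.0, 1.0, 0.0
for _ in range(k):
    wscale *= 1 - step_size * alpha   # L2 decay, factored out of the weights
    cum_sum += step_size / wscale     # coefficient of the deferred step
w_j -= cum_sum * avg_grad_j           # just-in-time catch-up on feature j

assert np.isclose(w_eager, wscale * w_j)
```

With `alpha = 0` the catch-up term reduces to exactly k identical steps applied at once, which is the plain "scaled by k" case described in the first note above.
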
@@ -368,7 +391,7 @@ def sag{{name_suffix}}(
num_seen += 1
seen_init[sample_ind] = 1

# make the weight updates
# make the weight updates (just-in-time gradient step, and prox operator)
if sample_itr > 0:
status = lagged_update{{name_suffix}}(
weights=weights,
@@ -420,6 +443,11 @@ def sag{{name_suffix}}(
val * (gradient[class_ind] -
gradient_memory[s_idx + class_ind])
if saga:
# Note that this is not the main gradient step,
# which is performed just-in-time in lagged_update.
# This part is done outside the JIT update
# as it does not depend on the average gradient.
# The prox operator is applied after the JIT update
weights[f_idx + class_ind] -= \
(gradient_correction * step_size
* (1 - 1. / num_seen) / wscale)
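
To illustrate why this split is valid, here is a hypothetical dense NumPy sketch (illustrative names, not the Cython variables) of one SAGA step decomposed into the immediate per-sample correction and the deferred average-gradient part:

```python
import numpy as np

# One SAGA step, split the way the comment above describes it.
rng = np.random.default_rng(0)
step_size, n_samples, n_features = 0.05, 4, 3

gradient_memory = rng.normal(size=(n_samples, n_features))  # stored g_k
avg_grad = gradient_memory.mean(axis=0)                      # (1/n) sum_k g_k
w = rng.normal(size=n_features)

i = 2
g_new = rng.normal(size=n_features)   # fresh gradient of sample i at current w

# textbook SAGA step: w <- w - step * (g_new - g_old + avg_grad)
w_ref = w - step_size * (g_new - gradient_memory[i] + avg_grad)

# same step, split: immediate correction + deferred average-gradient part
w_split = w - step_size * (g_new - gradient_memory[i])  # needs only the drawn sample
w_split -= step_size * avg_grad                          # needs only the running average
assert np.allclose(w_ref, w_split)
```

In the actual code the immediate part is further scaled by `(1 - 1. / num_seen)`, which accounts for the share of the new correction that the deferred average-gradient step will itself apply once `sum_gradient` has been updated.
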
@@ -467,6 +495,7 @@ def sag{{name_suffix}}(
(cumulative_sums_prox[sample_itr - 1] +
step_size * beta / wscale)
# If wscale gets too small, we need to reset the scale.
# This also resets the just-in-time update system.
if wscale < 1e-9:
if verbose:
with gil:
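
As a reminder of what the prox step mentioned above does, here is a small sketch of soft-thresholding, the proximal operator of the L1 penalty; the helper name and the reading of `beta` as the L1 strength are assumptions for illustration, not taken from the Cython signature:

```python
import numpy as np

def soft_threshold(w, threshold):
    """Prox of threshold * ||.||_1: shrink every coordinate towards zero."""
    return np.sign(w) * np.maximum(np.abs(w) - threshold, 0.0)

step_size, beta = 0.1, 1.0
w_after_gradient_step = np.array([0.8, -0.05, 0.3])
print(soft_threshold(w_after_gradient_step, step_size * beta))
# roughly [0.7, 0.0, 0.2]: the small middle coefficient is clipped to zero
```
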
@@ -493,8 +522,10 @@ def sag{{name_suffix}}(
if status == -1:
break

# we scale the weights every n_samples iterations and reset the
# We scale the weights every n_samples iterations and reset the
# just-in-time update system for numerical stability.
# Because this reset is done before every convergence check, we are
# sure there is no remaining lagged update when the algorithm stops.
status = scale_weights{{name_suffix}}(
weights=weights,
wscale=&wscale,
@@ -509,9 +540,9 @@ def sag{{name_suffix}}(
sum_gradient=sum_gradient,
n_iter=n_iter
)

if status == -1:
break

# check if the stopping criteria is reached
max_change = 0.0
max_weight = 0.0
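
A minimal sketch of a relative max-change stopping rule of the kind checked here, which is meaningful only because the rescaling above guarantees no lagged update is pending; the function name and exact form are illustrative, not copied from the Cython code:

```python
import numpy as np

def has_converged(previous_weights, weights, tol=1e-3):
    # largest absolute change, relative to the largest absolute weight
    max_change = np.max(np.abs(weights - previous_weights))
    max_weight = np.max(np.abs(weights))
    return (max_weight != 0 and max_change / max_weight <= tol) or (
        max_weight == 0 and max_change == 0
    )

w_prev = np.array([1.0, -2.0, 0.5])
w_new = np.array([1.0004, -2.0003, 0.5001])
print(has_converged(w_prev, w_new))  # True: largest relative change is ~2e-4
```
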
@@ -565,7 +596,10 @@ cdef int scale_weights{{name_suffix}}(
{{c_type}}* sum_gradient,
int n_iter
) nogil:
"""Scale the weights with wscale for numerical stability.
"""Scale the weights and reset wscale to 1.0 for numerical stability, and
reset the just-in-time (JIT) update system.

See `sag{{name_suffix}}`'s docstring about the JIT update system.

wscale = (1 - step_size * alpha) ** (n_iter * n_samples + sample_itr)
can become very small, so we reset it every n_samples iterations to 1.0 for
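
A back-of-the-envelope illustration, with made-up numbers rather than the solver's, of how quickly this formula drives `wscale` towards zero and why a periodic reset is needed:

```python
step_size, alpha = 1e-2, 1e-2   # each iteration multiplies wscale by (1 - 1e-4)
n_updates = 1_000_000           # n_iter * n_samples + sample_itr
wscale = (1 - step_size * alpha) ** n_updates
print(wscale)                   # ~3.7e-44, far below the 1e-9 reset threshold
```
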
@@ -618,6 +652,9 @@ cdef int lagged_update{{name_suffix}}(
int n_iter
) nogil:
"""Hard perform the JIT updates for non-zero features of present sample.

See `sag{{name_suffix}}`'s docstring about the JIT update system.

The updates that await are kept in memory using cumulative_sums,
cumulative_sums_prox, wscale and feature_hist. See original SAGA paper
(Defazio et al. 2014) for details. If reset=True, we also reset wscale to
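
A simplified dense sketch of that bookkeeping (illustrative names; `wscale` and the prox sums are left out), showing how the steps a feature missed since `feature_hist[j]` collapse into a difference of cumulative sums:

```python
import numpy as np

# Because sum_gradient[j] is constant while feature j is absent from the
# drawn samples, the missed steps can be recovered from cumulative sums.
n_iters, n_features = 6, 3
step_size, num_seen = 0.1, 10
sum_gradient = np.array([2.0, -1.0, 0.5])  # frozen while each feature is untouched

# cumulative_sums[t]: sum of per-iteration step coefficients up to iteration t
cumulative_sums = np.cumsum(np.full(n_iters, step_size / num_seen))

feature_hist = np.array([0, 2, 4])  # iteration at which each feature was last updated
weights = np.ones(n_features)

t = n_iters - 1                     # catch every feature up at the current iteration
for j in range(n_features):
    missed = cumulative_sums[t] - cumulative_sums[feature_hist[j]]
    weights[j] -= missed * sum_gradient[j]
    feature_hist[j] = t
print(weights)  # each feature receives exactly its missed steps, in one update
```
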