Commit 0f0f29a

MNT refactoring of sgd utilities (#16528)
1 parent bbcfad8 commit 0f0f29a

File tree

2 files changed: +83 -274 lines changed


sklearn/linear_model/_sgd_fast.pyx

Lines changed: 29 additions & 186 deletions
@@ -332,155 +332,39 @@ cdef class SquaredEpsilonInsensitive(Regression):
         return SquaredEpsilonInsensitive, (self.epsilon,)


-def plain_sgd(np.ndarray[double, ndim=1, mode='c'] weights,
-              double intercept,
-              LossFunction loss,
-              int penalty_type,
-              double alpha, double C,
-              double l1_ratio,
-              SequentialDataset dataset,
-              np.ndarray[unsigned char, ndim=1, mode='c'] validation_mask,
-              bint early_stopping, validation_score_cb,
-              int n_iter_no_change,
-              int max_iter, double tol, int fit_intercept,
-              int verbose, bint shuffle, np.uint32_t seed,
-              double weight_pos, double weight_neg,
-              int learning_rate, double eta0,
-              double power_t,
-              double t=1.0,
-              double intercept_decay=1.0):
-    """Plain SGD for generic loss functions and penalties.
-
-    Parameters
-    ----------
-    weights : ndarray[double, ndim=1]
-        The allocated coef_ vector.
-    intercept : double
-        The initial intercept.
-    loss : LossFunction
-        A concrete ``LossFunction`` object.
-    penalty_type : int
-        The penalty 2 for L2, 1 for L1, and 3 for Elastic-Net.
-    alpha : float
-        The regularization parameter.
-    C : float
-        Maximum step size for passive aggressive.
-    l1_ratio : float
-        The Elastic Net mixing parameter, with 0 <= l1_ratio <= 1.
-        l1_ratio=0 corresponds to L2 penalty, l1_ratio=1 to L1.
-    dataset : SequentialDataset
-        A concrete ``SequentialDataset`` object.
-    validation_mask : ndarray[unsigned char, ndim=1]
-        Equal to True on the validation set.
-    early_stopping : boolean
-        Whether to use a stopping criterion based on the validation set.
-    validation_score_cb : callable
-        A callable to compute a validation score given the current
-        coefficients and intercept values.
-        Used only if early_stopping is True.
-    n_iter_no_change : int
-        Number of iteration with no improvement to wait before stopping.
-    max_iter : int
-        The maximum number of iterations (epochs).
-    tol: double
-        The tolerance for the stopping criterion.
-    fit_intercept : int
-        Whether or not to fit the intercept (1 or 0).
-    verbose : int
-        Print verbose output; 0 for quite.
-    shuffle : boolean
-        Whether to shuffle the training data before each epoch.
-    weight_pos : float
-        The weight of the positive class.
-    weight_neg : float
-        The weight of the negative class.
-    seed : np.uint32_t
-        Seed of the pseudorandom number generator used to shuffle the data.
-    learning_rate : int
-        The learning rate:
-        (1) constant, eta = eta0
-        (2) optimal, eta = 1.0/(alpha * t).
-        (3) inverse scaling, eta = eta0 / pow(t, power_t)
-        (4) adaptive decrease
-        (5) Passive Aggressive-I, eta = min(alpha, loss/norm(x))
-        (6) Passive Aggressive-II, eta = 1.0 / (norm(x) + 0.5*alpha)
-    eta0 : double
-        The initial learning rate.
-    power_t : double
-        The exponent for inverse scaling learning rate.
-    t : double
-        Initial state of the learning rate. This value is equal to the
-        iteration count except when the learning rate is set to `optimal`.
-        Default: 1.0.
-    intercept_decay : double
-        The decay ratio of intercept, used in updating intercept.
-
-    Returns
-    -------
-    weights : array, shape=[n_features]
-        The fitted weight vector.
-    intercept : float
-        The fitted intercept term.
-    n_iter_ : int
-        The actual number of iter (epochs).
-    """
-    standard_weights, standard_intercept,\
-        _, _, n_iter_ = _plain_sgd(weights,
-                                   intercept,
-                                   None,
-                                   0,
-                                   loss,
-                                   penalty_type,
-                                   alpha, C,
-                                   l1_ratio,
-                                   dataset,
-                                   validation_mask,
-                                   early_stopping,
-                                   validation_score_cb,
-                                   n_iter_no_change,
-                                   max_iter, tol, fit_intercept,
-                                   verbose, shuffle, seed,
-                                   weight_pos, weight_neg,
-                                   learning_rate, eta0,
-                                   power_t,
-                                   t,
-                                   intercept_decay,
-                                   0)
-    return standard_weights, standard_intercept, n_iter_
-
-
-def average_sgd(np.ndarray[double, ndim=1, mode='c'] weights,
-                double intercept,
-                np.ndarray[double, ndim=1, mode='c'] average_weights,
-                double average_intercept,
-                LossFunction loss,
-                int penalty_type,
-                double alpha, double C,
-                double l1_ratio,
-                SequentialDataset dataset,
-                np.ndarray[unsigned char, ndim=1, mode='c'] validation_mask,
-                bint early_stopping, validation_score_cb,
-                int n_iter_no_change,
-                int max_iter, double tol, int fit_intercept,
-                int verbose, bint shuffle, np.uint32_t seed,
-                double weight_pos, double weight_neg,
-                int learning_rate, double eta0,
-                double power_t,
-                double t=1.0,
-                double intercept_decay=1.0,
-                int average=1):
-    """Average SGD for generic loss functions and penalties.
+def _plain_sgd(np.ndarray[double, ndim=1, mode='c'] weights,
+               double intercept,
+               np.ndarray[double, ndim=1, mode='c'] average_weights,
+               double average_intercept,
+               LossFunction loss,
+               int penalty_type,
+               double alpha, double C,
+               double l1_ratio,
+               SequentialDataset dataset,
+               np.ndarray[unsigned char, ndim=1, mode='c'] validation_mask,
+               bint early_stopping, validation_score_cb,
+               int n_iter_no_change,
+               int max_iter, double tol, int fit_intercept,
+               int verbose, bint shuffle, np.uint32_t seed,
+               double weight_pos, double weight_neg,
+               int learning_rate, double eta0,
+               double power_t,
+               double t=1.0,
+               double intercept_decay=1.0,
+               int average=0):
+    """SGD for generic loss functions and penalties with optional averaging

     Parameters
     ----------
     weights : ndarray[double, ndim=1]
-        The allocated coef_ vector.
+        The allocated vector of weights.
     intercept : double
         The initial intercept.
     average_weights : ndarray[double, ndim=1]
-        The average weights as computed for ASGD
+        The average weights as computed for ASGD. Should be None if average
+        is 0.
     average_intercept : double
-        The average intercept for ASGD
+        The average intercept for ASGD. Should be 0 if average is 0.
     loss : LossFunction
         A concrete ``LossFunction`` object.
     penalty_type : int
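
Note on the learning_rate codes documented in the removed plain_sgd docstring above: the sketch below transcribes those formulas into plain Python for illustration only. eta_for is a hypothetical helper, not a function from this module, and the adaptive and passive-aggressive cases are simplified.

import math

# Hypothetical helper (not part of _sgd_fast.pyx) mirroring the schedule
# codes (1)-(6) as written in the docstring above.
def eta_for(learning_rate, eta0, alpha, power_t, t, loss=0.0, x_norm=1.0):
    if learning_rate == 1:      # constant: eta = eta0
        return eta0
    if learning_rate == 2:      # optimal: eta = 1.0 / (alpha * t)
        return 1.0 / (alpha * t)
    if learning_rate == 3:      # inverse scaling: eta = eta0 / t**power_t
        return eta0 / math.pow(t, power_t)
    if learning_rate == 4:      # adaptive: starts at eta0; the caller decreases
        return eta0             # it when progress stalls (not shown here)
    if learning_rate == 5:      # Passive-Aggressive I: min(alpha, loss / norm(x))
        return min(alpha, loss / x_norm)
    if learning_rate == 6:      # Passive-Aggressive II: 1.0 / (norm(x) + 0.5 * alpha)
        return 1.0 / (x_norm + 0.5 * alpha)
    raise ValueError("unknown learning_rate code")
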
@@ -549,55 +433,14 @@ def average_sgd(np.ndarray[double, ndim=1, mode='c'] weights,
     intercept : float
         The fitted intercept term.
     average_weights : array shape=[n_features]
-        The averaged weights across iterations
+        The averaged weights across iterations. Values are valid only if
+        average > 0.
     average_intercept : float
-        The averaged intercept across iterations
+        The averaged intercept across iterations.
+        Values are valid only if average > 0.
     n_iter_ : int
         The actual number of iter (epochs).
     """
-    return _plain_sgd(weights,
-                      intercept,
-                      average_weights,
-                      average_intercept,
-                      loss,
-                      penalty_type,
-                      alpha, C,
-                      l1_ratio,
-                      dataset,
-                      validation_mask,
-                      early_stopping,
-                      validation_score_cb,
-                      n_iter_no_change,
-                      max_iter, tol, fit_intercept,
-                      verbose, shuffle, seed,
-                      weight_pos, weight_neg,
-                      learning_rate, eta0,
-                      power_t,
-                      t,
-                      intercept_decay,
-                      average)
-
-
-def _plain_sgd(np.ndarray[double, ndim=1, mode='c'] weights,
-               double intercept,
-               np.ndarray[double, ndim=1, mode='c'] average_weights,
-               double average_intercept,
-               LossFunction loss,
-               int penalty_type,
-               double alpha, double C,
-               double l1_ratio,
-               SequentialDataset dataset,
-               np.ndarray[unsigned char, ndim=1, mode='c'] validation_mask,
-               bint early_stopping, validation_score_cb,
-               int n_iter_no_change,
-               int max_iter, double tol, int fit_intercept,
-               int verbose, bint shuffle, np.uint32_t seed,
-               double weight_pos, double weight_neg,
-               int learning_rate, double eta0,
-               double power_t,
-               double t=1.0,
-               double intercept_decay=1.0,
-               int average=0):

     # get the data information into easy vars
     cdef Py_ssize_t n_samples = dataset.n_samples
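
With this refactoring, the plain and averaged code paths both go through the single _plain_sgd entry point shown above, selected by the trailing average argument (0 by default). The sketch below illustrates, under that assumption, how a caller could reproduce the behaviour of the two removed wrappers; the *_compat names are hypothetical, and the long Cython-typed argument list is abbreviated to *args/**kwargs.

# Hypothetical caller-side helpers (not scikit-learn names). _plain_sgd is
# passed in explicitly so the sketch stays independent of the private import.
def plain_sgd_compat(_plain_sgd, weights, intercept, *args, **kwargs):
    # Non-averaged case: dummy averaging buffers (None, 0), average=0, and the
    # averaged outputs are dropped -- mirroring the removed plain_sgd wrapper.
    w, b, _, _, n_iter_ = _plain_sgd(weights, intercept, None, 0,
                                     *args, average=0, **kwargs)
    return w, b, n_iter_

def average_sgd_compat(_plain_sgd, weights, intercept, average_weights,
                       average_intercept, *args, average=1, **kwargs):
    # Averaged case: forward the real buffers and return all five outputs,
    # mirroring the removed average_sgd wrapper.
    return _plain_sgd(weights, intercept, average_weights, average_intercept,
                      *args, average=average, **kwargs)
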

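For reference, the penalty_type codes in the docstring (1 for L1, 2 for L2, 3 for Elastic-Net) combine with alpha and l1_ratio roughly as in the NumPy sketch below. The 0.5 factor on the squared L2 term is an assumption for illustration, not taken from this diff.

import numpy as np

# Illustrative only: maps the integer penalty codes from the docstring to a
# regularization value. The constants and the 0.5 factor are assumptions.
L1, L2, ELASTICNET = 1, 2, 3

def penalty_value(w, penalty_type, alpha, l1_ratio):
    l1 = np.abs(w).sum()
    l2 = 0.5 * np.dot(w, w)
    if penalty_type == L1:
        return alpha * l1
    if penalty_type == L2:
        return alpha * l2
    if penalty_type == ELASTICNET:
        # l1_ratio=0 -> pure L2, l1_ratio=1 -> pure L1, as documented above.
        return alpha * (l1_ratio * l1 + (1.0 - l1_ratio) * l2)
    raise ValueError("unknown penalty_type code")

print(penalty_value(np.array([0.5, -1.0]), ELASTICNET, alpha=1e-4, l1_ratio=0.15))
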
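The tol, n_iter_no_change and early_stopping parameters describe the stopping rule: fitting stops once the monitored score fails to improve by more than tol for n_iter_no_change consecutive epochs (the validation score from validation_score_cb when early_stopping is set). The schematic Python below captures that counter logic as an assumption about the behaviour, not a transcription of the Cython loop.

def should_stop_history(scores, tol, n_iter_no_change):
    # Return the epoch index at which training would stop, or None.
    # Higher scores are better; improvement must exceed tol to reset the counter.
    best = float("-inf")
    no_improvement = 0
    for epoch, score in enumerate(scores):
        if score > best + tol:
            best = score
            no_improvement = 0
        else:
            no_improvement += 1
        if no_improvement >= n_iter_no_change:
            return epoch
    return None

# Example: stops after 3 consecutive epochs without sufficient improvement.
print(should_stop_history([0.60, 0.70, 0.71, 0.711, 0.712, 0.712], tol=1e-2,
                          n_iter_no_change=3))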