Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit d2ff733

Browse files
elanmartsoumith
authored andcommitted
Make ReduceLROnPlateau serializable. (#5300)
* replace lambdas with partial * flake8
1 parent 5964700 commit d2ff733

2 files changed

Lines changed: 24 additions & 13 deletions

File tree

test/test_optim.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -618,5 +618,6 @@ def _test_reduce_lr_on_plateau(self, scheduler, targets, metrics, epochs=10, ver
618618
msg='LR is wrong in epoch {}: expected {}, got {}'.format(
619619
epoch, target[epoch], param_group['lr']), delta=1e-5)
620620

621+
621622
if __name__ == '__main__':
622623
run_tests()

torch/optim/lr_scheduler.py

Lines changed: 23 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import math
22
from bisect import bisect_right
3+
from functools import partial
4+
35
from .optimizer import Optimizer
46

57

@@ -322,22 +324,30 @@ def _reduce_lr(self, epoch):
322324
def in_cooldown(self):
323325
return self.cooldown_counter > 0
324326

325-
def _init_is_better(self, mode, threshold, threshold_mode):
326-
if mode not in {'min', 'max'}:
327-
raise ValueError('mode ' + mode + ' is unknown!')
328-
if threshold_mode not in {'rel', 'abs'}:
329-
raise ValueError('threshold mode ' + threshold_mode + ' is unknown!')
327+
def _cmp(self, mode, threshold_mode, threshold, a, best):
330328
if mode == 'min' and threshold_mode == 'rel':
331329
rel_epsilon = 1. - threshold
332-
self.is_better = lambda a, best: a < best * rel_epsilon
333-
self.mode_worse = float('Inf')
330+
return a < best * rel_epsilon
331+
334332
elif mode == 'min' and threshold_mode == 'abs':
335-
self.is_better = lambda a, best: a < best - threshold
336-
self.mode_worse = float('Inf')
333+
return a < best - threshold
334+
337335
elif mode == 'max' and threshold_mode == 'rel':
338336
rel_epsilon = threshold + 1.
339-
self.is_better = lambda a, best: a > best * rel_epsilon
340-
self.mode_worse = -float('Inf')
337+
return a > best * rel_epsilon
338+
341339
else: # mode == 'max' and epsilon_mode == 'abs':
342-
self.is_better = lambda a, best: a > best + threshold
343-
self.mode_worse = -float('Inf')
340+
return a > best + threshold
341+
342+
def _init_is_better(self, mode, threshold, threshold_mode):
343+
if mode not in {'min', 'max'}:
344+
raise ValueError('mode ' + mode + ' is unknown!')
345+
if threshold_mode not in {'rel', 'abs'}:
346+
raise ValueError('threshold mode ' + threshold_mode + ' is unknown!')
347+
348+
if mode == 'min':
349+
self.mode_worse = float('inf')
350+
else: # mode == 'max':
351+
self.mode_worse = (-float('inf'))
352+
353+
self.is_better = partial(self._cmp, mode, threshold_mode, threshold)

0 commit comments

Comments
 (0)