|
1 | 1 | import math |
2 | 2 | from bisect import bisect_right |
| 3 | +from functools import partial |
| 4 | + |
3 | 5 | from .optimizer import Optimizer |
4 | 6 |
|
5 | 7 |
|
@@ -322,22 +324,30 @@ def _reduce_lr(self, epoch): |
322 | 324 | def in_cooldown(self): |
323 | 325 | return self.cooldown_counter > 0 |
324 | 326 |
|
325 | | - def _init_is_better(self, mode, threshold, threshold_mode): |
326 | | - if mode not in {'min', 'max'}: |
327 | | - raise ValueError('mode ' + mode + ' is unknown!') |
328 | | - if threshold_mode not in {'rel', 'abs'}: |
329 | | - raise ValueError('threshold mode ' + threshold_mode + ' is unknown!') |
| 327 | + def _cmp(self, mode, threshold_mode, threshold, a, best): |
330 | 328 | if mode == 'min' and threshold_mode == 'rel': |
331 | 329 | rel_epsilon = 1. - threshold |
332 | | - self.is_better = lambda a, best: a < best * rel_epsilon |
333 | | - self.mode_worse = float('Inf') |
| 330 | + return a < best * rel_epsilon |
| 331 | + |
334 | 332 | elif mode == 'min' and threshold_mode == 'abs': |
335 | | - self.is_better = lambda a, best: a < best - threshold |
336 | | - self.mode_worse = float('Inf') |
| 333 | + return a < best - threshold |
| 334 | + |
337 | 335 | elif mode == 'max' and threshold_mode == 'rel': |
338 | 336 | rel_epsilon = threshold + 1. |
339 | | - self.is_better = lambda a, best: a > best * rel_epsilon |
340 | | - self.mode_worse = -float('Inf') |
| 337 | + return a > best * rel_epsilon |
| 338 | + |
341 | 339 | else: # mode == 'max' and epsilon_mode == 'abs': |
342 | | - self.is_better = lambda a, best: a > best + threshold |
343 | | - self.mode_worse = -float('Inf') |
| 340 | + return a > best + threshold |
| 341 | + |
| 342 | + def _init_is_better(self, mode, threshold, threshold_mode): |
| 343 | + if mode not in {'min', 'max'}: |
| 344 | + raise ValueError('mode ' + mode + ' is unknown!') |
| 345 | + if threshold_mode not in {'rel', 'abs'}: |
| 346 | + raise ValueError('threshold mode ' + threshold_mode + ' is unknown!') |
| 347 | + |
| 348 | + if mode == 'min': |
| 349 | + self.mode_worse = float('inf') |
| 350 | + else: # mode == 'max': |
| 351 | + self.mode_worse = (-float('inf')) |
| 352 | + |
| 353 | + self.is_better = partial(self._cmp, mode, threshold_mode, threshold) |
0 commit comments