
Conversation

@yuta0821
Contributor

@yuta0821 yuta0821 commented Feb 27, 2022

Fixes #2465

Description:
Based on the discussion in #2465, I updated the code so that LRScheduler can be attached to Events.ITERATION_STARTED.
Main Contribution

  • ignite/handlers/param_scheduler.py
    • add a use_legacy argument to the LRScheduler class and document it (see the usage sketch after this list)
    • update the example code
    • update the __call__ method
    • update simulate_values, e.g. change the timing of the scheduler's __call__ method
    • update create_lr_scheduler_with_warmup
  • ignite/contrib/engines/common.py
    • remove some code to adjust to this change
  • tests/ignite/handlers/test_param_scheduler.py
    • add tests to make sure that use_legacy works
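
A minimal usage sketch of the intended difference, for reference (the exact example added to the docstring may differ; the argument name is assumed to stay use_legacy):

import torch
from torch.optim.lr_scheduler import StepLR

from ignite.engine import Engine, Events
from ignite.handlers.param_scheduler import LRScheduler

def dummy_update(engine, batch):
    pass

tensor = torch.zeros([1], requires_grad=True)
optimizer = torch.optim.SGD([tensor], lr=0.01)
torch_lr_scheduler = StepLR(optimizer, step_size=3, gamma=0.1)
trainer = Engine(dummy_update)

# New behavior: the wrapped scheduler is attached at the start of each iteration
scheduler = LRScheduler(torch_lr_scheduler)
trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)

# Legacy behavior, kept behind use_legacy=True: attach at the end of each iteration
# scheduler = LRScheduler(torch_lr_scheduler, use_legacy=True)
# trainer.add_event_handler(Events.ITERATION_COMPLETED, scheduler)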

Matters for consultation
The following part is currently implemented with use_legacy=True, but I feel it is not a very smart implementation.

# ignite/contrib/engines/common.py 1022
lr_scheduler = LRScheduler(lr_scheduler, save_history=save_history, use_legacy=True)

However, in the current implementation we need to set the initial learning rate at the end of milestones_values, and the change for issue #2465 will probably result in one extra initial learning rate value. I couldn't think of an easy way to fix this. If we don't use use_legacy=True, we would probably need to change the code in many places.

I ran the following tests and they pass

  • bash tests/run_cpu_tests.sh
  • bash ./tests/run_code_style.sh mypy
  • bash ./tests/run_code_style.sh fmt

Check list:

  • New tests are added (if a new feature is added)
  • New doc strings: description and/or example code are in RST format
  • Documentation is updated (if required)

@github-actions github-actions bot added module: contrib Contrib module module: handlers Core Handlers module labels Feb 27, 2022
Collaborator

@vfdev-5 vfdev-5 left a comment

Thanks a lot for the PR @yuta0821 !
I left a few comments and will check the modifications in the tests in detail.

  milestones_values.append((warmup_duration, init_lr))

-     lr_scheduler = LRScheduler(lr_scheduler, save_history=save_history)
+     lr_scheduler = LRScheduler(lr_scheduler, save_history=save_history, use_legacy=True)
Collaborator

Can we check if we could use use_legacy=False? Or what is blocking that here?

Contributor Author

If use_legacy=False, the result looks like the examples below.

  • example1
import torch
from torch.optim.lr_scheduler import StepLR

from ignite.engine import Engine, Events
from ignite.handlers.param_scheduler import create_lr_scheduler_with_warmup

def dummy_update(engine, batch):
    pass

tensor = torch.zeros([1], requires_grad=True)
optimizer = torch.optim.SGD([tensor], lr=0.01)
torch_lr_scheduler = StepLR(optimizer, step_size=3, gamma=0.1)

trainer = Engine(dummy_update)
scheduler = create_lr_scheduler_with_warmup(torch_lr_scheduler,
                                            warmup_start_value=0.0,
                                            warmup_end_value=0.1,
                                            warmup_duration=3)

trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)

@trainer.on(Events.ITERATION_COMPLETED)
def print_lr():
    print(optimizer.param_groups[0]["lr"])

_ = trainer.run([0] * 8, max_epochs=1)
  • output1
0.0
0.05
0.1
0.01
0.01
0.01
0.01
0.001
  • example2
from torch.optim.lr_scheduler import StepLR  # same setup as example 1, only the initial lr differs

tensor = torch.zeros([1], requires_grad=True)
optimizer = torch.optim.SGD([tensor], lr=0.1)
torch_lr_scheduler = StepLR(optimizer, step_size=3, gamma=0.1)

trainer = Engine(dummy_update)
scheduler = create_lr_scheduler_with_warmup(torch_lr_scheduler,
                                            warmup_start_value=0.0,
                                            warmup_end_value=0.1,
                                            warmup_duration=3)

trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)

@trainer.on(Events.ITERATION_COMPLETED)
def print_lr():
    print(optimizer.param_groups[0]["lr"])

_ = trainer.run([0] * 8, max_epochs=1)
  • output2
0.0
0.05
0.1
0.1
0.1
0.1
0.010000000000000002
0.010000000000000002

The initial value of StepLR (0.01 in example 1, 0.1 in example 2) appears one more time than step_size (=3) should allow.
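
Annotated, output 1 breaks down as follows (annotation added for clarity, in the same breakdown style used later in the thread):

--- warmup (PiecewiseLinear, warmup_duration=3)
0.0
0.05
0.1
--- StepLR (step_size=3, initial lr 0.01)
0.01
0.01
0.01
0.01   <--- extra repetition of the initial value
0.001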

Collaborator

Thanks for the details @yuta0821, can you please debug this a bit more and explicitly say which scheduler is responsible for adding the extra LR value (0.1 or 0.01)? I'm not quite sure I understand why exactly this happens. Thanks!

Contributor Author

@yuta0821 yuta0821 Feb 28, 2022

@vfdev-5 Thanks a lot for your comment !
Consider the case where warmup_start_value=0.0, warmup_end_value=0.1, warmup_duration=3.
In this case, milestones_values = [(0, 0.0), (2, 0.1)]
If the initial value of lr in the optimizer is different from the warmup_end_value, it is necessary to add the initial value to the end of milestones. Therefore, milestones_values = [(0, 0.0), (2, 0.1), (3, initial value of lr)]
This is because the LRScheduler updates the lr starting from the last value of the milestones_values.
After that, the following code is executed, which results in the initial value of lr being repeated.

super(LRScheduler, self).__call__(engine, name)
self.lr_scheduler.last_epoch += 1  # type: ignore[attr-defined]

Even if the initial value of lr in the optimizer is equal to warmup_end_value, the initial value of lr is still applied one extra time.

In the end, since the first __call__ of LRScheduler starts from the last value of milestones_values, the last value of milestones_values and the initial value of LRScheduler are duplicated.

If we fix this without falling back to use_legacy=True, we may have to change a lot of code, such as the parts related to PiecewiseLinear, which is beyond the scope of this PR.
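
To make the duplication concrete, here is a small sketch of the setup described above (warmup_start_value=0.0, warmup_end_value=0.1, warmup_duration=3, optimizer lr=0.01); the exact variable handling inside create_lr_scheduler_with_warmup may differ:

# warmup milestones before appending the optimizer's initial lr
milestones_values = [(0, 0.0), (2, 0.1)]
# the initial lr (0.01) differs from warmup_end_value, so it is appended as an extra milestone
milestones_values = [(0, 0.0), (2, 0.1), (3, 0.01)]

# PiecewiseLinear therefore emits 0.0, 0.05, 0.1, 0.01 over the first iterations,
# and the wrapped LRScheduler then starts again from 0.01,
# which is why 0.01 shows up four times in output 1 instead of three.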

Collaborator

@yuta0821 sorry for the delay and thanks for the explanation! I haven't checked it in detail but will do so as soon as possible on my side (~4-5 days).
@sdesrozis can you help with that if you have some bandwidth ?

Contributor Author

It performs the same operation as use_legacy, but wouldn't it be preferable to add a skip_initial_value argument intended for internal use?

Contributor

Yeah, expected rather than excepted 😅

As far as I understand, with use_legacy=False the first lr comes from the optimizer. Whatever schedulers are used, concatenating them will produce a repetition at each joint.

Having an internal option as you suggested sounds good to me. I mean, renaming use_legacy to skip_initial_value is fine. However, we have to keep use_legacy for the users.

What do you think ?

Contributor Author

@yuta0821 yuta0821 Mar 6, 2022

@sdesrozis
Sorry to trouble you with this question.
It seems that this problem can be solved by setting the internal argument keep_first_lr=True only when create_lr_scheduler_with_warmup is called, as shown below, so that the initial value is stored in LRScheduler.

class LRScheduler(ParamScheduler):
    def __init__(
        self,
        lr_scheduler: _LRScheduler,
        save_history: bool = False,
        use_legacy: bool = False,
        keep_first_lr: bool = False,
    ):
        # ... existing __init__ code ...
        if keep_first_lr:
            # Emulate context manager for pytorch>=1.4 and cache the very first lr value
            self.lr_scheduler._get_lr_called_within_step = True  # type: ignore[attr-defined]
            self.first_lr = self.lr_scheduler.get_lr()
            self.lr_scheduler._get_lr_called_within_step = False  # type: ignore[attr-defined]

    def get_param(self) -> Union[float, List[float]]:
        """Method to get current optimizer's parameter value"""
        if hasattr(self, "first_lr"):
            # Return the cached initial lr once, then fall back to the wrapped scheduler
            lr_list = self.first_lr
            del self.first_lr
        else:
            # Emulate context manager for pytorch>=1.4
            self.lr_scheduler._get_lr_called_within_step = True  # type: ignore[attr-defined]
            lr_list = self.lr_scheduler.get_lr()
            self.lr_scheduler._get_lr_called_within_step = False  # type: ignore[attr-defined]
        return lr_list[0] if len(lr_list) == 1 else lr_list


# Excerpt from create_lr_scheduler_with_warmup (signature and other code unchanged):
if isinstance(lr_scheduler, _LRScheduler):
    init_lr = param_group["lr"]
    lr_scheduler = LRScheduler(lr_scheduler, save_history=save_history, keep_first_lr=True)
else:
    init_lr = lr_scheduler.get_param()

I am running the existing tests now. I will commit once all tests pass! -> Done!
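
For illustration, with the keep_first_lr path above, the warmup example from earlier in this thread (optimizer lr=0.01, warmup from 0.0 to 0.1 over 3 iterations, StepLR(step_size=3, gamma=0.1), scheduler attached to Events.ITERATION_STARTED) would be expected to print, over 8 iterations:

0.0
0.05
0.1
0.01
0.01
0.01
0.001
0.001

i.e. output 1 from above without the extra repetition of the initial StepLR value.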

Contributor

It seems that the tests are failing...

Contributor Author

Sorry, I added type: ignore[attr-defined] !

@sdesrozis
Contributor

sdesrozis commented Mar 7, 2022

@yuta0821 Thanks very much for this work. I have to think about the API though. The use_legacy arg is for BC (so to be removed one day), but the internal-only keep_first_lr makes me a little doubtful. This arg is only required when using ConcatScheduler; could we do something smarter (e.g. without exposing it)? Anyway, it works and that is pretty cool. Let's just think about it. I suppose that @vfdev-5 would have an opinion on that.

@yuta0821
Contributor Author

yuta0821 commented Mar 7, 2022

@sdesrozis Yes, I agree, I don't want to add more args without careful thought. However, I haven't yet figured out how to control the lr value of an LRScheduler bound to a ConcatScheduler without using keep_first_lr. Originally, this was handled by adjusting the milestones_values of PiecewiseLinear, but with the change in this PR that no longer works. I'll see if there's a smarter way to do this.

@sdesrozis
Contributor

sdesrozis commented Mar 8, 2022

My thought is that the target would be to have the following code work fine

import torch
from torch.optim.lr_scheduler import StepLR

from ignite.handlers.param_scheduler import ConcatScheduler, LRScheduler

tensor = torch.zeros([1], requires_grad=True)
optimizer = torch.optim.SGD([tensor], lr=0.01)

scheduler1 = StepLR(optimizer, step_size=3, gamma=0.1)
scheduler1 = LRScheduler(scheduler1)

scheduler2 = StepLR(optimizer, step_size=3, gamma=0.01)
scheduler2 = LRScheduler(scheduler2)

scheduler = ConcatScheduler(schedulers=[scheduler1, scheduler2], durations=[4])

Whether or not use_legacy is used for the schedulers, it should produce the same output. It means ConcatScheduler should enable keep_first_lr on LRScheduler internally, without any user-facing API. The minimal option would be to enable it via an attribute or method rather than at init.

scheduler2 = LRScheduler(scheduler2)
scheduler2.keep_first_lr = True

I don't know yet...

Extra question: what happens if scheduler1 is built using use_legacy=True and scheduler2 isn't?

@yuta0821
Contributor Author

yuta0821 commented Mar 8, 2022

@sdesrozis I'd like to confirm the behavior of the following code using use_legacy=True. Is this output right?

import torch
from torch.optim.lr_scheduler import StepLR

from ignite.engine import Engine, Events
from ignite.handlers.param_scheduler import ConcatScheduler, LRScheduler

tensor = torch.zeros([1], requires_grad=True)
optimizer = torch.optim.SGD([tensor], lr=0.01)
torch_scheduler1 = StepLR(optimizer, step_size=3, gamma=0.1)
scheduler1 = LRScheduler(torch_scheduler1, use_legacy=True)

torch_scheduler2 = StepLR(optimizer, step_size=3, gamma=0.01)
scheduler2 = LRScheduler(torch_scheduler2, use_legacy=True)

scheduler = ConcatScheduler(schedulers=[scheduler1, scheduler2], durations=[4, ])

def dummy_update(engine, batch):
    pass
trainer = Engine(dummy_update)

@trainer.on(Events.ITERATION_COMPLETED)
def print_lr():
    print(optimizer.param_groups[0]["lr"])
trainer.add_event_handler(Events.ITERATION_COMPLETED, scheduler)

_ = trainer.run([0] * 9, max_epochs=1)
0.01
0.01
0.01
0.001
0.001
0.001
0.001
1e-05
1e-05

Whether or not use_legacy is used for the schedulers, it should produce the same output. It means ConcatScheduler should enable keep_first_lr on LRScheduler internally, without any user-facing API.

OK, I understand what the message means. Thank you for the clear explanation. Indeed, we must enable keep_first_lr of LRScheduler internally without any user-facing API.

@vfdev-5
Collaborator

vfdev-5 commented Mar 10, 2022

@sdesrozis I'd like to confirm the behavior of the following code using use_legacy=True. Is this output right?

(code and output quoted from the previous comment)

@yuta0821 the output looks correct to me. We have

--- scheduler 1 
0.01
0.01
0.01
0.001
--- scheduler 2
0.001
0.001
0.001
1e-05
1e-05

What happens if we use use_legacy=False and attach to the appropriate event?

@yuta0821
Contributor Author

@vfdev-5
Sorry for the late reply.

What happens if we use use_legacy=False and attach to the appropriate event?

If trainer.add_event_handler is attached at the proper event, the behavior with use_legacy=False becomes as follows.
With keep_first_lr, the behavior of this code is the same as with use_legacy=True.

import torch
from torch.optim.lr_scheduler import StepLR

from ignite.engine import Engine, Events
from ignite.handlers.param_scheduler import ConcatScheduler, LRScheduler

tensor = torch.zeros([1], requires_grad=True)
optimizer = torch.optim.SGD([tensor], lr=0.01)
torch_scheduler1 = StepLR(optimizer, step_size=3, gamma=0.1)
scheduler1 = LRScheduler(torch_scheduler1, use_legacy=False)

torch_scheduler2 = StepLR(optimizer, step_size=3, gamma=0.01)
scheduler2 = LRScheduler(torch_scheduler2, use_legacy=False)

scheduler = ConcatScheduler(schedulers=[scheduler1, scheduler2], durations=[4, ])

def dummy_update(engine, batch):
    pass
trainer = Engine(dummy_update)
trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)

@trainer.on(Events.ITERATION_COMPLETED)
def print_lr():
    print(optimizer.param_groups[0]["lr"])
    
_ = trainer.run([0] * 9, max_epochs=1)
--- scheduler 1 
0.01
0.01
0.01
0.001
--- scheduler 2 
0.001
0.001
0.001
1e-05
1e-05

@vfdev-5
Collaborator

vfdev-5 commented Mar 11, 2022

@yuta0821 sorry for the delay on this PR, I'm still missing a complete understanding of the problem here. I'll try to figure out exactly what happens on my side and try to suggest a way to handle it. Adding an argument for internal usage is not a good design.
Thanks for your patience and for working on the PR!

@yuta0821
Contributor Author

@vfdev-5

@yuta0821 sorry for the delay on this PR, I'm still missing a complete understanding of the problem here. I'll try to figure out exactly what happens on my side and try to suggest a way to handle it. Adding an argument for internal usage is not a good design.
Thanks for your patience and for working on the PR!

Not at all. Thanks a lot for your help. If I can contribute in any way, I'll be happy to do so.

@vfdev-5
Collaborator

vfdev-5 commented Mar 12, 2022

@yuta0821 I understand the problem and why you need to introduce the keep_first_lr argument.
In a few words for others, the problem with create_lr_scheduler_with_warmup is that either 1) we have an additional transitional lr value between PiecewiseLinear and LRScheduler:

1 0.0
- PiecewiseLinear updates LR
2 0.03
- PiecewiseLinear updates LR
3 0.06
- PiecewiseLinear updates LR
4 0.09
- PiecewiseLinear updates LR
5 0.12
- PiecewiseLinear updates LR  <----- additional step to remove
6 0.01
- LRScheduler updates LR 
7 0.01
- LRScheduler updates LR
8 0.01
- LRScheduler updates LR
9 0.01
- LRScheduler updates LR
10 0.005

or 2) LRScheduler modifies the LR in a relative way => if PiecewiseLinear stops one step earlier, then LRScheduler continues from that value, e.g. 0.12 from the example above. This also gives wrong behavior.

I see here two ways to move forward:

  • either we add a more meaningful public arg to LRScheduler, e.g. relative: bool = True, and thus we set up create_lr_scheduler_with_warmup with LRScheduler(..., relative=False)
  • or we do something like lr_scheduler.lr_scheduler.last_epoch += 1 inside create_lr_scheduler_with_warmup (see the sketch after this snippet):
lr_scheduler = LRScheduler(lr_scheduler, save_history=save_history)
lr_scheduler.lr_scheduler.last_epoch += 1
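
A minimal standalone sketch of option 2, assuming the LRScheduler wrapper from this PR (not necessarily the final merged form):

import torch
from torch.optim.lr_scheduler import StepLR

from ignite.handlers.param_scheduler import LRScheduler

tensor = torch.zeros([1], requires_grad=True)
optimizer = torch.optim.SGD([tensor], lr=0.01)

torch_lr_scheduler = StepLR(optimizer, step_size=3, gamma=0.1)
lr_scheduler = LRScheduler(torch_lr_scheduler, save_history=False)
# Advance the wrapped torch scheduler by one step so that the transitional value
# between PiecewiseLinear and LRScheduler is not emitted twice when
# create_lr_scheduler_with_warmup concatenates the two schedulers.
lr_scheduler.lr_scheduler.last_epoch += 1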

@yuta0821
Contributor Author

@yuta0821 I understand the problem and why you need to introduce the keep_first_lr argument.

Thank you for your analysis. Your summary is exactly what I was thinking.

@vfdev-5
Collaborator

vfdev-5 commented Mar 22, 2022

@yuta0821 I'm trying to figure out how to merge your work into master step by step.
You updated tests/ignite/handlers/test_param_scheduler.py; do you think we could merge it without the updates introduced in ignite/handlers/param_scheduler.py?
Same question for tests/ignite/contrib/engines/test_common.py...
If yes, could you please send a PR with only the updated tests? Thanks a lot for your patience

@yuta0821
Contributor Author

yuta0821 commented Mar 22, 2022

@vfdev-5 Thank you for your reply !
I don't think we could merge either tests/ignite/handlers/test_param_scheduler.py or tests/ignite/contrib/engines/test_common.py without the updates introduced in ignite/handlers/param_scheduler.py.
This is because test changes such as trainer.add_event_handler(Events.ITERATION_STARTED, scheduler) were added to handle the changes made in ignite/handlers/param_scheduler.py.

@vfdev-5 vfdev-5 force-pushed the feature/#2465_LRScheduler_attach_Events branch from 52dcbcf to f593385 Compare April 19, 2022 17:26
@vfdev-5
Collaborator

vfdev-5 commented Apr 19, 2022

@yuta0821 I pushed a few commits to get this PR merged. Basically, it is one of your suggestions. Currently, I do not have a better way to fix it. Let's keep it as it is.
Thanks again for starting this work and sorry for the delay it took to land it.

@vfdev-5 vfdev-5 enabled auto-merge (squash) April 19, 2022 17:41
@vfdev-5 vfdev-5 merged commit 545d125 into pytorch:master Apr 19, 2022
@yuta0821
Contributor Author

Thank you for your review.

I pushed a few commits to get this PR merged. Basically, it is one of your suggestions. Currently, I do not have a better way to fix it. Let's keep it as it is.

I understand. If I come up with a better way to fix it, I will create a new issue.

@vfdev-5 vfdev-5 changed the title Feature/#2465 lr scheduler attach events [BC-breaking] Feature/#2465 lr scheduler attach events May 3, 2022
@yuta0821 yuta0821 deleted the feature/#2465_LRScheduler_attach_Events branch October 28, 2022 13:44
