Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Fix lr_scheduler unexpectedly calls step() when init argument last_epoch is larger than -1 #149312

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from

Conversation

zeshengzong
Copy link
Contributor

@zeshengzong zeshengzong commented Mar 17, 2025

Fixes #102261

Changes

  • Use flag _is_initial to replace self.last_epoch == 0 condition to judge whether lr should be initial value
  • Add test for ExponentialLR checkpoint usecase

Test Result

pytest -s test/optim/test_lrscheduler.py  -vv

image

Copy link

pytorch-bot bot commented Mar 17, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/149312

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEVs

There are 1 currently active SEVs. If your PR is affected, please view them below:

✅ No Failures

As of commit 538d5e0 with merge base 01f226b (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@zeshengzong zeshengzong marked this pull request as ready for review March 18, 2025 08:06
@zeshengzong
Copy link
Contributor Author

Hello @albanD @janeyx99 , please check whether the fixing is feasible, if it works, I would like to continue fix more schedulers which have same problem, like MultiplicativeLR, LinearLR, thanks!

@janeyx99 janeyx99 added the triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module label Mar 20, 2025
@albanD albanD removed their request for review April 9, 2025 19:37
@zeshengzong
Copy link
Contributor Author

@pytorchbot rebase -b main

@pytorchmergebot
Copy link
Collaborator

@pytorchbot started a rebase job onto refs/remotes/origin/main. Check the current status here

@pytorchmergebot
Copy link
Collaborator

Successfully rebased fix/optim/step onto refs/remotes/origin/main, please pull locally before adding more changes (for example, via git checkout fix/optim/step && git pull --rebase)

Copy link
Contributor

@janeyx99 janeyx99 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This does not look like the right approach. If the discrepancy is for ExponentialLR between get_lr and _get_closed_form_lr, I'd expect the fix to be local there. Could you explain your approach a little bit?

optim2 = torch.optim.AdamW(model.parameters())
optim2.load_state_dict(optim.state_dict())
sch2 = LRClass(optim2, last_epoch=1)
self.assertEqual(optim.param_groups[0]["lr"], optim2.param_groups[0]["lr"])
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is not the same comparison as the repro--we should be comparing that the closed form lr is the same as the params group lr?

@@ -724,7 +738,7 @@ def get_lr(self):
"""Compute the learning rate of each parameter group."""
_warn_get_lr_called_within_step(self)

if self.last_epoch == 0:
if self._is_initial:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if self._is_initial:
// when loading from a checkpoint, we don't want _initial_step (called from the constructor) to update the lr
// one more step ahead of itself.
if self._is_initial:

Copy link
Contributor

@janeyx99 janeyx99 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh actually, I see what you're doing now. Sorry I was confused yesterday. I'm willing to accept this fix if you update the test case.

It would also be good to include a comment about why we prefer the _is_initial.

@janeyx99 janeyx99 added the topic: bug fixes topic category label May 6, 2025
@janeyx99 janeyx99 dismissed their stale review May 6, 2025 17:46

left newer review

@@ -134,7 +135,8 @@ def wrapper(*args, **kwargs):
def _initial_step(self):
"""Initialize step counts and perform a step."""
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As someone who has looked into LRScheduler more than I've been able to, have you seen a good reason why we need to call .step() from the constructor?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
open source release notes: optim topic: bug fixes topic category triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module
Projects
None yet
Development

Successfully merging this pull request may close these issues.

ExponentialLR unexpectedly calls step() when init argument last_epoch is larger than -1
4 participants