# Minimal reproduction: take one step with a mechanize-wrapped AdamW optimizer,
# copy its state into a freshly constructed wrapped optimizer via
# state_dict()/load_state_dict(), then take another step with the new instance.
# NOTE(review): `torch` and `mechanize` are not imported in this snippet; they
# are presumably brought into scope elsewhere (`mechanize` likely comes from
# the mechanic-pytorch package) — confirm.
model = torch.nn.Linear(1, 1)
optimizer = mechanize(torch.optim.AdamW)(model.parameters(), lr=1e-5)
x = torch.ones([5, 1])

# First training step, driven by the original optimizer.
out = torch.sum(model(x))
out.backward()
optimizer.step()
print('done first step')

# Clear the gradients from the first backward pass. Without this, the next
# backward() accumulates into them, and the resumed step would see the sum of
# both gradients instead of a fresh one (unlike a real checkpoint resume).
model.zero_grad()

# Build a second wrapped optimizer over the same parameters and restore the
# first optimizer's state into it, as when resuming from a checkpoint.
new_optimizer = mechanize(torch.optim.AdamW)(model.parameters(), lr=1e-5)
new_optimizer.load_state_dict(optimizer.state_dict())

# Second training step, driven by the restored optimizer.
out = torch.sum(model(x))
out.backward()
new_optimizer.step()
print('done new steps using new optimizer loaded')
The simple script above reproduces the issue.