Weight decay causes loaded model to not match saved one #1201

@danielhers

This caused me a lot of frustration until I finally figured out why my saved models' results don't match after I load them.
After training a model and saving it, I expect the loaded model to produce exactly the same results as the original did just before it was saved (assuming no updates were done in between, of course). However, this is not the case when using weight decay. It looks like the weight decay is not applied to the loaded model, even though it is set globally in dynet_config.

Minimal working example:

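import dynet_config
# Set weight decay globally, before dynet itself is imported
dynet_config.set(weight_decay=1e-5)
import dynet as dy

# Build a one-parameter model and run a single trainer update
m1 = dy.ParameterCollection()
p1 = m1.add_parameters(1)
t = dy.SimpleSGDTrainer(m1)
p1.expr().forward()
p1.expr().backward()
t.update()
dy.renew_cg()
# Value of the parameter just before saving
v1 = p1.expr().value()
dy.save("test", [p1])
# Load into a fresh collection and compare with the value before saving
m2 = dy.ParameterCollection()
[p2] = dy.load("test", m2)
v2 = p2.expr().value()
assert v1 == v2, "%s != %s" % (v1, v2)
# >>> AssertionError: -1.5506035089492798 != -1.5506190061569214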

Changing weight_decay to 0 fixes the problem.
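
The two values in the assertion error differ by a relative amount very close to the weight decay itself, which would be consistent with the loaded parameter missing a single (1 - weight_decay) scaling step. This is only a reading of the reported numbers, not a confirmed diagnosis of DyNet's internals. A minimal arithmetic check in plain Python (no DyNet needed; the tolerance is an assumption chosen to be on the order of single-precision rounding):

```python
# Values copied from the assertion error in the example above
weight_decay = 1e-5
v1 = -1.5506035089492798  # value before saving
v2 = -1.5506190061569214  # value after loading

# Relative gap between the two values
ratio = v1 / v2
rel_gap = 1.0 - ratio

# Hypothesis (an assumption, not confirmed): v1 == v2 * (1 - weight_decay)
tolerance = 1e-7  # assumed, roughly float32 rounding error
is_consistent = abs(rel_gap - weight_decay) < tolerance

print("relative gap:", rel_gap)
print("consistent with one missing decay step:", is_consistent)
```

If the check holds, it suggests that the saved value does not include the decay factor that is applied when the parameter is read from the original model.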

Related to #917.

Labels: major bug (issues that silently cause incorrect results, break installation on common environments, etc.)
