Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit e4701b5

Browse files
authored
Update how the pretrained model is loaded
1 parent 285c260 commit e4701b5

File tree

1 file changed

+5
-4
lines changed

1 file changed

+5
-4
lines changed

implementations/esrgan/esrgan.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -72,8 +72,9 @@
7272

7373
if opt.epoch != 0:
7474
# Load pretrained models
75-
generator.load_state_dict(torch.load("saved_models/generator_%d.pth"))
76-
discriminator.load_state_dict(torch.load("saved_models/discriminator_%d.pth"))
75+
print("loading pretrain model")
76+
generator.load_state_dict(torch.load("saved_models/generator_%d.pth"%opt.epoch))
77+
discriminator.load_state_dict(torch.load("saved_models/discriminator_%d.pth"%opt.epoch))
7778

7879
# Optimizers
7980
optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
@@ -191,5 +192,5 @@
191192

192193
if batches_done % opt.checkpoint_interval == 0:
193194
# Save model checkpoints
194-
torch.save(generator.state_dict(), "saved_models/generator_%d.pth" % batches_done)
195-
torch.save(discriminator.state_dict(), "saved_models/discriminator_%d.pth" % batches_done)
195+
torch.save(generator.state_dict(), "saved_models/generator_%d.pth" % epoch)
196+
torch.save(discriminator.state_dict(), "saved_models/discriminator_%d.pth" %epoch)

0 commit comments

Comments (0)