diff --git a/17_save_load.py b/17_save_load.py index 95716fa8..061bef81 100644 --- a/17_save_load.py +++ b/17_save_load.py @@ -81,7 +81,7 @@ def forward(self, x): torch.save(checkpoint, FILE) model = Model(n_input_features=6) -optimizer = optimizer = torch.optim.SGD(model.parameters(), lr=0) +optimizer = torch.optim.SGD(model.parameters(), lr=0) checkpoint = torch.load(FILE) model.load_state_dict(checkpoint['model_state'])