|
46 | 46 | import matplotlib.pyplot as plt |
47 | 47 | import time |
48 | 48 | import os |
49 | | -import copy |
| 49 | +from tempfile import TemporaryDirectory |
50 | 50 |
|
51 | 51 | cudnn.benchmark = True |
52 | 52 | plt.ion() # interactive mode |
@@ -146,67 +146,71 @@ def imshow(inp, title=None): |
146 | 146 | def train_model(model, criterion, optimizer, scheduler, num_epochs=25): |
147 | 147 | since = time.time() |
148 | 148 |
|
149 | | - best_model_wts = copy.deepcopy(model.state_dict()) |
150 | | - best_acc = 0.0 |
151 | | - |
152 | | - for epoch in range(num_epochs): |
153 | | - print(f'Epoch {epoch}/{num_epochs - 1}') |
154 | | - print('-' * 10) |
155 | | - |
156 | | - # Each epoch has a training and validation phase |
157 | | - for phase in ['train', 'val']: |
158 | | - if phase == 'train': |
159 | | - model.train() # Set model to training mode |
160 | | - else: |
161 | | - model.eval() # Set model to evaluate mode |
162 | | - |
163 | | - running_loss = 0.0 |
164 | | - running_corrects = 0 |
165 | | - |
166 | | - # Iterate over data. |
167 | | - for inputs, labels in dataloaders[phase]: |
168 | | - inputs = inputs.to(device) |
169 | | - labels = labels.to(device) |
170 | | - |
171 | | - # zero the parameter gradients |
172 | | - optimizer.zero_grad() |
173 | | - |
174 | | - # forward |
175 | | - # track history if only in train |
176 | | - with torch.set_grad_enabled(phase == 'train'): |
177 | | - outputs = model(inputs) |
178 | | - _, preds = torch.max(outputs, 1) |
179 | | - loss = criterion(outputs, labels) |
180 | | - |
181 | | - # backward + optimize only if in training phase |
182 | | - if phase == 'train': |
183 | | - loss.backward() |
184 | | - optimizer.step() |
185 | | - |
186 | | - # statistics |
187 | | - running_loss += loss.item() * inputs.size(0) |
188 | | - running_corrects += torch.sum(preds == labels.data) |
189 | | - if phase == 'train': |
190 | | - scheduler.step() |
191 | | - |
192 | | - epoch_loss = running_loss / dataset_sizes[phase] |
193 | | - epoch_acc = running_corrects.double() / dataset_sizes[phase] |
194 | | - |
195 | | - print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}') |
196 | | - |
197 | | - # deep copy the model |
198 | | - if phase == 'val' and epoch_acc > best_acc: |
199 | | - best_acc = epoch_acc |
200 | | - best_model_wts = copy.deepcopy(model.state_dict()) |
201 | | - |
202 | | - print() |
203 | | - |
204 | | - time_elapsed = time.time() - since |
205 | | - print(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s') |
206 | | - print(f'Best val Acc: {best_acc:4f}') |
207 | | - |
208 | | - # load best model weights |
209 | | - model.load_state_dict(best_model_wts) |
| 149 | + # Create a temporary directory to save training checkpoints |
| 150 | + with TemporaryDirectory() as tempdir: |
| 151 | + best_model_params_path = os.path.join(tempdir, 'best_model_params.pt') |
| 152 | + |
| 153 | + torch.save(model.state_dict(), best_model_params_path) |
| 154 | + best_acc = 0.0 |
| 155 | + |
| 156 | + for epoch in range(num_epochs): |
| 157 | + print(f'Epoch {epoch}/{num_epochs - 1}') |
| 158 | + print('-' * 10) |
| 159 | + |
| 160 | + # Each epoch has a training and validation phase |
| 161 | + for phase in ['train', 'val']: |
| 162 | + if phase == 'train': |
| 163 | + model.train() # Set model to training mode |
| 164 | + else: |
| 165 | + model.eval() # Set model to evaluate mode |
| 166 | + |
| 167 | + running_loss = 0.0 |
| 168 | + running_corrects = 0 |
| 169 | + |
| 170 | + # Iterate over data. |
| 171 | + for inputs, labels in dataloaders[phase]: |
| 172 | + inputs = inputs.to(device) |
| 173 | + labels = labels.to(device) |
| 174 | + |
| 175 | + # zero the parameter gradients |
| 176 | + optimizer.zero_grad() |
| 177 | + |
| 178 | + # forward |
| 179 | + # track history if only in train |
| 180 | + with torch.set_grad_enabled(phase == 'train'): |
| 181 | + outputs = model(inputs) |
| 182 | + _, preds = torch.max(outputs, 1) |
| 183 | + loss = criterion(outputs, labels) |
| 184 | + |
| 185 | + # backward + optimize only if in training phase |
| 186 | + if phase == 'train': |
| 187 | + loss.backward() |
| 188 | + optimizer.step() |
| 189 | + |
| 190 | + # statistics |
| 191 | + running_loss += loss.item() * inputs.size(0) |
| 192 | + running_corrects += torch.sum(preds == labels.data) |
| 193 | + if phase == 'train': |
| 194 | + scheduler.step() |
| 195 | + |
| 196 | + epoch_loss = running_loss / dataset_sizes[phase] |
| 197 | + epoch_acc = running_corrects.double() / dataset_sizes[phase] |
| 198 | + |
| 199 | + print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}') |
| 200 | + |
| 201 | + # deep copy the model |
| 202 | + if phase == 'val' and epoch_acc > best_acc: |
| 203 | + best_acc = epoch_acc |
| 204 | + torch.save(model.state_dict(), best_model_params_path) |
| 205 | + |
| 206 | + print() |
| 207 | + |
| 208 | + time_elapsed = time.time() - since |
| 209 | + print(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s') |
| 210 | + print(f'Best val Acc: {best_acc:4f}') |
| 211 | + |
| 212 | + # load best model weights |
| 213 | + model.load_state_dict(torch.load(best_model_params_path)) |
210 | 214 | return model |
211 | 215 |
|
212 | 216 |
|
|
0 commit comments