
Commit d173914

added 15_tranfer_learning
1 parent 363aa89 commit d173914

File tree

1 file changed: +184 −0 lines


15_transfer_learning.py

Lines changed: 184 additions & 0 deletions
@@ -0,0 +1,184 @@
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import os
import copy

mean = np.array([0.5, 0.5, 0.5])
std = np.array([0.25, 0.25, 0.25])

data_transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean, std)
    ]),
    'val': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean, std)
    ]),
}
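
# A minimal sketch (assumption, not part of the original commit): the fixed
# mean/std above are round values; per-channel statistics could instead be
# estimated from the training images like this. The folder path mirrors the
# data_dir defined below.
_stats_ds = datasets.ImageFolder(
    'data/hymenoptera_data/train',
    transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()]))
_stats_loader = torch.utils.data.DataLoader(_stats_ds, batch_size=64)
_n = 0
_sum = torch.zeros(3)
_sum_sq = torch.zeros(3)
for _imgs, _ in _stats_loader:
    _n += _imgs.numel() / 3                      # pixels per channel
    _sum += _imgs.sum(dim=[0, 2, 3])             # per-channel sum
    _sum_sq += (_imgs ** 2).sum(dim=[0, 2, 3])   # per-channel sum of squares
_est_mean = _sum / _n
_est_std = (_sum_sq / _n - _est_mean ** 2).sqrt()
print('estimated mean/std:', _est_mean, _est_std)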

data_dir = 'data/hymenoptera_data'
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x),
                                          data_transforms[x])
                  for x in ['train', 'val']}
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=4,
                                              shuffle=True, num_workers=0)
               for x in ['train', 'val']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
class_names = image_datasets['train'].classes

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(class_names)

def imshow(inp, title):
    """Imshow for Tensor."""
    inp = inp.numpy().transpose((1, 2, 0))
    inp = std * inp + mean
    inp = np.clip(inp, 0, 1)
    plt.imshow(inp)
    plt.title(title)
    plt.show()


# Get a batch of training data
inputs, classes = next(iter(dataloaders['train']))

# Make a grid from batch
out = torchvision.utils.make_grid(inputs)

imshow(out, title=[class_names[x] for x in classes])

def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
    since = time.time()

    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode

            running_loss = 0.0
            running_corrects = 0

            # Iterate over data.
            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)

                # forward
                # track history only if in train
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    # backward + optimize only if in training phase
                    if phase == 'train':
                        optimizer.zero_grad()
                        loss.backward()
                        optimizer.step()

                # statistics
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)

            if phase == 'train':
                scheduler.step()

            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects.double() / dataset_sizes[phase]

            print('{} Loss: {:.4f} Acc: {:.4f}'.format(
                phase, epoch_loss, epoch_acc))

            # deep copy the model
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())

        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:.4f}'.format(best_acc))

    # load best model weights
    model.load_state_dict(best_model_wts)
    return model



#### Finetuning the convnet ####
# Load a pretrained model and reset the final fully connected layer.

model = models.resnet18(pretrained=True)
num_ftrs = model.fc.in_features
# Here the size of each output sample is set to 2.
# Alternatively, it can be generalized to nn.Linear(num_ftrs, len(class_names)).
model.fc = nn.Linear(num_ftrs, 2)

model = model.to(device)

criterion = nn.CrossEntropyLoss()

# Observe that all parameters are being optimized
optimizer = optim.SGD(model.parameters(), lr=0.001)

# StepLR decays the learning rate of each parameter group by gamma every step_size epochs.
# Here: decay LR by a factor of 0.1 every 7 epochs.
# Learning rate scheduling should be applied after the optimizer's update,
# e.g., you should write your code this way:
# for epoch in range(100):
#     train(...)
#     validate(...)
#     scheduler.step()

step_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)
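# A small illustration (assumption, not in the original commit) of what this
# schedule does: with lr=0.001, step_size=7 and gamma=0.1, reading
# optimizer.param_groups[0]['lr'] inside the epoch loop would show 0.001 for
# epochs 0-6, 0.0001 for epochs 7-13, and so on, dropping by a factor of 10
# every 7 epochs, because train_model calls scheduler.step() once per training epoch.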

model = train_model(model, criterion, optimizer, step_lr_scheduler, num_epochs=25)
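
# Sketch (assumption, not part of the original commit): persist the fine-tuned
# weights so they can be reloaded later with model.load_state_dict(torch.load(...)).
# The file name is hypothetical.
torch.save(model.state_dict(), 'resnet18_finetuned.pth')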


#### ConvNet as fixed feature extractor ####
# Here, we need to freeze all of the network except the final layer.
# We set requires_grad = False to freeze the parameters so that the gradients
# are not computed in backward().
model_conv = torchvision.models.resnet18(pretrained=True)
for param in model_conv.parameters():
    param.requires_grad = False

# Parameters of newly constructed modules have requires_grad=True by default
num_ftrs = model_conv.fc.in_features
model_conv.fc = nn.Linear(num_ftrs, 2)

model_conv = model_conv.to(device)

criterion = nn.CrossEntropyLoss()

# Observe that only the parameters of the final layer are being optimized, as
# opposed to before.
optimizer_conv = optim.SGD(model_conv.fc.parameters(), lr=0.001, momentum=0.9)

# Decay LR by a factor of 0.1 every 7 epochs
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_conv, step_size=7, gamma=0.1)

model_conv = train_model(model_conv, criterion, optimizer_conv,
                         exp_lr_scheduler, num_epochs=25)
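
# Sketch (assumption, not part of the original commit): run the trained
# feature-extractor model on one validation batch and compare its predictions
# with the ground-truth labels, reusing the dataloaders defined above.
model_conv.eval()
with torch.no_grad():
    val_inputs, val_labels = next(iter(dataloaders['val']))
    val_outputs = model_conv(val_inputs.to(device))
    _, val_preds = torch.max(val_outputs, 1)
print('predicted:', [class_names[i] for i in val_preds])
print('actual:   ', [class_names[i] for i in val_labels])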
