forked from ClementPinard/FlowNetPytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmultiscaleloss.py
More file actions
42 lines (35 loc) · 1.53 KB
/
Copy pathmultiscaleloss.py
File metadata and controls
42 lines (35 loc) · 1.53 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import torch
import torch.nn as nn
import math
class MultiScaleLoss(nn.Module):
    """Weighted sum of a pixel-wise loss evaluated at several resolutions.

    For scale ``i`` (``0 <= i < scales``) the target is average-pooled with
    kernel size and stride ``downscale * 2**i`` and compared with the i-th
    prediction; the per-scale losses are multiplied by ``weights[i]`` and
    summed.

    Args:
        scales: number of resolutions (also the maximum number of
            predictions ``forward`` accepts).
        downscale: pooling factor used for the finest scale (``i == 0``).
        weights: per-scale multipliers. ``None`` means 1.0 for every scale;
            a single number is broadcast to every scale; otherwise a
            sequence of length ``scales``.
        loss: ``'L1'`` (``'Abs'`` is an alias), ``'MSE'``, ``'SmoothL1'``,
            or any callable ``loss(prediction, target)``.

    Raises:
        ValueError: if ``loss`` is an unknown name or the number of weights
            differs from ``scales``.
    """

    # Built-in criteria selectable by name. 'Abs' is the constructor's
    # historical default and is treated as an alias of 'L1' so that the
    # default arguments are actually valid.
    _LOSSES = {
        'L1': nn.L1Loss,
        'Abs': nn.L1Loss,
        'MSE': nn.MSELoss,
        'SmoothL1': nn.SmoothL1Loss,
    }

    def __init__(self, scales, downscale, weights=None, loss='Abs'):
        super(MultiScaleLoss, self).__init__()
        self.downscale = downscale

        if weights is None:
            self.weights = torch.ones(scales)
        else:
            if isinstance(weights, (int, float)):
                # A bare number is not a sequence; apply it to every scale.
                weights = [float(weights)] * scales
            self.weights = torch.Tensor(weights)
        # Validate the normalized tensor, not the raw argument (which may be None).
        if len(self.weights) != scales:
            raise ValueError(
                "expected {} weights, got {}".format(scales, len(self.weights)))

        if isinstance(loss, str):
            if loss not in self._LOSSES:
                raise ValueError(
                    "loss must be one of {}, got {!r}".format(sorted(self._LOSSES), loss))
            self.loss = self._LOSSES[loss]()
        else:
            self.loss = loss

        # ModuleList registers the pooling layers as submodules; AvgPool2d has
        # no parameters or buffers, so state_dict() is unaffected.
        self.multiScales = nn.ModuleList(
            [nn.AvgPool2d(self.downscale * 2 ** i, self.downscale * 2 ** i)
             for i in range(scales)])

    def forward(self, input, target):
        """Compute the loss of ``input`` against ``target``.

        Args:
            input: either a tuple/list of predictions, the i-th one matched
                against the target pooled for scale ``i``, or a single
                prediction matched against the scale-0 pooled target.
            target: full-resolution ground truth.

        Returns:
            For a tuple/list, the weighted sum of per-scale losses (``0`` if
            it is empty); for a single prediction, the unweighted scale-0 loss.
        """
        if isinstance(input, (tuple, list)):
            out = 0
            for i, input_ in enumerate(input):
                target_ = self.multiScales[i](target)
                out = out + self.weights[i] * self.loss(input_, target_)
            return out
        # Single prediction: compared at scale 0 without weighting.
        return self.loss(input, self.multiScales[0](target))
def multiscaleloss(scales=5, downscale=4, weights=None, loss='L1'):
    """Build a MultiScaleLoss with FlowNet-style defaults.

    Args:
        scales: number of prediction scales.
        downscale: pooling factor of the finest scale.
        weights: per-scale weights. ``None`` selects the weights from the
            original article when ``scales == 5`` and uniform weights (1.0)
            for any other number of scales. A single number is repeated for
            every scale. Any other value is passed through unchanged.
        loss: loss name or callable, forwarded to MultiScaleLoss.

    Returns:
        A configured MultiScaleLoss instance.
    """
    if weights is None:
        article_weights = (0.005, 0.01, 0.02, 0.08, 0.32)  # as in original article
        # The article weights exist for exactly 5 scales; rather than failing
        # the per-scale length check for other scale counts, use uniform weights.
        if scales == len(article_weights):
            weights = article_weights
        else:
            weights = (1.0,) * scales
    elif isinstance(weights, (int, float)):
        # A bare number is not a sequence: wrap it (once per scale) so it can
        # be turned into a weight tensor of length `scales`. Sequences such as
        # lists are left alone instead of being nested.
        weights = (float(weights),) * scales
    return MultiScaleLoss(scales, downscale, weights, loss)