Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 92748d9

Browse files
committed
modify MMD and JMMD losses
1 parent 2dfdba2 commit 92748d9

File tree

1 file changed

+47
-15
lines changed

1 file changed

+47
-15
lines changed

pytorch/src/loss.py

Lines changed: 47 additions & 15 deletions
Original file line numberDiff line numberDiff line change
def DAN(source, target, kernel_mul=2.0, kernel_num=5, fix_sigma=None):
    """Multi-kernel MMD loss between a source and a target batch (DAN).

    Estimates squared Maximum Mean Discrepancy as

        mean over distinct source pairs   of k(s_i, s_j)
      + mean over distinct target pairs   of k(t_i, t_j)
      - 2 * mean over all cross pairs     of k(s_i, t_j)

    which matches the original per-pair double loops exactly, but uses
    vectorized block sums over the Gram matrix instead of Python-level
    element loops.  (An O(B) linear-time estimator — pairing consecutive
    samples only — previously existed here as commented-out code.)

    Args:
        source: (B, features) tensor of source-domain activations.
        target: (B, features) tensor of target-domain activations;
            assumed to have the same batch size B as `source` — TODO
            confirm callers guarantee equal batch sizes.
        kernel_mul: multiplicative step between Gaussian bandwidths.
        kernel_num: number of Gaussian kernels in the mixture.
        fix_sigma: optional fixed bandwidth; when None the bandwidth is
            estimated inside `guassian_kernel`.

    Returns:
        Scalar tensor holding the MMD loss.
    """
    batch_size = int(source.size()[0])
    # (2B, 2B) Gram matrix over the concatenated [source; target] batch:
    # rows/cols < B index source samples, >= B index target samples.
    kernels = guassian_kernel(source, target,
        kernel_mul=kernel_mul, kernel_num=kernel_num, fix_sigma=fix_sigma)

    ss = kernels[:batch_size, :batch_size]   # source-source block
    tt = kernels[batch_size:, batch_size:]   # target-target block
    st = kernels[:batch_size, batch_size:]   # source-target block

    # Within-domain term: the loops summed k over unordered distinct pairs
    # (s1 < s2) for both domains, then divided by B*(B-1)/2 pairs.  The
    # off-diagonal sum equals (block.sum() - block.trace()), and halving
    # converts ordered off-diagonal entries to unordered pairs.
    n_pairs = float(batch_size * (batch_size - 1) / 2)
    within = (ss.sum() - ss.trace()) + (tt.sum() - tt.trace())
    loss1 = within / 2.0 / n_pairs

    # Cross-domain term: the loops accumulated kernels[s1, t2] and
    # kernels[s2, t1] over ALL ordered (s1, s2) — both index the
    # source-target block, so the total is exactly 2 * st.sum().
    loss2 = -2.0 * st.sum() / float(batch_size * batch_size)

    return loss1 + loss2

4057
def RTN():
    """Residual Transfer Network loss — placeholder, not implemented yet."""
@@ -57,15 +74,30 @@ def JAN(source_list, target_list, kernel_muls=[2.0, 2.0], kernel_nums=[5, 1], fi
5774
joint_kernels = joint_kernels * kernels
5875
else:
5976
joint_kernels = kernels
60-
61-
loss = 0
62-
for i in range(batch_size):
63-
s1, s2 = i, (i+1)%batch_size
64-
t1, t2 = s1+batch_size, s2+batch_size
65-
loss += joint_kernels[s1, s2] + joint_kernels[t1, t2]
66-
loss -= joint_kernels[s1, t2] + joint_kernels[s2, t1]
67-
return loss / float(batch_size)
6877

78+
# Linear version
79+
# loss = 0
80+
# for i in range(batch_size):
81+
# s1, s2 = i, (i+1)%batch_size
82+
# t1, t2 = s1+batch_size, s2+batch_size
83+
# loss += joint_kernels[s1, s2] + joint_kernels[t1, t2]
84+
# loss -= joint_kernels[s1, t2] + joint_kernels[s2, t1]
85+
# return loss / float(batch_size)
86+
87+
loss1 = 0
88+
for s1 in range(batch_size):
89+
for s2 in range(s1 + 1, batch_size):
90+
t1, t2 = s1 + batch_size, s2 + batch_size
91+
loss1 += joint_kernels[s1, s2] + joint_kernels[t1, t2]
92+
loss1 = loss1 / float(batch_size * (batch_size - 1) / 2)
93+
94+
loss2 = 0
95+
for s1 in range(batch_size):
96+
for s2 in range(batch_size):
97+
t1, t2 = s1 + batch_size, s2 + batch_size
98+
loss2 -= joint_kernels[s1, t2] + joint_kernels[s2, t1]
99+
loss2 = loss2 / float(batch_size * batch_size)
100+
return loss1 + loss2
69101

70102

71103
# Name -> implementation registry so callers can select a transfer loss
# by string key; note RTN is currently a stub that does nothing.
loss_dict = {"DAN":DAN, "RTN":RTN, "JAN":JAN}

0 commit comments

Comments
 (0)