@@ -29,13 +29,30 @@ def DAN(source, target, kernel_mul=2.0, kernel_num=5, fix_sigma=None):
2929 batch_size = int (source .size ()[0 ])
3030 kernels = guassian_kernel (source , target ,
3131 kernel_mul = kernel_mul , kernel_num = kernel_num , fix_sigma = fix_sigma )
32- loss = 0
33- for i in range (batch_size ):
34- s1 , s2 = i , (i + 1 )% batch_size
35- t1 , t2 = s1 + batch_size , s2 + batch_size
36- loss += kernels [s1 , s2 ] + kernels [t1 , t2 ]
37- loss -= kernels [s1 , t2 ] + kernels [s2 , t1 ]
38- return loss / float (batch_size )
32+
33+ # Linear version
34+ # loss = 0
35+ # for i in range(batch_size):
36+ # s1, s2 = i, (i+1)%batch_size
37+ # t1, t2 = s1+batch_size, s2+batch_size
38+ # loss += kernels[s1, s2] + kernels[t1, t2]
39+ # loss -= kernels[s1, t2] + kernels[s2, t1]
40+ # return loss / float(batch_size)
41+
42+ loss1 = 0
43+ for s1 in range (batch_size ):
44+ for s2 in range (s1 + 1 , batch_size ):
45+ t1 , t2 = s1 + batch_size , s2 + batch_size
46+ loss1 += kernels [s1 , s2 ] + kernels [t1 , t2 ]
47+ loss1 = loss1 / float (batch_size * (batch_size - 1 ) / 2 )
48+
49+ loss2 = 0
50+ for s1 in range (batch_size ):
51+ for s2 in range (batch_size ):
52+ t1 , t2 = s1 + batch_size , s2 + batch_size
53+ loss2 -= kernels [s1 , t2 ] + kernels [s2 , t1 ]
54+ loss2 = loss2 / float (batch_size * batch_size )
55+ return loss1 + loss2
3956
4057def RTN ():
4158 pass
@@ -57,15 +74,30 @@ def JAN(source_list, target_list, kernel_muls=[2.0, 2.0], kernel_nums=[5, 1], fi
5774 joint_kernels = joint_kernels * kernels
5875 else :
5976 joint_kernels = kernels
60-
61- loss = 0
62- for i in range (batch_size ):
63- s1 , s2 = i , (i + 1 )% batch_size
64- t1 , t2 = s1 + batch_size , s2 + batch_size
65- loss += joint_kernels [s1 , s2 ] + joint_kernels [t1 , t2 ]
66- loss -= joint_kernels [s1 , t2 ] + joint_kernels [s2 , t1 ]
67- return loss / float (batch_size )
6877
78+ # Linear version
79+ # loss = 0
80+ # for i in range(batch_size):
81+ # s1, s2 = i, (i+1)%batch_size
82+ # t1, t2 = s1+batch_size, s2+batch_size
83+ # loss += joint_kernels[s1, s2] + joint_kernels[t1, t2]
84+ # loss -= joint_kernels[s1, t2] + joint_kernels[s2, t1]
85+ # return loss / float(batch_size)
86+
87+ loss1 = 0
88+ for s1 in range (batch_size ):
89+ for s2 in range (s1 + 1 , batch_size ):
90+ t1 , t2 = s1 + batch_size , s2 + batch_size
91+ loss1 += joint_kernels [s1 , s2 ] + joint_kernels [t1 , t2 ]
92+ loss1 = loss1 / float (batch_size * (batch_size - 1 ) / 2 )
93+
94+ loss2 = 0
95+ for s1 in range (batch_size ):
96+ for s2 in range (batch_size ):
97+ t1 , t2 = s1 + batch_size , s2 + batch_size
98+ loss2 -= joint_kernels [s1 , t2 ] + joint_kernels [s2 , t1 ]
99+ loss2 = loss2 / float (batch_size * batch_size )
100+ return loss1 + loss2
69101
70102
71103loss_dict = {"DAN" :DAN , "RTN" :RTN , "JAN" :JAN }
0 commit comments