@@ -30,15 +30,6 @@ def DAN(source, target, kernel_mul=2.0, kernel_num=5, fix_sigma=None):
     kernels = guassian_kernel(source, target,
         kernel_mul=kernel_mul, kernel_num=kernel_num, fix_sigma=fix_sigma)

-    # Linear version
-    # loss = 0
-    # for i in range(batch_size):
-    # s1, s2 = i, (i+1)%batch_size
-    # t1, t2 = s1+batch_size, s2+batch_size
-    # loss += kernels[s1, s2] + kernels[t1, t2]
-    # loss -= kernels[s1, t2] + kernels[s2, t1]
-    # return loss / float(batch_size)
-
     loss1 = 0
     for s1 in range(batch_size):
         for s2 in range(s1 + 1, batch_size):
@@ -54,6 +45,21 @@ def DAN(source, target, kernel_mul=2.0, kernel_num=5, fix_sigma=None):
     loss2 = loss2 / float(batch_size * batch_size)
     return loss1 + loss2

+def DAN_Linear(source, target, kernel_mul=2.0, kernel_num=5, fix_sigma=None):
+    batch_size = int(source.size()[0])
+    kernels = guassian_kernel(source, target,
+        kernel_mul=kernel_mul, kernel_num=kernel_num, fix_sigma=fix_sigma)
+
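+    # Linear-time MMD estimator (cf. the O(n) statistic in Gretton et al.,
+    # "A Kernel Two-Sample Test", JMLR 2012): each sample is paired only with
+    # its successor, so the loss reads O(batch_size) kernel entries instead of
+    # the O(batch_size^2) double loops in DAN above (the kernel matrix itself
+    # is still formed in full). Rows/columns [0, B) of the (2B x 2B) matrix
+    # index source samples and [B, 2B) index target samples.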
+    # Linear version
+    loss = 0
+    for i in range(batch_size):
+        s1, s2 = i, (i + 1) % batch_size
+        t1, t2 = s1 + batch_size, s2 + batch_size
+        loss += kernels[s1, s2] + kernels[t1, t2]
+        loss -= kernels[s1, t2] + kernels[s2, t1]
+    return loss / float(batch_size)
+
+
 def RTN():
     pass

@@ -75,15 +81,6 @@ def JAN(source_list, target_list, kernel_muls=[2.0, 2.0], kernel_nums=[5, 1], fix_sigma_list=[None, 1.68]):
     else:
         joint_kernels = kernels

-    # Linear version
-    # loss = 0
-    # for i in range(batch_size):
-    # s1, s2 = i, (i+1)%batch_size
-    # t1, t2 = s1+batch_size, s2+batch_size
-    # loss += joint_kernels[s1, s2] + joint_kernels[t1, t2]
-    # loss -= joint_kernels[s1, t2] + joint_kernels[s2, t1]
-    # return loss / float(batch_size)
-
     loss1 = 0
     for s1 in range(batch_size):
         for s2 in range(s1 + 1, batch_size):
@@ -99,5 +96,31 @@ def JAN(source_list, target_list, kernel_muls=[2.0, 2.0], kernel_nums=[5, 1], fix_sigma_list=[None, 1.68]):
     loss2 = loss2 / float(batch_size * batch_size)
     return loss1 + loss2

+def JAN_Linear(source_list, target_list, kernel_muls=[2.0, 2.0], kernel_nums=[5, 1], fix_sigma_list=[None, 1.68]):
+    batch_size = int(source_list[0].size()[0])
+    layer_num = len(source_list)
+    joint_kernels = None
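+    # Build the joint kernel as the elementwise product of the per-layer
+    # Gaussian kernels, mirroring the construction in JAN above.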
+    for i in range(layer_num):
+        source = source_list[i]
+        target = target_list[i]
+        kernel_mul = kernel_muls[i]
+        kernel_num = kernel_nums[i]
+        fix_sigma = fix_sigma_list[i]
+        kernels = guassian_kernel(source, target,
+            kernel_mul=kernel_mul, kernel_num=kernel_num, fix_sigma=fix_sigma)
+        if joint_kernels is not None:
+            joint_kernels = joint_kernels * kernels
+        else:
+            joint_kernels = kernels
+
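+    # Same O(batch_size) consecutive-pair MMD estimator as in DAN_Linear,
+    # applied to the joint kernel matrix.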
+    # Linear version
+    loss = 0
+    for i in range(batch_size):
+        s1, s2 = i, (i + 1) % batch_size
+        t1, t2 = s1 + batch_size, s2 + batch_size
+        loss += joint_kernels[s1, s2] + joint_kernels[t1, t2]
+        loss -= joint_kernels[s1, t2] + joint_kernels[s2, t1]
+    return loss / float(batch_size)
+
+

-loss_dict = {"DAN": DAN, "RTN": RTN, "JAN": JAN}
+loss_dict = {"DAN": DAN, "DAN_Linear": DAN_Linear, "RTN": RTN, "JAN": JAN, "JAN_Linear": JAN_Linear}