14
14
# ==============================================================================
15
15
import tensorflow as tf
16
16
17
- def transformer (U , theta , downsample_factor = 1 , name = 'SpatialTransformer' , ** kwargs ):
17
+ def transformer (U , theta , out_size , name = 'SpatialTransformer' , ** kwargs ):
18
18
"""Spatial Transformer Layer
19
19
20
20
Implements a spatial transformer layer as described in [1]_.
@@ -28,14 +28,9 @@ def transformer(U, theta, downsample_factor=1, name='SpatialTransformer', **kwar
28
28
theta: float
29
29
The output of the
30
30
localisation network should be [num_batch, 6].
31
- downsample_factor : float
32
- A value of 1 will keep the original size of the image
33
- Values larger than 1 will downsample the image.
34
- Values below 1 will upsample the image
35
- example image: height = 100, width = 200
36
- downsample_factor = 2
37
- output image will then be 50, 100
38
-
31
+ out_size: tuple of two floats
32
+ The size of the output of the network
33
+
39
34
References
40
35
----------
41
36
.. [1] Spatial Transformer Networks
@@ -61,7 +56,7 @@ def _repeat(x, n_repeats):
61
56
x = tf .matmul (tf .reshape (x ,(- 1 , 1 )), rep )
62
57
return tf .reshape (x ,[- 1 ])
63
58
64
- def _interpolate (im , x , y , downsample_factor ):
59
+ def _interpolate (im , x , y , out_size ):
65
60
with tf .variable_scope ('_interpolate' ):
66
61
# constants
67
62
num_batch = tf .shape (im )[0 ]
@@ -73,8 +68,8 @@ def _interpolate(im, x, y, downsample_factor):
73
68
y = tf .cast (y , 'float32' )
74
69
height_f = tf .cast (height , 'float32' )
75
70
width_f = tf .cast (width , 'float32' )
76
- out_height = tf . cast ( height_f // downsample_factor , 'int32' )
77
- out_width = tf . cast ( width_f // downsample_factor , 'int32' )
71
+ out_height = out_size [ 0 ]
72
+ out_width = out_size [ 1 ]
78
73
zero = tf .zeros ([], dtype = 'int32' )
79
74
max_y = tf .cast (tf .shape (im )[1 ] - 1 , 'int32' )
80
75
max_x = tf .cast (tf .shape (im )[2 ] - 1 , 'int32' )
@@ -142,7 +137,7 @@ def _meshgrid(height, width):
142
137
grid = tf .concat (0 , [x_t_flat , y_t_flat , ones ])
143
138
return grid
144
139
145
- def _transform (theta , input_dim , downsample_factor ):
140
+ def _transform (theta , input_dim , out_size ):
146
141
with tf .variable_scope ('_transform' ):
147
142
num_batch = tf .shape (input_dim )[0 ]
148
143
height = tf .shape (input_dim )[1 ]
@@ -154,8 +149,8 @@ def _transform(theta, input_dim, downsample_factor):
154
149
# grid of (x_t, y_t, 1), eq (1) in ref [1]
155
150
height_f = tf .cast (height , 'float32' )
156
151
width_f = tf .cast (width , 'float32' )
157
- out_height = tf . cast ( height_f // downsample_factor , 'int32' )
158
- out_width = tf . cast ( width_f // downsample_factor , 'int32' )
152
+ out_height = out_size [ 0 ]
153
+ out_width = out_size [ 1 ]
159
154
grid = _meshgrid (out_height , out_width )
160
155
grid = tf .expand_dims (grid ,0 )
161
156
grid = tf .reshape (grid ,[- 1 ])
@@ -171,11 +166,34 @@ def _transform(theta, input_dim, downsample_factor):
171
166
172
167
input_transformed = _interpolate (
173
168
input_dim , x_s_flat , y_s_flat ,
174
- downsample_factor )
169
+ out_size )
175
170
176
171
output = tf .reshape (input_transformed , tf .pack ([num_batch , out_height , out_width , num_channels ]))
177
172
return output
178
173
179
174
with tf .variable_scope (name ):
180
- output = _transform (theta , U , downsample_factor )
181
- return output
175
+ output = _transform (theta , U , out_size )
176
+ return output
177
+
178
def batch_transformer(U, thetas, out_size, name='BatchSpatialTransformer'):
    """Batch Spatial Transformer Layer

    Applies a separate affine transformation to each of ``num_transforms``
    copies of every input image by flattening the (batch, transform) pairs
    into one large batch and delegating to :func:`transformer`.

    Parameters
    ----------
    U : float
        tensor of inputs [num_batch, height, width, num_channels]
    thetas : float
        a set of transformations for each input [num_batch, num_transforms, 6]
    out_size : tuple of two ints
        the size of the output [out_height, out_width]

    Returns
    -------
    float
        Tensor of size [num_batch*num_transforms, out_height, out_width, num_channels]
    """
    with tf.variable_scope(name):
        # Static shapes are required here: the repeat indices are built as a
        # plain Python list, so batch/transform counts must be known ints.
        num_batch, num_transforms = map(int, thetas.get_shape().as_list()[:2])
        # Repeat each input image once per transform so every (image, theta)
        # pair becomes one element of a flat batch, e.g. [0,0,1,1,...].
        # NOTE: `range` (not Py2-only `xrange`) so this also runs on Python 3.
        indices = [[i] * num_transforms for i in range(num_batch)]
        input_repeated = tf.gather(U, tf.reshape(indices, [-1]))
        return transformer(input_repeated, thetas, out_size)
199
+
0 commit comments