Commit bf60abf

psycharo authored and martinwicke committed
Spatial transformer: (tensorflow#57)
* Modified the way the output size is specified.
* Added support for batches of inputs.
1 parent 8332400 commit bf60abf

File tree: 1 file changed (+36 −18 lines)

transformer/spatial_transformer.py

Lines changed: 36 additions & 18 deletions
@@ -14,7 +14,7 @@
 # ==============================================================================
 import tensorflow as tf
 
-def transformer(U, theta, downsample_factor=1, name='SpatialTransformer', **kwargs):
+def transformer(U, theta, out_size, name='SpatialTransformer', **kwargs):
     """Spatial Transformer Layer
 
     Implements a spatial transformer layer as described in [1]_.
@@ -28,14 +28,9 @@ def transformer(U, theta, downsample_factor=1, name='SpatialTransformer', **kwar
     theta: float
         The output of the
         localisation network should be [num_batch, 6].
-    downsample_factor : float
-        A value of 1 will keep the original size of the image
-        Values larger than 1 will downsample the image.
-        Values below 1 will upsample the image
-        example image: height = 100, width = 200
-        downsample_factor = 2
-        output image will then be 50, 100
-
+    out_size: tuple of two floats
+        The size of the output of the network
+
     References
     ----------
     .. [1] Spatial Transformer Networks
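
For reference, a minimal sketch of how a call site changes with this commit. The placeholder shapes and the `from spatial_transformer import transformer` import path are illustrative assumptions, not part of the diff.

# Sketch only: migrating a caller from downsample_factor to out_size.
# Assumes the TF 0.x-era API used elsewhere in this file.
import tensorflow as tf
from spatial_transformer import transformer  # assumed module name for transformer/spatial_transformer.py

U = tf.placeholder(tf.float32, [None, 100, 200, 3])  # input feature maps
theta = tf.placeholder(tf.float32, [None, 6])         # affine parameters from the localisation net

# Before this commit the output size was implicit:
#   h_trans = transformer(U, theta, downsample_factor=2)   # 100x200 -> 50x100
# After this commit it is stated explicitly as (out_height, out_width):
h_trans = transformer(U, theta, out_size=(50, 100))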
@@ -61,7 +56,7 @@ def _repeat(x, n_repeats):
             x = tf.matmul(tf.reshape(x,(-1, 1)), rep)
             return tf.reshape(x,[-1])
 
-    def _interpolate(im, x, y, downsample_factor):
+    def _interpolate(im, x, y, out_size):
        with tf.variable_scope('_interpolate'):
            # constants
            num_batch = tf.shape(im)[0]
@@ -73,8 +68,8 @@ def _interpolate(im, x, y, downsample_factor):
            y = tf.cast(y, 'float32')
            height_f = tf.cast(height, 'float32')
            width_f = tf.cast(width, 'float32')
-            out_height = tf.cast(height_f // downsample_factor, 'int32')
-            out_width = tf.cast(width_f // downsample_factor, 'int32')
+            out_height = out_size[0]
+            out_width = out_size[1]
            zero = tf.zeros([], dtype='int32')
            max_y = tf.cast(tf.shape(im)[1] - 1, 'int32')
            max_x = tf.cast(tf.shape(im)[2] - 1, 'int32')
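
Since out_height and out_width are now read directly from out_size rather than derived from the input shape, the old downsampling behaviour is reproduced by computing the tuple at the call site. A tiny sketch (the variable names are illustrative, not part of the diff):

# Caller-side equivalent of the removed downsample_factor logic.
height, width = 100, 200            # input spatial size
downsample_factor = 2
out_size = (height // downsample_factor, width // downsample_factor)  # (50, 100)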
@@ -142,7 +137,7 @@ def _meshgrid(height, width):
            grid = tf.concat(0, [x_t_flat, y_t_flat, ones])
            return grid
 
-    def _transform(theta, input_dim, downsample_factor):
+    def _transform(theta, input_dim, out_size):
        with tf.variable_scope('_transform'):
            num_batch = tf.shape(input_dim)[0]
            height = tf.shape(input_dim)[1]
@@ -154,8 +149,8 @@ def _transform(theta, input_dim, downsample_factor):
            # grid of (x_t, y_t, 1), eq (1) in ref [1]
            height_f = tf.cast(height, 'float32')
            width_f = tf.cast(width, 'float32')
-            out_height = tf.cast(height_f // downsample_factor, 'int32')
-            out_width = tf.cast(width_f // downsample_factor, 'int32')
+            out_height = out_size[0]
+            out_width = out_size[1]
            grid = _meshgrid(out_height, out_width)
            grid = tf.expand_dims(grid,0)
            grid = tf.reshape(grid,[-1])
@@ -171,11 +166,34 @@ def _transform(theta, input_dim, downsample_factor):
 
            input_transformed = _interpolate(
                input_dim, x_s_flat, y_s_flat,
-                downsample_factor)
+                out_size)
 
            output = tf.reshape(input_transformed, tf.pack([num_batch, out_height, out_width, num_channels]))
            return output
 
    with tf.variable_scope(name):
-        output = _transform(theta, U, downsample_factor)
-        return output
+        output = _transform(theta, U, out_size)
+        return output
+
+
+def batch_transformer(U, thetas, out_size, name='BatchSpatialTransformer'):
+    """Batch Spatial Transformer Layer
+
+    Parameters
+    ----------
+
+    U : float
+        tensor of inputs [num_batch,height,width,num_channels]
+    thetas : float
+        a set of transformations for each input [num_batch,num_transforms,6]
+    out_size : int
+        the size of the output [out_height,out_width]
+
+    Returns: float
+        Tensor of size [num_batch*num_transforms,out_height,out_width,num_channels]
+    """
+    with tf.variable_scope(name):
+        num_batch, num_transforms = map(int, thetas.get_shape().as_list()[:2])
+        indices = [[i]*num_transforms for i in xrange(num_batch)]
+        input_repeated = tf.gather(U, tf.reshape(indices, [-1]))
+        return transformer(input_repeated, thetas, out_size)
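
A hedged usage sketch for the new batch_transformer. The shapes follow the docstring above; the import path and placeholder shapes are assumptions for illustration only.

# Sketch only: applying several transforms per input with batch_transformer.
import tensorflow as tf
from spatial_transformer import batch_transformer  # assumed module name

U = tf.placeholder(tf.float32, [4, 100, 200, 3])   # num_batch=4 inputs
thetas = tf.placeholder(tf.float32, [4, 3, 6])     # num_transforms=3 affine transforms per input

# Internally each input is repeated num_transforms times via tf.gather
# (indices [0,0,0, 1,1,1, ...]) before a single call to transformer, so the
# result stacks every (input, transform) pair: [4*3, 50, 100, 3].
crops = batch_transformer(U, thetas, out_size=(50, 100))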

0 commit comments