Commit 8bc9fe9

Updated to the latest version of TF-Slim
1 parent c74897b commit 8bc9fe9

11 files changed, +320 -27 lines

inception/inception/slim/collections_test.py

Lines changed: 1 addition & 0 deletions
@@ -17,6 +17,7 @@
 from __future__ import division
 from __future__ import print_function
 
+
 import tensorflow as tf
 
 from inception.slim import slim

inception/inception/slim/inception_model.py

Lines changed: 1 addition & 0 deletions
@@ -43,6 +43,7 @@
 from __future__ import division
 from __future__ import print_function
 
+
 import tensorflow as tf
 
 from inception.slim import ops

inception/inception/slim/inception_test.py

Lines changed: 1 addition & 0 deletions
@@ -17,6 +17,7 @@
 from __future__ import division
 from __future__ import print_function
 
+
 import tensorflow as tf
 
 from inception.slim import inception_model as inception

inception/inception/slim/losses.py

Lines changed: 1 addition & 0 deletions
@@ -26,6 +26,7 @@
 from __future__ import division
 from __future__ import print_function
 
+
 import tensorflow as tf
 
 # In order to gather all losses in a network, the user should use this

inception/inception/slim/losses_test.py

Lines changed: 1 addition & 0 deletions
@@ -18,6 +18,7 @@
 from __future__ import print_function
 
 
+
 import tensorflow as tf
 
 from inception.slim import losses

inception/inception/slim/ops.py

Lines changed: 29 additions & 22 deletions
@@ -27,6 +27,7 @@
 from __future__ import print_function
 
 
+
 import tensorflow as tf
 
 from tensorflow.python.training import moving_averages
@@ -42,6 +43,7 @@
 @scopes.add_arg_scope
 def batch_norm(inputs,
                decay=0.999,
+               center=True,
                scale=False,
                epsilon=0.001,
                moving_vars='moving_vars',
@@ -57,6 +59,7 @@ def batch_norm(inputs,
     inputs: a tensor of size [batch_size, height, width, channels]
             or [batch_size, channels].
     decay: decay for the moving average.
+    center: If True, subtract beta. If False, beta is not created and ignored.
     scale: If True, multiply by gamma. If False, gamma is
       not used. When the next layer is linear (also e.g. ReLU), this can be
      disabled since the scaling can be done by the next layer.
@@ -78,31 +81,35 @@ def batch_norm(inputs,
   with tf.variable_op_scope([inputs], scope, 'BatchNorm', reuse=reuse):
     axis = list(range(len(inputs_shape) - 1))
     params_shape = inputs_shape[-1:]
-    with scopes.arg_scope([variables.variable], restore=restore):
-      # Allocate parameters for the beta and gamma of the normalization.
+    # Allocate parameters for the beta and gamma of the normalization.
+    beta, gamma = None, None
+    if center:
       beta = variables.variable('beta',
                                 params_shape,
                                 initializer=tf.zeros_initializer,
-                                trainable=trainable)
-      if scale:
-        gamma = variables.variable('gamma',
-                                   params_shape,
-                                   initializer=tf.ones,
-                                   trainable=trainable)
-      else:
-        gamma = None
-      # Create moving_mean and moving_variance add them to moving_vars and
-      # GraphKeys.MOVING_AVERAGE_VARIABLES collections.
-      with scopes.arg_scope([variables.variable], trainable=False,
-                            collections=[
-                                moving_vars,
-                                tf.GraphKeys.MOVING_AVERAGE_VARIABLES]):
-        moving_mean = variables.variable('moving_mean',
+                                trainable=trainable,
+                                restore=restore)
+    if scale:
+      gamma = variables.variable('gamma',
+                                 params_shape,
+                                 initializer=tf.ones_initializer,
+                                 trainable=trainable,
+                                 restore=restore)
+    # Create moving_mean and moving_variance add them to
+    # GraphKeys.MOVING_AVERAGE_VARIABLES collections.
+    moving_collections = [moving_vars, tf.GraphKeys.MOVING_AVERAGE_VARIABLES]
+    moving_mean = variables.variable('moving_mean',
+                                     params_shape,
+                                     initializer=tf.zeros_initializer,
+                                     trainable=False,
+                                     restore=restore,
+                                     collections=moving_collections)
+    moving_variance = variables.variable('moving_variance',
                                          params_shape,
-                                         initializer=tf.zeros_initializer)
-        moving_variance = variables.variable('moving_variance',
-                                             params_shape,
-                                             initializer=tf.ones)
+                                         initializer=tf.ones_initializer,
+                                         trainable=False,
+                                         restore=restore,
+                                         collections=moving_collections)
     if is_training:
       # Calculate the moments based on the individual batch.
       mean, variance = tf.nn.moments(inputs, axis)
@@ -400,7 +407,7 @@ def dropout(inputs, keep_prob=0.5, is_training=True, scope=None):
 
   Args:
     inputs: the tensor to pass to the Dropout layer.
-    keep_prob: the probability of dropping each input unit.
+    keep_prob: the probability of keeping each input unit.
     is_training: whether or not the model is in training mode. If so, dropout is
       applied and values scaled. Otherwise, inputs is returned.
     scope: Optional scope for op_scope.
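
For context, a minimal usage sketch of the new center argument (not part of this commit; the input shape and scope name are illustrative assumptions):

import tensorflow as tf

from inception.slim import ops

images = tf.random_uniform((8, 32, 32, 16), seed=1)
# Default behavior after this change: beta is created (center=True),
# gamma is not (scale=False).
net = ops.batch_norm(images)
# Drop beta as well, e.g. when a following layer adds its own bias term.
net = ops.batch_norm(images, center=False, scale=True, scope='BatchNorm_scaled')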

inception/inception/slim/ops_test.py

Lines changed: 43 additions & 0 deletions
@@ -18,6 +18,7 @@
 from __future__ import print_function
 
 
+
 import numpy as np
 import tensorflow as tf
 
@@ -476,6 +477,20 @@ def testCreateOp(self):
       self.assertListEqual(output.get_shape().as_list(), [5, height, width, 3])
 
   def testCreateVariables(self):
+    height, width = 3, 3
+    with self.test_session():
+      images = tf.random_uniform((5, height, width, 3), seed=1)
+      ops.batch_norm(images)
+      beta = variables.get_variables_by_name('beta')[0]
+      self.assertEquals(beta.op.name, 'BatchNorm/beta')
+      gamma = variables.get_variables_by_name('gamma')
+      self.assertEquals(gamma, [])
+      moving_mean = tf.moving_average_variables()[0]
+      moving_variance = tf.moving_average_variables()[1]
+      self.assertEquals(moving_mean.op.name, 'BatchNorm/moving_mean')
+      self.assertEquals(moving_variance.op.name, 'BatchNorm/moving_variance')
+
+  def testCreateVariablesWithScale(self):
     height, width = 3, 3
     with self.test_session():
       images = tf.random_uniform((5, height, width, 3), seed=1)
@@ -489,6 +504,34 @@ def testCreateVariables(self):
       self.assertEquals(moving_mean.op.name, 'BatchNorm/moving_mean')
       self.assertEquals(moving_variance.op.name, 'BatchNorm/moving_variance')
 
+  def testCreateVariablesWithoutCenterWithScale(self):
+    height, width = 3, 3
+    with self.test_session():
+      images = tf.random_uniform((5, height, width, 3), seed=1)
+      ops.batch_norm(images, center=False, scale=True)
+      beta = variables.get_variables_by_name('beta')
+      self.assertEquals(beta, [])
+      gamma = variables.get_variables_by_name('gamma')[0]
+      self.assertEquals(gamma.op.name, 'BatchNorm/gamma')
+      moving_mean = tf.moving_average_variables()[0]
+      moving_variance = tf.moving_average_variables()[1]
+      self.assertEquals(moving_mean.op.name, 'BatchNorm/moving_mean')
+      self.assertEquals(moving_variance.op.name, 'BatchNorm/moving_variance')
+
+  def testCreateVariablesWithoutCenterWithoutScale(self):
+    height, width = 3, 3
+    with self.test_session():
+      images = tf.random_uniform((5, height, width, 3), seed=1)
+      ops.batch_norm(images, center=False, scale=False)
+      beta = variables.get_variables_by_name('beta')
+      self.assertEquals(beta, [])
+      gamma = variables.get_variables_by_name('gamma')
+      self.assertEquals(gamma, [])
+      moving_mean = tf.moving_average_variables()[0]
+      moving_variance = tf.moving_average_variables()[1]
+      self.assertEquals(moving_mean.op.name, 'BatchNorm/moving_mean')
+      self.assertEquals(moving_variance.op.name, 'BatchNorm/moving_variance')
+
   def testMovingAverageVariables(self):
     height, width = 3, 3
     with self.test_session():

inception/inception/slim/scopes.py

Lines changed: 1 addition & 0 deletions
@@ -53,6 +53,7 @@ def conv2d(*args, **kwargs)
 import contextlib
 import functools
 
+
 from tensorflow.python.framework import ops
 
 _ARGSTACK_KEY = ("__arg_stack",)

inception/inception/slim/scopes_test.py

Lines changed: 1 addition & 0 deletions
@@ -18,6 +18,7 @@
 from __future__ import print_function
 
 
+
 import tensorflow as tf
 from inception.slim import scopes
 
inception/inception/slim/variables.py

Lines changed: 77 additions & 4 deletions
@@ -82,8 +82,10 @@
 from __future__ import division
 from __future__ import print_function
 
+
 import tensorflow as tf
 
+from tensorflow.core.framework import graph_pb2
 from inception.slim import scopes
 
 # Collection containing all the variables created using slim.variables
@@ -171,6 +173,79 @@ def get_unique_variable(name):
   raise ValueError('Variable %s does not uniquely identify a variable', name)
 
 
+class VariableDeviceChooser(object):
+  """Slim device chooser for variables.
+
+  When using a parameter server it will assign them in a round-robin fashion.
+  When not using a parameter server it allows GPU:0 placement otherwise CPU:0.
+  """
+
+  def __init__(self,
+               num_parameter_servers=0,
+               ps_device='/job:ps',
+               placement='CPU:0'):
+    """Initialize VariableDeviceChooser.
+
+    Args:
+      num_parameter_servers: number of parameter servers.
+      ps_device: string representing the parameter server device.
+      placement: string representing the placement of the variable either CPU:0
+        or GPU:0. When using parameter servers forced to CPU:0.
+    """
+    self._num_ps = num_parameter_servers
+    self._ps_device = ps_device
+    self._placement = placement if num_parameter_servers == 0 else 'CPU:0'
+    self._next_task_id = 0
+
+  def __call__(self, op):
+    device_string = ''
+    if self._num_ps > 0:
+      task_id = self._next_task_id
+      self._next_task_id = (self._next_task_id + 1) % self._num_ps
+      device_string = '%s/task:%d' % (self._ps_device, task_id)
+    device_string += '/%s' % self._placement
+    return device_string
+
+
+# TODO(sguada) Remove once get_variable is able to colocate op.devices.
+def variable_device(device, name):
+  """Fix the variable device to colocate its ops."""
+  if callable(device):
+    var_name = tf.get_variable_scope().name + '/' + name
+    var_def = graph_pb2.NodeDef(name=var_name, op='Variable')
+    device = device(var_def)
+  if device is None:
+    device = ''
+  return device
+
+
+@scopes.add_arg_scope
+def global_step(device=''):
+  """Returns the global step variable.
+
+  Args:
+    device: Optional device to place the variable. It can be an string or a
+      function that is called to get the device for the variable.
+
+  Returns:
+    the tensor representing the global step variable.
+  """
+  global_step_ref = tf.get_collection(tf.GraphKeys.GLOBAL_STEP)
+  if global_step_ref:
+    return global_step_ref[0]
+  else:
+    collections = [
+        VARIABLES_TO_RESTORE,
+        tf.GraphKeys.VARIABLES,
+        tf.GraphKeys.GLOBAL_STEP,
+    ]
+    # Get the device for the variable.
+    with tf.device(variable_device(device, 'global_step')):
+      return tf.get_variable('global_step', shape=[], dtype=tf.int64,
+                             initializer=tf.zeros_initializer,
+                             trainable=False, collections=collections)
+
+
 @scopes.add_arg_scope
 def variable(name, shape=None, dtype=tf.float32, initializer=None,
              regularizer=None, trainable=True, collections=None, device='',
@@ -200,9 +275,6 @@ def variable(name, shape=None, dtype=tf.float32, initializer=None,
   Returns:
     The created or existing variable.
   """
-  # Instantiate the device for this variable if it is passed as a function.
-  if device and callable(device):
-    device = device()
   collections = list(collections or [])
 
   # Make sure variables are added to tf.GraphKeys.VARIABLES and MODEL_VARIABLES
@@ -212,7 +284,8 @@ def variable(name, shape=None, dtype=tf.float32, initializer=None,
     collections.append(VARIABLES_TO_RESTORE)
   # Remove duplicates
   collections = set(collections)
-  with tf.device(device):
+  # Get the device for the variable.
+  with tf.device(variable_device(device, name)):
    return tf.get_variable(name, shape=shape, dtype=dtype,
                           initializer=initializer, regularizer=regularizer,
                           trainable=trainable, collections=collections)
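
For context, a sketch of how the new VariableDeviceChooser and global_step helpers could be combined with slim's arg_scope (not part of this commit; the two-parameter-server setup and variable shape are illustrative assumptions, and an actual run would need a matching ps/worker cluster):

import tensorflow as tf

from inception.slim import scopes
from inception.slim import variables

# Round-robin placement of variables across two parameter servers, on CPU:0.
chooser = variables.VariableDeviceChooser(num_parameter_servers=2)
with scopes.arg_scope([variables.variable, variables.global_step],
                      device=chooser):
  # First variable goes to /job:ps/task:0/CPU:0, the next to /job:ps/task:1/CPU:0.
  step = variables.global_step()
  weights = variables.variable('weights', shape=[3, 3, 16, 32],
                               initializer=tf.truncated_normal_initializer())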
