82
82
from __future__ import division
83
83
from __future__ import print_function
84
84
85
+
85
86
import tensorflow as tf
86
87
88
+ from tensorflow .core .framework import graph_pb2
87
89
from inception .slim import scopes
88
90
89
91
# Collection containing all the variables created using slim.variables
@@ -171,6 +173,79 @@ def get_unique_variable(name):
171
173
raise ValueError ('Variable %s does not uniquely identify a variable' , name )
172
174
173
175
176
class VariableDeviceChooser(object):
  """Slim device chooser for variables.

  When using a parameter server it will assign them in a round-robin fashion.
  When not using a parameter server it allows GPU:0 placement otherwise CPU:0.
  """

  def __init__(self,
               num_parameter_servers=0,
               ps_device='/job:ps',
               placement='CPU:0'):
    """Initialize VariableDeviceChooser.

    Args:
      num_parameter_servers: number of parameter servers.
      ps_device: string representing the parameter server device.
      placement: string representing the placement of the variable either CPU:0
        or GPU:0. When using parameter servers forced to CPU:0.
    """
    self._num_ps = num_parameter_servers
    self._ps_device = ps_device
    # Parameter servers always host variables on CPU:0; otherwise honor the
    # caller's requested placement.
    self._placement = 'CPU:0' if num_parameter_servers != 0 else placement
    self._next_task_id = 0

  def __call__(self, op):
    # Build the device string in pieces: an optional PS task prefix followed
    # by the device placement suffix.
    pieces = []
    if self._num_ps > 0:
      # Round-robin over the available parameter-server tasks.
      task = self._next_task_id
      self._next_task_id = (task + 1) % self._num_ps
      pieces.append('%s/task:%d' % (self._ps_device, task))
    pieces.append('/%s' % self._placement)
    return ''.join(pieces)
208
+
209
+
210
# TODO(sguada) Remove once get_variable is able to colocate op.devices.
def variable_device(device, name):
  """Fix the variable device to colocate its ops."""
  if callable(device):
    # Device functions expect a NodeDef; synthesize one for the variable so
    # the chooser can decide its placement.
    scoped_name = tf.get_variable_scope().name + '/' + name
    node = graph_pb2.NodeDef(name=scoped_name, op='Variable')
    device = device(node)
  # tf.device() needs a string; map "no preference" to the empty device.
  return '' if device is None else device
220
+
221
+
222
@scopes.add_arg_scope
def global_step(device=''):
  """Returns the global step variable.

  Args:
    device: Optional device to place the variable. It can be an string or a
      function that is called to get the device for the variable.

  Returns:
    the tensor representing the global step variable.
  """
  # Reuse the existing global step if one was already created in this graph.
  existing = tf.get_collection(tf.GraphKeys.GLOBAL_STEP)
  if existing:
    return existing[0]
  step_collections = [
      VARIABLES_TO_RESTORE,
      tf.GraphKeys.VARIABLES,
      tf.GraphKeys.GLOBAL_STEP,
  ]
  # Resolve the (possibly callable) device spec and pin the variable there.
  with tf.device(variable_device(device, 'global_step')):
    return tf.get_variable('global_step', shape=[], dtype=tf.int64,
                           initializer=tf.zeros_initializer,
                           trainable=False,
                           collections=step_collections)
247
+
248
+
174
249
@scopes .add_arg_scope
175
250
def variable (name , shape = None , dtype = tf .float32 , initializer = None ,
176
251
regularizer = None , trainable = True , collections = None , device = '' ,
@@ -200,9 +275,6 @@ def variable(name, shape=None, dtype=tf.float32, initializer=None,
200
275
Returns:
201
276
The created or existing variable.
202
277
"""
203
- # Instantiate the device for this variable if it is passed as a function.
204
- if device and callable (device ):
205
- device = device ()
206
278
collections = list (collections or [])
207
279
208
280
# Make sure variables are added to tf.GraphKeys.VARIABLES and MODEL_VARIABLES
@@ -212,7 +284,8 @@ def variable(name, shape=None, dtype=tf.float32, initializer=None,
212
284
collections .append (VARIABLES_TO_RESTORE )
213
285
# Remove duplicates
214
286
collections = set (collections )
215
- with tf .device (device ):
287
+ # Get the device for the variable.
288
+ with tf .device (variable_device (device , name )):
216
289
return tf .get_variable (name , shape = shape , dtype = dtype ,
217
290
initializer = initializer , regularizer = regularizer ,
218
291
trainable = trainable , collections = collections )
0 commit comments