@@ -220,28 +220,23 @@ def train(dataset):
     # Number of classes in the Dataset label set plus 1.
     # Label 0 is reserved for an (unused) background class.
     num_classes = dataset.num_classes() + 1
+
+    # Split the batch of images and labels for towers.
+    images_splits = tf.split(0, FLAGS.num_gpus, images)
+    labels_splits = tf.split(0, FLAGS.num_gpus, labels)
 
     # Calculate the gradients for each model tower.
     tower_grads = []
     for i in xrange(FLAGS.num_gpus):
       with tf.device('/gpu:%d' % i):
         with tf.name_scope('%s_%d' % (inception.TOWER_NAME, i)) as scope:
-          # Split the batch of images and labels.
-          batch_start = split_batch_size * i
-          images_batch = tf.slice(images,
-                                  begin=[batch_start, 0, 0, 0],
-                                  size=[split_batch_size, -1, -1, -1])
-          labels_batch = tf.slice(labels,
-                                  begin=[batch_start],
-                                  size=[split_batch_size])
-
-
           # Force all Variables to reside on the CPU.
           with slim.arg_scope([slim.variables.variable], device='/cpu:0'):
             # Calculate the loss for one tower of the ImageNet model. This
             # function constructs the entire ImageNet model but shares the
             # variables across all towers.
-            loss = _tower_loss(images_batch, labels_batch, num_classes, scope)
+            loss = _tower_loss(images_splits[i], labels_splits[i], num_classes,
+                               scope)
 
             # Reuse variables for the next tower.
             tf.get_variable_scope().reuse_variables()
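
A minimal sketch (not part of the commit) of the equivalence behind this change, assuming the pre-1.0 tf.split(split_dim, num_split, value) signature used here and an illustrative batch size of 8 images of shape 299x299x3 on two GPUs; splitting the batch once along dimension 0 yields the same per-tower tensors as the removed per-tower tf.slice calls.

import tensorflow as tf

num_gpus = 2
split_batch_size = 4                                    # assumed per-tower batch size
images = tf.placeholder(tf.float32, [8, 299, 299, 3])   # assumed full batch
labels = tf.placeholder(tf.int32, [8])

# New approach: split the batch once, outside the tower loop.
images_splits = tf.split(0, num_gpus, images)
labels_splits = tf.split(0, num_gpus, labels)

for i in xrange(num_gpus):
  # Old approach: carve the same slice out per tower, inside the loop.
  batch_start = split_batch_size * i
  images_batch = tf.slice(images,
                          begin=[batch_start, 0, 0, 0],
                          size=[split_batch_size, -1, -1, -1])
  # images_splits[i] and images_batch both have shape [4, 299, 299, 3];
  # the split version avoids recomputing begin/size per tower.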