@@ -76,8 +76,6 @@ def multiclass_non_max_suppression(boxes,
76
76
a BoxList holding M boxes with a rank-1 scores field representing
77
77
corresponding scores for each box with scores sorted in decreasing order
78
78
and a rank-1 classes field representing a class label for each box.
79
- If masks, keypoints, keypoint_heatmaps is not None, the boxlist will
80
- contain masks, keypoints, keypoint_heatmaps corresponding to boxes.
81
79
82
80
Raises:
83
81
ValueError: if iou_thresh is not in [0, 1] or if input boxlist does not have
@@ -174,6 +172,7 @@ def batch_multiclass_non_max_suppression(boxes,
174
172
change_coordinate_frame = False ,
175
173
num_valid_boxes = None ,
176
174
masks = None ,
175
+ additional_fields = None ,
177
176
scope = None ,
178
177
parallel_iterations = 32 ):
179
178
"""Multi-class version of non maximum suppression that operates on a batch.
@@ -203,11 +202,13 @@ def batch_multiclass_non_max_suppression(boxes,
203
202
is provided)
204
203
num_valid_boxes: (optional) a Tensor of type `int32`. A 1-D tensor of shape
205
204
[batch_size] representing the number of valid boxes to be considered
206
- for each image in the batch. This parameter allows for ignoring zero
207
- paddings.
205
+ for each image in the batch. This parameter allows for ignoring zero
206
+ paddings.
208
207
masks: (optional) a [batch_size, num_anchors, q, mask_height, mask_width]
209
208
float32 tensor containing box masks. `q` can be either number of classes
210
209
or 1 depending on whether a separate mask is predicted per class.
210
+ additional_fields: (optional) If not None, a dictionary that maps keys to
211
+ tensors whose dimensions are [batch_size, num_anchors, ...].
211
212
scope: tf scope name.
212
213
parallel_iterations: (optional) number of batch items to process in
213
214
parallel.
@@ -223,9 +224,13 @@ def batch_multiclass_non_max_suppression(boxes,
223
224
[batch_size, max_detections, mask_height, mask_width] float32 tensor
224
225
containing masks for each selected box. This is set to None if input
225
226
`masks` is None.
227
+ 'nmsed_additional_fields': (optional) a dictionary of
228
+ [batch_size, max_detections, ...] float32 tensors corresponding to the
229
+ tensors specified in the input `additional_fields`. This is not returned
230
+ if input `additional_fields` is None.
226
231
'num_detections': A [batch_size] int32 tensor indicating the number of
227
232
valid detections per batch item. Only the top num_detections[i] entries in
228
- nms_boxes[i], nms_scores[i] and nms_class[i] are valid. the rest of the
233
+ nms_boxes[i], nms_scores[i] and nms_class[i] are valid. The rest of the
229
234
entries are zero paddings.
230
235
231
236
Raises:
@@ -239,6 +244,7 @@ def batch_multiclass_non_max_suppression(boxes,
239
244
'to the third dimension of scores' )
240
245
241
246
original_masks = masks
247
+ original_additional_fields = additional_fields
242
248
with tf .name_scope (scope , 'BatchMultiClassNonMaxSuppression' ):
243
249
boxes_shape = boxes .shape
244
250
batch_size = boxes_shape [0 ].value
@@ -255,58 +261,135 @@ def batch_multiclass_non_max_suppression(boxes,
255
261
num_valid_boxes = tf .ones ([batch_size ], dtype = tf .int32 ) * num_anchors
256
262
257
263
# If masks aren't provided, create dummy masks so we can only have one copy
258
- # of single_image_nms_fn and discard the dummy masks after map_fn.
264
+ # of _single_image_nms_fn and discard the dummy masks after map_fn.
259
265
if masks is None :
260
266
masks_shape = tf .stack ([batch_size , num_anchors , 1 , 0 , 0 ])
261
267
masks = tf .zeros (masks_shape )
262
268
263
- def single_image_nms_fn (args ):
264
- """Runs NMS on a single image and returns padded output."""
265
- (per_image_boxes , per_image_scores , per_image_masks ,
266
- per_image_num_valid_boxes ) = args
269
+ if additional_fields is None :
270
+ additional_fields = {}
271
+
272
+ def _single_image_nms_fn (args ):
273
+ """Runs NMS on a single image and returns padded output.
274
+
275
+ Args:
276
+ args: A list of tensors consisting of the following:
277
+ per_image_boxes - A [num_anchors, q, 4] float32 tensor containing
278
+ detections. If `q` is 1 then same boxes are used for all classes
279
+ otherwise, if `q` is equal to number of classes, class-specific
280
+ boxes are used.
281
+ per_image_scores - A [num_anchors, num_classes] float32 tensor
282
+ containing the scores for each of the `num_anchors` detections.
283
+ per_image_masks - A [num_anchors, q, mask_height, mask_width] float32
284
+ tensor containing box masks. `q` can be either number of classes
285
+ or 1 depending on whether a separate mask is predicted per class.
286
+ per_image_additional_fields - (optional) A variable number of float32
287
+ tensors each with size [num_anchors, ...].
288
+ per_image_num_valid_boxes - A tensor of type `int32`. A 1-D tensor of
289
+ shape [batch_size] representing the number of valid boxes to be
290
+ considered for each image in the batch. This parameter allows for
291
+ ignoring zero paddings.
292
+
293
+ Returns:
294
+ 'nmsed_boxes': A [max_detections, 4] float32 tensor containing the
295
+ non-max suppressed boxes.
296
+ 'nmsed_scores': A [max_detections] float32 tensor containing the scores
297
+ for the boxes.
298
+ 'nmsed_classes': A [max_detections] float32 tensor containing the class
299
+ for boxes.
300
+ 'nmsed_masks': (optional) a [max_detections, mask_height, mask_width]
301
+ float32 tensor containing masks for each selected box. This is set to
302
+ None if input `masks` is None.
303
+ 'nmsed_additional_fields': (optional) A variable number of float32
304
+ tensors each with size [max_detections, ...] corresponding to the
305
+ input `per_image_additional_fields`.
306
+ 'num_detections': A [batch_size] int32 tensor indicating the number of
307
+ valid detections per batch item. Only the top num_detections[i]
308
+ entries in nms_boxes[i], nms_scores[i] and nms_class[i] are valid. The
309
+ rest of the entries are zero paddings.
310
+ """
311
+ per_image_boxes = args [0 ]
312
+ per_image_scores = args [1 ]
313
+ per_image_masks = args [2 ]
314
+ per_image_additional_fields = {
315
+ key : value
316
+ for key , value in zip (additional_fields , args [3 :- 1 ])
317
+ }
318
+ per_image_num_valid_boxes = args [- 1 ]
267
319
per_image_boxes = tf .reshape (
268
320
tf .slice (per_image_boxes , 3 * [0 ],
269
321
tf .stack ([per_image_num_valid_boxes , - 1 , - 1 ])), [- 1 , q , 4 ])
270
322
per_image_scores = tf .reshape (
271
323
tf .slice (per_image_scores , [0 , 0 ],
272
324
tf .stack ([per_image_num_valid_boxes , - 1 ])),
273
325
[- 1 , num_classes ])
274
-
275
326
per_image_masks = tf .reshape (
276
327
tf .slice (per_image_masks , 4 * [0 ],
277
328
tf .stack ([per_image_num_valid_boxes , - 1 , - 1 , - 1 ])),
278
329
[- 1 , q , per_image_masks .shape [2 ].value ,
279
330
per_image_masks .shape [3 ].value ])
331
+ if per_image_additional_fields is not None :
332
+ for key , tensor in per_image_additional_fields .items ():
333
+ additional_field_shape = tensor .get_shape ()
334
+ additional_field_dim = len (additional_field_shape )
335
+ per_image_additional_fields [key ] = tf .reshape (
336
+ tf .slice (per_image_additional_fields [key ],
337
+ additional_field_dim * [0 ],
338
+ tf .stack ([per_image_num_valid_boxes ] +
339
+ (additional_field_dim - 1 ) * [- 1 ])),
340
+ [- 1 ] + [dim .value for dim in additional_field_shape [1 :]])
280
341
nmsed_boxlist = multiclass_non_max_suppression (
281
342
per_image_boxes ,
282
343
per_image_scores ,
283
344
score_thresh ,
284
345
iou_thresh ,
285
346
max_size_per_class ,
286
347
max_total_size ,
287
- masks = per_image_masks ,
288
348
clip_window = clip_window ,
289
- change_coordinate_frame = change_coordinate_frame )
349
+ change_coordinate_frame = change_coordinate_frame ,
350
+ masks = per_image_masks ,
351
+ additional_fields = per_image_additional_fields )
290
352
padded_boxlist = box_list_ops .pad_or_clip_box_list (nmsed_boxlist ,
291
353
max_total_size )
292
354
num_detections = nmsed_boxlist .num_boxes ()
293
355
nmsed_boxes = padded_boxlist .get ()
294
356
nmsed_scores = padded_boxlist .get_field (fields .BoxListFields .scores )
295
357
nmsed_classes = padded_boxlist .get_field (fields .BoxListFields .classes )
296
358
nmsed_masks = padded_boxlist .get_field (fields .BoxListFields .masks )
297
- return [nmsed_boxes , nmsed_scores , nmsed_classes , nmsed_masks ,
298
- num_detections ]
359
+ nmsed_additional_fields = [
360
+ padded_boxlist .get_field (key ) for key in per_image_additional_fields
361
+ ]
362
+ return ([nmsed_boxes , nmsed_scores , nmsed_classes , nmsed_masks ] +
363
+ nmsed_additional_fields + [num_detections ])
364
+
365
+ num_additional_fields = 0
366
+ if additional_fields is not None :
367
+ num_additional_fields = len (additional_fields )
368
+ num_nmsed_outputs = 4 + num_additional_fields
299
369
300
- (batch_nmsed_boxes , batch_nmsed_scores ,
301
- batch_nmsed_classes , batch_nmsed_masks ,
302
- batch_num_detections ) = tf .map_fn (
303
- single_image_nms_fn ,
304
- elems = [boxes , scores , masks , num_valid_boxes ],
305
- dtype = [tf .float32 , tf .float32 , tf .float32 , tf .float32 , tf .int32 ],
306
- parallel_iterations = parallel_iterations )
370
+ batch_outputs = tf .map_fn (
371
+ _single_image_nms_fn ,
372
+ elems = ([boxes , scores , masks ] + list (additional_fields .values ()) +
373
+ [num_valid_boxes ]),
374
+ dtype = (num_nmsed_outputs * [tf .float32 ] + [tf .int32 ]),
375
+ parallel_iterations = parallel_iterations )
376
+
377
+ batch_nmsed_boxes = batch_outputs [0 ]
378
+ batch_nmsed_scores = batch_outputs [1 ]
379
+ batch_nmsed_classes = batch_outputs [2 ]
380
+ batch_nmsed_masks = batch_outputs [3 ]
381
+ batch_nmsed_additional_fields = {
382
+ key : value
383
+ for key , value in zip (additional_fields , batch_outputs [4 :- 1 ])
384
+ }
385
+ batch_num_detections = batch_outputs [- 1 ]
307
386
308
387
if original_masks is None :
309
388
batch_nmsed_masks = None
310
389
390
+ if original_additional_fields is None :
391
+ batch_nmsed_additional_fields = None
392
+
311
393
return (batch_nmsed_boxes , batch_nmsed_scores , batch_nmsed_classes ,
312
- batch_nmsed_masks , batch_num_detections )
394
+ batch_nmsed_masks , batch_nmsed_additional_fields ,
395
+ batch_num_detections )
0 commit comments