Thanks for visiting codestin.com
Credit goes to github.com

Skip to content

Commit ff88581

Browse files
authored
Merge pull request tensorflow#2629 from tombstone/meta_arch_update
update post_processing module, builders, and meta architectures.
2 parents 018e62f + aeeaf9a commit ff88581

12 files changed

+1026
-265
lines changed

research/object_detection/builders/post_processing_builder.py

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,8 @@ def build(post_processing_config):
2828
configuration.
2929
3030
Non-max suppression callable takes `boxes`, `scores`, and optionally
31-
`clip_window`, `parallel_iterations` and `scope` as inputs. It returns
32-
`nms_boxes`, `nms_scores`, `nms_nms_classes` and `num_detections`. See
31+
`clip_window`, `parallel_iterations`, `masks`, and `scope` as inputs. It returns
32+
`nms_boxes`, `nms_scores`, `nms_classes`, `nms_masks` and `num_detections`. See
3333
post_processing.batch_multiclass_non_max_suppression for the type and shape
3434
of these tensors.
3535
@@ -55,7 +55,8 @@ def build(post_processing_config):
5555
non_max_suppressor_fn = _build_non_max_suppressor(
5656
post_processing_config.batch_non_max_suppression)
5757
score_converter_fn = _build_score_converter(
58-
post_processing_config.score_converter)
58+
post_processing_config.score_converter,
59+
post_processing_config.logit_scale)
5960
return non_max_suppressor_fn, score_converter_fn
6061

6162

@@ -87,14 +88,25 @@ def _build_non_max_suppressor(nms_config):
8788
return non_max_suppressor_fn
8889

8990

90-
def _build_score_converter(score_converter_config):
91+
def _score_converter_fn_with_logit_scale(tf_score_converter_fn, logit_scale):
92+
"""Create a function to scale logits then apply a Tensorflow function."""
93+
def score_converter_fn(logits):
94+
scaled_logits = tf.divide(logits, logit_scale, name='scale_logits')
95+
return tf_score_converter_fn(scaled_logits, name='convert_scores')
96+
score_converter_fn.__name__ = '%s_with_logit_scale' % (
97+
tf_score_converter_fn.__name__)
98+
return score_converter_fn
99+
100+
101+
def _build_score_converter(score_converter_config, logit_scale):
91102
"""Builds score converter based on the config.
92103
93104
Builds one of [tf.identity, tf.sigmoid, tf.nn.softmax] score converters based on
94105
the config.
95106
96107
Args:
97108
score_converter_config: post_processing_pb2.PostProcessing.score_converter.
109+
logit_scale: temperature by which logits are divided before score conversion.
98110
99111
Returns:
100112
Callable score converter op.
@@ -103,9 +115,9 @@ def _build_score_converter(score_converter_config):
103115
ValueError: On unknown score converter.
104116
"""
105117
if score_converter_config == post_processing_pb2.PostProcessing.IDENTITY:
106-
return tf.identity
118+
return _score_converter_fn_with_logit_scale(tf.identity, logit_scale)
107119
if score_converter_config == post_processing_pb2.PostProcessing.SIGMOID:
108-
return tf.sigmoid
120+
return _score_converter_fn_with_logit_scale(tf.sigmoid, logit_scale)
109121
if score_converter_config == post_processing_pb2.PostProcessing.SOFTMAX:
110-
return tf.nn.softmax
122+
return _score_converter_fn_with_logit_scale(tf.nn.softmax, logit_scale)
111123
raise ValueError('Unknown score converter.')

research/object_detection/builders/post_processing_builder_test.py

Lines changed: 37 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,31 @@ def test_build_identity_score_converter(self):
4848
post_processing_config = post_processing_pb2.PostProcessing()
4949
text_format.Merge(post_processing_text_proto, post_processing_config)
5050
_, score_converter = post_processing_builder.build(post_processing_config)
51-
self.assertEqual(score_converter, tf.identity)
51+
self.assertEqual(score_converter.__name__, 'identity_with_logit_scale')
52+
53+
inputs = tf.constant([1, 1], tf.float32)
54+
outputs = score_converter(inputs)
55+
with self.test_session() as sess:
56+
converted_scores = sess.run(outputs)
57+
expected_converted_scores = sess.run(inputs)
58+
self.assertAllClose(converted_scores, expected_converted_scores)
59+
60+
def test_build_identity_score_converter_with_logit_scale(self):
61+
post_processing_text_proto = """
62+
score_converter: IDENTITY
63+
logit_scale: 2.0
64+
"""
65+
post_processing_config = post_processing_pb2.PostProcessing()
66+
text_format.Merge(post_processing_text_proto, post_processing_config)
67+
_, score_converter = post_processing_builder.build(post_processing_config)
68+
self.assertEqual(score_converter.__name__, 'identity_with_logit_scale')
69+
70+
inputs = tf.constant([1, 1], tf.float32)
71+
outputs = score_converter(inputs)
72+
with self.test_session() as sess:
73+
converted_scores = sess.run(outputs)
74+
expected_converted_scores = sess.run(tf.constant([.5, .5], tf.float32))
75+
self.assertAllClose(converted_scores, expected_converted_scores)
5276

5377
def test_build_sigmoid_score_converter(self):
5478
post_processing_text_proto = """
@@ -57,7 +81,7 @@ def test_build_sigmoid_score_converter(self):
5781
post_processing_config = post_processing_pb2.PostProcessing()
5882
text_format.Merge(post_processing_text_proto, post_processing_config)
5983
_, score_converter = post_processing_builder.build(post_processing_config)
60-
self.assertEqual(score_converter, tf.sigmoid)
84+
self.assertEqual(score_converter.__name__, 'sigmoid_with_logit_scale')
6185

6286
def test_build_softmax_score_converter(self):
6387
post_processing_text_proto = """
@@ -66,7 +90,17 @@ def test_build_softmax_score_converter(self):
6690
post_processing_config = post_processing_pb2.PostProcessing()
6791
text_format.Merge(post_processing_text_proto, post_processing_config)
6892
_, score_converter = post_processing_builder.build(post_processing_config)
69-
self.assertEqual(score_converter, tf.nn.softmax)
93+
self.assertEqual(score_converter.__name__, 'softmax_with_logit_scale')
94+
95+
def test_build_softmax_score_converter_with_temperature(self):
96+
post_processing_text_proto = """
97+
score_converter: SOFTMAX
98+
logit_scale: 2.0
99+
"""
100+
post_processing_config = post_processing_pb2.PostProcessing()
101+
text_format.Merge(post_processing_text_proto, post_processing_config)
102+
_, score_converter = post_processing_builder.build(post_processing_config)
103+
self.assertEqual(score_converter.__name__, 'softmax_with_logit_scale')
70104

71105

72106
if __name__ == '__main__':

research/object_detection/core/post_processing.py

Lines changed: 106 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -76,8 +76,6 @@ def multiclass_non_max_suppression(boxes,
7676
a BoxList holding M boxes with a rank-1 scores field representing
7777
corresponding scores for each box with scores sorted in decreasing order
7878
and a rank-1 classes field representing a class label for each box.
79-
If masks, keypoints, keypoint_heatmaps is not None, the boxlist will
80-
contain masks, keypoints, keypoint_heatmaps corresponding to boxes.
8179
8280
Raises:
8381
ValueError: if iou_thresh is not in [0, 1] or if input boxlist does not have
@@ -174,6 +172,7 @@ def batch_multiclass_non_max_suppression(boxes,
174172
change_coordinate_frame=False,
175173
num_valid_boxes=None,
176174
masks=None,
175+
additional_fields=None,
177176
scope=None,
178177
parallel_iterations=32):
179178
"""Multi-class version of non maximum suppression that operates on a batch.
@@ -203,11 +202,13 @@ def batch_multiclass_non_max_suppression(boxes,
203202
is provided)
204203
num_valid_boxes: (optional) a Tensor of type `int32`. A 1-D tensor of shape
205204
[batch_size] representing the number of valid boxes to be considered
206-
for each image in the batch. This parameter allows for ignoring zero
207-
paddings.
205+
for each image in the batch. This parameter allows for ignoring zero
206+
paddings.
208207
masks: (optional) a [batch_size, num_anchors, q, mask_height, mask_width]
209208
float32 tensor containing box masks. `q` can be either number of classes
210209
or 1 depending on whether a separate mask is predicted per class.
210+
additional_fields: (optional) If not None, a dictionary that maps keys to
211+
tensors whose dimensions are [batch_size, num_anchors, ...].
211212
scope: tf scope name.
212213
parallel_iterations: (optional) number of batch items to process in
213214
parallel.
@@ -223,9 +224,13 @@ def batch_multiclass_non_max_suppression(boxes,
223224
[batch_size, max_detections, mask_height, mask_width] float32 tensor
224225
containing masks for each selected box. This is set to None if input
225226
`masks` is None.
227+
'nmsed_additional_fields': (optional) a dictionary of
228+
[batch_size, max_detections, ...] float32 tensors corresponding to the
229+
tensors specified in the input `additional_fields`. This is set to None
230+
if input `additional_fields` is None.
226231
'num_detections': A [batch_size] int32 tensor indicating the number of
227232
valid detections per batch item. Only the top num_detections[i] entries in
228-
nms_boxes[i], nms_scores[i] and nms_class[i] are valid. the rest of the
233+
nms_boxes[i], nms_scores[i] and nms_class[i] are valid. The rest of the
229234
entries are zero paddings.
230235
231236
Raises:
@@ -239,6 +244,7 @@ def batch_multiclass_non_max_suppression(boxes,
239244
'to the third dimension of scores')
240245

241246
original_masks = masks
247+
original_additional_fields = additional_fields
242248
with tf.name_scope(scope, 'BatchMultiClassNonMaxSuppression'):
243249
boxes_shape = boxes.shape
244250
batch_size = boxes_shape[0].value
@@ -255,58 +261,135 @@ def batch_multiclass_non_max_suppression(boxes,
255261
num_valid_boxes = tf.ones([batch_size], dtype=tf.int32) * num_anchors
256262

257263
# If masks aren't provided, create dummy masks so we can only have one copy
258-
# of single_image_nms_fn and discard the dummy masks after map_fn.
264+
# of _single_image_nms_fn and discard the dummy masks after map_fn.
259265
if masks is None:
260266
masks_shape = tf.stack([batch_size, num_anchors, 1, 0, 0])
261267
masks = tf.zeros(masks_shape)
262268

263-
def single_image_nms_fn(args):
264-
"""Runs NMS on a single image and returns padded output."""
265-
(per_image_boxes, per_image_scores, per_image_masks,
266-
per_image_num_valid_boxes) = args
269+
if additional_fields is None:
270+
additional_fields = {}
271+
272+
def _single_image_nms_fn(args):
273+
"""Runs NMS on a single image and returns padded output.
274+
275+
Args:
276+
args: A list of tensors consisting of the following:
277+
per_image_boxes - A [num_anchors, q, 4] float32 tensor containing
278+
detections. If `q` is 1 then same boxes are used for all classes
279+
otherwise, if `q` is equal to number of classes, class-specific
280+
boxes are used.
281+
per_image_scores - A [num_anchors, num_classes] float32 tensor
282+
containing the scores for each of the `num_anchors` detections.
283+
per_image_masks - A [num_anchors, q, mask_height, mask_width] float32
284+
tensor containing box masks. `q` can be either number of classes
285+
or 1 depending on whether a separate mask is predicted per class.
286+
per_image_additional_fields - (optional) A variable number of float32
287+
tensors each with size [num_anchors, ...].
288+
per_image_num_valid_boxes - A scalar tensor of type `int32`
289+
representing the number of valid boxes to be considered for this
290+
image (one element of the batch). This parameter allows for
291+
ignoring zero paddings.
292+
293+
Returns:
294+
'nmsed_boxes': A [max_detections, 4] float32 tensor containing the
295+
non-max suppressed boxes.
296+
'nmsed_scores': A [max_detections] float32 tensor containing the scores
297+
for the boxes.
298+
'nmsed_classes': A [max_detections] float32 tensor containing the class
299+
for boxes.
300+
'nmsed_masks': (optional) a [max_detections, mask_height, mask_width]
301+
float32 tensor containing masks for each selected box. This is set to
302+
None if input `masks` is None.
303+
'nmsed_additional_fields': (optional) A variable number of float32
304+
tensors each with size [max_detections, ...] corresponding to the
305+
input `per_image_additional_fields`.
306+
'num_detections': A scalar int32 tensor indicating the number of
307+
valid detections for this image. Only the top num_detections
308+
entries in nmsed_boxes, nmsed_scores and nmsed_classes are valid. The
309+
rest of the entries are zero paddings.
310+
"""
311+
per_image_boxes = args[0]
312+
per_image_scores = args[1]
313+
per_image_masks = args[2]
314+
per_image_additional_fields = {
315+
key: value
316+
for key, value in zip(additional_fields, args[3:-1])
317+
}
318+
per_image_num_valid_boxes = args[-1]
267319
per_image_boxes = tf.reshape(
268320
tf.slice(per_image_boxes, 3 * [0],
269321
tf.stack([per_image_num_valid_boxes, -1, -1])), [-1, q, 4])
270322
per_image_scores = tf.reshape(
271323
tf.slice(per_image_scores, [0, 0],
272324
tf.stack([per_image_num_valid_boxes, -1])),
273325
[-1, num_classes])
274-
275326
per_image_masks = tf.reshape(
276327
tf.slice(per_image_masks, 4 * [0],
277328
tf.stack([per_image_num_valid_boxes, -1, -1, -1])),
278329
[-1, q, per_image_masks.shape[2].value,
279330
per_image_masks.shape[3].value])
331+
if per_image_additional_fields is not None:
332+
for key, tensor in per_image_additional_fields.items():
333+
additional_field_shape = tensor.get_shape()
334+
additional_field_dim = len(additional_field_shape)
335+
per_image_additional_fields[key] = tf.reshape(
336+
tf.slice(per_image_additional_fields[key],
337+
additional_field_dim * [0],
338+
tf.stack([per_image_num_valid_boxes] +
339+
(additional_field_dim - 1) * [-1])),
340+
[-1] + [dim.value for dim in additional_field_shape[1:]])
280341
nmsed_boxlist = multiclass_non_max_suppression(
281342
per_image_boxes,
282343
per_image_scores,
283344
score_thresh,
284345
iou_thresh,
285346
max_size_per_class,
286347
max_total_size,
287-
masks=per_image_masks,
288348
clip_window=clip_window,
289-
change_coordinate_frame=change_coordinate_frame)
349+
change_coordinate_frame=change_coordinate_frame,
350+
masks=per_image_masks,
351+
additional_fields=per_image_additional_fields)
290352
padded_boxlist = box_list_ops.pad_or_clip_box_list(nmsed_boxlist,
291353
max_total_size)
292354
num_detections = nmsed_boxlist.num_boxes()
293355
nmsed_boxes = padded_boxlist.get()
294356
nmsed_scores = padded_boxlist.get_field(fields.BoxListFields.scores)
295357
nmsed_classes = padded_boxlist.get_field(fields.BoxListFields.classes)
296358
nmsed_masks = padded_boxlist.get_field(fields.BoxListFields.masks)
297-
return [nmsed_boxes, nmsed_scores, nmsed_classes, nmsed_masks,
298-
num_detections]
359+
nmsed_additional_fields = [
360+
padded_boxlist.get_field(key) for key in per_image_additional_fields
361+
]
362+
return ([nmsed_boxes, nmsed_scores, nmsed_classes, nmsed_masks] +
363+
nmsed_additional_fields + [num_detections])
364+
365+
num_additional_fields = 0
366+
if additional_fields is not None:
367+
num_additional_fields = len(additional_fields)
368+
num_nmsed_outputs = 4 + num_additional_fields
299369

300-
(batch_nmsed_boxes, batch_nmsed_scores,
301-
batch_nmsed_classes, batch_nmsed_masks,
302-
batch_num_detections) = tf.map_fn(
303-
single_image_nms_fn,
304-
elems=[boxes, scores, masks, num_valid_boxes],
305-
dtype=[tf.float32, tf.float32, tf.float32, tf.float32, tf.int32],
306-
parallel_iterations=parallel_iterations)
370+
batch_outputs = tf.map_fn(
371+
_single_image_nms_fn,
372+
elems=([boxes, scores, masks] + list(additional_fields.values()) +
373+
[num_valid_boxes]),
374+
dtype=(num_nmsed_outputs * [tf.float32] + [tf.int32]),
375+
parallel_iterations=parallel_iterations)
376+
377+
batch_nmsed_boxes = batch_outputs[0]
378+
batch_nmsed_scores = batch_outputs[1]
379+
batch_nmsed_classes = batch_outputs[2]
380+
batch_nmsed_masks = batch_outputs[3]
381+
batch_nmsed_additional_fields = {
382+
key: value
383+
for key, value in zip(additional_fields, batch_outputs[4:-1])
384+
}
385+
batch_num_detections = batch_outputs[-1]
307386

308387
if original_masks is None:
309388
batch_nmsed_masks = None
310389

390+
if original_additional_fields is None:
391+
batch_nmsed_additional_fields = None
392+
311393
return (batch_nmsed_boxes, batch_nmsed_scores, batch_nmsed_classes,
312-
batch_nmsed_masks, batch_num_detections)
394+
batch_nmsed_masks, batch_nmsed_additional_fields,
395+
batch_num_detections)

0 commit comments

Comments
 (0)