

Commit 4e92bc5

Author: Jonathan Huang
Parents: 141ed95 + ce41762

Merge pull request tensorflow#2639 from tombstone/data

Fixes tensorflow#2634

File tree: 3 files changed, 208 additions (+) and 50 deletions (-)

    research/object_detection/data_decoders/BUILD
    research/object_detection/data_decoders/tf_example_decoder.py
    research/object_detection/data_decoders/tf_example_decoder_test.py


research/object_detection/data_decoders/BUILD

Lines changed: 1 addition & 0 deletions
@@ -14,6 +14,7 @@ py_library(
         "//tensorflow",
         "//tensorflow_models/object_detection/core:data_decoder",
         "//tensorflow_models/object_detection/core:standard_fields",
+        "//tensorflow_models/object_detection/utils:label_map_util",
     ],
 )
 

research/object_detection/data_decoders/tf_example_decoder.py

Lines changed: 91 additions & 38 deletions
@@ -22,35 +22,67 @@
 
 from object_detection.core import data_decoder
 from object_detection.core import standard_fields as fields
+from object_detection.utils import label_map_util
 
 slim_example_decoder = tf.contrib.slim.tfexample_decoder
 
 
 class TfExampleDecoder(data_decoder.DataDecoder):
   """Tensorflow Example proto decoder."""
 
-  def __init__(self):
-    """Constructor sets keys_to_features and items_to_handlers."""
+  def __init__(self,
+               load_instance_masks=False,
+               label_map_proto_file=None,
+               use_display_name=False):
+    """Constructor sets keys_to_features and items_to_handlers.
+
+    Args:
+      load_instance_masks: whether or not to load and handle instance masks.
+      label_map_proto_file: a file path to a
+        object_detection.protos.StringIntLabelMap proto. If provided, then the
+        mapped IDs of 'image/object/class/text' will take precedence over the
+        existing 'image/object/class/label' ID. Also, if provided, it is
+        assumed that 'image/object/class/text' will be in the data.
+      use_display_name: whether or not to use the `display_name` for label
+        mapping (instead of `name`). Only used if label_map_proto_file is
+        provided.
+    """
     self.keys_to_features = {
-        'image/encoded': tf.FixedLenFeature((), tf.string, default_value=''),
-        'image/format': tf.FixedLenFeature((), tf.string, default_value='jpeg'),
-        'image/filename': tf.FixedLenFeature((), tf.string, default_value=''),
-        'image/key/sha256': tf.FixedLenFeature((), tf.string, default_value=''),
-        'image/source_id': tf.FixedLenFeature((), tf.string, default_value=''),
-        'image/height': tf.FixedLenFeature((), tf.int64, 1),
-        'image/width': tf.FixedLenFeature((), tf.int64, 1),
+        'image/encoded':
+            tf.FixedLenFeature((), tf.string, default_value=''),
+        'image/format':
+            tf.FixedLenFeature((), tf.string, default_value='jpeg'),
+        'image/filename':
+            tf.FixedLenFeature((), tf.string, default_value=''),
+        'image/key/sha256':
+            tf.FixedLenFeature((), tf.string, default_value=''),
+        'image/source_id':
+            tf.FixedLenFeature((), tf.string, default_value=''),
+        'image/height':
+            tf.FixedLenFeature((), tf.int64, 1),
+        'image/width':
+            tf.FixedLenFeature((), tf.int64, 1),
         # Object boxes and classes.
-        'image/object/bbox/xmin': tf.VarLenFeature(tf.float32),
-        'image/object/bbox/xmax': tf.VarLenFeature(tf.float32),
-        'image/object/bbox/ymin': tf.VarLenFeature(tf.float32),
-        'image/object/bbox/ymax': tf.VarLenFeature(tf.float32),
-        'image/object/class/label': tf.VarLenFeature(tf.int64),
-        'image/object/area': tf.VarLenFeature(tf.float32),
-        'image/object/is_crowd': tf.VarLenFeature(tf.int64),
-        'image/object/difficult': tf.VarLenFeature(tf.int64),
-        # Instance masks and classes.
-        'image/segmentation/object': tf.VarLenFeature(tf.int64),
-        'image/segmentation/object/class': tf.VarLenFeature(tf.int64)
+        'image/object/bbox/xmin':
+            tf.VarLenFeature(tf.float32),
+        'image/object/bbox/xmax':
+            tf.VarLenFeature(tf.float32),
+        'image/object/bbox/ymin':
+            tf.VarLenFeature(tf.float32),
+        'image/object/bbox/ymax':
+            tf.VarLenFeature(tf.float32),
+        'image/object/class/label':
+            tf.VarLenFeature(tf.int64),
+        'image/object/class/text':
+            tf.VarLenFeature(tf.string),
+        'image/object/area':
+            tf.VarLenFeature(tf.float32),
+        'image/object/is_crowd':
+            tf.VarLenFeature(tf.int64),
+        'image/object/difficult':
+            tf.VarLenFeature(tf.int64),
+        'image/object/group_of':
+            tf.VarLenFeature(tf.int64),
     }
     self.items_to_handlers = {
         fields.InputDataFields.image: slim_example_decoder.Image(
@@ -65,22 +97,42 @@ def __init__(self):
         fields.InputDataFields.groundtruth_boxes: (
             slim_example_decoder.BoundingBox(
                 ['ymin', 'xmin', 'ymax', 'xmax'], 'image/object/bbox/')),
-        fields.InputDataFields.groundtruth_classes: (
-            slim_example_decoder.Tensor('image/object/class/label')),
         fields.InputDataFields.groundtruth_area: slim_example_decoder.Tensor(
             'image/object/area'),
         fields.InputDataFields.groundtruth_is_crowd: (
             slim_example_decoder.Tensor('image/object/is_crowd')),
         fields.InputDataFields.groundtruth_difficult: (
             slim_example_decoder.Tensor('image/object/difficult')),
-        # Instance masks and classes.
-        fields.InputDataFields.groundtruth_instance_masks: (
-            slim_example_decoder.ItemHandlerCallback(
-                ['image/segmentation/object', 'image/height', 'image/width'],
-                self._reshape_instance_masks)),
-        fields.InputDataFields.groundtruth_instance_classes: (
-            slim_example_decoder.Tensor('image/segmentation/object/class')),
+        fields.InputDataFields.groundtruth_group_of: (
+            slim_example_decoder.Tensor('image/object/group_of'))
     }
+    if load_instance_masks:
+      self.keys_to_features['image/object/mask'] = tf.VarLenFeature(tf.float32)
+      self.items_to_handlers[
+          fields.InputDataFields.groundtruth_instance_masks] = (
+              slim_example_decoder.ItemHandlerCallback(
+                  ['image/object/mask', 'image/height', 'image/width'],
+                  self._reshape_instance_masks))
+    if label_map_proto_file:
+      label_map = label_map_util.get_label_map_dict(label_map_proto_file,
+                                                    use_display_name)
+      # We use a default_value of -1, but we expect all labels to be contained
+      # in the label map.
+      table = tf.contrib.lookup.HashTable(
+          initializer=tf.contrib.lookup.KeyValueTensorInitializer(
+              keys=tf.constant(list(label_map.keys())),
+              values=tf.constant(list(label_map.values()), dtype=tf.int64)),
+          default_value=-1)
+      # If the label_map_proto is provided, try to use it in conjunction with
+      # the class text, and fall back to a materialized ID.
+      label_handler = slim_example_decoder.BackupHandler(
+          slim_example_decoder.LookupTensor(
+              'image/object/class/text', table, default_value=''),
+          slim_example_decoder.Tensor('image/object/class/label'))
+    else:
+      label_handler = slim_example_decoder.Tensor('image/object/class/label')
+    self.items_to_handlers[
+        fields.InputDataFields.groundtruth_classes] = label_handler
 
   def decode(self, tf_example_string_tensor):
     """Decodes serialized tensorflow example and returns a tensor dictionary.
@@ -106,14 +158,14 @@ def decode(self, tf_example_string_tensor):
         [None] containing containing object mask area in pixel squared.
       fields.InputDataFields.groundtruth_is_crowd - 1D bool tensor of shape
         [None] indicating if the boxes enclose a crowd.
+      Optional:
       fields.InputDataFields.groundtruth_difficult - 1D bool tensor of shape
         [None] indicating if the boxes represent `difficult` instances.
+      fields.InputDataFields.groundtruth_group_of - 1D bool tensor of shape
+        [None] indicating if the boxes represent `group_of` instances.
       fields.InputDataFields.groundtruth_instance_masks - 3D int64 tensor of
         shape [None, None, None] containing instance masks.
-      fields.InputDataFields.groundtruth_instance_classes - 1D int64 tensor
-        of shape [None] containing classes for the instance masks.
     """
-
     serialized_example = tf.reshape(tf_example_string_tensor, shape=[])
     decoder = slim_example_decoder.TFExampleDecoder(self.keys_to_features,
                                                     self.items_to_handlers)
@@ -135,13 +187,14 @@ def _reshape_instance_masks(self, keys_to_tensors):
       keys_to_tensors: a dictionary from keys to tensors.
 
     Returns:
-      A 3-D boolean tensor of shape [num_instances, height, width].
+      A 3-D float tensor of shape [num_instances, height, width] with values
+        in {0, 1}.
     """
-    masks = keys_to_tensors['image/segmentation/object']
-    if isinstance(masks, tf.SparseTensor):
-      masks = tf.sparse_tensor_to_dense(masks)
     height = keys_to_tensors['image/height']
     width = keys_to_tensors['image/width']
     to_shape = tf.cast(tf.stack([-1, height, width]), tf.int32)
-
-    return tf.cast(tf.reshape(masks, to_shape), tf.bool)
+    masks = keys_to_tensors['image/object/mask']
+    if isinstance(masks, tf.SparseTensor):
+      masks = tf.sparse_tensor_to_dense(masks)
+    masks = tf.reshape(tf.to_float(tf.greater(masks, 0.0)), to_shape)
+    return tf.cast(masks, tf.float32)
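
The tests below exercise these options end to end. As a quick orientation, here is a minimal usage sketch of the new constructor arguments (not part of this commit), assuming TF 1.x / Python 2 as in the surrounding code; the temporary label-map path and its two entries are made up for illustration. Note that when label_map_proto_file is set, the lookup table behind the class-text handler must be initialized with tf.tables_initializer() before evaluating the decoded tensors.

# Minimal usage sketch (not part of this commit); assumes TF 1.x / Python 2.
# The label-map path and its entries are hypothetical.
import numpy as np
import tensorflow as tf

from object_detection.core import standard_fields as fields
from object_detection.data_decoders import tf_example_decoder

# Write a tiny StringIntLabelMap text proto to a temporary file.
label_map_path = '/tmp/label_map.pbtxt'  # hypothetical path
with tf.gfile.Open(label_map_path, 'wb') as f:
  f.write("item { id: 1 name: 'dog' }\nitem { id: 2 name: 'cat' }\n")

# Build a serialized tf.train.Example carrying an encoded image and class text.
image = np.random.randint(255, size=(4, 5, 3)).astype(np.uint8)
encoded_jpeg = tf.Session().run(tf.image.encode_jpeg(tf.constant(image)))
example = tf.train.Example(features=tf.train.Features(feature={
    'image/encoded': tf.train.Feature(
        bytes_list=tf.train.BytesList(value=[encoded_jpeg])),
    'image/format': tf.train.Feature(
        bytes_list=tf.train.BytesList(value=['jpeg'])),
    'image/object/class/text': tf.train.Feature(
        bytes_list=tf.train.BytesList(value=['cat', 'dog'])),
})).SerializeToString()

# Class text is mapped to IDs through the label map (load_instance_masks and
# use_display_name keep their defaults here).
decoder = tf_example_decoder.TfExampleDecoder(
    label_map_proto_file=label_map_path)
tensor_dict = decoder.decode(tf.convert_to_tensor(example))

with tf.Session() as sess:
  sess.run(tf.tables_initializer())  # needed for the class-text lookup table
  out = sess.run(tensor_dict)
  print(out[fields.InputDataFields.groundtruth_classes])  # expected: [2 1]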

research/object_detection/data_decoders/tf_example_decoder_test.py

Lines changed: 116 additions & 12 deletions
@@ -15,6 +15,7 @@
 
 """Tests for object_detection.data_decoders.tf_example_decoder."""
 
+import os
 import numpy as np
 import tensorflow as tf
 
@@ -51,6 +52,8 @@ def _FloatFeature(self, value):
     return tf.train.Feature(float_list=tf.train.FloatList(value=value))
 
   def _BytesFeature(self, value):
+    if isinstance(value, list):
+      return tf.train.Feature(bytes_list=tf.train.BytesList(value=value))
     return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
 
   def testDecodeJpegImage(self):
@@ -165,6 +168,48 @@ def testDecodeObjectLabel(self):
     self.assertAllEqual(bbox_classes,
                         tensor_dict[fields.InputDataFields.groundtruth_classes])
 
+  def testDecodeObjectLabelWithMapping(self):
+    image_tensor = np.random.randint(255, size=(4, 5, 3)).astype(np.uint8)
+    encoded_jpeg = self._EncodeImage(image_tensor)
+    bbox_classes_text = ['cat', 'dog']
+    example = tf.train.Example(
+        features=tf.train.Features(
+            feature={
+                'image/encoded':
+                    self._BytesFeature(encoded_jpeg),
+                'image/format':
+                    self._BytesFeature('jpeg'),
+                'image/object/class/text':
+                    self._BytesFeature(bbox_classes_text),
+            })).SerializeToString()
+
+    label_map_string = """
+      item {
+        id:3
+        name:'cat'
+      }
+      item {
+        id:1
+        name:'dog'
+      }
+    """
+    label_map_path = os.path.join(self.get_temp_dir(), 'label_map.pbtxt')
+    with tf.gfile.Open(label_map_path, 'wb') as f:
+      f.write(label_map_string)
+    example_decoder = tf_example_decoder.TfExampleDecoder(
+        label_map_proto_file=label_map_path)
+    tensor_dict = example_decoder.decode(tf.convert_to_tensor(example))
+
+    self.assertAllEqual((tensor_dict[fields.InputDataFields.groundtruth_classes]
+                         .get_shape().as_list()), [None])
+
+    with self.test_session() as sess:
+      sess.run(tf.tables_initializer())
+      tensor_dict = sess.run(tensor_dict)
+
+    self.assertAllEqual([3, 1],
+                        tensor_dict[fields.InputDataFields.groundtruth_classes])
+
   def testDecodeObjectArea(self):
     image_tensor = np.random.randint(255, size=(4, 5, 3)).astype(np.uint8)
     encoded_jpeg = self._EncodeImage(image_tensor)
@@ -232,6 +277,30 @@ def testDecodeObjectDifficult(self):
                         tensor_dict[
                             fields.InputDataFields.groundtruth_difficult])
 
+  def testDecodeObjectGroupOf(self):
+    image_tensor = np.random.randint(255, size=(4, 5, 3)).astype(np.uint8)
+    encoded_jpeg = self._EncodeImage(image_tensor)
+    object_group_of = [0, 1]
+    example = tf.train.Example(features=tf.train.Features(
+        feature={
+            'image/encoded': self._BytesFeature(encoded_jpeg),
+            'image/format': self._BytesFeature('jpeg'),
+            'image/object/group_of': self._Int64Feature(object_group_of),
+        })).SerializeToString()
+
+    example_decoder = tf_example_decoder.TfExampleDecoder()
+    tensor_dict = example_decoder.decode(tf.convert_to_tensor(example))
+
+    self.assertAllEqual((tensor_dict[
+        fields.InputDataFields.groundtruth_group_of].get_shape().as_list()),
+                        [None])
+    with self.test_session() as sess:
+      tensor_dict = sess.run(tensor_dict)
+
+    self.assertAllEqual(
+        [bool(item) for item in object_group_of],
+        tensor_dict[fields.InputDataFields.groundtruth_group_of])
+
   def testDecodeInstanceSegmentation(self):
     num_instances = 4
     image_height = 5
@@ -244,44 +313,79 @@ def testDecodeInstanceSegmentation(self):
     encoded_jpeg = self._EncodeImage(image_tensor)
 
     # Randomly generate instance segmentation masks.
-    instance_segmentation = (
+    instance_masks = (
         np.random.randint(2, size=(num_instances,
                                    image_height,
-                                   image_width)).astype(np.int64))
+                                   image_width)).astype(np.float32))
+    instance_masks_flattened = np.reshape(instance_masks, [-1])
 
     # Randomly generate class labels for each instance.
-    instance_segmentation_classes = np.random.randint(
+    object_classes = np.random.randint(
        100, size=(num_instances)).astype(np.int64)
 
     example = tf.train.Example(features=tf.train.Features(feature={
         'image/encoded': self._BytesFeature(encoded_jpeg),
         'image/format': self._BytesFeature('jpeg'),
         'image/height': self._Int64Feature([image_height]),
         'image/width': self._Int64Feature([image_width]),
-        'image/segmentation/object': self._Int64Feature(
-            instance_segmentation.flatten()),
-        'image/segmentation/object/class': self._Int64Feature(
-            instance_segmentation_classes)})).SerializeToString()
-    example_decoder = tf_example_decoder.TfExampleDecoder()
+        'image/object/mask': self._FloatFeature(instance_masks_flattened),
+        'image/object/class/label': self._Int64Feature(
+            object_classes)})).SerializeToString()
+    example_decoder = tf_example_decoder.TfExampleDecoder(
+        load_instance_masks=True)
     tensor_dict = example_decoder.decode(tf.convert_to_tensor(example))
 
     self.assertAllEqual((
         tensor_dict[fields.InputDataFields.groundtruth_instance_masks].
         get_shape().as_list()), [None, None, None])
 
     self.assertAllEqual((
-        tensor_dict[fields.InputDataFields.groundtruth_instance_classes].
+        tensor_dict[fields.InputDataFields.groundtruth_classes].
         get_shape().as_list()), [None])
 
     with self.test_session() as sess:
       tensor_dict = sess.run(tensor_dict)
 
     self.assertAllEqual(
-        instance_segmentation.astype(np.bool),
+        instance_masks.astype(np.float32),
         tensor_dict[fields.InputDataFields.groundtruth_instance_masks])
     self.assertAllEqual(
-        instance_segmentation_classes,
-        tensor_dict[fields.InputDataFields.groundtruth_instance_classes])
+        object_classes,
+        tensor_dict[fields.InputDataFields.groundtruth_classes])
+
+  def testInstancesNotAvailableByDefault(self):
+    num_instances = 4
+    image_height = 5
+    image_width = 3
+    # Randomly generate image.
+    image_tensor = np.random.randint(255, size=(image_height,
+                                                image_width,
+                                                3)).astype(np.uint8)
+    encoded_jpeg = self._EncodeImage(image_tensor)
+
+    # Randomly generate instance segmentation masks.
+    instance_masks = (
+        np.random.randint(2, size=(num_instances,
+                                   image_height,
+                                   image_width)).astype(np.float32))
+    instance_masks_flattened = np.reshape(instance_masks, [-1])
+
+    # Randomly generate class labels for each instance.
+    object_classes = np.random.randint(
+        100, size=(num_instances)).astype(np.int64)
+
+    example = tf.train.Example(features=tf.train.Features(feature={
+        'image/encoded': self._BytesFeature(encoded_jpeg),
+        'image/format': self._BytesFeature('jpeg'),
+        'image/height': self._Int64Feature([image_height]),
+        'image/width': self._Int64Feature([image_width]),
+        'image/object/mask': self._FloatFeature(instance_masks_flattened),
+        'image/object/class/label': self._Int64Feature(
+            object_classes)})).SerializeToString()
+    example_decoder = tf_example_decoder.TfExampleDecoder()
+    tensor_dict = example_decoder.decode(tf.convert_to_tensor(example))
+    self.assertTrue(fields.InputDataFields.groundtruth_instance_masks
+                    not in tensor_dict)
 
 
 if __name__ == '__main__':
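
On the producer side, a dataset-creation script has to write tf.train.Examples whose keys match what the decoder now reads: flattened float masks under 'image/object/mask' and class text under 'image/object/class/text'. A hedged sketch of such a writer follows (not part of this commit); the output path, labels, and shapes are illustrative, Python 2 string literals are used for the bytes features as in the tests, and a real record would also carry 'image/encoded', 'image/format', and the bounding-box keys.

# Illustrative writer sketch (not part of this commit); paths and values are
# hypothetical. A complete record would also include 'image/encoded',
# 'image/format', and the 'image/object/bbox/*' keys.
import numpy as np
import tensorflow as tf

height, width, num_instances = 5, 3, 2
# Binary masks, one [height, width] array per instance, flattened row-major.
masks = np.random.randint(
    2, size=(num_instances, height, width)).astype(np.float32)

feature = {
    'image/height': tf.train.Feature(
        int64_list=tf.train.Int64List(value=[height])),
    'image/width': tf.train.Feature(
        int64_list=tf.train.Int64List(value=[width])),
    'image/object/class/label': tf.train.Feature(
        int64_list=tf.train.Int64List(value=[3, 1])),
    'image/object/class/text': tf.train.Feature(
        bytes_list=tf.train.BytesList(value=['cat', 'dog'])),
    # New mask key: one flat float list, num_instances * height * width long.
    'image/object/mask': tf.train.Feature(
        float_list=tf.train.FloatList(value=masks.reshape(-1).tolist())),
}
example = tf.train.Example(features=tf.train.Features(feature=feature))

with tf.python_io.TFRecordWriter('/tmp/examples.record') as writer:
  writer.write(example.SerializeToString())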
