22
22
23
23
from object_detection .core import data_decoder
24
24
from object_detection .core import standard_fields as fields
25
+ from object_detection .utils import label_map_util
25
26
26
27
slim_example_decoder = tf .contrib .slim .tfexample_decoder
27
28
28
29
29
30
class TfExampleDecoder (data_decoder .DataDecoder ):
30
31
"""Tensorflow Example proto decoder."""
31
32
32
- def __init__ (self ):
33
- """Constructor sets keys_to_features and items_to_handlers."""
33
+ def __init__ (self ,
34
+ load_instance_masks = False ,
35
+ label_map_proto_file = None ,
36
+ use_display_name = False ):
37
+ """Constructor sets keys_to_features and items_to_handlers.
38
+
39
+ Args:
40
+ load_instance_masks: whether or not to load and handle instance masks.
41
+ label_map_proto_file: a file path to a
42
+ object_detection.protos.StringIntLabelMap proto. If provided, then the
43
+ mapped IDs of 'image/object/class/text' will take precedence over the
44
+ existing 'image/object/class/label' ID. Also, if provided, it is
45
+ assumed that 'image/object/class/text' will be in the data.
46
+ use_display_name: whether or not to use the `display_name` for label
47
+ mapping (instead of `name`). Only used if label_map_proto_file is
48
+ provided.
49
+ """
34
50
self .keys_to_features = {
35
- 'image/encoded' : tf .FixedLenFeature ((), tf .string , default_value = '' ),
36
- 'image/format' : tf .FixedLenFeature ((), tf .string , default_value = 'jpeg' ),
37
- 'image/filename' : tf .FixedLenFeature ((), tf .string , default_value = '' ),
38
- 'image/key/sha256' : tf .FixedLenFeature ((), tf .string , default_value = '' ),
39
- 'image/source_id' : tf .FixedLenFeature ((), tf .string , default_value = '' ),
40
- 'image/height' : tf .FixedLenFeature ((), tf .int64 , 1 ),
41
- 'image/width' : tf .FixedLenFeature ((), tf .int64 , 1 ),
51
+ 'image/encoded' :
52
+ tf .FixedLenFeature ((), tf .string , default_value = '' ),
53
+ 'image/format' :
54
+ tf .FixedLenFeature ((), tf .string , default_value = 'jpeg' ),
55
+ 'image/filename' :
56
+ tf .FixedLenFeature ((), tf .string , default_value = '' ),
57
+ 'image/key/sha256' :
58
+ tf .FixedLenFeature ((), tf .string , default_value = '' ),
59
+ 'image/source_id' :
60
+ tf .FixedLenFeature ((), tf .string , default_value = '' ),
61
+ 'image/height' :
62
+ tf .FixedLenFeature ((), tf .int64 , 1 ),
63
+ 'image/width' :
64
+ tf .FixedLenFeature ((), tf .int64 , 1 ),
42
65
# Object boxes and classes.
43
- 'image/object/bbox/xmin' : tf .VarLenFeature (tf .float32 ),
44
- 'image/object/bbox/xmax' : tf .VarLenFeature (tf .float32 ),
45
- 'image/object/bbox/ymin' : tf .VarLenFeature (tf .float32 ),
46
- 'image/object/bbox/ymax' : tf .VarLenFeature (tf .float32 ),
47
- 'image/object/class/label' : tf .VarLenFeature (tf .int64 ),
48
- 'image/object/area' : tf .VarLenFeature (tf .float32 ),
49
- 'image/object/is_crowd' : tf .VarLenFeature (tf .int64 ),
50
- 'image/object/difficult' : tf .VarLenFeature (tf .int64 ),
51
- # Instance masks and classes.
52
- 'image/segmentation/object' : tf .VarLenFeature (tf .int64 ),
53
- 'image/segmentation/object/class' : tf .VarLenFeature (tf .int64 )
66
+ 'image/object/bbox/xmin' :
67
+ tf .VarLenFeature (tf .float32 ),
68
+ 'image/object/bbox/xmax' :
69
+ tf .VarLenFeature (tf .float32 ),
70
+ 'image/object/bbox/ymin' :
71
+ tf .VarLenFeature (tf .float32 ),
72
+ 'image/object/bbox/ymax' :
73
+ tf .VarLenFeature (tf .float32 ),
74
+ 'image/object/class/label' :
75
+ tf .VarLenFeature (tf .int64 ),
76
+ 'image/object/class/text' :
77
+ tf .VarLenFeature (tf .string ),
78
+ 'image/object/area' :
79
+ tf .VarLenFeature (tf .float32 ),
80
+ 'image/object/is_crowd' :
81
+ tf .VarLenFeature (tf .int64 ),
82
+ 'image/object/difficult' :
83
+ tf .VarLenFeature (tf .int64 ),
84
+ 'image/object/group_of' :
85
+ tf .VarLenFeature (tf .int64 ),
54
86
}
55
87
self .items_to_handlers = {
56
88
fields .InputDataFields .image : slim_example_decoder .Image (
@@ -65,22 +97,42 @@ def __init__(self):
65
97
fields .InputDataFields .groundtruth_boxes : (
66
98
slim_example_decoder .BoundingBox (
67
99
['ymin' , 'xmin' , 'ymax' , 'xmax' ], 'image/object/bbox/' )),
68
- fields .InputDataFields .groundtruth_classes : (
69
- slim_example_decoder .Tensor ('image/object/class/label' )),
70
100
fields .InputDataFields .groundtruth_area : slim_example_decoder .Tensor (
71
101
'image/object/area' ),
72
102
fields .InputDataFields .groundtruth_is_crowd : (
73
103
slim_example_decoder .Tensor ('image/object/is_crowd' )),
74
104
fields .InputDataFields .groundtruth_difficult : (
75
105
slim_example_decoder .Tensor ('image/object/difficult' )),
76
- # Instance masks and classes.
77
- fields .InputDataFields .groundtruth_instance_masks : (
78
- slim_example_decoder .ItemHandlerCallback (
79
- ['image/segmentation/object' , 'image/height' , 'image/width' ],
80
- self ._reshape_instance_masks )),
81
- fields .InputDataFields .groundtruth_instance_classes : (
82
- slim_example_decoder .Tensor ('image/segmentation/object/class' )),
106
+ fields .InputDataFields .groundtruth_group_of : (
107
+ slim_example_decoder .Tensor ('image/object/group_of' ))
83
108
}
109
+ if load_instance_masks :
110
+ self .keys_to_features ['image/object/mask' ] = tf .VarLenFeature (tf .float32 )
111
+ self .items_to_handlers [
112
+ fields .InputDataFields .groundtruth_instance_masks ] = (
113
+ slim_example_decoder .ItemHandlerCallback (
114
+ ['image/object/mask' , 'image/height' , 'image/width' ],
115
+ self ._reshape_instance_masks ))
116
+ if label_map_proto_file :
117
+ label_map = label_map_util .get_label_map_dict (label_map_proto_file ,
118
+ use_display_name )
119
+ # We use a default_value of -1, but we expect all labels to be contained
120
+ # in the label map.
121
+ table = tf .contrib .lookup .HashTable (
122
+ initializer = tf .contrib .lookup .KeyValueTensorInitializer (
123
+ keys = tf .constant (list (label_map .keys ())),
124
+ values = tf .constant (list (label_map .values ()), dtype = tf .int64 )),
125
+ default_value = - 1 )
126
+ # If the label_map_proto is provided, try to use it in conjunction with
127
+ # the class text, and fall back to a materialized ID.
128
+ label_handler = slim_example_decoder .BackupHandler (
129
+ slim_example_decoder .LookupTensor (
130
+ 'image/object/class/text' , table , default_value = '' ),
131
+ slim_example_decoder .Tensor ('image/object/class/label' ))
132
+ else :
133
+ label_handler = slim_example_decoder .Tensor ('image/object/class/label' )
134
+ self .items_to_handlers [
135
+ fields .InputDataFields .groundtruth_classes ] = label_handler
84
136
85
137
def decode (self , tf_example_string_tensor ):
86
138
"""Decodes serialized tensorflow example and returns a tensor dictionary.
@@ -106,14 +158,14 @@ def decode(self, tf_example_string_tensor):
106
158
[None] containing object mask area in pixel squared.
107
159
fields.InputDataFields.groundtruth_is_crowd - 1D bool tensor of shape
108
160
[None] indicating if the boxes enclose a crowd.
161
+ Optional:
109
162
fields.InputDataFields.groundtruth_difficult - 1D bool tensor of shape
110
163
[None] indicating if the boxes represent `difficult` instances.
164
+ fields.InputDataFields.groundtruth_group_of - 1D bool tensor of shape
165
+ [None] indicating if the boxes represent `group_of` instances.
111
166
fields.InputDataFields.groundtruth_instance_masks - 3D float32 tensor of
112
167
shape [None, None, None] containing instance masks.
113
- fields.InputDataFields.groundtruth_instance_classes - 1D int64 tensor
114
- of shape [None] containing classes for the instance masks.
115
168
"""
116
-
117
169
serialized_example = tf .reshape (tf_example_string_tensor , shape = [])
118
170
decoder = slim_example_decoder .TFExampleDecoder (self .keys_to_features ,
119
171
self .items_to_handlers )
@@ -135,13 +187,14 @@ def _reshape_instance_masks(self, keys_to_tensors):
135
187
keys_to_tensors: a dictionary from keys to tensors.
136
188
137
189
Returns:
138
- A 3-D boolean tensor of shape [num_instances, height, width].
190
+ A 3-D float tensor of shape [num_instances, height, width] with values
191
+ in {0, 1}.
139
192
"""
140
- masks = keys_to_tensors ['image/segmentation/object' ]
141
- if isinstance (masks , tf .SparseTensor ):
142
- masks = tf .sparse_tensor_to_dense (masks )
143
193
height = keys_to_tensors ['image/height' ]
144
194
width = keys_to_tensors ['image/width' ]
145
195
to_shape = tf .cast (tf .stack ([- 1 , height , width ]), tf .int32 )
146
-
147
- return tf .cast (tf .reshape (masks , to_shape ), tf .bool )
196
+ masks = keys_to_tensors ['image/object/mask' ]
197
+ if isinstance (masks , tf .SparseTensor ):
198
+ masks = tf .sparse_tensor_to_dense (masks )
199
+ masks = tf .reshape (tf .to_float (tf .greater (masks , 0.0 )), to_shape )
200
+ return tf .cast (masks , tf .float32 )
0 commit comments