Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit ed4e22b

Browse files
authored
Merge pull request tensorflow#3973 from pkulzc/master
Object detection internal changes
2 parents cac90a0 + 13b89b9 commit ed4e22b

File tree

61 files changed

+2808
-1218
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

61 files changed

+2808
-1218
lines changed

research/object_detection/builders/box_predictor_builder.py

Lines changed: 18 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -49,12 +49,12 @@ def build(argscope_fn, box_predictor_config, is_training, num_classes):
4949

5050
if box_predictor_oneof == 'convolutional_box_predictor':
5151
conv_box_predictor = box_predictor_config.convolutional_box_predictor
52-
conv_hyperparams = argscope_fn(conv_box_predictor.conv_hyperparams,
53-
is_training)
52+
conv_hyperparams_fn = argscope_fn(conv_box_predictor.conv_hyperparams,
53+
is_training)
5454
box_predictor_object = box_predictor.ConvolutionalBoxPredictor(
5555
is_training=is_training,
5656
num_classes=num_classes,
57-
conv_hyperparams=conv_hyperparams,
57+
conv_hyperparams_fn=conv_hyperparams_fn,
5858
min_depth=conv_box_predictor.min_depth,
5959
max_depth=conv_box_predictor.max_depth,
6060
num_layers_before_predictor=(conv_box_predictor.
@@ -73,12 +73,12 @@ def build(argscope_fn, box_predictor_config, is_training, num_classes):
7373
if box_predictor_oneof == 'weight_shared_convolutional_box_predictor':
7474
conv_box_predictor = (box_predictor_config.
7575
weight_shared_convolutional_box_predictor)
76-
conv_hyperparams = argscope_fn(conv_box_predictor.conv_hyperparams,
77-
is_training)
76+
conv_hyperparams_fn = argscope_fn(conv_box_predictor.conv_hyperparams,
77+
is_training)
7878
box_predictor_object = box_predictor.WeightSharedConvolutionalBoxPredictor(
7979
is_training=is_training,
8080
num_classes=num_classes,
81-
conv_hyperparams=conv_hyperparams,
81+
conv_hyperparams_fn=conv_hyperparams_fn,
8282
depth=conv_box_predictor.depth,
8383
num_layers_before_predictor=(conv_box_predictor.
8484
num_layers_before_predictor),
@@ -90,38 +90,40 @@ def build(argscope_fn, box_predictor_config, is_training, num_classes):
9090

9191
if box_predictor_oneof == 'mask_rcnn_box_predictor':
9292
mask_rcnn_box_predictor = box_predictor_config.mask_rcnn_box_predictor
93-
fc_hyperparams = argscope_fn(mask_rcnn_box_predictor.fc_hyperparams,
94-
is_training)
95-
conv_hyperparams = None
93+
fc_hyperparams_fn = argscope_fn(mask_rcnn_box_predictor.fc_hyperparams,
94+
is_training)
95+
conv_hyperparams_fn = None
9696
if mask_rcnn_box_predictor.HasField('conv_hyperparams'):
97-
conv_hyperparams = argscope_fn(mask_rcnn_box_predictor.conv_hyperparams,
98-
is_training)
97+
conv_hyperparams_fn = argscope_fn(
98+
mask_rcnn_box_predictor.conv_hyperparams, is_training)
9999
box_predictor_object = box_predictor.MaskRCNNBoxPredictor(
100100
is_training=is_training,
101101
num_classes=num_classes,
102-
fc_hyperparams=fc_hyperparams,
102+
fc_hyperparams_fn=fc_hyperparams_fn,
103103
use_dropout=mask_rcnn_box_predictor.use_dropout,
104104
dropout_keep_prob=mask_rcnn_box_predictor.dropout_keep_probability,
105105
box_code_size=mask_rcnn_box_predictor.box_code_size,
106-
conv_hyperparams=conv_hyperparams,
106+
conv_hyperparams_fn=conv_hyperparams_fn,
107107
predict_instance_masks=mask_rcnn_box_predictor.predict_instance_masks,
108108
mask_height=mask_rcnn_box_predictor.mask_height,
109109
mask_width=mask_rcnn_box_predictor.mask_width,
110110
mask_prediction_num_conv_layers=(
111111
mask_rcnn_box_predictor.mask_prediction_num_conv_layers),
112112
mask_prediction_conv_depth=(
113113
mask_rcnn_box_predictor.mask_prediction_conv_depth),
114+
masks_are_class_agnostic=(
115+
mask_rcnn_box_predictor.masks_are_class_agnostic),
114116
predict_keypoints=mask_rcnn_box_predictor.predict_keypoints)
115117
return box_predictor_object
116118

117119
if box_predictor_oneof == 'rfcn_box_predictor':
118120
rfcn_box_predictor = box_predictor_config.rfcn_box_predictor
119-
conv_hyperparams = argscope_fn(rfcn_box_predictor.conv_hyperparams,
120-
is_training)
121+
conv_hyperparams_fn = argscope_fn(rfcn_box_predictor.conv_hyperparams,
122+
is_training)
121123
box_predictor_object = box_predictor.RfcnBoxPredictor(
122124
is_training=is_training,
123125
num_classes=num_classes,
124-
conv_hyperparams=conv_hyperparams,
126+
conv_hyperparams_fn=conv_hyperparams_fn,
125127
crop_size=[rfcn_box_predictor.crop_height,
126128
rfcn_box_predictor.crop_width],
127129
num_spatial_bins=[rfcn_box_predictor.num_spatial_bins_height,

research/object_detection/builders/box_predictor_builder_test.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ def mock_conv_argscope_builder(conv_hyperparams_arg, is_training):
5454
box_predictor_config=box_predictor_proto,
5555
is_training=False,
5656
num_classes=10)
57-
(conv_hyperparams_actual, is_training) = box_predictor._conv_hyperparams
57+
(conv_hyperparams_actual, is_training) = box_predictor._conv_hyperparams_fn
5858
self.assertAlmostEqual((hyperparams_proto.regularizer.
5959
l1_regularizer.weight),
6060
(conv_hyperparams_actual.regularizer.l1_regularizer.
@@ -183,7 +183,7 @@ def mock_conv_argscope_builder(conv_hyperparams_arg, is_training):
183183
box_predictor_config=box_predictor_proto,
184184
is_training=False,
185185
num_classes=10)
186-
(conv_hyperparams_actual, is_training) = box_predictor._conv_hyperparams
186+
(conv_hyperparams_actual, is_training) = box_predictor._conv_hyperparams_fn
187187
self.assertAlmostEqual((hyperparams_proto.regularizer.
188188
l1_regularizer.weight),
189189
(conv_hyperparams_actual.regularizer.l1_regularizer.
@@ -297,7 +297,7 @@ def test_box_predictor_builder_calls_fc_argscope_fn(self):
297297
is_training=False,
298298
num_classes=10)
299299
mock_argscope_fn.assert_called_with(hyperparams_proto, False)
300-
self.assertEqual(box_predictor._fc_hyperparams, 'arg_scope')
300+
self.assertEqual(box_predictor._fc_hyperparams_fn, 'arg_scope')
301301

302302
def test_non_default_mask_rcnn_box_predictor(self):
303303
fc_hyperparams_text_proto = """
@@ -417,7 +417,7 @@ def mock_conv_argscope_builder(conv_hyperparams_arg, is_training):
417417
box_predictor_config=box_predictor_proto,
418418
is_training=False,
419419
num_classes=10)
420-
(conv_hyperparams_actual, is_training) = box_predictor._conv_hyperparams
420+
(conv_hyperparams_actual, is_training) = box_predictor._conv_hyperparams_fn
421421
self.assertAlmostEqual((hyperparams_proto.regularizer.
422422
l1_regularizer.weight),
423423
(conv_hyperparams_actual.regularizer.l1_regularizer.

research/object_detection/builders/dataset_builder.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,9 @@ def _get_padding_shapes(dataset, max_num_boxes=None, num_classes=None,
7272
fields.InputDataFields.num_groundtruth_boxes: [],
7373
fields.InputDataFields.groundtruth_label_types: [max_num_boxes],
7474
fields.InputDataFields.groundtruth_label_scores: [max_num_boxes],
75-
fields.InputDataFields.true_image_shape: [3]
75+
fields.InputDataFields.true_image_shape: [3],
76+
fields.InputDataFields.multiclass_scores: [
77+
max_num_boxes, num_classes + 1 if num_classes is not None else None],
7678
}
7779
# Determine whether groundtruth_classes are integers or one-hot encodings, and
7880
# apply batching appropriately.

research/object_detection/builders/hyperparams_builder.py

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import tensorflow as tf
1818

1919
from object_detection.protos import hyperparams_pb2
20+
from object_detection.utils import context_manager
2021

2122
slim = tf.contrib.slim
2223

@@ -43,7 +44,8 @@ def build(hyperparams_config, is_training):
4344
is_training: Whether the network is in training mode.
4445
4546
Returns:
46-
arg_scope: tf-slim arg_scope containing hyperparameters for ops.
47+
arg_scope_fn: A function to construct tf-slim arg_scope containing
48+
hyperparameters for ops.
4749
4850
Raises:
4951
ValueError: if hyperparams_config is not of type hyperparams.Hyperparams.
@@ -64,16 +66,21 @@ def build(hyperparams_config, is_training):
6466
if hyperparams_config.HasField('op') and (
6567
hyperparams_config.op == hyperparams_pb2.Hyperparams.FC):
6668
affected_ops = [slim.fully_connected]
67-
with slim.arg_scope(
68-
affected_ops,
69-
weights_regularizer=_build_regularizer(
70-
hyperparams_config.regularizer),
71-
weights_initializer=_build_initializer(
72-
hyperparams_config.initializer),
73-
activation_fn=_build_activation_fn(hyperparams_config.activation),
74-
normalizer_fn=batch_norm,
75-
normalizer_params=batch_norm_params) as sc:
76-
return sc
69+
def scope_fn():
70+
with (slim.arg_scope([slim.batch_norm], **batch_norm_params)
71+
if batch_norm_params is not None else
72+
context_manager.IdentityContextManager()):
73+
with slim.arg_scope(
74+
affected_ops,
75+
weights_regularizer=_build_regularizer(
76+
hyperparams_config.regularizer),
77+
weights_initializer=_build_initializer(
78+
hyperparams_config.initializer),
79+
activation_fn=_build_activation_fn(hyperparams_config.activation),
80+
normalizer_fn=batch_norm) as sc:
81+
return sc
82+
83+
return scope_fn
7784

7885

7986
def _build_activation_fn(activation_fn):
@@ -167,6 +174,9 @@ def _build_batch_norm_params(batch_norm, is_training):
167174
'center': batch_norm.center,
168175
'scale': batch_norm.scale,
169176
'epsilon': batch_norm.epsilon,
177+
# Remove is_training parameter from here and deprecate it in the proto
178+
# once we refactor Faster RCNN models to set is_training through an outer
179+
# arg_scope in the meta architecture.
170180
'is_training': is_training and batch_norm.train,
171181
}
172182
return batch_norm_params

0 commit comments

Comments
 (0)