@@ -49,12 +49,12 @@ def build(argscope_fn, box_predictor_config, is_training, num_classes):
49
49
50
50
if box_predictor_oneof == 'convolutional_box_predictor' :
51
51
conv_box_predictor = box_predictor_config .convolutional_box_predictor
52
- conv_hyperparams = argscope_fn (conv_box_predictor .conv_hyperparams ,
53
- is_training )
52
+ conv_hyperparams_fn = argscope_fn (conv_box_predictor .conv_hyperparams ,
53
+ is_training )
54
54
box_predictor_object = box_predictor .ConvolutionalBoxPredictor (
55
55
is_training = is_training ,
56
56
num_classes = num_classes ,
57
- conv_hyperparams = conv_hyperparams ,
57
+ conv_hyperparams_fn = conv_hyperparams_fn ,
58
58
min_depth = conv_box_predictor .min_depth ,
59
59
max_depth = conv_box_predictor .max_depth ,
60
60
num_layers_before_predictor = (conv_box_predictor .
@@ -73,12 +73,12 @@ def build(argscope_fn, box_predictor_config, is_training, num_classes):
73
73
if box_predictor_oneof == 'weight_shared_convolutional_box_predictor' :
74
74
conv_box_predictor = (box_predictor_config .
75
75
weight_shared_convolutional_box_predictor )
76
- conv_hyperparams = argscope_fn (conv_box_predictor .conv_hyperparams ,
77
- is_training )
76
+ conv_hyperparams_fn = argscope_fn (conv_box_predictor .conv_hyperparams ,
77
+ is_training )
78
78
box_predictor_object = box_predictor .WeightSharedConvolutionalBoxPredictor (
79
79
is_training = is_training ,
80
80
num_classes = num_classes ,
81
- conv_hyperparams = conv_hyperparams ,
81
+ conv_hyperparams_fn = conv_hyperparams_fn ,
82
82
depth = conv_box_predictor .depth ,
83
83
num_layers_before_predictor = (conv_box_predictor .
84
84
num_layers_before_predictor ),
@@ -90,38 +90,40 @@ def build(argscope_fn, box_predictor_config, is_training, num_classes):
90
90
91
91
if box_predictor_oneof == 'mask_rcnn_box_predictor' :
92
92
mask_rcnn_box_predictor = box_predictor_config .mask_rcnn_box_predictor
93
- fc_hyperparams = argscope_fn (mask_rcnn_box_predictor .fc_hyperparams ,
94
- is_training )
95
- conv_hyperparams = None
93
+ fc_hyperparams_fn = argscope_fn (mask_rcnn_box_predictor .fc_hyperparams ,
94
+ is_training )
95
+ conv_hyperparams_fn = None
96
96
if mask_rcnn_box_predictor .HasField ('conv_hyperparams' ):
97
- conv_hyperparams = argscope_fn (mask_rcnn_box_predictor . conv_hyperparams ,
98
- is_training )
97
+ conv_hyperparams_fn = argscope_fn (
98
+ mask_rcnn_box_predictor . conv_hyperparams , is_training )
99
99
box_predictor_object = box_predictor .MaskRCNNBoxPredictor (
100
100
is_training = is_training ,
101
101
num_classes = num_classes ,
102
- fc_hyperparams = fc_hyperparams ,
102
+ fc_hyperparams_fn = fc_hyperparams_fn ,
103
103
use_dropout = mask_rcnn_box_predictor .use_dropout ,
104
104
dropout_keep_prob = mask_rcnn_box_predictor .dropout_keep_probability ,
105
105
box_code_size = mask_rcnn_box_predictor .box_code_size ,
106
- conv_hyperparams = conv_hyperparams ,
106
+ conv_hyperparams_fn = conv_hyperparams_fn ,
107
107
predict_instance_masks = mask_rcnn_box_predictor .predict_instance_masks ,
108
108
mask_height = mask_rcnn_box_predictor .mask_height ,
109
109
mask_width = mask_rcnn_box_predictor .mask_width ,
110
110
mask_prediction_num_conv_layers = (
111
111
mask_rcnn_box_predictor .mask_prediction_num_conv_layers ),
112
112
mask_prediction_conv_depth = (
113
113
mask_rcnn_box_predictor .mask_prediction_conv_depth ),
114
+ masks_are_class_agnostic = (
115
+ mask_rcnn_box_predictor .masks_are_class_agnostic ),
114
116
predict_keypoints = mask_rcnn_box_predictor .predict_keypoints )
115
117
return box_predictor_object
116
118
117
119
if box_predictor_oneof == 'rfcn_box_predictor' :
118
120
rfcn_box_predictor = box_predictor_config .rfcn_box_predictor
119
- conv_hyperparams = argscope_fn (rfcn_box_predictor .conv_hyperparams ,
120
- is_training )
121
+ conv_hyperparams_fn = argscope_fn (rfcn_box_predictor .conv_hyperparams ,
122
+ is_training )
121
123
box_predictor_object = box_predictor .RfcnBoxPredictor (
122
124
is_training = is_training ,
123
125
num_classes = num_classes ,
124
- conv_hyperparams = conv_hyperparams ,
126
+ conv_hyperparams_fn = conv_hyperparams_fn ,
125
127
crop_size = [rfcn_box_predictor .crop_height ,
126
128
rfcn_box_predictor .crop_width ],
127
129
num_spatial_bins = [rfcn_box_predictor .num_spatial_bins_height ,
0 commit comments