Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 3237c08

Browse files
committed
add NASnet feature extractor
1 parent c839310 commit 3237c08

File tree

5 files changed

+501
-0
lines changed

5 files changed

+501
-0
lines changed

research/object_detection/builders/model_builder.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,8 @@
4747

4848
# A map of names to Faster R-CNN feature extractors.
4949
FASTER_RCNN_FEATURE_EXTRACTOR_CLASS_MAP = {
50+
'faster_rcnn_nas':
51+
frcnn_nas.FasterRCNNNASFeatureExtractor,
5052
'faster_rcnn_inception_resnet_v2':
5153
frcnn_inc_res.FasterRCNNInceptionResnetV2FeatureExtractor,
5254
'faster_rcnn_inception_v2':

research/object_detection/builders/model_builder_test.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from object_detection.meta_architectures import ssd_meta_arch
2525
from object_detection.models import faster_rcnn_inception_resnet_v2_feature_extractor as frcnn_inc_res
2626
from object_detection.models import faster_rcnn_inception_v2_feature_extractor as frcnn_inc_v2
27+
from object_detection.models import faster_rcnn_nas_feature_extractor as frcnn_nas
2728
from object_detection.models import faster_rcnn_resnet_v1_feature_extractor as frcnn_resnet_v1
2829
from object_detection.models.ssd_inception_v2_feature_extractor import SSDInceptionV2FeatureExtractor
2930
from object_detection.models.ssd_inception_v3_feature_extractor import SSDInceptionV3FeatureExtractor
@@ -412,6 +413,73 @@ def test_create_faster_rcnn_resnet101_with_mask_prediction_enabled(self):
412413
model = model_builder.build(model_proto, is_training=True)
413414
self.assertAlmostEqual(model._second_stage_mask_loss_weight, 3.0)
414415

416+
def test_create_faster_rcnn_nas_model_from_config(self):
417+
model_text_proto = """
418+
faster_rcnn {
419+
num_classes: 3
420+
image_resizer {
421+
keep_aspect_ratio_resizer {
422+
min_dimension: 600
423+
max_dimension: 1024
424+
}
425+
}
426+
feature_extractor {
427+
type: 'faster_rcnn_nas'
428+
}
429+
first_stage_anchor_generator {
430+
grid_anchor_generator {
431+
scales: [0.25, 0.5, 1.0, 2.0]
432+
aspect_ratios: [0.5, 1.0, 2.0]
433+
height_stride: 16
434+
width_stride: 16
435+
}
436+
}
437+
first_stage_box_predictor_conv_hyperparams {
438+
regularizer {
439+
l2_regularizer {
440+
}
441+
}
442+
initializer {
443+
truncated_normal_initializer {
444+
}
445+
}
446+
}
447+
initial_crop_size: 17
448+
maxpool_kernel_size: 1
449+
maxpool_stride: 1
450+
second_stage_box_predictor {
451+
mask_rcnn_box_predictor {
452+
fc_hyperparams {
453+
op: FC
454+
regularizer {
455+
l2_regularizer {
456+
}
457+
}
458+
initializer {
459+
truncated_normal_initializer {
460+
}
461+
}
462+
}
463+
}
464+
}
465+
second_stage_post_processing {
466+
batch_non_max_suppression {
467+
score_threshold: 0.01
468+
iou_threshold: 0.6
469+
max_detections_per_class: 100
470+
max_total_detections: 300
471+
}
472+
score_converter: SOFTMAX
473+
}
474+
}"""
475+
model_proto = model_pb2.DetectionModel()
476+
text_format.Merge(model_text_proto, model_proto)
477+
model = model_builder.build(model_proto, is_training=True)
478+
self.assertIsInstance(model, faster_rcnn_meta_arch.FasterRCNNMetaArch)
479+
self.assertIsInstance(
480+
model._feature_extractor,
481+
frcnn_nas.FasterRCNNNASFeatureExtractor)
482+
415483
def test_create_faster_rcnn_inception_resnet_v2_model_from_config(self):
416484
model_text_proto = """
417485
faster_rcnn {

research/object_detection/models/BUILD

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,29 @@ py_test(
135135
],
136136
)
137137

138+
py_test(
139+
name = "faster_rcnn_nas_feature_extractor_test",
140+
srcs = [
141+
"faster_rcnn_nas_feature_extractor_test.py",
142+
],
143+
deps = [
144+
":faster_rcnn_nas_feature_extractor",
145+
"//tensorflow",
146+
],
147+
)
148+
149+
py_library(
150+
name = "faster_rcnn_nas_feature_extractor",
151+
srcs = [
152+
"faster_rcnn_nas_feature_extractor.py",
153+
],
154+
deps = [
155+
"//tensorflow",
156+
"//tensorflow_models/object_detection/meta_architectures:faster_rcnn_meta_arch",
157+
"//tensorflow_models/slim:nasnet",
158+
],
159+
)
160+
138161
py_library(
139162
name = "faster_rcnn_inception_resnet_v2_feature_extractor",
140163
srcs = [

0 commit comments

Comments
 (0)