|
24 | 24 | from object_detection.meta_architectures import ssd_meta_arch
|
25 | 25 | from object_detection.models import faster_rcnn_inception_resnet_v2_feature_extractor as frcnn_inc_res
|
26 | 26 | from object_detection.models import faster_rcnn_inception_v2_feature_extractor as frcnn_inc_v2
|
| 27 | +from object_detection.models import faster_rcnn_nas_feature_extractor as frcnn_nas |
27 | 28 | from object_detection.models import faster_rcnn_resnet_v1_feature_extractor as frcnn_resnet_v1
|
28 | 29 | from object_detection.models.ssd_inception_v2_feature_extractor import SSDInceptionV2FeatureExtractor
|
29 | 30 | from object_detection.models.ssd_inception_v3_feature_extractor import SSDInceptionV3FeatureExtractor
|
@@ -412,6 +413,73 @@ def test_create_faster_rcnn_resnet101_with_mask_prediction_enabled(self):
|
412 | 413 | model = model_builder.build(model_proto, is_training=True)
|
413 | 414 | self.assertAlmostEqual(model._second_stage_mask_loss_weight, 3.0)
|
414 | 415 |
|
| 416 | + def test_create_faster_rcnn_nas_model_from_config(self): |
| 417 | + model_text_proto = """ |
| 418 | + faster_rcnn { |
| 419 | + num_classes: 3 |
| 420 | + image_resizer { |
| 421 | + keep_aspect_ratio_resizer { |
| 422 | + min_dimension: 600 |
| 423 | + max_dimension: 1024 |
| 424 | + } |
| 425 | + } |
| 426 | + feature_extractor { |
| 427 | + type: 'faster_rcnn_nas' |
| 428 | + } |
| 429 | + first_stage_anchor_generator { |
| 430 | + grid_anchor_generator { |
| 431 | + scales: [0.25, 0.5, 1.0, 2.0] |
| 432 | + aspect_ratios: [0.5, 1.0, 2.0] |
| 433 | + height_stride: 16 |
| 434 | + width_stride: 16 |
| 435 | + } |
| 436 | + } |
| 437 | + first_stage_box_predictor_conv_hyperparams { |
| 438 | + regularizer { |
| 439 | + l2_regularizer { |
| 440 | + } |
| 441 | + } |
| 442 | + initializer { |
| 443 | + truncated_normal_initializer { |
| 444 | + } |
| 445 | + } |
| 446 | + } |
| 447 | + initial_crop_size: 17 |
| 448 | + maxpool_kernel_size: 1 |
| 449 | + maxpool_stride: 1 |
| 450 | + second_stage_box_predictor { |
| 451 | + mask_rcnn_box_predictor { |
| 452 | + fc_hyperparams { |
| 453 | + op: FC |
| 454 | + regularizer { |
| 455 | + l2_regularizer { |
| 456 | + } |
| 457 | + } |
| 458 | + initializer { |
| 459 | + truncated_normal_initializer { |
| 460 | + } |
| 461 | + } |
| 462 | + } |
| 463 | + } |
| 464 | + } |
| 465 | + second_stage_post_processing { |
| 466 | + batch_non_max_suppression { |
| 467 | + score_threshold: 0.01 |
| 468 | + iou_threshold: 0.6 |
| 469 | + max_detections_per_class: 100 |
| 470 | + max_total_detections: 300 |
| 471 | + } |
| 472 | + score_converter: SOFTMAX |
| 473 | + } |
| 474 | + }""" |
| 475 | + model_proto = model_pb2.DetectionModel() |
| 476 | + text_format.Merge(model_text_proto, model_proto) |
| 477 | + model = model_builder.build(model_proto, is_training=True) |
| 478 | + self.assertIsInstance(model, faster_rcnn_meta_arch.FasterRCNNMetaArch) |
| 479 | + self.assertIsInstance( |
| 480 | + model._feature_extractor, |
| 481 | + frcnn_nas.FasterRCNNNASFeatureExtractor) |
| 482 | + |
415 | 483 | def test_create_faster_rcnn_inception_resnet_v2_model_from_config(self):
|
416 | 484 | model_text_proto = """
|
417 | 485 | faster_rcnn {
|
|
0 commit comments