-
Notifications
You must be signed in to change notification settings - Fork 388
Description
I am using mask_rcnn_swin_small_patch4_window7_mstrain_480-800_adamw_3x_coco.py, with my own training script. Its my custom dataset which I am using. I have not modified anything in the config file but getting error
import json
import mmcv
import os.path as osp
from mmdet.apis import train_detector, set_random_seed
from mmdet.datasets import build_dataset
from mmdet.models import build_detector
import mmdet
import torch
from mmdet.datasets.builder import DATASETS
class ObjectDetector:
def __init__(self, config_path):
torch.cuda.empty_cache()
with open(config_path, 'r') as file:
self.config = json.load(file)
# Load the default config file
self.cfg = mmcv.Config.fromfile('/home/biosense/Documents/Rana/Ronit/SWIN/Swin-Transformer-Object-Detection/configs/_base_/models/cascade_mask_rcnn_r50_fpn.py')
# self.cfg = ''
self._update_config()
self._save_config()
def _update_config(self):
# Dataset configuration
self.cfg.dataset_type = 'CocoDataset'
self.cfg.data_root = self.config['train_folder']
self.cfg.data.train.type = self.cfg.dataset_type
self.cfg.data.train.data_root = self.config['train_folder']
self.cfg.data.train.ann_file = self.config['train_json']
self.cfg.data.train.img_prefix = ''
self.cfg.data.val.type = self.cfg.dataset_type
self.cfg.data.val.data_root = self.config['val_folder']
self.cfg.data.val.ann_file = self.config['val_json']
self.cfg.data.val.img_prefix = ''
self.cfg.data.test.type = self.cfg.dataset_type
self.cfg.data.test.data_root = self.config['test_folder']
self.cfg.data.test.ann_file = self.config['test_json']
self.cfg.data.test.img_prefix = ''
# Model settings
self.cfg.model.roi_head.bbox_head.num_classes = self.config['num_classes']
# Training settings
self.cfg.runner.max_epochs = self.config['epochs']
self.cfg.work_dir = './results'
# Optimizer settings
self.cfg.optimizer.lr = self.config['learning_rate']
self.cfg.optimizer.weight_decay = self.config['weight_decay']
# Learning rate scheduler settings
self.cfg.lr_config.policy = 'step'
self.cfg.lr_config.warmup = self.config['warmup']
self.cfg.lr_config.warmup_iters = self.config['warmup_iters']
self.cfg.lr_config.warmup_ratio = self.config['warmup_ratio']
self.cfg.lr_config.step = self.config['step_lr_policy']
# Logging settings
self.cfg.log_config.interval = self.config['log_interval']
# Evaluation settings
self.cfg.evaluation.metric = 'bbox'
self.cfg.evaluation.interval = self.config['evaluation_interval']
# Checkpoint settings
self.cfg.checkpoint_config.interval = self.config['checkpoint_interval']
# Seed for reproducibility
self.cfg.seed = 0
set_random_seed(0, deterministic=False)
self.cfg.gpu_ids = range(1)
print(f'Config:\n{self.cfg.pretty_text}')
def _save_config(self):
# Save the updated config to a file
self.cfg.dump('./updated_config.py')
def train(self):
# Build dataset and model
mmdet.datasets.coco.CocoDataset.CLASSES = ('pt', 'gp')
datasets = [build_dataset(self.cfg.data.train)]
model = build_detector(self.cfg.model, train_cfg=self.cfg.get('train_cfg'), test_cfg=self.cfg.get('test_cfg'))
model.CLASSES = datasets[0].CLASSES
# Create directory for work
mmcv.mkdir_or_exist(osp.abspath(self.cfg.work_dir))
train_detector(model, datasets, self.cfg, distributed=False, validate=True)
if __name__ == '__main__':
config_path = 'config.json'
detector = ObjectDetector(config_path)
detector.train()
File "/home/biosense/Documents/Rana/Ronit/SWIN/Swin-Transformer-Object-Detection/mmdet/datasets/utils.py", line 136, in _check_head
assert module.num_classes == len(dataset.CLASSES),
AssertionError: The num_classes (80) in FCNMaskHead of MMDataParallel does not matches the length of CLASSES 2) in CocoDataset