Training with custom dataset #32

@seongin0417


Hi, I'm trying to train PoET with my own dataset, and I've already trained Mask R-CNN with the same library versions inside Docker.

When I run inference with the Mask R-CNN model I trained, it performs well (high scores, almost 1), but in your code's MaskRCNNBackbone forward method the scores come out too low (under 0.4).

How did you train your backbone? Could you share the code you used to train it?

Here is my code.

# --- First: Inference code
import torch
from torch.nn import functional as F  # torch.nn.functional, needed for F.interpolate below
from torchvision.models.detection.backbone_utils import resnet_fpn_backbone
from torchvision.models.detection.mask_rcnn import MaskRCNN
from torchvision.models.detection.rpn import concat_box_prediction_layers
from torchvision.models.detection.image_list import ImageList
from torchvision.transforms import ToTensor
from PIL import Image
from typing import Dict

class MaskRCNNBackbone(MaskRCNN):
    def __init__(
        self,
        n_classes: int = 2,
        return_interm_layers=True,
        train_backbone=False
    ):
        backbone = resnet_fpn_backbone('resnet50', pretrained=False)
        self.train_backbone = train_backbone

        if return_interm_layers:
            self.return_layers = ['2', '3', 'pool']
            self.strides = [8, 16, 32]
            self.num_channels = [256, 256, 256]
        else:
            self.return_layers = ['pool']
            self.strides = [256]
            self.num_channels = [256]
        
        super().__init__(backbone=backbone,
                         num_classes=n_classes)

    def forward(self, tensor_list: ImageList):
        # Note: this uses the padded batch size for every image; the transform
        # also stores the pre-padding sizes in tensor_list.image_sizes.
        image_sizes = [img.shape[-2:] for img in tensor_list.tensors]
        features = self.backbone(tensor_list.tensors)

        # Run the RPN head over the FPN features and rebuild the anchor grid
        # by hand, mirroring torchvision's RegionProposalNetwork.forward.
        feature_maps = list(features.values())
        objectness, pred_bbox_deltas = self.rpn.head(feature_maps)
        grid_sizes = [f.shape[-2:] for f in feature_maps]
        image_size = tensor_list.tensors.shape[-2:]
        dtype, device = feature_maps[0].dtype, feature_maps[0].device

        strides = [
            [torch.tensor(image_size[0]//g[0], dtype=torch.int64, device=device),
             torch.tensor(image_size[1]//g[1], dtype=torch.int64, device=device)]
            for g in grid_sizes
        ]
        self.rpn.anchor_generator.set_cell_anchors(dtype, device)
        all_anchors = self.rpn.anchor_generator.grid_anchors(grid_sizes, strides)
        anchors = [
            torch.cat(all_anchors) for _ in range(len(tensor_list.tensors))
        ]

        # Decode and filter proposals exactly as torchvision's RPN does, then
        # run the RoI heads to get the final per-image detections.
        objectness, pred_bbox_deltas = concat_box_prediction_layers(objectness, pred_bbox_deltas)
        proposals = self.rpn.box_coder.decode(pred_bbox_deltas.detach(), anchors)
        proposals = proposals.view(len(anchors), -1, 4)
        boxes, scores = self.rpn.filter_proposals(proposals, objectness, image_sizes, [a.shape[0] for a in all_anchors])

        detections, _ = self.roi_heads(features, boxes, image_sizes)

        # Pack each image's detections as rows of [x1, y1, x2, y2, score, label].
        predictions = []
        for det in detections:
            preds = []
            for c, lbl in enumerate(det["labels"]):
                b = det["boxes"][c]
                score = det["scores"][c]
                preds.append(torch.hstack((b, score.unsqueeze(0), lbl.unsqueeze(0))))
            predictions.append(torch.stack(preds) if preds else None)

        out: Dict[str, ImageList] = {}
        # torchvision's ImageList has no .mask attribute; build a padding mask
        # from the per-image sizes (True = padding), mimicking PoET's NestedTensor.
        mask = getattr(tensor_list, "mask", None)
        if mask is None:
            n, _, h, w = tensor_list.tensors.shape
            mask = torch.ones((n, h, w), dtype=torch.bool, device=tensor_list.tensors.device)
            for i, (ih, iw) in enumerate(tensor_list.image_sizes):
                mask[i, :ih, :iw] = False
        for lvl in self.return_layers:
            feat = features[lvl]
            resized_mask = F.interpolate(mask[None].float(), size=feat.shape[-2:])[0].to(torch.bool)
            out[lvl] = ImageList(feat, resized_mask)

        return predictions, out

if __name__ == "__main__":
    CHECKPOINT_PATH = "/output/maskrcnn/maskrcnn_epoch_015.pth"

    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = MaskRCNNBackbone(n_classes=2)
    model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location=DEVICE))
    model.to(DEVICE).eval()

    img_path = "/data/test_all/rgb_sample/009000.png"
    image = Image.open(img_path).convert("RGB")
    img_tensor = ToTensor()(image).to(DEVICE)

    # GeneralizedRCNNTransform: resizes, normalizes, and batches into an ImageList
    image_list, _ = model.transform([img_tensor])
    image_list = image_list.to(DEVICE)

    with torch.no_grad():
        detections, features = model(image_list)

    print("Detections per image:")
    for det in detections:
        print(det)

    print("\nIntermediate feature maps:")
    for lvl, img_list in features.items():
        print(f"Level {lvl}: tensor shape = {img_list.tensors.shape}")

# --- Second: Training code

import os
import json
import numpy as np
import torch
import torchvision.transforms as T
from torch.utils.data import DataLoader, Dataset
from torchvision.models.detection import maskrcnn_resnet50_fpn
from PIL import Image

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
DATA_DIR = "/data/maskrcnn_data/syndata0402_topview"
OUTPUT_DIR = "/output"

img_dir = os.path.join(DATA_DIR, "rgb")
mask_dir = os.path.join(DATA_DIR, "binary")
annotation_file = os.path.join(DATA_DIR, "annotations.json")

batch_size = 8
learning_rate = 0.005
num_epochs = 30

num_classes = 2

class CustomDataset(Dataset):
    def __init__(self, image_dir, mask_dir, annotation_file, transforms=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transforms = transforms

        with open(annotation_file) as f:
            self.annotations = json.load(f)

    def __getitem__(self, idx):
        ann = self.annotations[idx]

        # Load RGB image
        img_path = os.path.join(self.image_dir, ann["file_name"])
        img = Image.open(img_path).convert("RGB")

        # Load binary mask from file
        mask_path = os.path.join(self.mask_dir, ann["mask_file"])
        mask = Image.open(mask_path).convert("L")  # to grayscale
        mask = (np.array(mask) > 128).astype(np.uint8)  # binarize
        masks = torch.tensor(mask[np.newaxis, ...])  # [1, H, W]

        # Load bbox and label
        boxes = torch.as_tensor(ann["boxes"], dtype=torch.float32)  # [[x1, y1, x2, y2]]
        labels = torch.as_tensor(ann["labels"], dtype=torch.int64)  # [1]

        target = {
            "boxes": boxes,
            "labels": labels,
            "masks": masks,
            "image_id": torch.tensor([idx])
        }

        # Only the image needs transforming here; ToTensor is not geometric,
        # so boxes and masks stay valid.
        if self.transforms:
            img = self.transforms(img)

        return img, target

    def __len__(self):
        return len(self.annotations)


transforms = T.Compose([
    T.ToTensor()
])


dataset = CustomDataset(
    image_dir=img_dir,
    mask_dir=mask_dir,
    annotation_file=annotation_file,
    transforms=transforms
)

# Detection models take lists of variable-sized images/targets, so collate
# into tuples instead of stacking into a single batch tensor.
data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=lambda x: tuple(zip(*x)))


model = maskrcnn_resnet50_fpn(pretrained=False, num_classes=num_classes)
model.to(DEVICE)
model.train()

params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(params, lr=learning_rate, momentum=0.9, weight_decay=0.0005)


loss_history = []

print("Training Start!!")
for epoch in range(num_epochs):
    model.train()
    epoch_loss = 0.0

    for images, targets in data_loader:
        images = [img.to(DEVICE) for img in images]
        targets = [{k: v.to(DEVICE) for k, v in t.items()} for t in targets]

        loss_dict = model(images, targets)
        losses = sum(loss for loss in loss_dict.values())

        optimizer.zero_grad()
        losses.backward()
        optimizer.step()

        epoch_loss += losses.item()

    avg_loss = epoch_loss / len(data_loader)
    loss_history.append(avg_loss)
    print(f"Epoch {epoch+1}/{num_epochs}, Avg Loss: {avg_loss:.4f}")

    if (epoch + 1) % 5 == 0:
        ckpt_path = f"/output/maskrcnn_epoch_{epoch+1:03}.pth"
        torch.save(model.state_dict(), ckpt_path)
        print(f"✅ Saved checkpoint: {ckpt_path}")

with open("/output/loss_history.json", "w") as f:
    json.dump(loss_history, f)
print("📈 Loss history saved to /output/loss_history.json")

