Hi, I'm trying to train PoET on my own dataset.
I've already trained a Mask R-CNN with the same version, using Docker.
When I run inference directly with the Mask R-CNN model I trained, it performs well (high scores, almost 1), but in your code's MaskRCNNBackbone forward method the scores come out much too low (under 0.4). A plain eval-mode check along these lines is sketched after the inference code below for comparison.
How did you train your backbone? Could you share the code you used to train it?
Here is my code.
# --- First: Inference code
import torch
from torch.nn import functional as F  # provides interpolate() used in forward()
from torchvision.models.detection.backbone_utils import resnet_fpn_backbone
from torchvision.models.detection.mask_rcnn import MaskRCNN
from torchvision.models.detection.rpn import concat_box_prediction_layers
from torchvision.models.detection.image_list import ImageList
from torchvision.transforms import ToTensor
from PIL import Image
from typing import Dict
class MaskRCNNBackbone(MaskRCNN):
    def __init__(self,
                 n_classes: int = 2,
                 return_interm_layers: bool = True,
                 train_backbone: bool = False):
        backbone = resnet_fpn_backbone('resnet50', pretrained=False)
        super().__init__(backbone=backbone, num_classes=n_classes)
        self.train_backbone = train_backbone
        # Feature levels returned to PoET, with their strides and channel counts
        if return_interm_layers:
            self.return_layers = ['2', '3', 'pool']
            self.strides = [8, 16, 32]
            self.num_channels = [256, 256, 256]
        else:
            self.return_layers = ['pool']
            self.strides = [256]
            self.num_channels = [256]
    def forward(self, tensor_list: ImageList):
        image_sizes = [img.shape[-2:] for img in tensor_list.tensors]
        # Backbone + FPN features
        features = self.backbone(tensor_list.tensors)
        feature_maps = list(features.values())
        # RPN head outputs
        objectness, pred_bbox_deltas = self.rpn.head(feature_maps)
        grid_sizes = [f.shape[-2:] for f in feature_maps]
        image_size = tensor_list.tensors.shape[-2:]
        dtype, device = feature_maps[0].dtype, feature_maps[0].device
        strides = [
            [torch.tensor(image_size[0] // g[0], dtype=torch.int64, device=device),
             torch.tensor(image_size[1] // g[1], dtype=torch.int64, device=device)]
            for g in grid_sizes
        ]
        # Generate anchors manually because the standard RPN forward is bypassed
        self.rpn.anchor_generator.set_cell_anchors(dtype, device)
        all_anchors = self.rpn.anchor_generator.grid_anchors(grid_sizes, strides)
        anchors = [torch.cat(all_anchors) for _ in range(len(tensor_list.tensors))]
        objectness, pred_bbox_deltas = concat_box_prediction_layers(objectness, pred_bbox_deltas)
        proposals = self.rpn.box_coder.decode(pred_bbox_deltas.detach(), anchors)
        proposals = proposals.view(len(anchors), -1, 4)
        boxes, scores = self.rpn.filter_proposals(
            proposals, objectness, image_sizes, [a.shape[0] for a in all_anchors])
        # Second stage: box / mask heads
        detections, _ = self.roi_heads(features, boxes, image_sizes)
        # Pack detections as [x1, y1, x2, y2, score, label] per image
        predictions = []
        for det in detections:
            preds = []
            for c, lbl in enumerate(det["labels"]):
                b = det["boxes"][c]
                score = det["scores"][c]
                preds.append(torch.hstack((b, score.unsqueeze(0), lbl.unsqueeze(0))))
            predictions.append(torch.stack(preds) if preds else None)
        # torchvision's ImageList carries no padding mask, so fall back to an
        # all-False mask (no padded pixels) when one isn't attached.
        mask = getattr(tensor_list, "mask", None)
        if mask is None:
            mask = torch.zeros((len(tensor_list.tensors), *image_size),
                               dtype=torch.bool, device=device)
        # Return the requested feature levels paired with a resized padding mask;
        # ImageList is only used as a simple (tensor, mask) container here.
        out: Dict[str, ImageList] = {}
        for lvl in self.return_layers:
            feat = features[lvl]
            resized_mask = F.interpolate(mask[None].float(), size=feat.shape[-2:])[0].to(torch.bool)
            out[lvl] = ImageList(feat, resized_mask)
        return predictions, out
if __name__ == "__main__":
    CHECKPOINT_PATH = "/output/maskrcnn/maskrcnn_epoch_015.pth"
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = MaskRCNNBackbone(n_classes=2)
    model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location=DEVICE))
    model.to(DEVICE).eval()

    img_path = "/data/test_all/rgb_sample/009000.png"
    image = Image.open(img_path).convert("RGB")
    img_tensor = ToTensor()(image).to(DEVICE)
    image_list, _ = model.transform([img_tensor])
    image_list = image_list.to(DEVICE)

    with torch.no_grad():
        detections, features = model(image_list)

    print("Detections per image:")
    for det in detections:
        print(det)

    print("\nIntermediate feature maps:")
    for lvl, img_list in features.items():
        print(f"Level {lvl}: tensor shape = {img_list.tensors.shape}")
# --- Second: Training code
import os
import json
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset
from torchvision.models.detection import maskrcnn_resnet50_fpn
import torchvision.transforms as T
from PIL import Image
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
DATA_DIR = "/data/maskrcnn_data/syndata0402_topview"
OUTPUT_DIR = "/output"
img_dir = os.path.join(DATA_DIR, "rgb")
mask_dir = os.path.join(DATA_DIR, "binary")
annotation_file = os.path.join(DATA_DIR, "annotations.json")
batch_size = 8
learning_rate = 0.005
num_epochs = 30
num_classes = 2
class CustomDataset(Dataset):
    def __init__(self, image_dir, mask_dir, annotation_file, transforms=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transforms = transforms
        with open(annotation_file) as f:
            self.annotations = json.load(f)

    def __getitem__(self, idx):
        ann = self.annotations[idx]
        # Load RGB image
        img_path = os.path.join(self.image_dir, ann["file_name"])
        img = Image.open(img_path).convert("RGB")
        # Load binary mask from file
        mask_path = os.path.join(self.mask_dir, ann["mask_file"])
        mask = Image.open(mask_path).convert("L")       # to grayscale
        mask = (np.array(mask) > 128).astype(np.uint8)  # binarize
        masks = torch.tensor(mask[np.newaxis, ...])     # [1, H, W]
        # Load bbox and label
        boxes = torch.as_tensor(ann["boxes"], dtype=torch.float32)  # [[x1, y1, x2, y2]]
        labels = torch.as_tensor(ann["labels"], dtype=torch.int64)  # [1]
        target = {
            "boxes": boxes,
            "labels": labels,
            "masks": masks,
            "image_id": torch.tensor([idx]),
        }
        if self.transforms:
            img = self.transforms(img)
        return img, target

    def __len__(self):
        return len(self.annotations)
transforms = T.Compose([
    T.ToTensor()
])

dataset = CustomDataset(
    image_dir=img_dir,
    mask_dir=mask_dir,
    annotation_file=annotation_file,
    transforms=transforms
)
data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True,
                         collate_fn=lambda x: tuple(zip(*x)))

model = maskrcnn_resnet50_fpn(pretrained=False, num_classes=num_classes)
model.to(DEVICE)
model.train()

params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(params, lr=learning_rate, momentum=0.9, weight_decay=0.0005)

loss_history = []
print("Training Start!!")
for epoch in range(num_epochs):
    model.train()
    epoch_loss = 0.0
    for images, targets in data_loader:
        images = [img.to(DEVICE) for img in images]
        targets = [{k: v.to(DEVICE) for k, v in t.items()} for t in targets]

        # In training mode, torchvision's Mask R-CNN returns a dict of losses
        loss_dict = model(images, targets)
        losses = sum(loss for loss in loss_dict.values())

        optimizer.zero_grad()
        losses.backward()
        optimizer.step()
        epoch_loss += losses.item()

    avg_loss = epoch_loss / len(data_loader)
    loss_history.append(avg_loss)
    print(f"Epoch {epoch+1}/{num_epochs}, Avg Loss: {avg_loss:.4f}")

    if (epoch + 1) % 5 == 0:
        ckpt_path = f"/output/maskrcnn_epoch_{epoch+1:03}.pth"
        torch.save(model.state_dict(), ckpt_path)
        print(f"✅ Saved checkpoint: {ckpt_path}")

with open("/output/loss_history.json", "w") as f:
    json.dump(loss_history, f)
print("📈 Loss history saved to /output/loss_history.json")