Training a PyTorch object detection model on your own data is surprisingly straightforward once you understand how the torchvision library structures datasets and models.
Let’s say you have a dataset of images and corresponding bounding box annotations for detecting cats and dogs. The goal is to train a Faster R-CNN model to identify these animals in new images.
Here’s a simplified view of what a single data sample looks like internally:
# One training sample is an (image, target) pair, as returned by Dataset.__getitem__:
(
    torch.zeros(C, H, W),  # image: float tensor of shape (C, H, W), values in [0, 1]
    {  # target: annotation tensors for that image
        'boxes': torch.zeros(N, 4, dtype=torch.float32),  # N boxes, each [x1, y1, x2, y2] in pixels
        'labels': torch.zeros(N, dtype=torch.int64),  # N class indices (int64); 0 is reserved for background
        'image_id': torch.tensor([0]),  # unique ID for the image
        'area': torch.zeros(N, dtype=torch.float32),  # area of each box, shape (N,)
        'iscrowd': torch.zeros(N, dtype=torch.int64),  # 0 for single objects, 1 for crowds
    },
)
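This structure is crucial because it is exactly what torchvision's detection models consume during training: a list of image tensors plus a matching list of target dictionaries with these keys, dtypes, and shapes. The models themselves need `boxes` and `labels`; `image_id`, `area`, and `iscrowd` are used by torchvision's COCO-style evaluation helpers. Your primary task is to convert your existing annotations (e.g., COCO, Pascal VOC, or even a simple CSV) into this format.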
Setting Up Your Custom Dataset
The core of custom training is the Dataset class. You’ll create a subclass of torch.utils.data.Dataset and implement two key methods: __len__ (returning the total number of samples) and __getitem__ (returning a single (image, target) pair, where target is the dictionary of annotation tensors described above).
import torch
import torchvision
from torch.utils.data import Dataset
from PIL import Image
import os
import xml.etree.ElementTree as ET # For Pascal VOC format
class CustomObjectDetectionDataset(Dataset):
    """Pascal VOC-style detection dataset yielding ``(image, target)`` pairs.

    Expected layout::

        root/images/<id>.jpg
        root/annotations/<id>.xml   (Pascal VOC XML)

    Each item is an RGB PIL image (or whatever ``transforms`` turns it into)
    and a target dict with ``boxes``, ``labels``, ``image_id``, ``area`` and
    ``iscrowd`` tensors.
    """

    # Class name -> label index. Index 0 is reserved for the background
    # class, so real classes start at 1.
    LABEL_MAP = {'cat': 1, 'dog': 2}

    def __init__(self, root, transforms=None):
        """
        Args:
            root: dataset root directory containing ``images/`` and
                ``annotations/``.
            transforms: optional callable ``(image, target) -> (image, target)``.
        """
        self.root = root
        self.transforms = transforms
        self.image_dir = os.path.join(root, "images")
        self.annotation_dir = os.path.join(root, "annotations")
        # os.listdir returns entries in arbitrary order; sorting keeps the
        # index -> image (and therefore image_id) mapping stable across runs.
        self.ids = sorted(
            os.path.splitext(f)[0]
            for f in os.listdir(self.image_dir)
            if f.endswith('.jpg')
        )

    def __getitem__(self, idx):
        """Return the ``(image, target)`` pair for sample ``idx``."""
        img_path = os.path.join(self.image_dir, f"{self.ids[idx]}.jpg")
        # The context manager closes the file handle; convert() returns a
        # new, fully loaded image that stays valid after the file is closed.
        with Image.open(img_path) as raw:
            img = raw.convert("RGB")

        annotation_path = os.path.join(self.annotation_dir, f"{self.ids[idx]}.xml")
        boxes, labels = self._parse_voc_xml(annotation_path)
        target = self._build_target(boxes, labels, idx)

        if self.transforms is not None:
            img, target = self.transforms(img, target)
        return img, target

    def __len__(self):
        """Return the number of images found in ``root/images``."""
        return len(self.ids)

    @staticmethod
    def _build_target(boxes, labels, idx):
        """Convert parsed box/label lists into the target dict of tensors."""
        # reshape(-1, 4) keeps an image with no objects as a (0, 4) tensor;
        # torch.as_tensor([]) alone is 1-D and would break the slicing below.
        boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)
        # Labels are 1-based class indices (0 = background), stored as int64.
        labels = torch.as_tensor(labels, dtype=torch.int64)
        area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])
        return {
            "boxes": boxes,
            "labels": labels,
            "image_id": torch.tensor([idx]),
            "area": area,
            # No crowd annotations are read from the VOC files.
            "iscrowd": torch.zeros((len(boxes),), dtype=torch.int64),
        }

    def _parse_voc_xml(self, xml_path):
        """Parse a Pascal VOC XML file into ``(boxes, labels)`` lists.

        Objects whose ``<name>`` is not in ``LABEL_MAP`` are skipped. Box
        coordinates are parsed as floats because some VOC exporters write
        fractional pixel values, on which ``int()`` would raise.
        """
        root = ET.parse(xml_path).getroot()
        boxes = []
        labels = []
        for obj in root.findall('object'):
            label = self.LABEL_MAP.get((obj.findtext('name') or '').strip())
            if label is None:
                continue
            bbox = obj.find('bndbox')
            boxes.append([float(bbox.findtext(tag)) for tag in ('xmin', 'ymin', 'xmax', 'ymax')])
            labels.append(label)
        return boxes, labels
Transforms for Object Detection
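Object detection requires transformations that operate on both the image and its annotations: flipping or cropping an image must move its bounding boxes too. The classic torchvision.transforms classes (ToTensor, Resize, RandomHorizontalFlip, …) only touch the image. You therefore either wrap them in small classes that also update the target, as below, or use torchvision.transforms.v2, which can transform bounding boxes alongside images.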
from torchvision.transforms import functional as F
class Compose:
    """Apply a sequence of ``(image, target) -> (image, target)`` transforms in order."""

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, image, target):
        for t in self.transforms:
            image, target = t(image, target)
        return image, target


class ToTensor:
    """Convert the image with ``F.to_tensor``; the target is passed through unchanged."""

    def __call__(self, image, target):
        return F.to_tensor(image), target


class RandomHorizontalFlip:
    """With probability ``p``, mirror the image and its boxes left-to-right.

    Expects ``image`` to be a tensor whose last dimension is the width (such
    as the ``(C, H, W)`` output of ``ToTensor``), and ``target["boxes"]`` to
    be an ``(N, 4)`` tensor of ``[x1, y1, x2, y2]`` coordinates. The input
    target dict and its box tensor are not modified; a new dict is returned.
    """

    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, image, target):
        if torch.rand(1).item() < self.p:
            width = image.shape[-1]
            image = image.flip(-1)
            boxes = target["boxes"].clone()
            # New x1 = width - old x2 and new x2 = width - old x1, so the
            # corners stay ordered (x1 <= x2) after mirroring.
            boxes[:, [0, 2]] = width - target["boxes"][:, [2, 0]]
            target = {**target, "boxes": boxes}
        return image, target


def get_transform(train):
    """Build the paired image/target transform pipeline.

    Args:
        train: if true, add random horizontal flipping for augmentation.

    Returns:
        A ``Compose`` callable taking and returning ``(image, target)``.

    Any other geometric augmentation (resizing, cropping, ...) must update
    ``target["boxes"]`` as well, in the same way ``RandomHorizontalFlip`` does.
    """
    transforms = [ToTensor()]
    if train:
        transforms.append(RandomHorizontalFlip(p=0.5))
    return Compose(transforms)
Model Selection and Training
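You can use pre-trained models from torchvision.models.detection and fine-tune them. Faster R-CNN, Mask R-CNN, RetinaNet, and SSD are popular choices. The head-replacement code below is specific to Faster R-CNN; other architectures expose their classification heads differently.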
import torch
import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
def collate_fn(batch):
    """Turn a list of ``(image, target)`` samples into ``(images, targets)`` tuples.

    Detection images can differ in size, and each target holds a different
    number of boxes, so the default collate function (which stacks tensors)
    cannot batch them. A named module-level function, unlike a lambda, can
    be pickled and sent to DataLoader worker processes.
    """
    return tuple(zip(*batch))


def main():
    """Fine-tune a COCO-pretrained Faster R-CNN on the custom cat/dog dataset."""
    # Load a model pre-trained on COCO. weights="DEFAULT" replaces the
    # deprecated pretrained=True flag (torchvision >= 0.13).
    model = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights="DEFAULT")

    # 2 classes (cat, dog) + 1 background class at label index 0.
    num_classes = 3

    # Replace the box predictor (a FastRCNNPredictor for Faster R-CNN) with
    # one that has num_classes outputs; in_features is the input width of
    # the existing classifier layer.
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    model.to(device)

    dataset_train = CustomObjectDetectionDataset('/path/to/your/train/data', get_transform(train=True))
    dataset_val = CustomObjectDetectionDataset('/path/to/your/val/data', get_transform(train=False))

    data_loader_train = torch.utils.data.DataLoader(
        dataset_train, batch_size=2, shuffle=True, num_workers=4,
        collate_fn=collate_fn,
    )
    data_loader_val = torch.utils.data.DataLoader(
        dataset_val, batch_size=2, shuffle=False, num_workers=4,
        collate_fn=collate_fn,
    )

    # Optimize only trainable parameters.
    params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.SGD(params, lr=0.005, momentum=0.9, weight_decay=0.0005)
    # Multiply the learning rate by 0.1 every 3 epochs.
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.1)

    num_epochs = 10
    for epoch in range(num_epochs):
        model.train()
        epoch_loss = 0.0
        for images, targets in data_loader_train:
            images = [image.to(device) for image in images]
            targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
            # In train mode the model returns a dict of loss tensors; their
            # sum is the total loss to back-propagate.
            loss_dict = model(images, targets)
            losses = sum(loss_dict.values())

            optimizer.zero_grad()
            losses.backward()
            optimizer.step()
            epoch_loss += losses.item()

        lr_scheduler.step()
        mean_loss = epoch_loss / max(len(data_loader_train), 1)
        print(f"Epoch {epoch + 1}/{num_epochs}: mean loss {mean_loss:.4f}")
        # Validation step (omitted for brevity) would iterate data_loader_val.

    print("Training finished!")


if __name__ == "__main__":
    # The guard is required when DataLoader workers are started with the
    # "spawn" method (the default on Windows and macOS): each worker
    # re-imports this module and would otherwise re-run training.
    main()
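The collate_fn is critical. PyTorch's default collate function tries to stack samples into single tensors. That fails here because images can have different sizes and each target contains a different number of bounding boxes. tuple(zip(*batch)) instead turns a batch of (image, target) pairs into two tuples: one holding all the images and one holding the matching targets. The training loop then converts these into the lists of images and target dictionaries that the model expects.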
The most surprising thing about custom object detection training is how much of the complexity is abstracted away by PyTorch’s Dataset API and torchvision’s detection models. Your primary job is data formatting: making sure __getitem__ produces the exact structure the library expects, including the correct tensor dtypes and shapes. Boxes must be float32 with shape (N, 4), and labels int64 with shape (N,).
Once you have your custom dataset formatted and a suitable Dataset class, you can plug it into the standard DataLoader and use the same training loop structure as provided by torchvision examples. The key is adapting the final roi_heads.box_predictor to your num_classes.
After training, the next hurdle is evaluating your model’s performance using metrics like Mean Average Precision (mAP).