Training a PyTorch semantic segmentation model on your own data is less about the nn.Module and more about preparing your datasets carefully and understanding how the loss function interacts with your pixel-wise labels.
Let’s get a model running. Imagine we have a dataset of satellite images where we want to classify each pixel as either "land" or "water."
First, we need a Dataset class. This is where the magic of loading and transforming your images and masks happens.
import torch
from torch.utils.data import Dataset
from PIL import Image
import torchvision.transforms as T
import os
class SatelliteDataset(Dataset):
    """Pairs each image with its mask by sorted filename order."""

    def __init__(self, image_dir, mask_dir, transform=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform  # callable taking (image, mask) and returning both
        self.images = sorted([os.path.join(image_dir, img) for img in os.listdir(image_dir) if img.endswith(('.png', '.jpg', '.jpeg'))])
        # Prefer a lossless format such as PNG for masks: JPEG compression alters label values
        self.masks = sorted([os.path.join(mask_dir, msk) for msk in os.listdir(mask_dir) if msk.endswith(('.png', '.jpg', '.jpeg'))])
        # Sorted order only pairs files correctly when image and mask filenames correspond
        assert len(self.images) == len(self.masks), "image/mask count mismatch"

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_path = self.images[idx]
        mask_path = self.masks[idx]
        image = Image.open(img_path).convert("RGB")
        mask = Image.open(mask_path).convert("L")  # 'L' for grayscale, assuming mask is single channel
        if self.transform:
            image, mask = self.transform(image, mask)
        return image, mask
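# Define transformations
# Geometric transforms must be applied identically to image and mask, but the
# interpolation, scaling, and normalization must differ: the mask holds class
# indices, not intensities.
import numpy as np
from torchvision.transforms import InterpolationMode

class SegmentationTransform:
    def __init__(self, size=(256, 256)):
        self.resize_image = T.Resize(size)  # bilinear by default
        # Nearest-neighbor keeps mask values as valid class indices
        self.resize_mask = T.Resize(size, interpolation=InterpolationMode.NEAREST)
        self.to_tensor = T.ToTensor()
        self.normalize = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    def __call__(self, image, mask):
        image = self.normalize(self.to_tensor(self.resize_image(image)))
        mask = self.resize_mask(mask)
        # No ToTensor/Normalize for the mask: ToTensor would rescale 0..255 to 0..1
        mask = torch.from_numpy(np.array(mask)).long()  # shape (H, W), class indices
        return image, mask

# Example transformations
data_transform = SegmentationTransform(size=(256, 256))

# Instantiate the dataset (assuming you have 'data/images' and 'data/masks' directories)
# train_dataset = SatelliteDataset(image_dir='data/images', mask_dir='data/masks', transform=data_transform)
# drop_last=True avoids a final batch of size 1, which BatchNorm rejects in training mode
# train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=4, shuffle=True, drop_last=True)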
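To see why a mask needs its own resize mode, it can help to resize a tiny label map both ways. The sketch below uses torch.nn.functional.interpolate on a made-up 2x2 mask (the values and sizes are assumptions for illustration only); it is not part of the dataset pipeline itself.

```python
import torch
import torch.nn.functional as F

# Hypothetical 2x2 label map: 0 = land, 1 = water
mask = torch.tensor([[0, 1],
                     [1, 1]], dtype=torch.float32).view(1, 1, 2, 2)

# Nearest-neighbor: every output pixel copies one input label
nearest = F.interpolate(mask, size=(4, 4), mode="nearest")

# Bilinear: neighboring labels are blended into fractional values
bilinear = F.interpolate(mask, size=(4, 4), mode="bilinear", align_corners=False)

nearest_values = set(nearest.unique().tolist())
bilinear_values = set(bilinear.unique().tolist())

print("nearest :", sorted(nearest_values))
print("bilinear:", sorted(bilinear_values))
```

The nearest-neighbor result contains only the original labels, while the bilinear result contains in-between values that do not correspond to any class.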
Now, let’s consider the model. We can start from a pre-trained segmentation model in torchvision.models.segmentation; deeplabv3_resnet101 is a good choice. The key is to replace the last layer of its classifier head so the output channels match the number of classes in your dataset.
import torchvision
from torchvision.models.segmentation import DeepLabV3_ResNet101_Weights
# Load a pre-trained model
weights = DeepLabV3_ResNet101_Weights.DEFAULT
model = torchvision.models.segmentation.deeplabv3_resnet101(weights=weights)
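# Get the number of output channels from the current classifier
num_classes_original = model.classifier[4].out_channels  # 21 for the default weights
# Assuming we have 2 classes: 0 for background (land) and 1 for foreground (water)
num_classes_custom = 2
# Replace the classifier head
# model.classifier is a Sequential module; its last layer, index 4, is a 1x1 Conv2d
# Read in_channels from the existing layer (256 here, not 2048) instead of hard-coding it
in_channels = model.classifier[4].in_channels
model.classifier[4] = torch.nn.Conv2d(in_channels, num_classes_custom, kernel_size=1)
# The pre-trained weights also load an auxiliary head; replace its last layer too
# if you plan to use the 'aux' output
if model.aux_classifier is not None:
    aux_in_channels = model.aux_classifier[4].in_channels
    model.aux_classifier[4] = torch.nn.Conv2d(aux_in_channels, num_classes_custom, kernel_size=1)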
# Move model to GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
The training loop itself is standard PyTorch, but the loss function is critical for semantic segmentation. nn.CrossEntropyLoss is the go-to, and it expects raw logits from the model. Crucially, it expects the target mask to have class indices, not one-hot encoded vectors.
import torch.optim as optim
import torch.nn as nn
# Loss function
# ignore_index marks pixels that are excluded from the loss. Never set it to a real class:
# ignore_index=0 would drop every land pixel here. 255 is a common "void" label.
criterion = nn.CrossEntropyLoss(ignore_index=255)  # Adjust ignore_index if needed
# Optimizer
optimizer = optim.Adam(model.parameters(), lr=0.001)
# --- Training Loop Snippet ---
# Assuming train_loader is defined as above
num_epochs = 10
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    num_samples = 0
    for images, masks in train_loader:
        images = images.to(device)
        # Masks need to be LongTensor for CrossEntropyLoss and match image spatial dims
        # Also, ensure class indices are correct (e.g., 0 for land, 1 for water)
        masks = masks.to(device).long()  # Ensure it's LongTensor
        if masks.dim() == 4:  # drop a singleton channel dimension: (B, 1, H, W) -> (B, H, W)
            masks = masks.squeeze(1)
        optimizer.zero_grad()
        outputs = model(images)['out']  # DeepLabV3 returns a dict with 'out'
        # outputs shape: (batch_size, num_classes_custom, height, width)
        # masks shape: (batch_size, height, width)
        loss = criterion(outputs, masks)
        loss.backward()
        optimizer.step()
        running_loss += loss.item() * images.size(0)
        num_samples += images.size(0)
    # Divide by the samples actually seen, which stays correct if the loader drops a batch
    epoch_loss = running_loss / num_samples
    print(f'Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}')
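The nn.CrossEntropyLoss combines LogSoftmax and NLLLoss internally, so you pass it raw logits and it applies the log-softmax itself. For semantic segmentation, it expects the target to be a tensor of class indices with shape (batch_size, height, width) and dtype torch.long. If your masks are saved as images where pixel value 0 is class 0, 1 is class 1, etc., then Image.open(mask_path).convert("L") followed by a direct conversion to a torch.long tensor will work. Do not route the mask through T.ToTensor() first: it rescales pixel values to the range [0, 1], so every index below 255 truncates to 0 when cast to long. The ignore_index parameter is vital if certain pixel values in your mask should not contribute to the loss (e.g., a padding value).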
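The relationship between these pieces can be checked on a handful of pixels. The sketch below uses random logits and a made-up 2x2 target in which 255 is assumed to be the void label; it compares nn.CrossEntropyLoss against LogSoftmax followed by NLLLoss, and against a manual mean over the non-ignored pixels.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(1, 2, 2, 2)          # (batch, classes, height, width)
target = torch.tensor([[[0, 1],
                        [1, 255]]])       # (batch, height, width); 255 = assumed void label

# CrossEntropyLoss takes raw logits directly
ce = nn.CrossEntropyLoss(ignore_index=255)(logits, target)

# Equivalent: log-softmax over the class dimension, then NLLLoss
nll = nn.NLLLoss(ignore_index=255)(F.log_softmax(logits, dim=1), target)

# The ignored pixel is left out of the mean: average only the labelled pixels
per_pixel = F.cross_entropy(logits, target, ignore_index=255, reduction="none")
manual = per_pixel[target != 255].mean()

print(ce.item(), nll.item(), manual.item())
```

All three values should agree up to floating-point error, which shows both the LogSoftmax/NLLLoss decomposition and that ignored pixels do not dilute the average.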
In practice, your Dataset class often ends up more complex than your model definition: it handles file loading, applies geometric augmentations identically to images and masks, and makes sure data types and shapes line up with what PyTorch’s loss functions expect.
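For random augmentations, "identically" means drawing the random decision once and applying it to both tensors. Here is a minimal sketch of a paired horizontal flip; the function name random_hflip_pair and its parameters are hypothetical, and it assumes the image is a (C, H, W) tensor and the mask an (H, W) tensor.

```python
import torch

def random_hflip_pair(image, mask, p=0.5, generator=None):
    """Flip image (C, H, W) and mask (H, W) together with probability p."""
    # One random draw decides the fate of both tensors
    if torch.rand(1, generator=generator).item() < p:
        image = torch.flip(image, dims=[-1])  # flip along the width axis
        mask = torch.flip(mask, dims=[-1])
    return image, mask

image = torch.arange(12, dtype=torch.float32).view(3, 2, 2)
mask = torch.tensor([[0, 1],
                     [0, 1]])

# p=1.0 forces the flip so the effect is visible
flipped_image, flipped_mask = random_hflip_pair(image, mask, p=1.0)
# p=0.0 never flips
same_image, same_mask = random_hflip_pair(image, mask, p=0.0)
```

Calling a random transform separately on the image and on the mask would draw two independent decisions and silently misalign the pair.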
The core challenge lies in data preparation: ensuring your masks are perfectly aligned with your images, that their pixel values correctly map to your defined classes, and that your transformations (resizing, cropping, augmentation) are applied consistently. If your masks are not single-channel grayscale images with direct class index mappings, you’ll need to write custom logic within __getitem__ to convert them.
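As an example of such custom logic, a color-coded mask can be mapped to class indices with a palette lookup. The palette below is invented for illustration (green for land, blue for water), and rgb_mask_to_indices is a hypothetical helper; substitute the colors your annotation tool actually writes.

```python
import numpy as np
import torch

# Hypothetical color coding; replace with your own palette
PALETTE = {
    (0, 128, 0): 0,   # green -> land
    (0, 0, 255): 1,   # blue  -> water
}

def rgb_mask_to_indices(rgb, palette=PALETTE, void_label=255):
    """Convert an (H, W, 3) uint8 color mask into an (H, W) LongTensor of class indices."""
    # Start with every pixel marked as void, then fill in known colors
    indices = np.full(rgb.shape[:2], void_label, dtype=np.int64)
    for color, class_idx in palette.items():
        matches = np.all(rgb == np.array(color, dtype=rgb.dtype), axis=-1)
        indices[matches] = class_idx
    return torch.from_numpy(indices)

rgb = np.array([[[0, 128, 0], [0, 0, 255]],
                [[0, 0, 255], [9, 9, 9]]], dtype=np.uint8)
class_mask = rgb_mask_to_indices(rgb)
print(class_mask)
```

Inside __getitem__, the input array would come from np.array(Image.open(mask_path).convert("RGB")). Pixels whose color is not in the palette keep the void label, which pairs naturally with ignore_index in the loss.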
When you evaluate your model, you’ll typically use metrics like Mean Intersection over Union (mIoU). You’ll need to convert the model’s output logits to class predictions (using torch.argmax(outputs, dim=1)) and then compare these predictions against your ground truth masks, carefully handling the ignore_index if used during training.
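One common way to compute mIoU is through a confusion matrix. The sketch below is a minimal version under a few assumptions: mean_iou is a hypothetical helper, 255 is the ignore label, and classes that appear in neither the predictions nor the targets are skipped when averaging.

```python
import torch

def mean_iou(logits, target, num_classes, ignore_index=255):
    """mIoU for logits (B, C, H, W) against target (B, H, W)."""
    preds = torch.argmax(logits, dim=1)
    valid = target != ignore_index           # drop ignored pixels before counting
    preds, target = preds[valid], target[valid]
    # Confusion matrix: rows = ground truth, columns = prediction
    conf = torch.bincount(target * num_classes + preds, minlength=num_classes ** 2)
    conf = conf.view(num_classes, num_classes).float()
    intersection = conf.diag()
    union = conf.sum(dim=0) + conf.sum(dim=1) - intersection
    present = union > 0                      # skip classes absent from both sides
    return (intersection[present] / union[present]).mean().item()

# Tiny example: 4 pixels, one of them ignored
logits = torch.tensor([[[[2.0, 0.0], [0.0, 0.0]],     # class 0 scores
                        [[0.0, 2.0], [2.0, 2.0]]]])   # class 1 scores
target = torch.tensor([[[0, 1],
                        [0, 255]]])
score = mean_iou(logits, target, num_classes=2)
print(score)
```

For a full validation set, it is usually better to accumulate the confusion matrix over all batches and compute the IoU once at the end, rather than averaging per-batch scores.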
The next hurdle you’ll face is optimizing performance and dealing with class imbalance, often requiring techniques like weighted loss functions or over/under-sampling strategies.
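As a starting point for a weighted loss, class weights can be derived from pixel frequencies. Inverse-frequency weighting, shown below, is only one common heuristic among several; inverse_frequency_weights is a hypothetical helper, and the tiny mask batch is made up so that land outnumbers water three to one.

```python
import torch
import torch.nn as nn

def inverse_frequency_weights(masks, num_classes, ignore_index=255):
    """Per-class weights proportional to 1 / pixel frequency, normalized to mean 1."""
    labels = masks[masks != ignore_index]
    counts = torch.bincount(labels, minlength=num_classes).float()
    freqs = counts / counts.sum()
    weights = 1.0 / freqs.clamp(min=1e-6)    # clamp guards against classes with no pixels
    return weights / weights.mean()

# Hypothetical batch of masks: 3 land pixels for every water pixel
masks = torch.tensor([[[0, 0], [0, 1]],
                      [[0, 0], [0, 1]]])
class_weights = inverse_frequency_weights(masks, num_classes=2)

# The weight tensor must live on the same device as the logits
weighted_criterion = nn.CrossEntropyLoss(weight=class_weights, ignore_index=255)
print(class_weights)
```

In a real project you would compute the counts over the whole training set once, not over a single batch, and move class_weights to the training device before building the criterion.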