Contrastive learning, as used to train Siamese networks in PyTorch, teaches a model to tell similar and dissimilar data points apart: it pulls the representations of similar items together in an embedding space and pushes the representations of dissimilar items apart.
Let’s see this in action with a simple example. Imagine we have images of cats and dogs. A Siamese network trained with contrastive learning should map two cat images (or two dog images) to representations that are very close, and map a cat image and a dog image to representations that are far apart.
Here’s a simplified PyTorch setup:
import random

import torch
import torch.nn as nn
import torch.optim as optim
from PIL import Image  # needed for Image.open in SiameseDataset
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models

# Assume we have a dataset of images
class SiameseDataset(Dataset):
    def __init__(self, image_paths, labels, transform=None):
        self.image_paths = image_paths
        self.labels = labels
        self.transform = transform

    def __getitem__(self, index):
        img1_path = self.image_paths[index]
        label1 = self.labels[index]
        # Flip a coin: positive pair (same class) or negative pair (different class)
        want_same = random.random() < 0.5
        # Resample until the second image has the required relationship to the first.
        # Assumes at least two classes are present; otherwise this loop never ends.
        while True:
            index2 = random.randint(0, len(self.image_paths) - 1)
            if (self.labels[index2] == label1) == want_same:
                break
        img2_path = self.image_paths[index2]
        img1 = Image.open(img1_path).convert('RGB')
        img2 = Image.open(img2_path).convert('RGB')
        if self.transform:
            img1 = self.transform(img1)
            img2 = self.transform(img2)
        # Pair label: 1.0 for same class, 0.0 for different
        return img1, img2, torch.tensor(float(want_same))

    def __len__(self):
        return len(self.image_paths)

# Define the Siamese Network
class SiameseNetwork(nn.Module):
    def __init__(self, base_network):
        super().__init__()
        self.base_network = base_network

    def forward(self, input1, input2):
        # The same base network (shared weights) embeds both inputs
        output1 = self.base_network(input1)
        output2 = self.base_network(input2)
        return output1, output2

# Define the Contrastive Loss
class ContrastiveLoss(nn.Module):
    def __init__(self, margin=2.0):
        super().__init__()
        self.margin = margin

    def forward(self, output1, output2, label):
        # label: 1.0 for a positive (same) pair, 0.0 for a negative (different) pair
        euclidean_distance = nn.functional.pairwise_distance(output1, output2)
        loss_contrastive = torch.mean(
            label * torch.pow(euclidean_distance, 2) +
            (1 - label) * torch.pow(torch.clamp(self.margin - euclidean_distance, min=0.0), 2)
        )
        return loss_contrastive

# --- Training Setup (simplified) ---
# Assume 'image_paths' and 'labels' are parallel lists (one label per image path)
# Preprocess images
data_transform = transforms.Compose([
    transforms.Resize((100, 100)),
    transforms.ToTensor(),
    # ImageNet mean/std, matching the pre-trained ResNet weights
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Load a pre-trained ResNet as the base network
# (torchvision >= 0.13; older versions use models.resnet18(pretrained=True))
resnet = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
# Replace the final classification layer with Identity to get the feature vector
resnet.fc = nn.Identity()

# Instantiate the Siamese Network
siamese_net = SiameseNetwork(resnet)
# Instantiate the Contrastive Loss
criterion = ContrastiveLoss(margin=1.0)
# Optimizer
optimizer = optim.Adam(siamese_net.parameters(), lr=0.001)

# DataLoader
# Replace with your actual data loading and preparation
# For demonstration (the "..." entries are placeholders to fill in):
image_paths = ["path/to/img1.jpg", "path/to/img2.jpg", ...]
labels = [0, 1, ...]  # 0 for cat, 1 for dog
dataset = SiameseDataset(image_paths, labels, transform=data_transform)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

# Training loop (simplified)
num_epochs = 10
siamese_net.train()
for epoch in range(num_epochs):
    # 'pair_labels' avoids shadowing the 'labels' list defined above
    for img1, img2, pair_labels in dataloader:
        optimizer.zero_grad()
        output1, output2 = siamese_net(img1, img2)
        # Ensure labels are float for the loss arithmetic
        loss = criterion(output1, output2, pair_labels.float())
        loss.backward()
        optimizer.step()
    # Reports the loss of the last batch in the epoch
    print(f'Epoch {epoch+1}/{num_epochs}, Loss: {loss.item():.4f}')

The core idea is to train on pairs of data points. In the code above, a "positive" pair is two images with the same class label and a "negative" pair is two images with different labels. In the self-supervised variant, a positive pair is instead two augmented views of the same original image, and a negative pair is views of different original images. The Siamese network applies the same subnetwork, with shared weights, to each image in a pair independently. The outputs of the subnetwork are feature vectors (embeddings). The contrastive loss function then penalizes the network if the distance between the embeddings of a positive pair is large, or if the distance between the embeddings of a negative pair is smaller than a predefined margin.
The SiameseDataset is crucial here. It takes your original dataset and generates these positive and negative pairs on the fly. For each item index, it first decides at random whether to build a positive or a negative pair, then samples a second item index2 until its label matches labels[index] (positive) or differs from it (negative). The ContrastiveLoss then uses these pairs. If label is 1 (positive pair), the loss is the squared Euclidean distance. If label is 0 (negative pair), the loss is the squared shortfall, (margin - distance)^2, and it applies only while the distance is less than the margin. If the distance is already at least the margin, the loss for that negative pair is zero. This encourages dissimilar items to be at least margin apart.
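Written out for a single pair, the loss computed by the ContrastiveLoss class above is shown below. The symbols are notation chosen here rather than taken from the code: d is the Euclidean distance between the two embeddings, y is 1 for a positive pair and 0 for a negative pair, and m is the margin.

```latex
\mathcal{L}(y, d) = y \, d^{2} + (1 - y) \, \max(0,\; m - d)^{2}
```

The class then averages this quantity over all pairs in the batch.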
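To make the loss arithmetic concrete, here is a small dependency-free sketch that evaluates the same formula for one pair at a time. The helper name contrastive_pair_loss and the example distances are made up for illustration; only the formula comes from the ContrastiveLoss class above.

```python
def contrastive_pair_loss(distance, is_same, margin=1.0):
    """Loss for one pair, mirroring the formula in ContrastiveLoss.forward."""
    if is_same:
        # Positive pair: any distance is penalized quadratically
        return distance ** 2
    # Negative pair: penalize only the shortfall below the margin
    shortfall = max(margin - distance, 0.0)
    return shortfall ** 2


positive = contrastive_pair_loss(0.5, is_same=True)           # 0.5^2 = 0.25
close_negative = contrastive_pair_loss(0.25, is_same=False)   # (1.0 - 0.25)^2 = 0.5625
far_negative = contrastive_pair_loss(1.5, is_same=False)      # beyond the margin: 0.0
print(positive, close_negative, far_negative)
```

Note that the negative pair at distance 1.5 contributes nothing: once a pair is past the margin, the loss stops pushing it further.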
The base network, often a pre-trained CNN like ResNet, acts as the feature extractor. By removing its final classification layer and replacing it with nn.Identity(), we obtain the feature vector representation of the image. The SiameseNetwork simply applies this base network to both images in a pair.
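The weight sharing can be checked directly. The sketch below repeats the SiameseNetwork wrapper and, to stay small and offline, swaps the ResNet for a hypothetical tiny encoder (tiny_encoder, an assumption for illustration only). It shows that both branches produce embeddings of the same shape and that only one set of parameters exists.

```python
import torch
import torch.nn as nn


class SiameseNetwork(nn.Module):
    """Same wrapper as in the article: one base network applied to both inputs."""

    def __init__(self, base_network):
        super().__init__()
        self.base_network = base_network

    def forward(self, input1, input2):
        return self.base_network(input1), self.base_network(input2)


torch.manual_seed(0)
# Hypothetical stand-in encoder: flattens a 3x8x8 image and projects it to 16 dims
tiny_encoder = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 16))
net = SiameseNetwork(tiny_encoder).eval()

batch_a = torch.randn(4, 3, 8, 8)
batch_b = torch.randn(4, 3, 8, 8)
with torch.no_grad():
    emb_a, emb_b = net(batch_a, batch_b)
    same_a, same_b = net(batch_a, batch_a)

print(emb_a.shape, emb_b.shape)     # one embedding per image in each branch
print(torch.equal(same_a, same_b))  # identical inputs give identical embeddings
# Only one set of weights exists, however many inputs pass through it
num_params = sum(p.numel() for p in net.parameters())
print(num_params)
```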
The "surprising" aspect for many is that contrastive learning can also learn powerful representations without explicit class labels. The example above builds its pairs from cat/dog labels, but in self-supervised pre-training the positive pairs come from augmented views of the same image, so no labels are needed. The model learns about the inherent structure of the data by understanding what makes two instances similar or different, which lets it leverage vast amounts of unlabeled data.
The margin parameter in ContrastiveLoss is a hyperparameter that controls how far apart negative pairs should be pushed. A larger margin means a more stringent requirement for separation. The choice of margin often depends on the scale of your embedding space and the nature of your data.
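One way to see why the margin depends on the embedding scale is to compare distances before and after L2 normalization. The article's code does not normalize its embeddings, so treat this as an illustrative assumption: unit-length vectors can never be more than 2 apart, which gives the margin a fixed range to live in, whereas raw embeddings can sit at almost any scale. The random tensors below are stand-ins for real embeddings.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
# Hypothetical raw embeddings with an arbitrary scale (here 512-d, scaled by 10)
raw_a = 10.0 * torch.randn(1000, 512)
raw_b = 10.0 * torch.randn(1000, 512)

raw_dist = F.pairwise_distance(raw_a, raw_b)
# L2-normalize so every embedding lies on the unit sphere
unit_dist = F.pairwise_distance(F.normalize(raw_a, dim=1), F.normalize(raw_b, dim=1))

print(f"raw distances:  min {raw_dist.min().item():.1f}, max {raw_dist.max().item():.1f}")
print(f"unit distances: max {unit_dist.max().item():.3f}")  # bounded by about 2
```

With raw distances this large, a margin of 1.0 would already be satisfied by every negative pair, so the negative term of the loss would contribute nothing.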
Setting resnet.fc = nn.Identity() makes the base network output a fixed-size feature vector. For ResNet18, this is a 512-dimensional vector. This vector is the learned representation of the image. You can then use these embeddings for downstream tasks like classification, clustering, or similarity search.
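As a rough sketch of the similarity-search use case, the snippet below finds the nearest gallery embeddings for a few queries with torch.cdist. The gallery here is random data standing in for embeddings from the trained base network, and the variable names are illustrative.

```python
import torch

torch.manual_seed(0)
# Hypothetical stand-in for embeddings produced by the trained base network:
# 100 gallery images, 512 dimensions each (the ResNet18 feature size)
gallery = torch.randn(100, 512)
# Query with three of the gallery images themselves, as a sanity check
queries = gallery[:3]

# Pairwise Euclidean distances, shape (num_queries, num_gallery)
distances = torch.cdist(queries, gallery)
# Indices of the 5 closest gallery items per query (smallest distance first)
nearest_dist, nearest_idx = distances.topk(k=5, dim=1, largest=False)

print(nearest_idx[:, 0])  # each query's closest match should be itself
```

With real embeddings, the remaining columns of nearest_idx would be the most similar other images according to the learned metric.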
A key detail often overlooked is the augmentation strategy applied to the images. For contrastive learning to work effectively, especially in a self-supervised manner, the augmentations used for generating positive pairs must be strong enough to create perceptually different views of the same underlying object, but not so strong that they destroy the semantic content. Common augmentations include random cropping, resizing, color jittering, flipping, and Gaussian blur. The exact choice and strength of these augmentations significantly impact the quality of learned representations.
The next logical step after training a Siamese network with contrastive learning is to evaluate its learned embeddings on a downstream task, such as linear probing or fine-tuning for classification.
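Linear probing can be sketched as follows: freeze the encoder and train only a linear classifier on its embeddings. The encoder here is a hypothetical stand-in for the trained base network, and the images and targets are random tensors in place of a labeled evaluation set, so the numbers it produces mean nothing beyond showing the mechanics.

```python
import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(0)
# Hypothetical stand-in for the trained base network (512-d output like ResNet18)
encoder = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 512))
# Freeze the encoder: linear probing trains only the classifier on top
for param in encoder.parameters():
    param.requires_grad = False
encoder.eval()

probe = nn.Linear(512, 2)  # 2 classes, e.g. cat vs. dog
optimizer = optim.Adam(probe.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

# Random tensors in place of a labeled evaluation set
images = torch.randn(32, 3, 8, 8)
targets = torch.randint(0, 2, (32,))

for step in range(20):
    optimizer.zero_grad()
    with torch.no_grad():
        features = encoder(images)  # embeddings are fixed; no gradients needed
    loss = criterion(probe(features), targets)
    loss.backward()
    optimizer.step()

logits = probe(encoder(images))
print(logits.shape)
```

Fine-tuning differs only in that the encoder's parameters stay trainable and are passed to the optimizer as well.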