Knowledge distillation lets you train a smaller, faster "student" model to mimic the behavior of a larger, more powerful "teacher" model, often achieving comparable accuracy with a fraction of the computational cost.
Let’s see it in action. Imagine we have a pre-trained ResNet-50 (our teacher) that’s great at classifying images but too slow for real-time mobile deployment. We want to train a smaller MobileNetV2 (our student) to perform similarly.
First, we need our teacher model, already loaded and ready to go.
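import torch
import torchvision.models as models
# Load a pre-trained teacher model (e.g., ResNet-50).
# torchvision >= 0.13 uses the `weights` argument; older releases used `pretrained=True`.
teacher_model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
teacher_model.eval()  # Evaluation mode: freezes batch-norm statistics and disables dropout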
Next, we define the student: a MobileNetV2, which is far smaller (about 3.5 million parameters versus about 25.6 million for ResNet-50).
# Define a smaller student model (e.g., MobileNetV2), trained from scratch (no pre-trained weights)
# The student must predict over the same classes as the teacher, or the distillation loss is meaningless
num_classes = 1000  # Assuming ImageNet classes
student_model = models.mobilenet_v2(weights=None, num_classes=num_classes)
student_model.train()  # Set to training mode
The core idea of distillation is to use the teacher's soft targets as training signals for the student. Instead of relying only on hard "one-hot" labels (e.g., [0, 0, 1, 0]), we also use the full probability distribution over classes that the teacher predicts. This distribution carries extra information about how classes relate: a teacher that gives a photo of a cat some probability of "dog" and almost none of "truck" is telling the student that cats look more like dogs than like trucks.
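To make the "extra information" point concrete, here is a small plain-Python sketch (no PyTorch). It uses hypothetical logits for three classes (cat, dog, truck), and the true label is cat. The two students below agree on the cat probability but rank dog and truck differently. They get exactly the same hard-label cross-entropy, yet their KL divergence to the teacher's distribution tells them apart.

```python
import math

def log_softmax(logits):
    """Numerically stable log-softmax of a list of logits."""
    m = max(logits)
    log_total = m + math.log(math.fsum(math.exp(z - m) for z in logits))
    return [z - log_total for z in logits]

def hard_label_ce(student_logits, label):
    """Cross-entropy against a one-hot label, given as a class index."""
    return -log_softmax(student_logits)[label]

def kl_to_teacher(teacher_logits, student_logits):
    """KL(p_teacher || q_student) at temperature 1."""
    log_p = log_softmax(teacher_logits)
    log_q = log_softmax(student_logits)
    return math.fsum(math.exp(lp) * (lp - lq) for lp, lq in zip(log_p, log_q))

# Hypothetical logits for classes [cat, dog, truck]; the true label is cat (index 0)
teacher = [6.0, 3.0, 1.0]
student_a = [4.0, 2.0, 0.0]  # ranks dog above truck, like the teacher
student_b = [4.0, 0.0, 2.0]  # same logits, but dog and truck swapped

print(hard_label_ce(student_a, 0), hard_label_ce(student_b, 0))  # identical
print(kl_to_teacher(teacher, student_a), kl_to_teacher(teacher, student_b))  # a < b
```

The hard label only constrains the probability of the correct class. The soft target also rewards student A for getting the relative ranking of the wrong classes right.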
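To get these soft targets, we divide the teacher's logits by a temperature T before applying the softmax. The student's logits are divided by the same T when computing the distillation loss. A higher temperature smooths out the distribution: it boosts the small probabilities assigned to non-target classes and reveals the similarities between classes that the teacher has learned.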
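The temperature-scaled softmax is small enough to sketch in plain Python. The logits below are hypothetical values for three classes, chosen only to show the effect. At T = 1 the top class takes almost all of the probability mass, while at T = 4 the runner-up classes become clearly visible.

```python
import math

def softmax_with_temperature(logits, temperature=1.0):
    """Softmax of logits / temperature, computed in a numerically stable way."""
    scaled = [z / temperature for z in logits]
    m = max(scaled)  # subtracting the max avoids overflow in exp()
    exps = [math.exp(s - m) for s in scaled]
    total = math.fsum(exps)
    return [e / total for e in exps]

# Hypothetical teacher logits for three classes
teacher_logits = [6.0, 3.0, 1.0]
for t in (1.0, 4.0):
    probs = softmax_with_temperature(teacher_logits, t)
    print(t, [round(p, 2) for p in probs])
# 1.0 [0.95, 0.05, 0.01]
# 4.0 [0.57, 0.27, 0.16]
```

In PyTorch the same operation is `torch.nn.functional.softmax(logits / T, dim=1)`. The loss function in this article uses the log-probability form, `log_softmax(logits / T, dim=1)`.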
import torch.nn.functional as F
def distillation_loss(teacher_logits, student_logits, labels, temperature=2.0, alpha=0.1):
    """
    Calculates the combined knowledge-distillation loss.
    Args:
        teacher_logits: Logits from the teacher model, shape (batch, num_classes).
        student_logits: Logits from the student model, shape (batch, num_classes).
        labels: Ground-truth class indices, shape (batch,) (not one-hot vectors).
        temperature: The temperature T used to soften both sets of logits.
        alpha: The weight of the distillation term; (1 - alpha) weights the hard-label term.
    Returns:
        alpha * distillation loss + (1 - alpha) * cross-entropy loss.
    """
    # Soften both distributions with the same temperature (as log-probabilities)
    soft_teacher_log_probs = F.log_softmax(teacher_logits / temperature, dim=1)
    soft_student_log_probs = F.log_softmax(student_logits / temperature, dim=1)
    # Distillation loss: KL(teacher || student). F.kl_div takes the student's
    # log-probabilities as input; because the teacher target is also passed as
    # log-probabilities, log_target=True is required. Scaling by T^2 keeps the
    # gradient magnitude roughly independent of the temperature.
    distill_loss = F.kl_div(
        soft_student_log_probs,
        soft_teacher_log_probs,
        reduction="batchmean",
        log_target=True,
    ) * (temperature ** 2)
    # Student loss: standard cross-entropy with the hard labels (at T = 1)
    student_loss = F.cross_entropy(student_logits, labels)
    # Total loss is a weighted sum of the two terms
    return alpha * distill_loss + (1 - alpha) * student_loss
The alpha parameter controls the trade-off between mimicking the teacher (the distillation term) and learning from the ground-truth labels (the cross-entropy term). With alpha = 1 the labels are ignored entirely, and with alpha = 0 this reduces to ordinary supervised training. Values vary widely between implementations (the function above defaults to 0.1, while the loop below uses 0.5), so treat alpha as a hyperparameter to tune together with the temperature.
The training loop then looks something like this:
# Assume 'data_loader' is your PyTorch DataLoader yielding (images, labels), with labels
# as class indices and images preprocessed the way the teacher expects (e.g., ImageNet normalization)
optimizer = torch.optim.Adam(student_model.parameters(), lr=0.001)
num_epochs = 10
temperature = 4.0  # Experiment with this value
alpha = 0.5  # Experiment with this value
for epoch in range(num_epochs):
    running_loss = 0.0
    for images, labels in data_loader:
        optimizer.zero_grad()
        # Get teacher predictions (no gradients needed: the teacher stays frozen)
        with torch.no_grad():
            teacher_outputs = teacher_model(images)
        # Get student predictions
        student_outputs = student_model(images)
        # Calculate the combined loss (soft targets + hard labels)
        loss = distillation_loss(teacher_outputs, student_outputs, labels,
                                 temperature=temperature, alpha=alpha)
        # Backpropagate and update the student model only
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    print(f"Epoch {epoch+1}/{num_epochs}, Avg loss: {running_loss / len(data_loader):.4f}")
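The key is that during training, the student is optimized against both the hard ground-truth labels (through the cross-entropy term) and the teacher's softened probabilities (through the distillation term). torch.nn.functional.kl_div computes the Kullback-Leibler divergence KL(teacher ‖ student) between the two distributions. With reduction='batchmean', the divergence is summed over classes and averaged over the batch, which matches the mathematical definition of KL divergence. Multiplying by temperature ** 2 compensates for the gradients of the softened term shrinking roughly as 1/T². As a result, the distillation term keeps about the same weight relative to the cross-entropy term when you change the temperature.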
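To see why the T² factor is there, the sketch below uses the closed-form gradient of KL(p_T ‖ q_T) with respect to the student logits. For q = softmax(z / T), that gradient is (q_i − p_i) / T. The example is plain Python with hypothetical logits and an untrained student that predicts a uniform distribution. In this example, doubling the temperature from 5 to 10 shrinks the raw gradient by roughly a factor of four, while the T²-scaled gradient stays nearly constant.

```python
import math

def softmax_with_temperature(logits, temperature):
    """Softmax of logits / temperature, computed in a numerically stable way."""
    scaled = [z / temperature for z in logits]
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]
    total = math.fsum(exps)
    return [e / total for e in exps]

def soft_target_grad(teacher_logits, student_logits, temperature):
    """Gradient of KL(p_T || q_T) w.r.t. the student logits: (q_i - p_i) / T."""
    p = softmax_with_temperature(teacher_logits, temperature)
    q = softmax_with_temperature(student_logits, temperature)
    return [(qi - pi) / temperature for pi, qi in zip(p, q)]

teacher = [2.0, 1.0, 0.0]  # hypothetical teacher logits
student = [0.0, 0.0, 0.0]  # untrained student: uniform prediction

for t in (1.0, 2.0, 5.0, 10.0):
    g = soft_target_grad(teacher, student, t)[0]  # gradient for class 0
    print(f"T={t:>4}: raw grad {g:+.5f}, T^2-scaled {g * t * t:+.4f}")
```

The hard-label cross-entropy gradient does not depend on T. Without the T² factor, raising the temperature would therefore quietly shift the balance toward the labels, independent of alpha.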
The magic here is that the teacher, having learned rich feature representations, can guide the student beyond what the raw labels convey. By assigning small but non-zero probabilities to related classes, it shows the student which classes are similar and easy to confuse. This can make the student more robust and often leads to better generalization than training the same architecture on hard labels alone.
What’s often overlooked is that the choice of temperature is critical and highly dependent on the dataset and the models. Too low a temperature makes the soft targets too close to hard labels, losing the distillation benefit. Too high a temperature can make the distribution so flat that it provides little useful signal. For many tasks, values between 2.0 and 10.0 are a good starting point.
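One way to build intuition for this trade-off is to watch the entropy of the teacher's softened distribution as T grows. An entropy near zero means the soft targets are close to a one-hot label. An entropy near log(N) means a nearly uniform distribution with little class-similarity signal left. The sketch below uses hypothetical logits for a 5-class problem. Real teachers and datasets will behave differently, so treat it as an illustration rather than a way to pick T.

```python
import math

def softmax_with_temperature(logits, temperature):
    """Softmax of logits / temperature, computed in a numerically stable way."""
    scaled = [z / temperature for z in logits]
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]
    total = math.fsum(exps)
    return [e / total for e in exps]

def entropy(probs):
    """Shannon entropy in nats."""
    return -math.fsum(p * math.log(p) for p in probs if p > 0)

teacher_logits = [9.0, 5.0, 4.0, 1.0, 0.0]  # hypothetical logits, 5 classes
max_entropy = math.log(len(teacher_logits))  # entropy of the uniform distribution

for t in (0.5, 1.0, 2.0, 4.0, 10.0, 50.0):
    h = entropy(softmax_with_temperature(teacher_logits, t))
    print(f"T={t:>5}: entropy {h:.3f} nats ({h / max_entropy:.0%} of uniform)")
```

Entropy rises steadily with T. The useful range is somewhere in between: high enough that non-target classes get visible probability, low enough that their ranking still stands out.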
Once trained, you can discard the large teacher model and deploy the much smaller student model.
The next step in model compression after knowledge distillation is often exploring techniques like pruning or quantization to further reduce the model’s size and inference time.