You can create a custom loss function in PyTorch by subclassing torch.nn.Module and implementing the forward method.

Here’s a simple example of a custom Mean Squared Error (MSE) loss:

import torch
import torch.nn as nn

class MyMSELoss(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, input, target):
        # Ensure input and target have the same shape
        if input.shape != target.shape:
            raise ValueError("Input and target must have the same shape.")
        
        # Calculate squared difference
        squared_diff = (input - target) ** 2
        
        # Calculate mean of squared differences
        loss = torch.mean(squared_diff)
        
        return loss

# Example usage:
input_tensor = torch.randn(3, 4, requires_grad=True)
target_tensor = torch.randn(3, 4)

criterion = MyMSELoss()
loss = criterion(input_tensor, target_tensor)
loss.backward() # Compute gradients

print(f"Input tensor:\n{input_tensor}")
print(f"Target tensor:\n{target_tensor}")
print(f"Calculated loss: {loss.item()}")
print(f"Gradient of input:\n{input_tensor.grad}")

This MyMSELoss class computes the mean squared error between the input and target tensors. The forward method holds the core logic. It first checks that the input and target tensors have exactly the same shape, raising a ValueError otherwise instead of letting broadcasting silently compare mismatched elements. It then computes the element-wise squared difference and takes the mean of those squared differences. Calling .backward() on the resulting scalar triggers backpropagation and fills in input_tensor.grad, and .item() extracts the loss as a plain Python number for printing.
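A quick way to sanity-check a hand-written loss is to compare it with the built-in equivalent and with the analytic gradient. For the mean over N elements, the gradient of the MSE with respect to the input is 2 * (input - target) / N. The sketch below re-declares a minimal MyMSELoss so it runs on its own (no __init__ is needed when there is no state to set up). It then checks the result against nn.MSELoss, whose default reduction is 'mean'.

```python
import torch
import torch.nn as nn


class MyMSELoss(nn.Module):
    """Same computation as the example above: mean of squared differences."""

    def forward(self, input, target):
        if input.shape != target.shape:
            raise ValueError("Input and target must have the same shape.")
        return torch.mean((input - target) ** 2)


torch.manual_seed(0)
x = torch.randn(3, 4, requires_grad=True)
y = torch.randn(3, 4)

custom = MyMSELoss()(x, y)
builtin = nn.MSELoss()(x, y)  # reduction='mean' by default

custom.backward()
# d/dx mean((x - y)^2) = 2 * (x - y) / N, where N is the number of elements
expected_grad = 2 * (x.detach() - y) / x.numel()

print(torch.allclose(custom, builtin))       # True
print(torch.allclose(x.grad, expected_grad))  # True
```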

Let’s extend this to a slightly more complex scenario: a weighted MSE, where errors on certain elements are penalized more heavily than others.

import torch
import torch.nn as nn

class WeightedMSELoss(nn.Module):
    def __init__(self, weight=None):
        super().__init__()
        # The weight can be a tensor or a scalar.
        # If it's a tensor, it must be broadcastable to the input shape.
        self.weight = weight

    def forward(self, input, target):
        if input.shape != target.shape:
            raise ValueError("Input and target must have the same shape.")
        
        squared_diff = (input - target) ** 2
        
        if self.weight is not None:
            # Convert a Python scalar to a tensor once
            if not isinstance(self.weight, torch.Tensor):
                self.weight = torch.tensor(self.weight)
            # Match the input's device and dtype on every call (a no-op if they already match),
            # so a weight created on the CPU still works when the input is on a GPU
            weight = self.weight.to(device=input.device, dtype=input.dtype)
            
            # Element-wise multiplication broadcasts automatically (e.g. a scalar weight);
            # incompatible shapes raise a RuntimeError, which is re-raised with more context.
            try:
                weighted_squared_diff = squared_diff * weight
            except RuntimeError as e:
                raise RuntimeError(
                    f"Weight broadcasting failed. Input shape: {input.shape}, "
                    f"Weight shape: {weight.shape}."
                ) from e
        else:
            weighted_squared_diff = squared_diff
            
        loss = torch.mean(weighted_squared_diff)
        return loss

# Example with a weight tensor
input_tensor_w = torch.randn(2, 3, requires_grad=True)
target_tensor_w = torch.randn(2, 3)
# Let's heavily penalize the first row, second column
weight_tensor = torch.ones(2, 3)
weight_tensor[0, 1] = 10.0 

criterion_w = WeightedMSELoss(weight=weight_tensor)
loss_w = criterion_w(input_tensor_w, target_tensor_w)
loss_w.backward()

print("\n--- Weighted MSE Example ---")
print(f"Input tensor:\n{input_tensor_w}")
print(f"Target tensor:\n{target_tensor_w}")
print(f"Weight tensor:\n{weight_tensor}")
print(f"Calculated weighted loss: {loss_w.item()}")
print(f"Gradient of input:\n{input_tensor_w.grad}")

# Example with a scalar weight
criterion_scalar_w = WeightedMSELoss(weight=5.0)
loss_scalar_w = criterion_scalar_w(input_tensor_w, target_tensor_w)
print(f"\nCalculated scalar-weighted loss: {loss_scalar_w.item()}")

In this WeightedMSELoss, the __init__ method accepts an optional weight argument. In forward, if a weight is provided, it is multiplied element-wise with squared_diff. Before that, the code makes sure the weight is a torch.Tensor on the same device and with the same dtype as the input tensor, which is crucial for avoiding runtime errors once the model runs on a GPU. A try-except block catches broadcasting failures when the weight tensor’s shape isn’t compatible with the input and re-raises them with both shapes in the message. The scalar example shows that weight can also be a plain float. PyTorch broadcasts it across every element, so WeightedMSELoss(weight=5.0) simply returns five times the unweighted MSE.
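Another option is to store the weight with register_buffer instead of as a plain attribute. The sketch below assumes the weight is known when the loss is constructed. Buffers travel with the module on .to(device) or .cuda() and are saved in state_dict(). They are never returned by parameters(), so an optimizer will not try to train them. BufferedWeightedMSELoss is a hypothetical name used only for this illustration.

```python
import torch
import torch.nn as nn


class BufferedWeightedMSELoss(nn.Module):
    """Hypothetical variant of WeightedMSELoss that keeps the weight in a buffer."""

    def __init__(self, weight=None):
        super().__init__()
        if weight is not None:
            # Accept a Python scalar, a list, or a tensor
            weight = torch.as_tensor(weight, dtype=torch.float32)
        # register_buffer accepts None; self.weight then simply stays None
        self.register_buffer("weight", weight)

    def forward(self, input, target):
        if input.shape != target.shape:
            raise ValueError("Input and target must have the same shape.")
        squared_diff = (input - target) ** 2
        if self.weight is not None:
            # The device already matches if the module was moved with .to();
            # only the dtype is reconciled here.
            squared_diff = squared_diff * self.weight.to(dtype=input.dtype)
        return squared_diff.mean()


torch.manual_seed(0)
criterion = BufferedWeightedMSELoss(weight=[[1.0, 10.0, 1.0], [1.0, 1.0, 1.0]])
x = torch.randn(2, 3, requires_grad=True)
y = torch.randn(2, 3)

loss = criterion(x, y)
loss.backward()

print("weight" in criterion.state_dict())  # True: saved with the module
print(list(criterion.parameters()))         # []: not a trainable parameter
```

With this design, calling criterion.to("cuda"), or moving a parent module that holds the loss as a submodule, moves the weight as well, so forward only has to reconcile the dtype.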

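The key to making custom loss functions work seamlessly with PyTorch’s autograd system is to use PyTorch tensor operations within the forward method. Operations like +, -, *, /, **, torch.mean(), torch.sum(), torch.log(), and torch.exp() all have defined backward passes. Python’s built-in sum() over a list of tensors is also fine, because it just chains + operations. What breaks the computation graph is leaving PyTorch. Calling .item(), .tolist(), or .detach() on an intermediate value silently cuts the gradient path. Then loss.backward() either fails, if nothing in the loss still requires gradients, or returns gradients that ignore the detached part. Calling .numpy() on a tensor that requires gradients raises a RuntimeError outright.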
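The sketch below shows which operations keep autograd’s history and which drop it. The first part uses the built-in sum() over a list of tensors and still yields a differentiable loss with the expected gradient 2x. In the second part, a value routed through .item() comes back as a plain Python float with no history. The last part shows that calling .numpy() directly on a tensor that requires gradients raises an error.

```python
import torch

torch.manual_seed(0)
x = torch.randn(5, requires_grad=True)

# Built-in sum() over a list of tensors chains torch '+' ops, so autograd tracks it
ok_loss = sum([t ** 2 for t in x])
ok_loss.backward()
print(ok_loss.grad_fn is not None)             # True
print(torch.allclose(x.grad, 2 * x.detach()))  # True

# .item() leaves the graph: rebuilding a tensor from the float has no history,
# so calling broken_loss.backward() would raise an error
broken_value = (x ** 2).sum().item()
broken_loss = torch.tensor(broken_value)
print(broken_loss.requires_grad)  # False

# NumPy cannot see tensors that require grad; use .detach().numpy() for logging only
try:
    (x ** 2).numpy()
except RuntimeError as err:
    print("numpy() failed:", err)
```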

Consider a scenario where your loss function needs to incorporate a regularization term, like L2 regularization. You’d typically add this to the output of your primary loss calculation.

import torch
import torch.nn as nn

class MSEWithL2(nn.Module):
    def __init__(self, model, lambda_l2=0.001):
        super().__init__()
        self.mse_loss = nn.MSELoss()
        self.model = model # Store the model to access its parameters
        self.lambda_l2 = lambda_l2

    def forward(self, input, target):
        # Calculate the primary MSE loss
        mse = self.mse_loss(input, target)
        
        # Calculate the L2 regularization term: sum of squared L2 norms of the parameters
        l2_reg = torch.tensor(0., device=input.device, dtype=input.dtype)
        for param in self.model.parameters():
            # Only consider parameters that require gradients
            if param.requires_grad:
                l2_reg = l2_reg + param.pow(2).sum() # Squared L2 norm of the parameter
        
        # Combine MSE loss with L2 regularization
        total_loss = mse + self.lambda_l2 * l2_reg
        
        return total_loss

# Example usage with a dummy model
class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 5) # A simple layer

    def forward(self, x):
        return self.linear(x)

model = SimpleModel()
input_data = torch.randn(4, 10, requires_grad=True)
target_data = torch.randn(4, 5)

# Instantiate the custom loss
criterion_l2 = MSEWithL2(model=model, lambda_l2=0.01)

# Forward pass through the model
output_data = model(input_data)

# Calculate the custom loss
loss_l2 = criterion_l2(output_data, target_data)
loss_l2.backward()

# Recompute the two components for display only (no graph needed)
with torch.no_grad():
    mse_part = criterion_l2.mse_loss(output_data, target_data).item()
    l2_part = criterion_l2.lambda_l2 * sum(p.pow(2).sum() for p in model.parameters()).item()

print("\n--- MSE with L2 Regularization Example ---")
print(f"Model's linear layer weight requires grad: {model.linear.weight.requires_grad}")
print(f"MSE loss component: {mse_part:.4f}")
print(f"L2 regularization component: {l2_part:.4f}")
print(f"Total combined loss: {loss_l2.item():.4f}")
print(f"Gradient of input data:\n{input_data.grad[:2]}") # Showing first 2 rows
print(f"Gradient of model weight:\n{model.linear.weight.grad[:2, :2]}") # Showing first 2x2

In MSEWithL2, we instantiate a standard nn.MSELoss within our custom loss. The __init__ takes the model itself as an argument, which is a common pattern for losses that need access to model parameters (like regularization). Because the model is an nn.Module, assigning it to self.model also registers it as a submodule, so criterion_l2.parameters() yields the model’s parameters too. In forward, we first compute the MSE, then iterate through self.model.parameters(). For each parameter that requires gradients, we add its squared L2 norm (param.pow(2).sum()) to l2_reg. Squaring is what makes this standard L2 regularization. The penalty λ‖θ‖² has gradient 2λθ, which shrinks each weight in proportion to its size. An unsquared norm (torch.norm(param, p=2)) has gradient λθ/‖θ‖ instead and behaves differently. Finally, total_loss is the sum of the MSE and the scaled L2 term. This demonstrates how you can combine existing PyTorch modules with custom logic.
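With the squared-norm penalty λ·Σ‖θ‖², the gradient contribution is 2λθ. Plain torch.optim.SGD with weight_decay=wd adds wd·θ to each gradient, so one SGD step should be the same either way when wd = 2λ. The sketch below checks that for a single step. The seed, learning rate, λ, and layer sizes are arbitrary choices for illustration.

```python
import copy

import torch
import torch.nn as nn

torch.manual_seed(0)
lam, lr = 0.01, 0.1

model_a = nn.Linear(10, 5)
model_b = copy.deepcopy(model_a)  # identical starting weights
x, y = torch.randn(4, 10), torch.randn(4, 5)
mse = nn.MSELoss()

# A: explicit penalty lam * (sum of squared L2 norms), no decay in the optimizer
opt_a = torch.optim.SGD(model_a.parameters(), lr=lr)
penalty = sum(p.pow(2).sum() for p in model_a.parameters())
loss_a = mse(model_a(x), y) + lam * penalty
opt_a.zero_grad()
loss_a.backward()
opt_a.step()

# B: plain MSE, decay applied by the optimizer with weight_decay = 2 * lam
opt_b = torch.optim.SGD(model_b.parameters(), lr=lr, weight_decay=2 * lam)
loss_b = mse(model_b(x), y)
opt_b.zero_grad()
loss_b.backward()
opt_b.step()

same = all(
    torch.allclose(pa, pb, atol=1e-6)
    for pa, pb in zip(model_a.parameters(), model_b.parameters())
)
print(same)  # True
```

So for SGD, passing weight_decay to the optimizer is a shorter route to the same effect. torch.optim.AdamW, by contrast, decouples decay from the gradient, so it is not equivalent to adding this penalty to the loss.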

A common pitfall is forgetting to handle the device and dtype of tensors, especially when creating new tensors inside your loss function or when using weights. If your model is on a GPU and your loss function creates a tensor on the CPU, you’ll get a RuntimeError about mismatched devices. Make sure that any new tensor you create, such as the initial l2_reg accumulator or a weight converted with torch.tensor(...), ends up on the same device as the inputs. Typically you do this by passing device=input.device (or device=target.device) or by calling .to(input.device). Match the dtype as well. PyTorch promotes mixed dtypes in arithmetic, so a float64 weight tensor of the input’s shape would silently turn a float32 loss into float64.
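A few standard PyTorch idioms create tensors that inherit the device and dtype of an existing tensor, which sidesteps most of these mismatches. The sketch runs on the CPU with a float64 stand-in for a model output, and the same calls work unchanged when x lives on a GPU. For weights stored on a module, register_buffer serves the same purpose, because .to(device) moves buffers along with the module.

```python
import torch

torch.manual_seed(0)
x = torch.randn(4, 3, dtype=torch.float64)  # stand-in for a model output

# Zero-dim accumulator with x's device and dtype
acc = x.new_zeros(())
# Same shape, device, and dtype as x
mask = torch.ones_like(x)
# Python list converted to match x's device and dtype
w = torch.as_tensor([1.0, 10.0, 1.0], device=x.device, dtype=x.dtype)

acc = acc + (x ** 2 * w).sum()  # w broadcasts across the rows of x
print(acc.dtype, mask.dtype, w.dtype)  # torch.float64 torch.float64 torch.float64
```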

The next thing you’ll likely encounter is needing to implement losses that are not differentiable everywhere, or that require more complex mathematical operations not directly available as standard PyTorch functions.
