Autograd, PyTorch’s automatic differentiation engine, is surprisingly flexible, allowing you to define custom backward passes for your operations, not just rely on the built-in ones.
Let’s see it in action. Imagine a simple operation, my_relu, that is just a standard ReLU, but whose backward pass we want to write by hand.
```python
import torch

class MyReLU(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        # Save the input tensor for the backward pass
        ctx.save_for_backward(input)
        # Apply the ReLU activation element-wise
        output = input.clamp_min(0)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        # Retrieve the saved input tensor
        input, = ctx.saved_tensors
        # Let gradients through only where the input was positive;
        # zero them elsewhere (including x == 0, matching torch.relu)
        grad_input = grad_output.clone()
        grad_input[input <= 0] = 0
        return grad_input

# Alias the Function's apply method so it can be called like a regular function
my_relu_fn = MyReLU.apply

# Test the custom function
x = torch.randn(5, requires_grad=True)
y = my_relu_fn(x)
z = y.mean()
z.backward()
print("Input x:", x)
print("Output y (my_relu):", y)
print("Gradient of z w.r.t x:", x.grad)

# Compare with built-in ReLU on the *same* input values
x_builtin = x.detach().clone().requires_grad_(True)
y_builtin = torch.relu(x_builtin)
z_builtin = y_builtin.mean()
z_builtin.backward()
print("\nInput x_builtin:", x_builtin)
print("Output y_builtin (torch.relu):", y_builtin)
print("Gradient of z_builtin w.r.t x_builtin:", x_builtin.grad)
print("Gradients match:", torch.equal(x.grad, x_builtin.grad))
```
This code defines MyReLU as a subclass of torch.autograd.Function. The forward method computes the element-wise maximum of the input and 0, and crucially, saves the input tensor using ctx.save_for_backward(input). This saved tensor is essential for the backward pass. The backward method receives grad_output (the gradient of the final loss with respect to the output of this operation) and uses the saved input to compute the gradient with respect to the input. For ReLU, this means the gradient is grad_output where the input was positive, and 0 otherwise. The MyReLU.apply method is the way you invoke your custom function within a PyTorch computation graph.
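One way to check a hand-written backward is torch.autograd.gradcheck, which compares it against finite-difference estimates. The sketch below restates MyReLU compactly and runs it on double-precision inputs chosen away from 0, since finite differences straddling the kink would not match either slope; the specific test values are arbitrary choices for illustration.

```python
import torch

class MyReLU(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        return input.clamp_min(0)

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        grad_input = grad_output.clone()
        # Zero the gradient where the input was not positive (same as torch.relu)
        grad_input[input <= 0] = 0
        return grad_input

# gradcheck needs float64 inputs; these values stay well away from the kink at 0
x = torch.tensor([-2.0, -0.5, 0.3, 1.5, 4.0], dtype=torch.float64, requires_grad=True)
ok = torch.autograd.gradcheck(MyReLU.apply, (x,))
print("gradcheck passed:", ok)
```

By default gradcheck raises an error on a mismatch, so reaching the print means the analytical and numerical Jacobians agreed within tolerance.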
The core problem autograd solves is making neural network training feasible by computing gradients automatically. When you define a custom backward function, you are telling PyTorch how to propagate gradients through your operation: given the gradient of the loss with respect to the output, return the gradient with respect to each input. This is powerful because it lets you implement novel operations or layers for which PyTorch has no built-in gradient, or supply a gradient that is more numerically stable or cheaper than the one autograd would derive on its own. The ctx object is a per-call context object (not a Python context manager) that carries information from the forward pass to the backward pass. Tensors should be stored with ctx.save_for_backward, which lets autograd detect whether they were modified in place before backward runs and release them after the backward pass (unless retain_graph=True); non-tensor values such as Python numbers can simply be attached as attributes, e.g. ctx.alpha = alpha.
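To make the ctx discussion concrete, here is a sketch with a hypothetical op, ScaledExp, that computes exp(alpha * x). It saves its output tensor with save_for_backward (the derivative alpha * exp(alpha * x) can reuse it, so nothing is recomputed) and stores the Python float alpha as a plain ctx attribute. The op and its name are illustrative, not part of PyTorch.

```python
import torch

class ScaledExp(torch.autograd.Function):
    """Computes exp(alpha * x). Hypothetical example op, not part of PyTorch."""

    @staticmethod
    def forward(ctx, x, alpha):
        out = torch.exp(alpha * x)
        # Save the *output*: d/dx exp(alpha*x) = alpha * exp(alpha*x),
        # so backward can reuse it instead of recomputing exp.
        ctx.save_for_backward(out)
        # A plain Python float is stored as an ordinary attribute on ctx.
        ctx.alpha = alpha
        return out

    @staticmethod
    def backward(ctx, grad_output):
        out, = ctx.saved_tensors
        grad_x = grad_output * ctx.alpha * out
        # alpha is a Python float, so its gradient slot is None.
        return grad_x, None

x = torch.tensor([0.0, 1.0, -1.0], requires_grad=True)
y = ScaledExp.apply(x, 2.0)
y.sum().backward()
print(x.grad)  # equals 2 * exp(2 * x)
```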
The grad_input = grad_output.clone() line, together with the masked assignment after it, is the heart of the ReLU backward pass. Cloning first avoids modifying grad_output in place; the mask then ensures that gradients only flow through the "active" part of the ReLU (where the input was positive). Where the input was negative, the function is flat, so the gradient there is zero. At exactly x = 0 ReLU is not differentiable, and torch.relu's backward uses 0 at that point.
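The masking logic, and the x = 0 convention, can be checked directly. This sketch compares torch.relu's gradient with the mask written out of place as grad_output * (x > 0), which is an equivalent alternative to the clone-and-assign form.

```python
import torch

# Include an exact zero to probe ReLU's kink
x = torch.tensor([-1.5, 0.0, 2.0], requires_grad=True)
torch.relu(x).sum().backward()
print(x.grad)  # gradient is 0 at x == 0 as well as for negative x

# The same gradient as an out-of-place mask: grad_output * 1[x > 0]
grad_output = torch.ones(3)
manual = grad_output * (x.detach() > 0).to(grad_output.dtype)
print(torch.equal(x.grad, manual))  # True
```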
The real payoff of custom autograd functions often lies in numerical stability. Consider log(sigmoid(x)). For large negative x, sigmoid(x) underflows to 0, so log(sigmoid(x)) evaluates to -inf, and the gradient autograd builds by chaining the two backward passes (dividing by sigmoid(x), then multiplying by sigmoid(x) * (1 - sigmoid(x))) degenerates into inf * 0 = nan. A custom function can use algebraic identities to avoid these unstable intermediates: the forward value can be computed as log(sigmoid(x)) = -softplus(-x) = -log(1 + exp(-x)), and the gradient as 1 - sigmoid(x) = sigmoid(-x), which is finite for every x. By the same reasoning, the gradient of softplus(x) = log(1 + exp(x)) is simply sigmoid(x).
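A minimal sketch of the stable version, written as a hypothetical StableLogSigmoid Function built on the identities above. It relies on torch.nn.functional.softplus for the forward value and compares against the naive torch.log(torch.sigmoid(x)) at x = -200, where float32 sigmoid underflows. PyTorch already provides torch.nn.functional.logsigmoid, so this is for illustration only.

```python
import torch
import torch.nn.functional as F

class StableLogSigmoid(torch.autograd.Function):
    """log(sigmoid(x)) with a hand-written, underflow-safe gradient (illustrative)."""

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        # log(sigmoid(x)) = -softplus(-x); PyTorch evaluates softplus stably
        return -F.softplus(-x)

    @staticmethod
    def backward(ctx, grad_output):
        x, = ctx.saved_tensors
        # d/dx log(sigmoid(x)) = 1 - sigmoid(x) = sigmoid(-x), finite for all x
        return grad_output * torch.sigmoid(-x)

x = torch.tensor([-200.0, 0.0, 200.0], requires_grad=True)
y = StableLogSigmoid.apply(x)
y.sum().backward()
print(y)       # first entry is -200, no -inf
print(x.grad)  # approximately [1.0, 0.5, 0.0]

# Naive version for comparison
x_naive = torch.tensor([-200.0, 0.0, 200.0], requires_grad=True)
y_naive = torch.log(torch.sigmoid(x_naive))
y_naive.sum().backward()
print(y_naive)       # first entry is -inf
print(x_naive.grad)  # first entry is nan
```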
When implementing a custom backward function, you must return one gradient per input to forward (not counting ctx), in the same order. For a non-tensor argument, or a tensor that does not require a gradient, return None in that position. The tuple ctx.needs_input_grad holds one boolean per forward input, so backward can skip computing gradients that nobody needs. For example, if forward took (ctx, input1, input2, constant_val), backward would return (grad_input1, grad_input2, None).
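A sketch of these return-value rules, using a hypothetical ScaledMul Function that computes alpha * a * b from two tensor inputs and a Python float alpha. Its backward returns exactly three values and consults ctx.needs_input_grad to skip the gradient for b, which does not require one here.

```python
import torch

class ScaledMul(torch.autograd.Function):
    """Computes alpha * a * b for tensors a, b and a Python float alpha (illustrative)."""

    @staticmethod
    def forward(ctx, a, b, alpha):
        ctx.save_for_backward(a, b)
        ctx.alpha = alpha
        return alpha * a * b

    @staticmethod
    def backward(ctx, grad_output):
        a, b = ctx.saved_tensors
        grad_a = grad_b = None
        # needs_input_grad has one boolean per forward input: (a, b, alpha)
        if ctx.needs_input_grad[0]:
            grad_a = grad_output * ctx.alpha * b
        if ctx.needs_input_grad[1]:
            grad_b = grad_output * ctx.alpha * a
        # One return value per forward input, in order; None for alpha
        return grad_a, grad_b, None

a = torch.tensor([1.0, 2.0], requires_grad=True)
b = torch.tensor([3.0, 4.0])  # does not require grad
ScaledMul.apply(a, b, 0.5).sum().backward()
print(a.grad)  # 0.5 * b, i.e. values 1.5 and 2.0
print(b.grad)  # None
```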
The next step in understanding autograd’s flexibility is often realizing that you can define custom backward passes for operations that are not even differentiable in the traditional sense, by defining a "straight-through estimator" or a surrogate gradient.
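As a taste of that idea, here is a sketch of a straight-through estimator for torch.sign, whose true gradient is zero almost everywhere. The name SignSTE is hypothetical, and clipping the pass-through to |x| <= 1 is one common variant; a plain straight-through estimator would return grad_output unchanged.

```python
import torch

class SignSTE(torch.autograd.Function):
    """sign(x) in forward; clipped straight-through gradient in backward (illustrative)."""

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return torch.sign(x)

    @staticmethod
    def backward(ctx, grad_output):
        x, = ctx.saved_tensors
        # The true gradient of sign is 0 almost everywhere; instead pass
        # grad_output straight through, but only where |x| <= 1.
        return grad_output * (x.abs() <= 1).to(grad_output.dtype)

x = torch.tensor([-2.0, -0.5, 0.3, 1.7], requires_grad=True)
y = SignSTE.apply(x)
y.sum().backward()
print(y)       # values -1, -1, 1, 1
print(x.grad)  # values 0, 1, 1, 0
```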