PyTorch’s BatchNorm and LayerNorm are both normalization techniques, but they operate on different axes, leading to fundamentally different use cases.
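Both layers apply the same per-element transform. They differ only in which elements the mean and variance are pooled over. Here is a compact statement for a 4-D input of shape (N, C, H, W). It uses ε for the small stability constant (PyTorch's default eps is 1e-5) and γ, β for the learnable affine parameters. In BatchNorm2d, γ and β are per channel. In LayerNorm, they have the shape of normalized_shape.

```latex
\begin{aligned}
y &= \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}}\,\gamma + \beta \\
\text{BatchNorm2d:}\quad \mu_c &= \frac{1}{NHW}\sum_{n,h,w} x_{n,c,h,w}, \qquad
\sigma_c^2 = \frac{1}{NHW}\sum_{n,h,w} \left(x_{n,c,h,w} - \mu_c\right)^2 \\
\text{LayerNorm over }(C,H,W):\quad \mu_n &= \frac{1}{CHW}\sum_{c,h,w} x_{n,c,h,w}, \qquad
\sigma_n^2 = \frac{1}{CHW}\sum_{c,h,w} \left(x_{n,c,h,w} - \mu_n\right)^2
\end{aligned}
```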
Let’s see BatchNorm in action. Imagine a batch of images, say 32 images, each 3x224x224 (channels, height, width).
import torch
import torch.nn as nn
# Batch of 32 images, 3 channels, 224x224 resolution
batch_size = 32
channels = 3
height = 224
width = 224
dummy_input = torch.randn(batch_size, channels, height, width)
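# BatchNorm2d keeps one (gamma, beta) pair and one running mean/var per channel
bn = nn.BatchNorm2d(channels)  # expects input of shape (N, C, H, W)
# A freshly created module is in training mode, so this forward pass normalizes
# with the current batch's statistics and updates running_mean / running_var
output_bn = bn(dummy_input)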
print(f"Input shape: {dummy_input.shape}")
print(f"BatchNorm output shape: {output_bn.shape}")
print(f"BatchNorm running_mean shape (one value per channel): {bn.running_mean.shape}")
print(f"BatchNorm running_var shape (one value per channel): {bn.running_var.shape}")
Output:
Input shape: torch.Size([32, 3, 224, 224])
BatchNorm output shape: torch.Size([32, 3, 224, 224])
BatchNorm running_mean shape (one value per channel): torch.Size([3])
BatchNorm running_var shape (one value per channel): torch.Size([3])
BatchNorm computes one mean and one variance per channel. For each channel, it pools every value of that channel across all samples in the mini-batch and all spatial positions (the N, H and W axes). It normalizes with those statistics and then applies a learnable per-channel scale and shift (gamma and beta). The statistics are therefore batch-dependent. In training mode, the layer uses the current mini-batch statistics and updates running_mean and running_var as an exponential moving average (momentum=0.1 by default). In evaluation mode (after model.eval()), it normalizes with those running estimates instead.
This reliance on batch statistics works well for Convolutional Neural Networks (CNNs), where samples in a batch are expected to share a similar feature distribution. It stabilizes training and typically allows higher learning rates and faster convergence. The original BatchNorm paper attributed this to reducing "internal covariate shift", an explanation that later work has disputed. The key point is that statistics are computed per channel and shared across all spatial positions: each channel gets a single mean and variance, not one per pixel.
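To make the per-channel pooling concrete, here is a small sketch. It recomputes BatchNorm2d's training-mode output by hand and checks the running-statistics update. It assumes PyTorch's defaults: momentum=0.1, eps=1e-5, and affine parameters initialized to gamma=1, beta=0. It uses a smaller tensor than the example above only to keep it quick.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Smaller than the 32x3x224x224 batch above to keep the check fast
x = torch.randn(8, 3, 16, 16)
bn = nn.BatchNorm2d(3)  # training mode by default; gamma=1, beta=0 at init

# One statistic per channel, pooled over the batch (dim 0) and spatial dims (2, 3)
mean = x.mean(dim=(0, 2, 3), keepdim=True)                # shape (1, 3, 1, 1)
var = x.var(dim=(0, 2, 3), unbiased=False, keepdim=True)  # biased variance, as used for normalization
manual = (x - mean) / torch.sqrt(var + bn.eps)

out = bn(x)
print(torch.allclose(out, manual, atol=1e-5))  # True

# The forward pass also updated the running estimates:
# running = (1 - momentum) * running + momentum * batch_stat, with momentum=0.1.
# running_mean starts at 0, so after one step it equals 0.1 * batch mean.
print(torch.allclose(bn.running_mean, 0.1 * mean.flatten(), atol=1e-5))  # True

# In eval mode the layer normalizes with running_mean / running_var instead
bn.eval()
out_eval = bn(x)
rm = bn.running_mean.view(1, -1, 1, 1)
rv = bn.running_var.view(1, -1, 1, 1)
manual_eval = (x - rm) / torch.sqrt(rv + bn.eps)
print(torch.allclose(out_eval, manual_eval, atol=1e-5))  # True
```

Because running_var is updated with the unbiased variance, the eval-mode output differs slightly from the training-mode output even on the same batch.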
Now let’s apply LayerNorm to the same batch of images.
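# LayerNorm applied to the same input
ln = nn.LayerNorm([channels, height, width])  # normalizes over the last 3 dims, so input must end in (C, H, W)
output_ln = ln(dummy_input)
print(f"LayerNorm output shape: {output_ln.shape}")
# Unlike BatchNorm, LayerNorm keeps no running statistics: mean and variance
# are computed for each sample over the normalized dims, in train and eval mode alike.
# With the default elementwise_affine=True, ln.weight and ln.bias have shape (C, H, W).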
Output:
LayerNorm output shape: torch.Size([32, 3, 224, 224])
LayerNorm normalizes within each sample. With normalized_shape=[C, H, W] as above, it computes the mean and variance over all of a sample’s features (every channel and spatial position) and normalizes that sample with them. The statistics depend only on the sample itself, not on the rest of the batch, so the layer behaves identically in training and evaluation and with any batch size. This makes LayerNorm well-suited to Recurrent Neural Networks (RNNs) and Transformers. There it is typically applied over the last (hidden) dimension, so each token or time step is normalized on its own. Variable sequence lengths and batch composition therefore do not affect the statistics, and there is no need to assume that features are similarly distributed across a batch. The hidden dimension also typically holds hundreds or thousands of values, so each token’s statistics are estimated from plenty of data without relying on the batch.
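The following sketch mirrors the BatchNorm check above for LayerNorm. It recomputes the per-sample statistics by hand and confirms that evaluation mode changes nothing. The small tensor shape is arbitrary, and the default initialization (weight=1, bias=0) is assumed.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

x = torch.randn(8, 3, 16, 16)
ln = nn.LayerNorm([3, 16, 16])  # normalizes over the last three dims (C, H, W)

# One mean/variance per sample, pooled over channels and spatial positions
mean = x.mean(dim=(1, 2, 3), keepdim=True)                # shape (8, 1, 1, 1)
var = x.var(dim=(1, 2, 3), unbiased=False, keepdim=True)  # biased variance
manual = (x - mean) / torch.sqrt(var + ln.eps)            # weight=1, bias=0 at init

out = ln(x)
print(torch.allclose(out, manual, atol=1e-5))  # True

# Same result in eval mode: there are no running statistics to switch to
ln.eval()
print(torch.allclose(ln(x), out))  # True
```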
The choice hinges on your architecture and data. If you’re working with CNNs and your batch size is reasonably large (a common rule of thumb is more than about 16), BatchNorm is usually the go-to, because the batch provides stable statistics. If you’re using RNNs (including LSTMs and GRUs) or Transformers, or if your batches are very small, LayerNorm is generally preferred. It normalizes each sample independently, so sequence length and batch size do not affect its statistics.
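The batch-size argument is easiest to see by putting one sample into two different batches. The sketch below uses random data and arbitrary shapes chosen only for illustration. It shows that training-mode BatchNorm’s output for that sample changes with its batch-mates, while LayerNorm’s does not. It also shows that BatchNorm cannot train when there is only one value per channel.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

sample = torch.randn(1, 3, 16, 16)
others_a = torch.randn(3, 3, 16, 16)
others_b = torch.randn(3, 3, 16, 16) * 5 + 2  # deliberately different distribution

bn = nn.BatchNorm2d(3)          # training mode: uses batch statistics
ln = nn.LayerNorm([3, 16, 16])  # per-sample statistics

batch_a = torch.cat([sample, others_a])
batch_b = torch.cat([sample, others_b])

# BatchNorm: the same sample is normalized differently depending on its batch-mates
bn_a = bn(batch_a)[0]
bn_b = bn(batch_b)[0]
print(torch.allclose(bn_a, bn_b))  # False

# LayerNorm: the sample's output depends only on the sample itself
ln_a = ln(batch_a)[0]
ln_b = ln(batch_b)[0]
ln_alone = ln(sample)[0]
print(torch.allclose(ln_a, ln_b) and torch.allclose(ln_a, ln_alone))  # True

# With a single value per channel, training-mode BatchNorm cannot estimate a variance,
# so PyTorch raises an error instead of normalizing
try:
    nn.BatchNorm1d(4)(torch.randn(1, 4))
except (ValueError, RuntimeError) as err:
    print("BatchNorm1d with batch size 1:", err)
```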
A common misconception is that BatchNorm is always better for image tasks. It is for many CNNs, but Transformer-based vision models such as the Vision Transformer (ViT) use LayerNorm as standard. LayerNorm normalizes over the last M dimensions of its input, where M is the length of the normalized_shape argument (an integer counts as a single dimension). For an input of shape (N, C, H, W), normalized_shape=(C, H, W) normalizes over C, H and W for each sample N. By contrast, normalized_shape=(H, W) normalizes over H and W separately for each (N, C) pair. This flexibility is what lets one module cover different layouts. In a Transformer, for example, normalized_shape is simply the hidden size, so each token is normalized over its embedding dimension.
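The next sketch shows how normalized_shape picks the axes, using small made-up sizes. It checks which groups of values end up with zero mean and prints the shape of the learnable weight. The final lines use a hypothetical (batch, seq_len, hidden) tensor as a stand-in for Transformer activations.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
N, C, H, W = 2, 3, 4, 5
x = torch.randn(N, C, H, W)

# normalized_shape=(C, H, W): one mean/variance per sample
ln_chw = nn.LayerNorm((C, H, W))
y = ln_chw(x)
print(y.mean(dim=(1, 2, 3)).abs().max().item() < 1e-5)  # True

# normalized_shape=(H, W): one mean/variance per (sample, channel) pair
ln_hw = nn.LayerNorm((H, W))
z = ln_hw(x)
print(z.mean(dim=(2, 3)).abs().max().item() < 1e-5)  # True

# The affine parameters take the shape of normalized_shape
print(ln_chw.weight.shape, ln_hw.weight.shape)  # torch.Size([3, 4, 5]) torch.Size([4, 5])

# Transformer-style usage (hypothetical sizes): normalize each token over its hidden dim
batch, seq_len, hidden = 2, 7, 16
tokens = torch.randn(batch, seq_len, hidden)
ln_tok = nn.LayerNorm(hidden)  # an int means "the last dimension only"
out = ln_tok(tokens)
print(out.shape)  # torch.Size([2, 7, 16])
```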
The next step after understanding these is to explore how they interact with different activation functions and the impact on gradient flow.