AdamW is often presented as a superior optimizer to Adam, but the real surprise is that the difference between them comes down to one subtle implementation detail: how Adam itself applies weight decay.
Let’s see this in action. Imagine we’re training a simple linear model.
import copy
import torch
import torch.nn as nn
import torch.optim as optim
torch.manual_seed(0)  # Reproducible data and initial weights
# Dummy data
X = torch.randn(100, 10)
y = torch.randn(100, 1)
# Base model; each optimizer trains its own copy so all three start from the same weights
base_model = nn.Linear(10, 1)
criterion = nn.MSELoss()
# --- Adam: learning rate and weight decay ---
lr_adam = 0.001
weight_decay_adam = 0.01
# --- AdamW: same values as Adam for comparison ---
lr_adamw = 0.001
weight_decay_adamw = 0.01
# --- SGD: learning rate, momentum, and weight decay ---
lr_sgd = 0.01
momentum_sgd = 0.9
weight_decay_sgd = 0.0001  # SGD's weight decay is usually smaller
# One model copy and one optimizer per method
model_adam = copy.deepcopy(base_model)
adam_optimizer = optim.Adam(model_adam.parameters(), lr=lr_adam, weight_decay=weight_decay_adam)
model_adamw = copy.deepcopy(base_model)
adamw_optimizer = optim.AdamW(model_adamw.parameters(), lr=lr_adamw, weight_decay=weight_decay_adamw)
model_sgd = copy.deepcopy(base_model)
sgd_optimizer = optim.SGD(model_sgd.parameters(), lr=lr_sgd, momentum=momentum_sgd, weight_decay=weight_decay_sgd)
# Take one training step with each optimizer: forward pass, loss, backward pass, step
for name, model, optimizer in [
    ("Adam", model_adam, adam_optimizer),
    ("AdamW", model_adamw, adamw_optimizer),
    ("SGD", model_sgd, sgd_optimizer),
]:
    optimizer.zero_grad()  # Clear any stale gradients
    loss = criterion(model(X), y)
    loss.backward()  # Compute gradients
    optimizer.step()  # Update the weights
    print(f"{name}: loss before step = {loss.item():.4f}")
print("Simulated training steps for Adam, AdamW, and SGD.")
The core problem Adam and AdamW are trying to solve is adapting the learning rate for each parameter based on its historical gradients. SGD, on the other hand, applies a single global learning rate to every parameter (though momentum helps smooth the updates).
Here’s the breakdown:
Stochastic Gradient Descent (SGD)
This is the simplest of the three. At each step, you calculate the gradient of the loss with respect to your parameters, and then you update the parameters in the opposite direction of the gradient, scaled by the learning rate.
- Update Rule (simplified): $w_{t+1} = w_t - \eta \nabla L(w_t)$
  - $w$ is the parameter.
  - $\eta$ is the learning rate.
  - $\nabla L(w_t)$ is the gradient of the loss $L$ at parameter $w_t$.
- Momentum: Adds a fraction of the previous update vector to the current one, helping to accelerate convergence and overcome small local minima.
- Weight Decay (L2 Regularization): In SGD, weight decay is typically implemented by adding a term to the gradient that’s proportional to the parameter’s current value. This encourages smaller weights.
  - Effective Gradient: $\nabla L(w_t) + \lambda w_t$
  - $\lambda$ is the weight decay coefficient.
  - Without momentum, the update becomes $w_{t+1} = (1 - \eta\lambda) w_t - \eta \nabla L(w_t)$, so each step shrinks the weights towards zero by a factor of $1 - \eta\lambda$.
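The SGD rules above can be checked numerically. The following is a minimal sketch that folds the weight decay term into the gradient, accumulates a momentum buffer, and compares the result against `torch.optim.SGD`. The quadratic loss, the `target` tensor, and the variable names are assumptions chosen for illustration, and the manual update assumes the optimizer's default settings (`dampening=0`, `nesterov=False`).

```python
import torch

torch.manual_seed(0)
lr, momentum, weight_decay = 0.01, 0.9, 1e-4

# Hypothetical toy problem: pull w toward a fixed target with a quadratic loss
target = torch.randn(5, dtype=torch.float64)
w_ref = torch.randn(5, dtype=torch.float64, requires_grad=True)
w_manual = w_ref.detach().clone()
buf = torch.zeros_like(w_manual)  # momentum buffer

opt = torch.optim.SGD([w_ref], lr=lr, momentum=momentum, weight_decay=weight_decay)

for _ in range(3):
    # Reference step with the built-in optimizer
    opt.zero_grad()
    loss = 0.5 * ((w_ref - target) ** 2).sum()
    loss.backward()
    opt.step()

    # Manual step
    grad = w_manual - target               # gradient of the same quadratic loss
    grad = grad + weight_decay * w_manual  # effective gradient: grad L + lambda * w
    buf = momentum * buf + grad            # momentum accumulates past effective gradients
    w_manual = w_manual - lr * buf         # descent step

sgd_matches = torch.allclose(w_ref.detach(), w_manual)
print("manual SGD matches torch.optim.SGD:", sgd_matches)
```

Because the decay term enters through the gradient, it also flows into the momentum buffer, which is the "coupled" behavior that carries over to Adam.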
Adam (Adaptive Moment Estimation)
Adam is more sophisticated. It keeps track of both the first moment (mean) and the second moment (uncentered variance) of the gradients.
- First Moment (m): Like momentum, it tracks the average gradient.
- Second Moment (v): Tracks the average of the squared gradients.
- Update Rule (simplified):
  - Update biased first moment estimate: $m_t = \beta_1 m_{t-1} + (1-\beta_1) g_t$
  - Update biased second moment estimate: $v_t = \beta_2 v_{t-1} + (1-\beta_2) g_t^2$
  - Bias-correction: $\hat{m}_t = m_t / (1-\beta_1^t)$, $\hat{v}_t = v_t / (1-\beta_2^t)$
  - Parameter update: $w_{t+1} = w_t - \eta \hat{m}_t / (\sqrt{\hat{v}_t} + \epsilon)$
  - $\beta_1, \beta_2$ are decay rates (typically 0.9 and 0.999).
  - $g_t$ is the current gradient.
  - $\epsilon$ is a small constant for numerical stability.
- Weight Decay in Adam: This is where things get tricky. Standard Adam implementations often incorporate weight decay by adding it to the gradient before calculating the moments ($m$ and $v$). This means the weight decay term is also scaled by the adaptive learning rates derived from $m$ and $v$.
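To make the coupling concrete, here is a small sketch that implements the update rule above by hand, with the weight decay term added to the gradient before the moments are computed, and checks it against `torch.optim.Adam`. The quadratic loss, `target`, and all variable names are illustrative assumptions, not part of any library.

```python
import torch

torch.manual_seed(0)
lr, beta1, beta2, eps, weight_decay = 1e-3, 0.9, 0.999, 1e-8, 1e-2

# Hypothetical toy problem: quadratic loss pulling w toward a fixed target
target = torch.randn(5, dtype=torch.float64)
w_ref = torch.randn(5, dtype=torch.float64, requires_grad=True)
w_manual = w_ref.detach().clone()
m = torch.zeros_like(w_manual)  # first moment estimate
v = torch.zeros_like(w_manual)  # second moment estimate

opt = torch.optim.Adam([w_ref], lr=lr, betas=(beta1, beta2), eps=eps,
                       weight_decay=weight_decay)

for t in range(1, 4):
    # Reference step with the built-in optimizer
    opt.zero_grad()
    loss = 0.5 * ((w_ref - target) ** 2).sum()
    loss.backward()
    opt.step()

    # Manual step: the decay term is folded into the gradient (coupled)
    g = (w_manual - target) + weight_decay * w_manual
    m = beta1 * m + (1 - beta1) * g
    v = beta2 * v + (1 - beta2) * g * g
    m_hat = m / (1 - beta1 ** t)  # bias correction
    v_hat = v / (1 - beta2 ** t)
    w_manual = w_manual - lr * m_hat / (v_hat.sqrt() + eps)

adam_matches = torch.allclose(w_ref.detach(), w_manual)
print("manual coupled Adam matches torch.optim.Adam:", adam_matches)
```

Note that `weight_decay * w_manual` ends up inside both `m` and `v`, so the decay is divided by the same adaptive denominator as the loss gradient.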
AdamW (Adam with Decoupled Weight Decay)
AdamW was introduced to fix what was perceived as an issue with how weight decay was applied in Adam.
- The Problem with Adam’s Weight Decay: When weight decay is added to the gradient before moment estimation, it gets coupled with the adaptive learning rate mechanism. This means that parameters with large historical gradients (and thus larger scaling factors from $\sqrt{v_t}$) will have their weight decay effect diminished, while parameters with small historical gradients will have their weight decay effect amplified. This is often not the desired behavior, as weight decay is usually intended to be a direct penalty on the magnitude of weights.
- AdamW’s Solution: AdamW decouples the weight decay from the gradient. The moment estimates $m$ and $v$ are computed from the loss gradient alone, and the decay is applied directly to the weights, so it is never rescaled by $\sqrt{\hat{v}_t}$.
  - Effective Update (conceptually):
    - Calculate the gradient $g_t$ of the loss only (no decay term).
    - Run the Adam moment updates on $g_t$ to get the update direction $\Delta w_{Adam} = \hat{m}_t / (\sqrt{\hat{v}_t} + \epsilon)$.
    - Apply the step and the weight decay together: $w_{t+1} = w_t - \eta (\Delta w_{Adam} + \lambda w_t)$.
  - Implementation Detail: In `torch.optim.AdamW`, the `weight_decay` parameter is handled by multiplying each parameter by `1 - lr * weight_decay` (that is, subtracting `lr * weight_decay * param`) and then applying the standard Adam update computed from the undecayed gradient. Because the Adam update does not depend on the decayed value, the result matches the formula above.
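The decoupled update can be sketched by hand in the same way. In the example below, the moments see only the loss gradient, and the decay is applied as a multiplication by `1 - lr * weight_decay` before the Adam step, which is algebraically the same as the conceptual update above and appears to match the order used by `torch.optim.AdamW`. The toy loss and variable names are assumptions for illustration.

```python
import torch

torch.manual_seed(0)
lr, beta1, beta2, eps, weight_decay = 1e-3, 0.9, 0.999, 1e-8, 1e-2

# Hypothetical toy problem: quadratic loss pulling w toward a fixed target
target = torch.randn(5, dtype=torch.float64)
w_ref = torch.randn(5, dtype=torch.float64, requires_grad=True)
w_manual = w_ref.detach().clone()
m = torch.zeros_like(w_manual)
v = torch.zeros_like(w_manual)

opt = torch.optim.AdamW([w_ref], lr=lr, betas=(beta1, beta2), eps=eps,
                        weight_decay=weight_decay)

for t in range(1, 4):
    # Reference step with the built-in optimizer
    opt.zero_grad()
    loss = 0.5 * ((w_ref - target) ** 2).sum()
    loss.backward()
    opt.step()

    # Manual step: gradient of the loss only, no decay term
    g = w_manual - target
    # Decoupled decay: shrink the weights directly
    w_manual = w_manual * (1 - lr * weight_decay)
    # Standard Adam moments and update, untouched by the decay
    m = beta1 * m + (1 - beta1) * g
    v = beta2 * v + (1 - beta2) * g * g
    m_hat = m / (1 - beta1 ** t)
    v_hat = v / (1 - beta2 ** t)
    w_manual = w_manual - lr * m_hat / (v_hat.sqrt() + eps)

adamw_matches = torch.allclose(w_ref.detach(), w_manual)
print("manual decoupled AdamW matches torch.optim.AdamW:", adamw_matches)
```

The only differences from a coupled Adam step are the gradient `g`, which no longer includes `weight_decay * w_manual`, and the separate shrink applied to the weights.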
Why AdamW Often Performs Better:
For many tasks, especially large models trained with regularization, AdamW often generalizes better than Adam with the same `weight_decay` value. AdamW’s decoupled weight decay behaves like weight decay in plain SGD: every weight is shrunk by the same relative factor at each step, regardless of its gradient history. Adam’s coupled weight decay can be too aggressive on some weights and too weak on others, leading to suboptimal regularization.
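A rough back-of-the-envelope calculation shows how uneven the coupled decay can be. The sketch below assumes that $\sqrt{\hat{v}_t}$ is approximately the typical gradient magnitude of a parameter and that the decay term's contribution to $\hat{m}_t$ is approximately $\lambda w$. Both are simplifying assumptions, and the helper names `coupled_shrink` and `decoupled_shrink` are hypothetical.

```python
def coupled_shrink(w, lr, weight_decay, grad_scale, eps=1e-8):
    """Approximate per-step shrink when the decay passes through Adam's normalization.

    Assumes sqrt(v_hat) is roughly the typical gradient magnitude `grad_scale`.
    """
    return lr * weight_decay * w / (grad_scale + eps)


def decoupled_shrink(w, lr, weight_decay):
    """Per-step shrink when the decay is applied directly to the weight."""
    return lr * weight_decay * w


lr, weight_decay, w = 1e-3, 1e-2, 1.0
shrink_table = {
    grad_scale: (
        coupled_shrink(w, lr, weight_decay, grad_scale),
        decoupled_shrink(w, lr, weight_decay),
    )
    for grad_scale in (0.01, 1.0, 100.0)
}
for grad_scale, (coupled, decoupled) in shrink_table.items():
    print(f"grad scale {grad_scale:>6}: coupled {coupled:.2e}, decoupled {decoupled:.2e}")
```

Under these assumptions, a weight with small gradients is decayed far more strongly by the coupled scheme than a weight with large gradients, while the decoupled shrink is the same for all three.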
What most people don’t realize is that setting `weight_decay=0.01` in `optim.Adam` does not apply the same amount of weight decay as setting `weight_decay=0.01` in `optim.AdamW`. In Adam, the value is an L2 penalty coefficient whose effect passes through the adaptive scaling; in AdamW, it is a direct shrink rate on the weights. The decoupled approach is the intended way to apply weight decay in Adam-like optimizers.
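One way to see the mismatch is to take a single optimizer step on a weight of 1.0 while forcing the loss gradient to zero, so that any change in the weight comes from weight decay alone. This is a contrived setup chosen for illustration, and the helper `one_step_shrink` is hypothetical.

```python
import torch


def one_step_shrink(optimizer_cls, weight_decay, lr=1e-3):
    """Return how much a weight of 1.0 shrinks in one step with a zero loss gradient."""
    w = torch.ones(1, dtype=torch.float64, requires_grad=True)
    opt = optimizer_cls([w], lr=lr, weight_decay=weight_decay)
    w.grad = torch.zeros_like(w)  # pretend the loss contributes no gradient
    opt.step()
    return 1.0 - w.item()


adam_shrink = one_step_shrink(torch.optim.Adam, weight_decay=0.01)
adam_shrink_10x = one_step_shrink(torch.optim.Adam, weight_decay=0.1)
adamw_shrink = one_step_shrink(torch.optim.AdamW, weight_decay=0.01)

print(f"Adam,  weight_decay=0.01: shrink = {adam_shrink:.3e}")
print(f"Adam,  weight_decay=0.1:  shrink = {adam_shrink_10x:.3e}")
print(f"AdamW, weight_decay=0.01: shrink = {adamw_shrink:.3e}")
```

In this setup, Adam's normalization divides the decay term by its own magnitude, so the shrink is close to `lr` and barely changes when `weight_decay` grows tenfold. AdamW shrinks the weight by `lr * weight_decay`, which is about 100 times smaller here.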
The next hurdle is understanding how to tune the learning rate and weight decay hyperparameters for each optimizer, as their optimal values can differ significantly.