PyTorch’s torch.save and torch.load are your primary tools for saving and resuming training, but understanding what to save and how to structure it is key to avoiding data loss and ensuring smooth restarts.
Let’s see this in action. Imagine a simple training loop for a small neural network.
import torch
import torch.nn as nn
import torch.optim as optim
# Define a simple model
class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(10, 2)

    def forward(self, x):
        return self.fc(x)
model = SimpleModel()
optimizer = optim.SGD(model.parameters(), lr=0.01)
criterion = nn.CrossEntropyLoss()
# Dummy data
inputs = torch.randn(32, 10)
labels = torch.randint(0, 2, (32,))
# --- Training Step ---
optimizer.zero_grad()
outputs = model(inputs)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
print(f"Example training step complete. Loss: {loss.item():.4f}")
# --- Saving Checkpoint ---
checkpoint_path = "model_checkpoint.pt"
torch.save({
    'epoch': 0,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': loss.item(),
}, checkpoint_path)
print(f"Checkpoint saved to {checkpoint_path}")
# --- Resuming Training ---
# Simulate restarting the script later
print("\n--- Simulating restart ---")
# Re-initialize model and optimizer (as if starting a new script)
model_resume = SimpleModel()
optimizer_resume = optim.SGD(model_resume.parameters(), lr=0.01)
# Load the checkpoint
checkpoint = torch.load(checkpoint_path)
# Load state dictionaries
model_resume.load_state_dict(checkpoint['model_state_dict'])
optimizer_resume.load_state_dict(checkpoint['optimizer_state_dict'])
epoch_resume = checkpoint['epoch']
loss_resume = checkpoint['loss']
print(f"Checkpoint loaded. Resuming from epoch {epoch_resume + 1} with last loss {loss_resume:.4f}")
# Now you would continue the training loop from here...
# For demonstration, let's just do one more step with the resumed model
optimizer_resume.zero_grad()
outputs_resume = model_resume(inputs)
loss_resume_step = criterion(outputs_resume, labels)
loss_resume_step.backward()
optimizer_resume.step()
print(f"Resumed training step complete. Loss: {loss_resume_step.item():.4f}")
The core problem checkpointing solves is stopping training at any point and resuming later without losing progress. This is crucial for long-running training jobs that might be interrupted by system restarts, power outages, or preemption, and it also gives you saved snapshots to evaluate or roll back to. It matters for learning rate scheduling too: a schedule that depends on the current epoch or on accumulated state only continues correctly after a restart if that state was saved.
Internally, torch.save serializes Python objects into a file using Python’s pickle module, with tensor data stored alongside the pickled structure. When saving, you’re essentially pickling a Python dictionary or a custom object. The most important pieces to save are the state_dict of your model and optimizer. The model’s state_dict is a Python dictionary that maps each parameter and buffer name (for example fc.weight and fc.bias) to its tensor. The optimizer’s state_dict has two parts: param_groups, which holds hyperparameters such as the current learning rate and momentum coefficient, and state, which holds per-parameter buffers specific to the optimization algorithm (momentum buffers in SGD, or the first and second moment estimates in Adam). Beyond these, you’ll want to save the current epoch number, any learning rate scheduler state, and perhaps the best validation metric achieved so far.
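To make the structure of these dictionaries concrete, here is a minimal sketch that prints what a model and an optimizer actually expose. It assumes a single nn.Linear layer and SGD with momentum=0.9, both chosen only for illustration.

```python
import torch
import torch.nn as nn
import torch.optim as optim

# Illustrative setup (an assumption, not the article's SimpleModel).
model = nn.Linear(10, 2)
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

# Model state_dict: parameter (and buffer) names mapped to tensors.
model_keys = list(model.state_dict().keys())
print(model_keys)

# Optimizer state_dict: hyperparameters plus per-parameter state.
opt_state = optimizer.state_dict()
top_level_keys = sorted(opt_state.keys())
print(top_level_keys)
print(opt_state["param_groups"][0]["lr"])

# Per-parameter state is empty until the optimizer has taken a step.
state_before = len(opt_state["state"])

loss = model(torch.randn(4, 10)).sum()
loss.backward()
optimizer.step()

# After one step, SGD with momentum keeps one buffer per parameter.
state_after = len(optimizer.state_dict()["state"])
print(state_before, state_after)
```

The empty state before the first step is why a checkpoint taken immediately after construction looks smaller than one taken mid-training.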
When resuming, torch.load deserializes these objects. You then load the saved state_dict back into your newly instantiated model and optimizer objects. It’s critical that the model you instantiate before loading has the same architecture as when it was saved. If you change the architecture (e.g., add a layer, change its size), the model’s load_state_dict will raise a RuntimeError by default because the keys or shapes of the tensors won’t match. The optimizer behaves differently: its load_state_dict raises a ValueError only when the number or sizes of the parameter groups differ. Changing a hyperparameter such as the base learning rate before loading does not raise an error; the saved value silently replaces the one you passed to the constructor, so apply any new value after loading.
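The two mismatch cases behave differently, which the following sketch tries to show. The layer sizes and learning rates are arbitrary values picked for the demonstration.

```python
import torch.nn as nn
import torch.optim as optim

# State captured from the "original" run (kept in memory for brevity).
saved_model = nn.Linear(10, 2)
saved_opt = optim.SGD(saved_model.parameters(), lr=0.01)
model_sd = saved_model.state_dict()
opt_sd = saved_opt.state_dict()

# Case 1: a different layer size. The model refuses to load.
wider = nn.Linear(10, 4)
shape_error = None
try:
    wider.load_state_dict(model_sd)
except RuntimeError as exc:
    shape_error = str(exc)
print(shape_error is not None and "size mismatch" in shape_error)

# Case 2: a different base learning rate. No error is raised;
# the saved hyperparameters overwrite the constructor arguments.
new_model = nn.Linear(10, 2)
new_opt = optim.SGD(new_model.parameters(), lr=0.5)
new_opt.load_state_dict(opt_sd)
restored_lr = new_opt.param_groups[0]["lr"]
print(restored_lr)
```

If you do want a new learning rate after resuming, one option is to assign it to each entry of optimizer.param_groups once load_state_dict has returned.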
The learning rate scheduler is another common component to checkpoint. Schedulers in torch.optim.lr_scheduler expose state_dict() and load_state_dict() just like models and optimizers, so you save and restore them the same way. For ReduceLROnPlateau, the state includes the best metric seen so far and the number of epochs without improvement; for epoch-based schedulers like StepLR, it includes the last epoch count. When resuming, rebuild the scheduler with the same arguments, load its state, and keep calling scheduler.step() at the same point in your loop as before (scheduler.step(validation_loss) for ReduceLROnPlateau). Passing the epoch number to scheduler.step() is deprecated, so restore the saved state rather than re-seeding the epoch by hand.
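As a rough illustration of scheduler checkpointing, the sketch below trains for three epochs with StepLR, captures all three state dicts, and restores them into freshly built objects. The build and train_one_epoch helpers and the scheduler_state_dict key are names invented for this example, and the checkpoint is kept in memory rather than written with torch.save to keep the example short.

```python
import copy
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR


def build():
    """Hypothetical helper: rebuild everything with the original settings."""
    model = nn.Linear(10, 2)
    optimizer = optim.SGD(model.parameters(), lr=0.1)
    scheduler = StepLR(optimizer, step_size=2, gamma=0.5)
    return model, optimizer, scheduler


def train_one_epoch(model, optimizer, scheduler):
    """Hypothetical helper: one dummy update, then one scheduler step."""
    optimizer.zero_grad()
    model(torch.randn(4, 10)).sum().backward()
    optimizer.step()
    scheduler.step()


model, optimizer, scheduler = build()
for epoch in range(3):
    train_one_epoch(model, optimizer, scheduler)

# In a real script this dictionary would be passed to torch.save.
checkpoint = copy.deepcopy({
    "epoch": epoch,
    "model_state_dict": model.state_dict(),
    "optimizer_state_dict": optimizer.state_dict(),
    "scheduler_state_dict": scheduler.state_dict(),
})
lr_at_save = scheduler.get_last_lr()[0]

# Simulated restart: fresh objects, then restore all three states.
model_r, optimizer_r, scheduler_r = build()
model_r.load_state_dict(checkpoint["model_state_dict"])
optimizer_r.load_state_dict(checkpoint["optimizer_state_dict"])
scheduler_r.load_state_dict(checkpoint["scheduler_state_dict"])
lr_after_restore = optimizer_r.param_groups[0]["lr"]

# The schedule continues from epoch 3 instead of starting over.
train_one_epoch(model_r, optimizer_r, scheduler_r)
lr_next = scheduler_r.get_last_lr()[0]
print(lr_at_save, lr_after_restore, lr_next)
```

With step_size=2 and gamma=0.5, the learning rate has already been halved once when the checkpoint is taken, and the resumed run applies the second halving on schedule.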
A common pitfall is forgetting to save the optimizer’s state. If you only save the model’s state_dict, training resumes with a freshly constructed optimizer: momentum buffers and Adam’s moment estimates start again from zero, and the learning rate reverts to whatever value you pass to the constructor rather than the value the scheduler had reached. The first updates after the restart can then differ noticeably from what an uninterrupted run would have produced. Saving and loading the optimizer’s state_dict ensures that the optimizer continues from precisely where it left off.
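The difference between the two ways of resuming can be seen by inspecting the optimizer state directly. This sketch assumes Adam and a dummy squared-output loss purely for illustration.

```python
import copy
import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(0)
model = nn.Linear(10, 2)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
inputs = torch.randn(8, 10)

# Take a few steps so Adam accumulates moment estimates.
for _ in range(3):
    optimizer.zero_grad()
    model(inputs).pow(2).mean().backward()
    optimizer.step()

model_sd = copy.deepcopy(model.state_dict())
opt_sd = copy.deepcopy(optimizer.state_dict())

# Resume A: weights only. The optimizer starts with no history.
model_a = nn.Linear(10, 2)
model_a.load_state_dict(model_sd)
opt_a = optim.Adam(model_a.parameters(), lr=1e-3)

# Resume B: weights and optimizer state.
model_b = nn.Linear(10, 2)
model_b.load_state_dict(model_sd)
opt_b = optim.Adam(model_b.parameters(), lr=1e-3)
opt_b.load_state_dict(opt_sd)

entries_a = len(opt_a.state_dict()["state"])
entries_b = len(opt_b.state_dict()["state"])
# The step counter may be an int or a tensor depending on the version.
steps_b = int(opt_b.state_dict()["state"][0]["step"])
print(entries_a, entries_b, steps_b)
```

Resume A has no per-parameter entries at all, while Resume B carries over the step count and moment estimates from the three earlier updates.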
The most surprising thing about torch.save and torch.load is that they are not just for state_dicts; you can save and load any picklable Python object. This means you could save your entire model instance, the optimizer instance, or even custom data structures. However, a pickled object stores a reference to its class by import path, so the file stops loading once that code is renamed, moved, or refactored, and it is harder to inspect or modify parts of the saved state. Unpickling can also execute arbitrary code, so only load files you trust; PyTorch 2.6 and later default torch.load to weights_only=True, which restricts loading to tensors and plain containers unless you explicitly opt out. Saving state_dicts within a structured dictionary is the idiomatic and recommended approach for training checkpoints because it decouples the saved file from the specific instantiation of your classes, offering more flexibility.
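A structured dictionary of state_dicts and plain values loads cleanly under the restricted loader, as this small sketch suggests. It writes to an in-memory buffer instead of a file, and the best_val_loss key and its value are made up for the example.

```python
import io
import torch
import torch.nn as nn

model = nn.Linear(10, 2)

# Only tensors and plain Python values: no custom classes are pickled.
checkpoint = {
    "epoch": 4,
    "model_state_dict": model.state_dict(),
    "best_val_loss": 0.37,  # hypothetical metric for illustration
}

buffer = io.BytesIO()
torch.save(checkpoint, buffer)
buffer.seek(0)

# weights_only=True refuses anything beyond tensors and basic containers.
loaded = torch.load(buffer, weights_only=True)

restored = nn.Linear(10, 2)
restored.load_state_dict(loaded["model_state_dict"])
weights_match = torch.equal(restored.weight, model.weight)
print(loaded["epoch"], weights_match)
```

Because nothing in the file refers to your own classes, the checkpoint stays loadable even if the model definition later moves to a different module.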
When you save a checkpoint, you might be tempted to save the entire training dataset or validation set. This is generally a bad idea due to the massive file sizes it would create. Instead, save the indices or a subset identifier of the data used for that epoch, or simply rely on having the data files accessible on disk when you resume. The state_dict of the model and optimizer, along with the current epoch and any scheduler state, are usually sufficient to resume training effectively.
The next thing you’ll likely encounter is managing multiple checkpoints, perhaps saving the best model based on validation performance and implementing strategies for keeping a history of recent checkpoints.