torch.compile is your new best friend for PyTorch speedups, but it’s not a magic bullet. It’s a sophisticated compiler, and you get the most out of it once you understand what it does to your code.

Let’s see it in action. Imagine a simple feed-forward network.

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

# Define a simple model
class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(128, 256)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(256, 10)

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x

# Create dummy data
batch_size = 32
input_features = 128
num_classes = 10
num_samples = 1000

data = torch.randn(num_samples, input_features)
labels = torch.randint(0, num_classes, (num_samples,))

dataset = TensorDataset(data, labels)
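# drop_last=True keeps every batch the same shape (1000 isn't a multiple of 32),
# so the compiled model isn't recompiled for a smaller final batch
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True)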

# Instantiate model and optimizer
model = SimpleModel()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# --- Without torch.compile ---
print("Running without torch.compile...")
model_no_compile = SimpleModel()
model_no_compile.load_state_dict(model.state_dict()) # Ensure same weights
# This model needs its own optimizer: `optimizer` above only updates model's parameters
optimizer_no_compile = torch.optim.Adam(model_no_compile.parameters(), lr=0.001)

for epoch in range(2):
    running_loss = 0.0
    for i, (inputs, targets) in enumerate(dataloader):
        optimizer_no_compile.zero_grad()
        outputs = model_no_compile(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer_no_compile.step()
        running_loss += loss.item()
    print(f"Epoch {epoch+1}, Loss: {running_loss/len(dataloader):.4f}")

# --- With torch.compile ---
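print("\nRunning with torch.compile...")
# The magic happens here! compiled_model shares its parameters with `model`,
# so `optimizer` (created for model.parameters()) trains it correctly.
compiled_model = torch.compile(model)

# IMPORTANT: the first call to the compiled model triggers compilation, and a new
# input shape can trigger a recompile. Run a "warm-up" step first so compile time
# isn't mixed into the training loop (or any timing of it). Calling backward() here
# also compiles the backward graph; there is no optimizer.step(), so weights don't change.
print("Warming up compiled model...")
warmup_inputs, warmup_targets = next(iter(dataloader))
criterion(compiled_model(warmup_inputs), warmup_targets).backward()
optimizer.zero_grad()  # discard the warm-up gradients
print("Warm-up complete.")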

# Now, run the training loop with the compiled model
for epoch in range(2):
    running_loss = 0.0
    for i, (inputs, targets) in enumerate(dataloader):
        optimizer.zero_grad()
        outputs = compiled_model(inputs) # Use the compiled model
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    print(f"Epoch {epoch+1}, Loss: {running_loss/len(dataloader):.4f}")

The core idea behind torch.compile is to take your PyTorch model, which normally runs as a sequence of Python operations executed one by one (eager mode), and transform it into an optimized, lower-level representation. Two components do most of the work. TorchDynamo hooks into CPython’s frame evaluation, analyzes your function’s bytecode as it runs, and captures the PyTorch operations it sees as an FX graph. TorchInductor, the default backend, then compiles that graph into optimized code: Triton kernels on GPUs and C++/OpenMP code on CPUs. For training, AOTAutograd sits between the two and traces the backward pass ahead of time, so the backward graph gets compiled too. Working on whole graphs lets the compiler fuse chains of operations into fewer kernels, cut redundant memory traffic, and exploit hardware parallelism more effectively than eager execution, which launches one kernel per operation.
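To see what TorchDynamo actually captures, you can hand torch.compile a custom backend: any callable that takes an FX GraphModule plus example inputs and returns a callable. The sketch below uses a hypothetical backend, inspect_backend, that prints the Python code regenerated from each captured graph and then runs it without any optimization. The model is an nn.Sequential stand-in with the same layers as SimpleModel.

```python
import torch
import torch.nn as nn

captured_graphs = []  # every FX graph Dynamo hands to our backend

def inspect_backend(gm: torch.fx.GraphModule, example_inputs):
    """Hypothetical debugging backend: record and print the graph, then run it as-is."""
    captured_graphs.append(gm)
    print(gm.code)      # Python source regenerated from the captured FX graph
    return gm.forward   # no Inductor here: the graph runs op by op, like eager mode

# Same layer structure as SimpleModel, written as nn.Sequential for brevity
model = nn.Sequential(nn.Linear(128, 256), nn.ReLU(), nn.Linear(256, 10))
compiled = torch.compile(model, backend=inspect_backend)

out = compiled(torch.randn(32, 128))
print(f"Dynamo captured {len(captured_graphs)} graph(s); output shape {tuple(out.shape)}")
```

Because this forward pass contains no unsupported Python, Dynamo should capture it as a single graph. With the default backend, that graph would go to TorchInductor instead of being run as-is.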

To use torch.compile, pass your nn.Module (or any Python function that operates on PyTorch tensors) to torch.compile(). It returns a callable with the same interface as the original. Its outputs match eager execution up to small floating-point differences, and after the initial compilation it typically runs faster, although the actual speedup depends on the model, the hardware, and the input shapes.

# Example:
import torch.nn as nn
import torch

model = nn.Linear(10, 5)
compiled_model = torch.compile(model)

# You can use compiled_model just like the original model
input_tensor = torch.randn(1, 10)
output = compiled_model(input_tensor)
print(output.shape)

The most surprising thing about torch.compile is that it doesn’t necessarily compile your whole model into a single, monolithic program. When TorchDynamo reaches code it can’t trace, such as a print() call, a call into an unsupported third-party library, or an if statement that depends on a tensor’s value, it inserts a graph break: it compiles the graph captured so far, runs the unsupported code in regular Python, and resumes capturing afterward. Execution then alternates between compiled graphs and eager Python. This is what lets torch.compile handle arbitrary Python code without rewrites, but each break splits the model into smaller graphs, limits what the compiler can fuse, and adds overhead, so poorly placed graph breaks can erase much of the performance gain.
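A quick way to observe graph breaks is to count how many graphs reach the backend. This sketch uses a hypothetical counting_backend and a function that branches on a tensor’s value, a common cause of graph breaks. It also shows fullgraph=True, which turns a graph break into an error. torch._dynamo.reset() is an internal helper, used here only to clear Dynamo’s cache between the two runs, and the exception type raised for a graph break varies between PyTorch versions, so the sketch catches Exception.

```python
import torch
import torch._dynamo  # internal module, used only for reset()

graph_count = 0

def counting_backend(gm: torch.fx.GraphModule, example_inputs):
    """Hypothetical backend: count the graphs Dynamo produces, run each unoptimized."""
    global graph_count
    graph_count += 1
    return gm.forward

def branchy(x):
    y = x * 2
    # Branching on a tensor's *value* is data-dependent control flow: Dynamo
    # can't know which side to trace, so it inserts a graph break here.
    if y.sum() > 0:
        return y + 1
    return y - 1

compiled = torch.compile(branchy, backend=counting_backend)
out = compiled(torch.ones(4))  # positive input, so the `y + 1` branch runs
graphs_default = graph_count   # one graph before the break, one after it
print("graphs compiled:", graphs_default)

# fullgraph=True refuses to split the function and raises instead
torch._dynamo.reset()  # clear cached compiled code so the next call re-traces
strict = torch.compile(branchy, backend=counting_backend, fullgraph=True)
try:
    strict(torch.ones(4))
    raised = False
except Exception as err:
    raised = True
    print("fullgraph=True rejected the graph break:", type(err).__name__)
```

Running a suspect model once with fullgraph=True is a cheap way to find out whether it compiles as a single graph.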

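The main lever you control is the mode argument of torch.compile. The default mode ("default", which is what mode=None gives you) balances compilation time against runtime performance. "reduce-overhead" targets the per-call overhead of launching many small kernels from Python: on CUDA it uses CUDA graphs to replay a whole sequence of kernels with a single launch, which helps most for small models and small batches where launch overhead dominates. It can use extra memory, and it does not reduce compilation time. "max-autotune" benchmarks several candidate kernel implementations (for example, for matrix multiplications) on your hardware and picks the fastest, which can yield the best runtime performance but takes much longer to compile; on GPU it also enables CUDA graphs ("max-autotune-no-cudagraphs" skips them). Beyond mode, fullgraph=True turns any graph break into an error, and dynamic controls whether Dynamo compiles shape-generic graphs.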

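# Example with different modes (model is the nn.Linear defined above):
compiled_model_max = torch.compile(model, mode="max-autotune")        # longest compile, autotuned kernels
compiled_model_reduce = torch.compile(model, mode="reduce-overhead")  # CUDA graphs; pays off on GPU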

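One thing that often trips people up is when compilation happens. Calling torch.compile() returns immediately; the real work is deferred. The first time you call the compiled model, TorchDynamo traces the execution and records guards on properties of the inputs (such as shapes, dtypes, and devices), and TorchInductor generates the optimized code, which can take anywhere from seconds to minutes. Later calls whose inputs pass those guards reuse the compiled code. This is why the "warm-up" phase in the example is crucial for accurate benchmarking: a measurement that includes the first call mostly measures compilation. If a later call fails the guards, for example because the input shape changed, Dynamo recompiles. With the default dynamic=None, recent PyTorch versions respond to the first shape change by recompiling a shape-generic (dynamic) graph, so further changes to that dimension don’t trigger yet another compilation. A smaller final batch from a DataLoader is a common source of this extra recompile, which drop_last=True avoids.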

The next concept you’ll likely encounter is managing graph breaks and understanding how they impact performance, especially with complex control flow or dynamic operations.
