PyTorch Serving: Deploy Models with FastAPI
The most surprising thing about deploying PyTorch models with FastAPI is how much of the heavy lifting is handled by standard Python web frameworks, not specialized ML serving infrastructure.
Let’s see it in action. Imagine we have a simple PyTorch model that predicts sentiment from text.
```python
import zlib

import torch
import torch.nn as nn


class SimpleSentimentModel(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.rnn = nn.LSTM(embedding_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, text):
        # text: LongTensor of token IDs with shape (batch, seq_len)
        embedded = self.embedding(text)
        output, (hidden, cell) = self.rnn(embedded)
        # hidden has shape (num_layers, batch, hidden_dim); take the last layer's
        # hidden state after the final time step
        last_hidden = hidden[-1, :, :]
        prediction = self.fc(last_hidden)
        return self.sigmoid(prediction)  # probability in (0, 1), shape (batch, 1)


# Assume we have a pre-trained model and a tokenizer.
# For demonstration, let's just create a dummy (untrained) model and tokenizer.
VOCAB_SIZE = 1000
EMBEDDING_DIM = 64
HIDDEN_DIM = 32

model = SimpleSentimentModel(VOCAB_SIZE, EMBEDDING_DIM, HIDDEN_DIM)
# In a real scenario, you'd load your trained weights:
# model.load_state_dict(torch.load("sentiment_model.pth", map_location="cpu"))
model.eval()  # Set model to evaluation mode (affects layers such as dropout and batch norm)


# Dummy tokenizer (in reality, use the same tokenizer the model was trained with,
# e.g. one from Hugging Face)
def tokenize(text):
    # Simple tokenization: lowercase and split on whitespace
    tokens = text.lower().split()
    # Map each token to an ID with CRC-32, which is stable across processes and
    # restarts (the built-in hash() of a str is salted per process)
    token_ids = [zlib.crc32(token.encode("utf-8")) % VOCAB_SIZE for token in tokens]
    # nn.Embedding needs integer IDs and the LSTM needs at least one time step,
    # so fall back to ID 0 (treated here as a hypothetical "unknown" token) for empty input
    if not token_ids:
        token_ids = [0]
    return torch.tensor(token_ids, dtype=torch.long).unsqueeze(0)  # Add batch dimension
```
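Deriving token IDs from Python's built-in `hash()` has a deployment pitfall. String hashes are randomized per interpreter process unless the `PYTHONHASHSEED` environment variable fixes the seed. As a result, the same word can map to different IDs after a restart, and potentially in different worker processes.

The sketch below illustrates this. It assumes the current interpreter can be relaunched with `subprocess`, and the helpers `builtin_hash_id` and `stable_id` are hypothetical names. It computes `hash()`-based IDs in two fresh interpreters with different seeds, then compares them with `zlib.crc32`, which depends only on the token's bytes.

```python
import os
import subprocess
import sys
import zlib

VOCAB_SIZE = 1000


def builtin_hash_id(token: str, seed: str) -> int:
    """Token ID computed with hash() in a fresh interpreter using the given PYTHONHASHSEED."""
    env = dict(os.environ, PYTHONHASHSEED=seed)
    result = subprocess.run(
        [sys.executable, "-c", f"print(abs(hash({token!r})) % {VOCAB_SIZE})"],
        env=env,
        capture_output=True,
        text=True,
        check=True,
    )
    return int(result.stdout)


def stable_id(token: str) -> int:
    """Token ID computed with CRC-32, which depends only on the token's UTF-8 bytes."""
    return zlib.crc32(token.encode("utf-8")) % VOCAB_SIZE


words = ["this", "is", "a", "great", "movie"]
print([builtin_hash_id(w, "1") for w in words])  # IDs as one process sees them
print([builtin_hash_id(w, "2") for w in words])  # almost certainly a different list
print([stable_id(w) for w in words])  # the same list in every process
```

Pinning `PYTHONHASHSEED` for every worker would also work. A content-based checksum avoids depending on the process environment at all.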
Now, let’s build the FastAPI application to serve this model.
```python
import torch
from fastapi import FastAPI
from pydantic import BaseModel

# Assume model and tokenize function are defined as above (e.g. in the same main.py)

app = FastAPI(title="PyTorch Sentiment API")


class TextInput(BaseModel):
    text: str


# A plain `def` (not `async def`) makes FastAPI run this CPU-bound handler in a
# thread pool, so a slow forward pass does not block the event loop
@app.post("/predict/")
def predict_sentiment(request: TextInput):
    # Preprocess the input text
    tokenized_text = tokenize(request.text)

    # Perform inference
    with torch.no_grad():  # Disable gradient tracking for inference
        prediction_tensor = model(tokenized_text)

    # Postprocess the output: a probability between 0 and 1, shape (1, 1)
    sentiment_score = prediction_tensor.item()
    sentiment = "positive" if sentiment_score > 0.5 else "negative"
    return {"text": request.text, "sentiment": sentiment, "score": sentiment_score}


@app.get("/")
async def read_root():
    return {"message": 'Welcome to the PyTorch Sentiment API. POST to /predict/ with JSON body {"text": "your text here"}.'}


# To run this:
# 1. Save the code as main.py
# 2. Install the dependencies: pip install fastapi uvicorn torch
# 3. Run from your terminal: uvicorn main:app --reload  (--reload is for development only)
```
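With the server running, you can send a POST request to `/predict/` with a JSON body like `{"text": "This is a great movie!"}`. You will receive a JSON response shaped like `{"text": "This is a great movie!", "sentiment": "positive", "score": 0.923456789}`. With the untrained dummy model above, the score is essentially arbitrary; it becomes meaningful only after trained weights are loaded.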
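To call the endpoint from another Python service, a client only needs to POST JSON. Below is a minimal sketch using only the standard library's `urllib.request`. It assumes the server is reachable at `http://127.0.0.1:8000` (Uvicorn's default host and port), and the helpers `build_predict_request` and `predict` are hypothetical names.

```python
import json
import urllib.request

BASE_URL = "http://127.0.0.1:8000"  # assumption: Uvicorn's default host and port


def build_predict_request(text: str, base_url: str = BASE_URL) -> urllib.request.Request:
    """Build (but do not send) a POST to /predict/ carrying {"text": ...} as JSON."""
    payload = json.dumps({"text": text}).encode("utf-8")
    return urllib.request.Request(
        url=f"{base_url}/predict/",
        data=payload,
        headers={"Content-Type": "application/json"},
        method="POST",
    )


def predict(text: str, base_url: str = BASE_URL, timeout: float = 5.0) -> dict:
    """Send the request and decode the JSON response (requires the server to be running)."""
    with urllib.request.urlopen(build_predict_request(text, base_url), timeout=timeout) as resp:
        return json.loads(resp.read().decode("utf-8"))
```

With the server running, `predict("This is a great movie!")` returns the decoded response dictionary. `urlopen` raises `urllib.error.HTTPError` for non-2xx responses, such as a 422 from request validation.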
The core problem this solves is making your trained PyTorch model accessible over the network, so that other applications or services can use its predictions. The work is split three ways:
- FastAPI handles routing, request validation, response serialization, and API documentation (Swagger UI at `/docs`).
- Uvicorn, the ASGI server, handles the network connections.
- PyTorch handles the model execution.
Internally, when a request comes in:
- FastAPI receives the HTTP POST request at `/predict/`.
- It parses the JSON body and validates it against the `TextInput` Pydantic model. A body that fails validation gets an HTTP 422 response without reaching your code.
- The `predict_sentiment` function is called with the validated input.
- The input text is tokenized using our dummy function.
- The tokenized data is passed to the PyTorch `model` for a forward pass. `torch.no_grad()` matters here because it stops autograd from recording a computation graph. Since no gradients are ever needed at inference time, this saves the memory that graph would hold and speeds up inference.
- The model’s output (a tensor) is processed to extract the final sentiment label and score.
- FastAPI serializes the Python dictionary containing the results back into JSON and sends it as the HTTP response.
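The effect of `torch.no_grad()` is easy to observe on any layer. Outside the context manager, the output carries a `grad_fn` linking it to the recorded computation graph; inside it, nothing is recorded. The sketch also shows `torch.inference_mode()`, a stricter context manager available in PyTorch 1.9 and later that some deployments use instead. Swapping it in is an option, not something the code above requires.

```python
import torch
import torch.nn as nn

layer = nn.Linear(4, 1)
x = torch.randn(2, 4)

# Default mode: autograd records the operation so gradients could flow back later
y = layer(x)
print(y.requires_grad, y.grad_fn is not None)  # True True

# Inside no_grad, no graph is recorded: the output is a plain tensor
with torch.no_grad():
    z = layer(x)
print(z.requires_grad, z.grad_fn)  # False None

# inference_mode() also skips graph recording and marks outputs as inference tensors
with torch.inference_mode():
    w = layer(x)
print(w.requires_grad, w.is_inference())  # False True
```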
The exact levers you control are:
- Model Loading: where and when `model.load_state_dict()` or `torch.load()` is called. It’s common to load the model once when the FastAPI app starts.
- Preprocessing: the `tokenize` function. This needs to match exactly how your model was trained.
- Inference Logic: the `with torch.no_grad(): model(tokenized_text)` part. You might add batching here for higher throughput.
- Postprocessing: converting the raw model output tensor into a human-readable prediction (e.g., sentiment label, class ID, bounding box coordinates).
- API Endpoint Design: the URL (`/predict/`), HTTP method (POST), request body structure (`TextInput`), and response body structure.
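As a sketch of the batching lever, the code below pads several requests' token IDs into one tensor and runs a single forward pass. It assumes the token IDs are already computed; `PackedSentimentModel` and `predict_batch` are hypothetical names.

`PackedSentimentModel` keeps the same parameterized layers as `SimpleSentimentModel`, so its `state_dict` keys should line up with a trained checkpoint. The difference is that it uses `pack_padded_sequence`, which stops padding from leaking into each sequence's final hidden state.

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_sequence


class PackedSentimentModel(nn.Module):
    """Same parameterized layers as SimpleSentimentModel; forward() takes a padded batch plus lengths."""

    def __init__(self, vocab_size, embedding_dim, hidden_dim):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.rnn = nn.LSTM(embedding_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, 1)

    def forward(self, padded_ids, lengths):
        embedded = self.embedding(padded_ids)  # (batch, max_len, embedding_dim)
        # Packing tells the LSTM where each sequence really ends, so the returned
        # hidden state belongs to each sequence's last real token, not to padding
        packed = pack_padded_sequence(embedded, lengths, batch_first=True, enforce_sorted=False)
        _, (hidden, _) = self.rnn(packed)
        return torch.sigmoid(self.fc(hidden[-1]))  # (batch, 1), in original batch order


def predict_batch(model, token_id_lists):
    """Score several token-ID lists in one forward pass; every list needs at least one ID."""
    sequences = [torch.tensor(ids, dtype=torch.long) for ids in token_id_lists]
    lengths = torch.tensor([len(s) for s in sequences], dtype=torch.long)  # must stay on CPU
    padded = pad_sequence(sequences, batch_first=True, padding_value=0)
    with torch.no_grad():
        scores = model(padded, lengths)
    return scores.squeeze(1).tolist()
```

Collecting concurrent HTTP requests into such a batch (often called dynamic batching) needs an extra queueing layer in front of the model. This sketch leaves that out.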
One common pattern for optimizing inference performance is to move the model to a GPU if available. This is typically done once during application startup.
```python
# At the top, after the model definition and before app instantiation
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)  # nn.Module.to() moves the parameters in place

# Inside the predict_sentiment function, ensure the input is on the same device:
tokenized_text = tokenize(request.text).to(device)  # Move tokenized input to device
```
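This way every inference request runs on the GPU when one is present. The benefit grows with model size and batch size. For a model as small as this one, the overhead of copying each input to the GPU and synchronizing to read the result back (which `.item()` does) can offset much of the gain.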
The next step in building a robust ML serving system is often implementing more sophisticated batching strategies and model versioning.