Building ML APIs with PyTorch and FastAPI
Learn how to deploy PyTorch models in production using FastAPI for high-performance, scalable machine learning APIs
Machine learning models are only as valuable as their accessibility. In this guide, I'll walk you through deploying PyTorch models using FastAPI to create production-ready APIs that are both performant and scalable.
Why PyTorch + FastAPI?
PyTorch is my go-to deep learning framework for its:
- Intuitive Python-first design
- Dynamic computation graphs (see the short example after these lists)
- Strong community support
- Excellent debugging capabilities
FastAPI complements PyTorch perfectly with:
- Asynchronous request handling
- Automatic API documentation (OpenAPI/Swagger)
- Type hints and validation with Pydantic
- High performance (on par with Node.js and Go)
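To make the "dynamic computation graphs" point concrete, here is a minimal sketch (the module and sizes are made up for illustration): the forward pass is plain Python, so the graph can branch on the data and is rebuilt on every call.

```python
import torch
import torch.nn as nn

class GatedMLP(nn.Module):
    """Toy module: the computation path depends on the input at runtime."""
    def __init__(self, dim: int = 16):
        super().__init__()
        self.fc1 = nn.Linear(dim, dim)
        self.fc2 = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Ordinary Python branching; the graph is built dynamically per call
        if x.mean() > 0:
            return self.fc1(x)
        return self.fc2(x)

x = torch.randn(4, 16)
print(GatedMLP()(x).shape)  # torch.Size([4, 16])
```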
Project Structure
```
ml-api/
├── models/
│   └── trained_model.pth
├── app/
│   ├── main.py
│   ├── model.py
│   └── schemas.py
├── requirements.txt
└── Dockerfile
```
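The requirements.txt referenced above would list roughly the following, based on the imports used in this guide (versions omitted here; pin them for reproducible builds):

```
fastapi
uvicorn[standard]
torch
numpy
pydantic
```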
Loading the PyTorch Model
First, create a model loader that initializes once on startup:
```python
# app/model.py
import torch
import torch.nn as nn
from typing import Optional


class ModelPredictor:
    def __init__(self, model_path: str, device: Optional[str] = None):
        # Prefer GPU when available, otherwise fall back to CPU
        self.device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
        self.model = self._load_model(model_path)
        self.model.eval()  # disable dropout/batch-norm updates for inference

    def _load_model(self, path: str) -> nn.Module:
        model = YourModelArchitecture()  # replace with your actual model class
        model.load_state_dict(torch.load(path, map_location=self.device))
        return model.to(self.device)

    @torch.no_grad()  # no gradient tracking needed at inference time
    def predict(self, input_data: torch.Tensor):
        input_data = input_data.to(self.device)
        output = self.model(input_data)
        return output.cpu().numpy()
```
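Before wiring this into the API, it can help to sanity-check the predictor on its own. A minimal sketch, assuming YourModelArchitecture has been replaced with your real model class and that it accepts a (batch, features) float tensor; the (1, 10) shape is just a placeholder:

```python
# Quick local smoke test (not part of the API itself)
import torch
from app.model import ModelPredictor

predictor = ModelPredictor("models/trained_model.pth")
dummy_input = torch.randn(1, 10)  # placeholder shape; use your model's real input shape
print(predictor.predict(dummy_input))  # numpy array of raw model outputs
```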
FastAPI Application
Create the API endpoints with proper type validation:
```python
# app/main.py
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import numpy as np
import torch

from app.model import ModelPredictor

app = FastAPI(
    title="PyTorch ML API",
    description="Production ML inference API",
    version="1.0.0"
)

# CORS middleware (tighten allow_origins for production)
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# Global model instance, populated once at startup
model_predictor = None


@app.on_event("startup")
async def load_model():
    global model_predictor
    model_predictor = ModelPredictor("models/trained_model.pth")
    print("Model loaded successfully")


class PredictionRequest(BaseModel):
    data: list[float]


class PredictionResponse(BaseModel):
    prediction: list[float]
    confidence: float


@app.post("/predict", response_model=PredictionResponse)
async def predict(request: PredictionRequest):
    try:
        # Convert input to a (1, n_features) tensor
        input_tensor = torch.tensor([request.data], dtype=torch.float32)

        # Get prediction
        output = model_predictor.predict(input_tensor)

        return PredictionResponse(
            prediction=output[0].tolist(),
            confidence=float(np.max(output))
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


@app.get("/health")
async def health_check():
    return {"status": "healthy", "model_loaded": model_predictor is not None}
```
Performance Optimizations
1. Model Quantization
Reduce model size and increase inference speed:
```python
# Convert to a dynamically quantized model (int8 weights for Linear layers)
model_quantized = torch.quantization.quantize_dynamic(
    model, {torch.nn.Linear}, dtype=torch.qint8
)
```
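How much quantization helps depends on the architecture; dynamic quantization mainly targets Linear and LSTM layers. A minimal sketch for checking the size difference on disk, assuming `model` and `model_quantized` from the snippet above:

```python
import os
import torch

def state_dict_size_mb(m: torch.nn.Module, path: str) -> float:
    # Serialize the state dict and report its size on disk
    torch.save(m.state_dict(), path)
    return os.path.getsize(path) / 1e6

print("fp32:", state_dict_size_mb(model, "/tmp/fp32.pth"), "MB")
print("int8:", state_dict_size_mb(model_quantized, "/tmp/int8.pth"), "MB")
```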
2. Batch Processing
Process multiple requests together:
```python
@app.post("/batch-predict")
async def batch_predict(requests: list[PredictionRequest]):
    # Stack the inputs into one (batch, n_features) tensor;
    # every request must carry the same number of features
    batch_tensor = torch.stack([
        torch.tensor(req.data, dtype=torch.float32)
        for req in requests
    ])
    outputs = model_predictor.predict(batch_tensor)
    return [{"prediction": out.tolist()} for out in outputs]
```
3. Async Workers
Use Uvicorn with multiple workers:
```bash
uvicorn app.main:app --workers 4 --host 0.0.0.0 --port 8000
```
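Keep in mind that each worker is a separate process and loads its own copy of the model, so memory usage scales with the worker count. If you prefer gunicorn as the process manager, a commonly used equivalent invocation looks like this:

```bash
gunicorn app.main:app --workers 4 --worker-class uvicorn.workers.UvicornWorker \
  --bind 0.0.0.0:8000
```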
Real-World Application: Biometric Verification
At Shri Asharam Memorial Navjeevan Hospital, I implemented a similar architecture for a Siamese Neural Network-based biometric verification system:
- 98.6% accuracy across 5,000+ patient records
- FastAPI backend handling 400+ daily transactions
- PyTorch model for real-time patient verification
- 99.5% system uptime with robust error handling
The key was proper model caching, async request handling, and efficient tensor operations.
Deployment with Docker
```dockerfile
FROM python:3.10-slim
WORKDIR /app
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt
COPY . .
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"]
```
Key Takeaways
- Load models once at startup to avoid repeated disk I/O
- Use @torch.no_grad() for inference to save memory
- Leverage FastAPI's async capabilities for I/O operations
- Implement proper error handling and input validation
- Monitor performance with logging and metrics (see the middleware sketch below)
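As a starting point for the monitoring item above, here is a minimal sketch of an HTTP middleware that logs per-request latency; it would live in app/main.py alongside the endpoints:

```python
import logging
import time

logger = logging.getLogger("ml-api")

@app.middleware("http")
async def log_request_timing(request, call_next):
    # Time each request and log method, path, status code, and latency
    start = time.perf_counter()
    response = await call_next(request)
    elapsed_ms = (time.perf_counter() - start) * 1000
    logger.info("%s %s -> %d (%.1f ms)",
                request.method, request.url.path,
                response.status_code, elapsed_ms)
    return response
```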
Performance Results
In my production deployments:
- Average latency: < 50ms per request
- Throughput: 1000+ requests/second
- Memory footprint: ~500MB for quantized models
- CPU utilization: 30-40% with 4 workers
Next Steps
- Implement model versioning for A/B testing
- Add Redis caching for frequently requested inputs (sketched after this list)
- Set up monitoring with Prometheus and Grafana
- Deploy on Kubernetes for auto-scaling
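For the Redis caching item, one possible shape is below. This is a sketch only: it assumes a local Redis instance, the redis-py async client, and that the function sits in app/main.py where torch and model_predictor are already available; the key scheme and TTL are arbitrary.

```python
import hashlib
import json

import redis.asyncio as redis

cache = redis.Redis(host="localhost", port=6379)

async def cached_predict(data: list[float]) -> list[float]:
    # Hash the input to build a stable cache key
    key = "pred:" + hashlib.sha256(json.dumps(data).encode()).hexdigest()
    hit = await cache.get(key)
    if hit is not None:
        return json.loads(hit)

    output = model_predictor.predict(torch.tensor([data], dtype=torch.float32))
    result = output[0].tolist()
    await cache.set(key, json.dumps(result), ex=3600)  # cache for 1 hour
    return result
```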
FastAPI and PyTorch together create a powerful stack for deploying ML models in production. The combination of Python's flexibility and FastAPI's performance makes it ideal for real-world applications.
Have questions about deploying ML models? Feel free to reach out!
