Building ML APIs with PyTorch and FastAPI
Tags: PyTorch, FastAPI, Machine Learning, Python, API Development


Learn how to deploy PyTorch models in production using FastAPI for high-performance, scalable machine learning APIs


Machine learning models are only as valuable as their accessibility. In this guide, I'll walk you through deploying PyTorch models using FastAPI to create production-ready APIs that are both performant and scalable.

Why PyTorch + FastAPI?

PyTorch is my go-to deep learning framework for its:

  • Intuitive Python-first design
  • Dynamic computation graphs
  • Strong community support
  • Excellent debugging capabilities

FastAPI complements PyTorch perfectly with:

  • Asynchronous request handling
  • Automatic API documentation (OpenAPI/Swagger)
  • Type hints and validation with Pydantic
  • High performance (on par with Node.js and Go)

Project Structure

ml-api/
├── models/
│   └── trained_model.pth
├── app/
│   ├── main.py
│   ├── model.py
│   └── schemas.py
├── requirements.txt
└── Dockerfile
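
The exact versions depend on your project, but a minimal requirements.txt for this layout only needs the packages used in the code below (pin versions for a real deployment):

# requirements.txt
fastapi
uvicorn[standard]
torch
numpy
pydantic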

Loading the PyTorch Model

First, create a model loader that initializes once on startup:

# app/model.py
import torch
import torch.nn as nn
from typing import Optional

class ModelPredictor:
    def __init__(self, model_path: str, device: Optional[str] = None):
        self.device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
        self.model = self._load_model(model_path)
        self.model.eval()
    
    def _load_model(self, path: str):
        # Build the network architecture, then load the trained weights onto the target device
        model = YourModelArchitecture()
        model.load_state_dict(torch.load(path, map_location=self.device))
        return model.to(self.device)
    
    @torch.no_grad()
    def predict(self, input_data: torch.Tensor):
        input_data = input_data.to(self.device)
        output = self.model(input_data)
        return output.cpu().numpy()
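
Here, YourModelArchitecture stands in for whatever nn.Module you actually trained. Purely for illustration, a minimal placeholder for a small classifier might look like this (the layer sizes are made up):

# app/architecture.py -- hypothetical placeholder, swap in your real network
import torch.nn as nn

class YourModelArchitecture(nn.Module):
    def __init__(self, in_features: int = 16, num_classes: int = 4):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_features, 64),
            nn.ReLU(),
            nn.Linear(64, num_classes),
        )

    def forward(self, x):
        return self.net(x)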

FastAPI Application

Create the API endpoints with proper type validation:

# app/main.py
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import numpy as np
import torch

from app.model import ModelPredictor

app = FastAPI(
    title="PyTorch ML API",
    description="Production ML inference API",
    version="1.0.0"
)

# CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# Global model instance
model_predictor = None

@app.on_event("startup")
async def load_model():
    global model_predictor
    model_predictor = ModelPredictor("models/trained_model.pth")
    print("Model loaded successfully")

class PredictionRequest(BaseModel):
    data: list[float]
    
class PredictionResponse(BaseModel):
    prediction: list[float]
    confidence: float

@app.post("/predict", response_model=PredictionResponse)
async def predict(request: PredictionRequest):
    try:
        # Convert input to tensor
        input_tensor = torch.tensor([request.data], dtype=torch.float32)
        
        # Get prediction
        output = model_predictor.predict(input_tensor)
        
        return PredictionResponse(
            prediction=output[0].tolist(),
            confidence=float(np.max(output))
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

@app.get("/health")
async def health_check():
    return {"status": "healthy", "model_loaded": model_predictor is not None}

Performance Optimizations

1. Model Quantization

Dynamic quantization converts the weights of the listed layer types (here nn.Linear) to int8, shrinking the model and speeding up CPU inference:

# Convert to quantized model
model_quantized = torch.quantization.quantize_dynamic(
    model, {torch.nn.Linear}, dtype=torch.qint8
)

2. Batch Processing

Process multiple requests in a single forward pass (note that torch.stack requires every input to have the same length):

@app.post("/batch-predict")
async def batch_predict(requests: list[PredictionRequest]):
    batch_tensor = torch.stack([
        torch.tensor(req.data, dtype=torch.float32) 
        for req in requests
    ])
    outputs = model_predictor.predict(batch_tensor)
    return [{"prediction": out.tolist()} for out in outputs]

3. Async Workers

Use Uvicorn with multiple worker processes; keep in mind that each worker loads its own copy of the model:

uvicorn app.main:app --workers 4 --host 0.0.0.0 --port 8000

Real-World Application: Biometric Verification

At Shri Asharam Memorial Navjeevan Hospital, I implemented a similar architecture for a Siamese Neural Network-based biometric verification system:

  • 98.6% accuracy across 5,000+ patient records
  • FastAPI backend handling 400+ daily transactions
  • PyTorch model for real-time patient verification
  • 99.5% system uptime with robust error handling

The key was proper model caching, async request handling, and efficient tensor operations.

Deployment with Docker

FROM python:3.10-slim

WORKDIR /app

COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

COPY . .

CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"]

Key Takeaways

  1. Load models once at startup to avoid repeated disk I/O
  2. Use @torch.no_grad() for inference to save memory
  3. Leverage FastAPI's async capabilities for I/O operations
  4. Implement proper error handling and input validation
  5. Monitor performance with logging and metrics (a minimal latency-logging sketch follows below)
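
On the last point, here is a minimal sketch of request-latency logging using FastAPI's HTTP middleware hook, attached to the app from app/main.py. The logger name and log format are arbitrary, and in production you would likely ship these numbers to a metrics backend rather than plain logs:

# latency logging middleware (illustrative sketch)
import logging
import time

from fastapi import Request

logger = logging.getLogger("ml-api")

@app.middleware("http")
async def log_request_latency(request: Request, call_next):
    start = time.perf_counter()
    response = await call_next(request)
    elapsed_ms = (time.perf_counter() - start) * 1000
    logger.info("%s %s -> %d in %.1f ms",
                request.method, request.url.path,
                response.status_code, elapsed_ms)
    return response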

Performance Results

In my production deployments:

  • Average latency: < 50ms per request
  • Throughput: 1000+ requests/second
  • Memory footprint: ~500MB for quantized models
  • CPU utilization: 30-40% with 4 workers

Next Steps

  • Implement model versioning for A/B testing
  • Add Redis caching for frequently requested inputs (rough sketch after this list)
  • Set up monitoring with Prometheus and Grafana
  • Deploy on Kubernetes for auto-scaling
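
As a starting point for the caching idea, here is a rough sketch that assumes the redis-py client, a local Redis instance, and the model_predictor from app/main.py. cached_predict and the key scheme are hypothetical, not part of the system described above:

# Redis caching sketch (hypothetical)
import hashlib
import json

import redis  # redis-py client
import torch

cache = redis.Redis(host="localhost", port=6379, db=0)

def cached_predict(data: list[float], ttl_seconds: int = 3600) -> list[float]:
    # Hash the raw input to build a stable cache key
    key = "pred:" + hashlib.sha256(json.dumps(data).encode()).hexdigest()
    hit = cache.get(key)
    if hit is not None:
        return json.loads(hit)
    input_tensor = torch.tensor([data], dtype=torch.float32)
    prediction = model_predictor.predict(input_tensor)[0].tolist()
    cache.set(key, json.dumps(prediction), ex=ttl_seconds)
    return prediction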

FastAPI and PyTorch together create a powerful stack for deploying ML models in production. The combination of Python's flexibility and FastAPI's performance makes it ideal for real-world applications.


Have questions about deploying ML models? Feel free to reach out!

Design & Developed by Shivratan Choudhary
© 2025. All rights reserved.