Building Inference-as-a-Service with FastAPI
What We're Building
A GPU-backed inference service: a FastAPI server with dynamic request batching and a health endpoint, plus a deployment script and a Python client SDK for calling it.
Prerequisites
Python 3.10+ with torch, transformers, accelerate (for device_map="auto"), fastapi, pydantic, and uvicorn installed. A CUDA GPU is recommended; the server falls back to CPU if none is available.
Step 1: The Inference Server
# inference_server.py
"""FastAPI inference server for ML models."""
import os
import time
import asyncio
from typing import List, Optional, Dict, Any
from contextlib import asynccontextmanager
from dataclasses import dataclass
import torch
from fastapi import FastAPI, HTTPException, BackgroundTasks
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import uvicorn
# ============== Models ==============
class InferenceRequest(BaseModel):
"""Single inference request."""
inputs: str | List[str]
parameters: Optional[Dict[str, Any]] = None
class InferenceResponse(BaseModel):
"""Inference response."""
outputs: List[str] | List[Dict]
processing_time_ms: float
model: str
class HealthResponse(BaseModel):
"""Health check response."""
status: str
gpu_available: bool
gpu_name: Optional[str]
gpu_memory_used_gb: Optional[float]
gpu_memory_total_gb: Optional[float]
model_loaded: bool
requests_processed: int
uptime_seconds: float
# ============== Batch Processor ==============
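# Strategy: incoming requests are queued and dispatched together. A batch runs
# as soon as the queue reaches max_batch_size, or after max_wait_ms has passed
# since the first item arrived, whichever happens first.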
@dataclass
class BatchItem:
"""Item in the batch queue."""
request_id: str
inputs: List[str]
parameters: Dict
future: asyncio.Future
class BatchProcessor:
"""Process inference requests in batches for efficiency."""
def __init__(self, max_batch_size: int = 8, max_wait_ms: int = 50):
self.max_batch_size = max_batch_size
self.max_wait_ms = max_wait_ms
self.queue: List[BatchItem] = []
self.lock = asyncio.Lock()
self.processing = False
async def add_request(self, inputs: List[str], parameters: Dict) -> Any:
"""Add request to batch and wait for result."""
import uuid
future = asyncio.Future()
item = BatchItem(
request_id=str(uuid.uuid4()),
inputs=inputs,
parameters=parameters,
future=future
)
async with self.lock:
self.queue.append(item)
# Start processing if batch is full
if len(self.queue) >= self.max_batch_size:
asyncio.create_task(self._process_batch())
# Or schedule processing after wait time
elif len(self.queue) == 1:
asyncio.create_task(self._delayed_process())
return await future
async def _delayed_process(self):
"""Process batch after max_wait_ms."""
await asyncio.sleep(self.max_wait_ms / 1000)
if self.queue and not self.processing:
await self._process_batch()
async def _process_batch(self):
"""Process current batch."""
async with self.lock:
if not self.queue or self.processing:
return
self.processing = True
batch = self.queue[:self.max_batch_size]
self.queue = self.queue[self.max_batch_size:]
try:
# Combine all inputs
all_inputs = []
for item in batch:
all_inputs.extend(item.inputs)
            # Run inference off the event loop in a worker thread.
            # Note: every item in the batch is generated with the first item's parameters.
            results = await asyncio.get_running_loop().run_in_executor(
                None, model_manager.generate, all_inputs, batch[0].parameters
            )
# Distribute results
idx = 0
for item in batch:
n = len(item.inputs)
item.future.set_result(results[idx:idx + n])
idx += n
except Exception as e:
for item in batch:
item.future.set_exception(e)
        finally:
            self.processing = False
            # Re-dispatch if more requests were queued while this batch was running,
            # otherwise those items would wait indefinitely.
            if self.queue:
                asyncio.create_task(self._process_batch())
# ============== Model Manager ==============
class ModelManager:
"""Manage model loading and inference."""
def __init__(self):
self.model = None
self.tokenizer = None
self.model_name = ""
self.device = None
self.requests_processed = 0
self.start_time = time.time()
def load_model(self, model_name: str):
"""Load a model."""
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
print(f"Loading model: {model_name}")
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
self.model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=torch.float16 if self.device.type == "cuda" else torch.float32,
device_map="auto"
)
self.model_name = model_name
print(f"Model loaded on {self.device}")
def generate(self, inputs: List[str], parameters: Dict = None) -> List[str]:
"""Generate text for batch of inputs."""
parameters = parameters or {}
# Tokenize
encoded = self.tokenizer(
inputs,
return_tensors="pt",
padding=True,
truncation=True,
max_length=parameters.get("max_input_length", 512)
).to(self.device)
# Generate
with torch.no_grad():
outputs = self.model.generate(
**encoded,
max_new_tokens=parameters.get("max_new_tokens", 256),
temperature=parameters.get("temperature", 0.7),
top_p=parameters.get("top_p", 0.9),
do_sample=parameters.get("do_sample", True),
pad_token_id=self.tokenizer.pad_token_id
)
# Decode
results = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
# Remove input from output
cleaned = []
for inp, out in zip(inputs, results):
if out.startswith(inp):
out = out[len(inp):].strip()
cleaned.append(out)
self.requests_processed += len(inputs)
return cleaned
def get_health(self) -> Dict:
"""Get health status."""
gpu_available = torch.cuda.is_available()
return {
"status": "healthy" if self.model else "not_loaded",
"gpu_available": gpu_available,
"gpu_name": torch.cuda.get_device_name(0) if gpu_available else None,
"gpu_memory_used_gb": torch.cuda.memory_allocated() / 1e9 if gpu_available else None,
"gpu_memory_total_gb": torch.cuda.get_device_properties(0).total_memory / 1e9 if gpu_available else None,
"model_loaded": self.model is not None,
"requests_processed": self.requests_processed,
"uptime_seconds": time.time() - self.start_time
}
# ============== App ==============
model_manager = ModelManager()
batch_processor = BatchProcessor(max_batch_size=8, max_wait_ms=50)
@asynccontextmanager
async def lifespan(app: FastAPI):
"""Lifecycle management."""
# Startup
model_name = os.environ.get("MODEL_NAME", "gpt2")
model_manager.load_model(model_name)
yield
# Shutdown
pass
app = FastAPI(
title="Clore Inference API",
description="GPU-accelerated model inference",
version="1.0.0",
lifespan=lifespan
)
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_methods=["*"],
allow_headers=["*"],
)
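# Note: allow_origins=["*"] is convenient for testing; restrict it in production.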
@app.get("/health", response_model=HealthResponse)
async def health_check():
"""Health check endpoint."""
return model_manager.get_health()
@app.post("/generate", response_model=InferenceResponse)
async def generate(request: InferenceRequest):
"""Generate text from input."""
if not model_manager.model:
raise HTTPException(status_code=503, detail="Model not loaded")
start_time = time.time()
# Normalize inputs
inputs = request.inputs if isinstance(request.inputs, list) else [request.inputs]
parameters = request.parameters or {}
# Use batch processor
results = await batch_processor.add_request(inputs, parameters)
processing_time = (time.time() - start_time) * 1000
return InferenceResponse(
outputs=results,
processing_time_ms=processing_time,
model=model_manager.model_name
)
@app.post("/batch", response_model=InferenceResponse)
async def batch_generate(requests: List[InferenceRequest]):
"""Process multiple requests in one call."""
    if not model_manager.model:
        raise HTTPException(status_code=503, detail="Model not loaded")
    if not requests:
        raise HTTPException(status_code=400, detail="No requests provided")
start_time = time.time()
# Combine all inputs
all_inputs = []
for req in requests:
inputs = req.inputs if isinstance(req.inputs, list) else [req.inputs]
all_inputs.extend(inputs)
parameters = requests[0].parameters or {}
results = await batch_processor.add_request(all_inputs, parameters)
processing_time = (time.time() - start_time) * 1000
return InferenceResponse(
outputs=results,
processing_time_ms=processing_time,
model=model_manager.model_name
)
if __name__ == "__main__":
uvicorn.run(
"inference_server:app",
host="0.0.0.0",
port=8000,
reload=False
    )

Step 2: Deployment Script
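A minimal launcher sketch, assuming the server above is saved as inference_server.py and serves on port 8000: it sets MODEL_NAME, starts uvicorn in a subprocess, and polls /health until the model reports loaded. The requests dependency, the deploy.py filename, and the five-minute timeout are illustrative assumptions, not part of the server code.
# deploy.py -- launch the server and wait for /health to report model_loaded
import os
import subprocess
import sys
import time

import requests  # assumed client-side dependency

MODEL_NAME = os.environ.get("MODEL_NAME", "gpt2")
HEALTH_URL = "http://127.0.0.1:8000/health"


def main() -> None:
    # Start the inference server as a child process with the chosen model
    server = subprocess.Popen(
        [sys.executable, "inference_server.py"],
        env={**os.environ, "MODEL_NAME": MODEL_NAME},
    )
    # Poll /health until the model is loaded, or give up after ~5 minutes
    deadline = time.time() + 300
    while time.time() < deadline:
        try:
            health = requests.get(HEALTH_URL, timeout=5).json()
            if health.get("model_loaded"):
                print("Server ready:", health)
                return
        except requests.RequestException:
            pass  # server not accepting connections yet
        time.sleep(5)
    print("Server did not become healthy in time; shutting it down.")
    server.terminate()
    sys.exit(1)


if __name__ == "__main__":
    main()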
Step 3: Client SDK
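A thin Python client sketch wrapping the endpoints defined above (/health and /generate). The InferenceClient name, the requests dependency, and the default timeout are illustrative choices; only the URLs and payload shapes come from the server code.
# client.py -- small SDK wrapping the inference API
from typing import Any, Dict, List, Optional, Union

import requests


class InferenceClient:
    """Client for the FastAPI inference server in inference_server.py."""

    def __init__(self, base_url: str = "http://localhost:8000", timeout: float = 120.0):
        self.base_url = base_url.rstrip("/")
        self.timeout = timeout

    def health(self) -> Dict[str, Any]:
        """Return the /health payload (GPU info, model status, uptime)."""
        resp = requests.get(f"{self.base_url}/health", timeout=self.timeout)
        resp.raise_for_status()
        return resp.json()

    def generate(
        self,
        inputs: Union[str, List[str]],
        parameters: Optional[Dict[str, Any]] = None,
    ) -> List[str]:
        """POST to /generate and return the list of generated strings."""
        payload = {"inputs": inputs, "parameters": parameters}
        resp = requests.post(
            f"{self.base_url}/generate", json=payload, timeout=self.timeout
        )
        resp.raise_for_status()
        return resp.json()["outputs"]

Usage: InferenceClient().generate("Hello", {"max_new_tokens": 32}).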
Quick Start
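Assuming the server is running locally (for example with MODEL_NAME=gpt2 python inference_server.py), a quick smoke test with the requests library looks like this; the prompt text and parameters are only examples.
# quick_start.py -- smoke test against a locally running server
import requests

# 1. Confirm the model is loaded and check GPU status
print(requests.get("http://localhost:8000/health", timeout=10).json())

# 2. Generate a completion for a single prompt
resp = requests.post(
    "http://localhost:8000/generate",
    json={"inputs": "Once upon a time", "parameters": {"max_new_tokens": 64}},
    timeout=120,
)
print(resp.json()["outputs"][0])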
Performance Benchmarks
Model | GPU | Tokens/sec | Latency (p50) | Batch Size
Cost Analysis
Model | Requests/hr | Cost/hr | Cost/1K requests
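The last column is derived from the other two: Cost/1K requests = (Cost/hr ÷ Requests/hr) × 1000.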
Next Steps