Building Inference-as-a-Service with FastAPI

What We're Building

A production-ready model inference API using FastAPI, deployed on Clore GPUs with automatic batching, health checks, and cost-efficient scaling.

Prerequisites

  • Clore.ai API key

  • Python 3.10+

  • Basic FastAPI knowledge

Step 1: The Inference Server

# inference_server.py
"""FastAPI inference server for ML models."""

import os
import time
import asyncio
from typing import List, Optional, Dict, Any
from contextlib import asynccontextmanager
from dataclasses import dataclass
import torch
from fastapi import FastAPI, HTTPException, BackgroundTasks
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import uvicorn


# ============== Models ==============

class InferenceRequest(BaseModel):
    """Body of a /generate call: one prompt, or a list of prompts,
    plus optional generation parameters (temperature, top_p, ...)."""
    inputs: str | List[str]
    parameters: Dict[str, Any] | None = None

class InferenceResponse(BaseModel):
    """What /generate and /batch return: generated outputs plus
    latency and model metadata."""
    outputs: List[str] | List[Dict]
    processing_time_ms: float  # wall-clock time spent handling the request
    model: str                 # name of the model that produced the outputs

class HealthResponse(BaseModel):
    """Shape of the /health payload (see ModelManager.get_health)."""
    status: str                         # "healthy" or "not_loaded"
    gpu_available: bool
    gpu_name: str | None                # None when no GPU is present
    gpu_memory_used_gb: float | None
    gpu_memory_total_gb: float | None
    model_loaded: bool
    requests_processed: int             # cumulative inputs served
    uptime_seconds: float


# ============== Batch Processor ==============

@dataclass
class BatchItem:
    """Item in the batch queue."""
    request_id: str          # unique id, useful for tracing/debugging
    inputs: List[str]        # prompts belonging to this one request
    parameters: Dict         # generation parameters for this request
    future: asyncio.Future   # resolved with this request's slice of results

class BatchProcessor:
    """Process inference requests in batches for efficiency.

    Requests are queued and flushed either when ``max_batch_size`` requests
    are waiting or ``max_wait_ms`` after the first request arrived,
    whichever comes first.
    """

    def __init__(self, max_batch_size: int = 8, max_wait_ms: int = 50,
                 generate_fn=None):
        """
        Args:
            max_batch_size: flush once this many requests are queued.
            max_wait_ms: max time the first queued request waits for peers.
            generate_fn: optional callable ``(inputs, parameters) -> results``;
                defaults to the module-level ``model_manager.generate``.
                Mainly an injection point for testing.
        """
        self.max_batch_size = max_batch_size
        self.max_wait_ms = max_wait_ms
        self.queue: List[BatchItem] = []
        self.lock = asyncio.Lock()
        self.processing = False
        self._generate_fn = generate_fn

    async def add_request(self, inputs: List[str], parameters: Dict) -> Any:
        """Queue a request and wait for its results.

        Returns the slice of model outputs corresponding to ``inputs``;
        raises whatever the underlying generate call raised.
        """
        import uuid

        # create_future() ties the future to the running loop; bare
        # asyncio.Future() relies on the deprecated implicit loop lookup.
        future = asyncio.get_running_loop().create_future()
        item = BatchItem(
            request_id=str(uuid.uuid4()),
            inputs=inputs,
            parameters=parameters,
            future=future
        )

        async with self.lock:
            self.queue.append(item)

            # Flush immediately once the batch is full.
            if len(self.queue) >= self.max_batch_size:
                asyncio.create_task(self._process_batches())
            # First item in an empty queue: start the flush timer.
            elif len(self.queue) == 1:
                asyncio.create_task(self._delayed_process())

        return await future

    async def _delayed_process(self):
        """Flush whatever is queued once max_wait_ms has elapsed."""
        await asyncio.sleep(self.max_wait_ms / 1000)
        # No unlocked pre-checks here: _process_batches does all state
        # inspection under the lock.
        await self._process_batches()

    async def _process_batches(self):
        """Drain the queue, one batch at a time.

        Only one drainer runs at a time (guarded by ``self.processing``
        under the lock); concurrent callers return immediately and rely on
        the active drainer's loop to pick up items they enqueued.  The loop
        fixes a starvation bug in the original design, where requests
        enqueued while a batch was in flight were never flushed.
        """
        while True:
            async with self.lock:
                if not self.queue or self.processing:
                    return
                self.processing = True
                batch = self.queue[:self.max_batch_size]
                del self.queue[:self.max_batch_size]

            try:
                # Flatten every request's inputs into a single model call.
                all_inputs: List[str] = []
                for item in batch:
                    all_inputs.extend(item.inputs)

                # NOTE: the whole batch runs with the *first* request's
                # parameters -- a known limitation of this naive batching
                # scheme, preserved from the original behavior.
                generate = self._generate_fn or model_manager.generate
                results = await asyncio.get_running_loop().run_in_executor(
                    None, generate, all_inputs, batch[0].parameters
                )

                # Hand each request its slice of the results.
                idx = 0
                for item in batch:
                    n = len(item.inputs)
                    if not item.future.done():  # skip cancelled waiters
                        item.future.set_result(results[idx:idx + n])
                    idx += n

            except Exception as e:
                for item in batch:
                    if not item.future.done():
                        item.future.set_exception(e)
            finally:
                self.processing = False
            # Loop again: items may have arrived while we were generating.


# ============== Model Manager ==============

class ModelManager:
    """Load a causal LM and run batched text generation on it."""

    def __init__(self):
        self.model = None            # AutoModelForCausalLM once loaded
        self.tokenizer = None        # matching AutoTokenizer
        self.model_name = ""
        self.device = None           # torch.device, set by load_model
        self.requests_processed = 0  # total inputs generated, for /health
        self.start_time = time.time()

    def load_model(self, model_name: str):
        """Load *model_name* and place it on the best available device.

        Uses fp16 on CUDA, fp32 on CPU.
        """
        # transformers is imported lazily so the module can be imported
        # (e.g. by tooling) without the heavy dependency being resolved.
        from transformers import AutoModelForCausalLM, AutoTokenizer

        print(f"Loading model: {model_name}")

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        # Decoder-only models often ship without a pad token; reuse EOS,
        # but don't clobber a pad token the tokenizer already defines.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        # Left padding is required for correct *batched* generation with
        # decoder-only models: with right padding, continuations would be
        # conditioned on trailing pad tokens.
        self.tokenizer.padding_side = "left"

        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float16 if self.device.type == "cuda" else torch.float32,
            device_map="auto"
        )

        self.model_name = model_name
        print(f"Model loaded on {self.device}")

    def generate(self, inputs: List[str], parameters: Optional[Dict] = None) -> List[str]:
        """Generate a continuation for each prompt in *inputs*.

        Recognized *parameters* keys: max_input_length, max_new_tokens,
        temperature, top_p, do_sample.  Returns one string per input with
        the echoed prompt stripped off.
        """
        parameters = parameters or {}

        # Tokenize the whole batch at once, padded to the longest prompt.
        encoded = self.tokenizer(
            inputs,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=parameters.get("max_input_length", 512)
        ).to(self.device)

        # Inference only -- no autograd bookkeeping needed.
        with torch.no_grad():
            outputs = self.model.generate(
                **encoded,
                max_new_tokens=parameters.get("max_new_tokens", 256),
                temperature=parameters.get("temperature", 0.7),
                top_p=parameters.get("top_p", 0.9),
                do_sample=parameters.get("do_sample", True),
                pad_token_id=self.tokenizer.pad_token_id
            )

        results = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)

        # generate() echoes the prompt; strip it so callers only receive
        # the continuation.  (Fragile if decoding re-normalizes whitespace
        # -- kept for behavioral compatibility.)
        cleaned = []
        for inp, out in zip(inputs, results):
            if out.startswith(inp):
                out = out[len(inp):].strip()
            cleaned.append(out)

        self.requests_processed += len(inputs)
        return cleaned

    def get_health(self) -> Dict:
        """Snapshot of GPU/model status for the /health endpoint."""
        gpu_available = torch.cuda.is_available()

        return {
            "status": "healthy" if self.model else "not_loaded",
            "gpu_available": gpu_available,
            "gpu_name": torch.cuda.get_device_name(0) if gpu_available else None,
            "gpu_memory_used_gb": torch.cuda.memory_allocated() / 1e9 if gpu_available else None,
            "gpu_memory_total_gb": torch.cuda.get_device_properties(0).total_memory / 1e9 if gpu_available else None,
            "model_loaded": self.model is not None,
            "requests_processed": self.requests_processed,
            "uptime_seconds": time.time() - self.start_time
        }


# ============== App ==============

# Module-level singletons shared by every endpoint: the model wrapper and
# the queue that coalesces concurrent requests into GPU batches.
model_manager = ModelManager()
batch_processor = BatchProcessor(max_batch_size=8, max_wait_ms=50)

@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the model before serving; no teardown is required on shutdown."""
    # MODEL_NAME selects the checkpoint; gpt2 is the lightweight default.
    model_manager.load_model(os.environ.get("MODEL_NAME", "gpt2"))
    yield

# Application object; lifespan hook loads the model at startup.
app = FastAPI(
    title="Clore Inference API",
    description="GPU-accelerated model inference",
    version="1.0.0",
    lifespan=lifespan
)

# Wide-open CORS so browser clients on any origin can call the API.
# NOTE(review): "*" is fine for a demo; restrict allow_origins in production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.get("/health", response_model=HealthResponse)
async def health_check():
    """Report GPU, model, and uptime status."""
    snapshot = model_manager.get_health()
    return snapshot


@app.post("/generate", response_model=InferenceResponse)
async def generate(request: InferenceRequest):
    """Generate text for one prompt or a list of prompts.

    Returns 503 until the model has finished loading.
    """
    if not model_manager.model:
        raise HTTPException(status_code=503, detail="Model not loaded")

    # perf_counter is monotonic, so the reported latency can't be skewed
    # (or go negative) if the system clock is adjusted mid-request.
    start_time = time.perf_counter()

    # Accept either a single string or a list of strings.
    inputs = request.inputs if isinstance(request.inputs, list) else [request.inputs]
    parameters = request.parameters or {}

    # The batch processor coalesces concurrent requests into one GPU call.
    results = await batch_processor.add_request(inputs, parameters)

    processing_time = (time.perf_counter() - start_time) * 1000

    return InferenceResponse(
        outputs=results,
        processing_time_ms=processing_time,
        model=model_manager.model_name
    )


@app.post("/batch", response_model=InferenceResponse)
async def batch_generate(requests: List[InferenceRequest]):
    """Process multiple requests in one call.

    All requests share the *first* request's parameters (a limitation of
    the naive batching scheme).  Returns 503 until the model is loaded,
    422 for an empty request list.
    """
    if not model_manager.model:
        raise HTTPException(status_code=503, detail="Model not loaded")

    # Guard the requests[0] access below: an empty list previously raised
    # an unhandled IndexError and surfaced as HTTP 500.
    if not requests:
        raise HTTPException(status_code=422, detail="Empty request list")

    # Monotonic clock for latency measurement.
    start_time = time.perf_counter()

    # Flatten every request's inputs into a single batch.
    all_inputs = []
    for req in requests:
        inputs = req.inputs if isinstance(req.inputs, list) else [req.inputs]
        all_inputs.extend(inputs)

    parameters = requests[0].parameters or {}
    results = await batch_processor.add_request(all_inputs, parameters)

    processing_time = (time.perf_counter() - start_time) * 1000

    return InferenceResponse(
        outputs=results,
        processing_time_ms=processing_time,
        model=model_manager.model_name
    )


if __name__ == "__main__":
    # reload=False: auto-reload would re-load the model on every code change.
    uvicorn.run("inference_server:app", host="0.0.0.0", port=8000, reload=False)

Step 2: Deployment Script

Step 3: Client SDK

Quick Start

Performance Benchmarks

| Model      | GPU      | Tokens/sec | Latency (p50) | Batch Size |
|------------|----------|------------|---------------|------------|
| GPT-2      | RTX 4090 | ~150       | 45ms          | 8          |
| Llama 7B   | RTX 4090 | ~35        | 120ms         | 4          |
| Mistral 7B | RTX 4090 | ~40        | 100ms         | 4          |

Cost Analysis

| Model      | Requests/hr | Cost/hr | Cost/1K requests |
|------------|-------------|---------|------------------|
| GPT-2      | ~10,000     | $0.35   | $0.035           |
| Llama 7B   | ~2,000      | $0.40   | $0.20            |
| Mistral 7B | ~2,500      | $0.40   | $0.16            |

For comparison, the OpenAI API costs roughly $0.50–$2.00 per 1K requests.

📚 See also: How to Deploy an LLM Inference Server on Clore.ai

Next Steps

Last updated

Was this helpful?