# Building Inference-as-a-Service with FastAPI

## What We're Building

A production-ready model inference API using FastAPI, deployed on Clore GPUs with automatic batching, health checks, and cost-efficient scaling.

## Prerequisites

* Clore.ai API key
* Python 3.10+
* Basic FastAPI knowledge

## Step 1: The Inference Server

```python
# inference_server.py
"""FastAPI inference server for ML models."""

import os
import time
import asyncio
from typing import List, Optional, Dict, Any
from contextlib import asynccontextmanager
from dataclasses import dataclass
import torch
from fastapi import FastAPI, HTTPException, BackgroundTasks
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import uvicorn


# ============== Models ==============

class InferenceRequest(BaseModel):
    """Single inference request."""
    inputs: str | List[str]
    parameters: Optional[Dict[str, Any]] = None

class InferenceResponse(BaseModel):
    """Inference response."""
    outputs: List[str] | List[Dict]
    processing_time_ms: float
    model: str

class HealthResponse(BaseModel):
    """Health check response."""
    status: str
    gpu_available: bool
    gpu_name: Optional[str]
    gpu_memory_used_gb: Optional[float]
    gpu_memory_total_gb: Optional[float]
    model_loaded: bool
    requests_processed: int
    uptime_seconds: float


# ============== Batch Processor ==============

@dataclass
class BatchItem:
    """Item in the batch queue."""
    request_id: str
    inputs: List[str]
    parameters: Dict
    future: asyncio.Future

class BatchProcessor:
    """Process inference requests in batches for efficiency."""
    
    def __init__(self, max_batch_size: int = 8, max_wait_ms: int = 50):
        self.max_batch_size = max_batch_size
        self.max_wait_ms = max_wait_ms
        self.queue: List[BatchItem] = []
        self.lock = asyncio.Lock()
        self.processing = False
    
    async def add_request(self, inputs: List[str], parameters: Dict) -> Any:
        """Add request to batch and wait for result."""
        import uuid
        
        future = asyncio.Future()
        item = BatchItem(
            request_id=str(uuid.uuid4()),
            inputs=inputs,
            parameters=parameters,
            future=future
        )
        
        async with self.lock:
            self.queue.append(item)
            
            # Start processing if batch is full
            if len(self.queue) >= self.max_batch_size:
                asyncio.create_task(self._process_batch())
            # Or schedule processing after wait time
            elif len(self.queue) == 1:
                asyncio.create_task(self._delayed_process())
        
        return await future
    
    async def _delayed_process(self):
        """Process batch after max_wait_ms."""
        await asyncio.sleep(self.max_wait_ms / 1000)
        if self.queue and not self.processing:
            await self._process_batch()
    
    async def _process_batch(self):
        """Process current batch."""
        async with self.lock:
            if not self.queue or self.processing:
                return
            
            self.processing = True
            batch = self.queue[:self.max_batch_size]
            self.queue = self.queue[self.max_batch_size:]
        
        try:
            # Combine all inputs
            all_inputs = []
            for item in batch:
                all_inputs.extend(item.inputs)
            
            # Run inference in a worker thread so the event loop stays responsive.
            # Note: the whole batch is generated with the first item's parameters.
            loop = asyncio.get_running_loop()
            results = await loop.run_in_executor(
                None, model_manager.generate, all_inputs, batch[0].parameters
            )
            
            # Distribute results
            idx = 0
            for item in batch:
                n = len(item.inputs)
                item.future.set_result(results[idx:idx + n])
                idx += n
                
        except Exception as e:
            for item in batch:
                item.future.set_exception(e)
        finally:
            self.processing = False
            # Pick up requests that queued up while this batch was running.
            if self.queue:
                asyncio.create_task(self._process_batch())


# ============== Model Manager ==============

class ModelManager:
    """Manage model loading and inference."""
    
    def __init__(self):
        self.model = None
        self.tokenizer = None
        self.model_name = ""
        self.device = None
        self.requests_processed = 0
        self.start_time = time.time()
    
    def load_model(self, model_name: str):
        """Load a model."""
        from transformers import AutoModelForCausalLM, AutoTokenizer
        import torch
        
        print(f"Loading model: {model_name}")
        
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float16 if self.device.type == "cuda" else torch.float32,
            device_map="auto"
        )
        
        self.model_name = model_name
        print(f"Model loaded on {self.device}")
    
    def generate(self, inputs: List[str], parameters: Optional[Dict] = None) -> List[str]:
        """Generate text for batch of inputs."""
        parameters = parameters or {}
        
        # Tokenize
        encoded = self.tokenizer(
            inputs,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=parameters.get("max_input_length", 512)
        ).to(self.device)
        
        # Generate
        with torch.no_grad():
            outputs = self.model.generate(
                **encoded,
                max_new_tokens=parameters.get("max_new_tokens", 256),
                temperature=parameters.get("temperature", 0.7),
                top_p=parameters.get("top_p", 0.9),
                do_sample=parameters.get("do_sample", True),
                pad_token_id=self.tokenizer.pad_token_id
            )
        
        # Decode
        results = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
        
        # Remove input from output
        cleaned = []
        for inp, out in zip(inputs, results):
            if out.startswith(inp):
                out = out[len(inp):].strip()
            cleaned.append(out)
        
        self.requests_processed += len(inputs)
        return cleaned
    
    def get_health(self) -> Dict:
        """Get health status."""
        gpu_available = torch.cuda.is_available()
        
        return {
            "status": "healthy" if self.model else "not_loaded",
            "gpu_available": gpu_available,
            "gpu_name": torch.cuda.get_device_name(0) if gpu_available else None,
            "gpu_memory_used_gb": torch.cuda.memory_allocated() / 1e9 if gpu_available else None,
            "gpu_memory_total_gb": torch.cuda.get_device_properties(0).total_memory / 1e9 if gpu_available else None,
            "model_loaded": self.model is not None,
            "requests_processed": self.requests_processed,
            "uptime_seconds": time.time() - self.start_time
        }


# ============== App ==============

model_manager = ModelManager()
batch_processor = BatchProcessor(max_batch_size=8, max_wait_ms=50)

@asynccontextmanager
async def lifespan(app: FastAPI):
    """Lifecycle management."""
    # Startup
    model_name = os.environ.get("MODEL_NAME", "gpt2")
    model_manager.load_model(model_name)
    yield
    # Shutdown
    pass

app = FastAPI(
    title="Clore Inference API",
    description="GPU-accelerated model inference",
    version="1.0.0",
    lifespan=lifespan
)

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.get("/health", response_model=HealthResponse)
async def health_check():
    """Health check endpoint."""
    return model_manager.get_health()


@app.post("/generate", response_model=InferenceResponse)
async def generate(request: InferenceRequest):
    """Generate text from input."""
    
    if not model_manager.model:
        raise HTTPException(status_code=503, detail="Model not loaded")
    
    start_time = time.time()
    
    # Normalize inputs
    inputs = request.inputs if isinstance(request.inputs, list) else [request.inputs]
    parameters = request.parameters or {}
    
    # Use batch processor
    results = await batch_processor.add_request(inputs, parameters)
    
    processing_time = (time.time() - start_time) * 1000
    
    return InferenceResponse(
        outputs=results,
        processing_time_ms=processing_time,
        model=model_manager.model_name
    )


@app.post("/batch", response_model=InferenceResponse)
async def batch_generate(requests: List[InferenceRequest]):
    """Process multiple requests in one call."""
    
    if not model_manager.model:
        raise HTTPException(status_code=503, detail="Model not loaded")
    
    start_time = time.time()
    
    # Combine all inputs
    all_inputs = []
    for req in requests:
        inputs = req.inputs if isinstance(req.inputs, list) else [req.inputs]
        all_inputs.extend(inputs)
    
    # Note: the first request's parameters are applied to the whole combined batch.
    parameters = requests[0].parameters or {}
    results = await batch_processor.add_request(all_inputs, parameters)
    
    processing_time = (time.time() - start_time) * 1000
    
    return InferenceResponse(
        outputs=results,
        processing_time_ms=processing_time,
        model=model_manager.model_name
    )


if __name__ == "__main__":
    uvicorn.run(
        "inference_server:app",
        host="0.0.0.0",
        port=8000,
        reload=False
    )
```
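Every call to `/generate` is funneled through the `BatchProcessor`, so requests that arrive within the same `max_wait_ms` window share a single forward pass. The sketch below is one quick way to see that behavior; it is not part of the server, and it assumes the server from this step is running locally on port 8000 with the default model.

```python
# batching_demo.py
"""Fire several /generate calls concurrently to exercise the batch processor."""

import time
from concurrent.futures import ThreadPoolExecutor

import requests

API_URL = "http://localhost:8000"  # adjust to your deployment
PROMPTS = [f"Prompt number {i}:" for i in range(8)]


def call_generate(prompt: str) -> dict:
    """Send one single-prompt request; concurrent calls get merged server-side."""
    response = requests.post(
        f"{API_URL}/generate",
        json={"inputs": prompt, "parameters": {"max_new_tokens": 32}},
        timeout=60,
    )
    response.raise_for_status()
    return response.json()


if __name__ == "__main__":
    start = time.time()
    with ThreadPoolExecutor(max_workers=8) as pool:
        results = list(pool.map(call_generate, PROMPTS))
    elapsed_ms = (time.time() - start) * 1000

    # Requests landing inside one max_wait_ms window share a forward pass,
    # so total wall time should sit close to a single request's latency.
    for r in results:
        print(f"{r['processing_time_ms']:.0f} ms  {r['outputs'][0][:60]!r}")
    print(f"Total: {elapsed_ms:.0f} ms for {len(PROMPTS)} requests")
```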

## Step 2: Deployment Script

```python
# deploy_inference.py
"""Deploy inference server on Clore."""

import os
import sys
import time
import requests
import paramiko
from scp import SCPClient

class InferenceDeployer:
    """Deploy inference API on Clore GPUs."""
    
    BASE_URL = "https://api.clore.ai"
    
    def __init__(self, api_key: str):
        self.api_key = api_key
        self.headers = {"auth": api_key}
        self.ssh_client = None
        self.order_id = None
        self.server_host = None
        self.server_port = None
    
    def _api(self, method: str, endpoint: str, **kwargs) -> dict:
        url = f"{self.BASE_URL}{endpoint}"
        response = requests.request(method, url, headers=self.headers, **kwargs)
        data = response.json()
        if data.get("code") != 0:
            raise Exception(f"API Error: {data}")
        return data
    
    def find_server(self, gpu_type: str, max_price: float) -> dict:
        """Find a suitable server."""
        servers = self._api("GET", "/v1/marketplace")["servers"]
        
        for server in servers:
            if server.get("rented"):
                continue
            
            gpus = server.get("gpu_array", [])
            if not any(gpu_type in g for g in gpus):
                continue
            
            price = server.get("price", {}).get("usd", {}).get("on_demand_clore", 999)
            if price <= max_price:
                return {
                    "id": server["id"],
                    "gpus": gpus,
                    "price": price
                }
        
        raise Exception(f"No {gpu_type} found under ${max_price}/hr")
    
    def provision(self, server_id: int, ssh_password: str) -> dict:
        """Provision the server."""
        order = self._api("POST", "/v1/create_order", json={
            "renting_server": server_id,
            "type": "on-demand",
            "currency": "CLORE-Blockchain",
            "image": "pytorch/pytorch:2.7.1-cuda12.8-cudnn9-devel",
            "ports": {"22": "tcp", "8000": "http"},
            "env": {"NVIDIA_VISIBLE_DEVICES": "all"},
            "ssh_password": ssh_password
        })
        
        self.order_id = order["order_id"]
        print(f"📦 Order: {self.order_id}")
        
        # Wait for ready
        for _ in range(120):
            orders = self._api("GET", "/v1/my_orders")["orders"]
            current = next((o for o in orders if o["order_id"] == self.order_id), None)
            if current and current.get("status") == "running":
                return current
            time.sleep(2)
        
        raise Exception("Timeout")
    
    def connect(self, host: str, port: int, password: str):
        """Connect via SSH."""
        self.ssh_client = paramiko.SSHClient()
        self.ssh_client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
        
        for _ in range(5):
            try:
                self.ssh_client.connect(host, port=port, username="root",
                                        password=password, timeout=30)
                self.server_host = host
                self.server_port = port
                print(f"✅ Connected to {host}:{port}")
                return
            except Exception:
                # Container may still be booting; retry after a short delay.
                time.sleep(10)
        
        raise Exception("SSH failed")
    
    def run_cmd(self, cmd: str, stream: bool = False) -> str:
        """Run command on server."""
        _, stdout, stderr = self.ssh_client.exec_command(cmd, get_pty=True)
        
        if stream:
            output = ""
            for line in iter(stdout.readline, ""):
                print(line, end="")
                output += line
            return output
        return stdout.read().decode()
    
    def setup(self, model_name: str):
        """Setup the inference environment."""
        print("\n🔧 Setting up environment...")
        
        commands = [
            "pip install --upgrade pip",
            "pip install fastapi uvicorn transformers accelerate",
            "mkdir -p /workspace"
        ]
        
        for cmd in commands:
            self.run_cmd(cmd)
        
        # Set model name in environment
        self.run_cmd(f"echo 'export MODEL_NAME={model_name}' >> ~/.bashrc")
        
        print("✅ Environment ready")
    
    def upload_server(self):
        """Upload inference server code."""
        script_path = os.path.join(os.path.dirname(__file__), "inference_server.py")
        
        with SCPClient(self.ssh_client.get_transport()) as scp:
            scp.put(script_path, "/workspace/inference_server.py")
        
        print("📤 Server code uploaded")
    
    def start_server(self, model_name: str):
        """Start the inference server."""
        print("\n🚀 Starting inference server...")
        
        cmd = f"""
cd /workspace
export MODEL_NAME='{model_name}'
nohup python inference_server.py > /var/log/inference.log 2>&1 &
"""
        self.run_cmd(cmd)
        
        # Give the server time to download weights and load the model.
        # For a sturdier check, poll the /health endpoint instead of sleeping.
        time.sleep(30)
        
        print("✅ Server started!")
    
    def get_api_url(self, order: dict) -> str:
        """Get the API URL."""
        ports = order.get("connection", {}).get("ports", {})
        for port_info in ports.values():
            if "8000" in str(port_info):
                return port_info
        return f"http://{self.server_host}:8000"
    
    def deploy(
        self,
        model_name: str = "gpt2",
        gpu_type: str = "RTX 4090",
        max_price: float = 0.50,
        ssh_password: str = "InferAPI123!"
    ) -> dict:
        """Deploy complete inference API."""
        
        try:
            # Find server
            print(f"🔍 Finding {gpu_type}...")
            server = self.find_server(gpu_type, max_price)
            print(f"   Found: Server {server['id']} @ ${server['price']:.2f}/hr")
            
            # Provision
            print("\n📦 Provisioning...")
            order = self.provision(server["id"], ssh_password)
            
            # Connect
            ssh_info = order["connection"]["ssh"]
            parts = ssh_info.split()
            host = parts[1].split("@")[1]
            port = int(parts[3]) if len(parts) > 3 else 22
            
            self.connect(host, port, ssh_password)
            
            # Setup
            self.setup(model_name)
            
            # Upload and start
            self.upload_server()
            self.start_server(model_name)
            
            api_url = self.get_api_url(order)
            
            print("\n" + "="*50)
            print("✅ Inference API Deployed!")
            print("="*50)
            print(f"\n🌐 API URL: {api_url}")
            print(f"📊 Health: {api_url}/health")
            print(f"🤖 Generate: POST {api_url}/generate")
            print(f"\n💰 Cost: ${server['price']:.2f}/hr")
            print(f"📋 Order ID: {self.order_id}")
            
            return {
                "api_url": api_url,
                "order_id": self.order_id,
                "price_per_hour": server["price"],
                "model": model_name
            }
            
        except Exception as e:
            print(f"❌ Deployment failed: {e}")
            self.shutdown()
            raise
    
    def shutdown(self):
        """Shutdown the deployment."""
        if self.ssh_client:
            self.ssh_client.close()
        
        if self.order_id:
            self._api("POST", "/v1/cancel_order", json={"id": self.order_id})
            print(f"✅ Order {self.order_id} cancelled")


def main():
    api_key = os.environ.get("CLORE_API_KEY") or (sys.argv[1] if len(sys.argv) > 1 else None)
    if not api_key:
        sys.exit("Usage: python deploy_inference.py <CLORE_API_KEY> [model_name]")
    model = sys.argv[2] if len(sys.argv) > 2 else "gpt2"
    
    deployer = InferenceDeployer(api_key)
    
    try:
        result = deployer.deploy(
            model_name=model,
            gpu_type="RTX 4090",
            max_price=0.50
        )
        
        # Test the API
        print("\n🧪 Testing API...")
        time.sleep(10)
        
        response = requests.post(
            f"{result['api_url']}/generate",
            json={"inputs": "Hello, I am a"}
        )
        
        if response.status_code == 200:
            print(f"   Response: {response.json()['outputs'][0][:100]}...")
        else:
            print(f"   Test failed: {response.status_code}")
        
        input("\n⏸️  Press Enter to shutdown...")
        
    finally:
        deployer.shutdown()


if __name__ == "__main__":
    main()
```
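If `deploy_inference.py` crashes or is killed before `shutdown()` runs, the rented server keeps billing. A small safety net, sketched below using the same `/v1/my_orders` and `/v1/cancel_order` endpoints the deployer already calls, lists your active orders and can cancel them:

```python
# cleanup_orders.py
"""List active Clore orders and optionally cancel them (safety-net sketch)."""

import os
import sys

import requests

BASE_URL = "https://api.clore.ai"
HEADERS = {"auth": os.environ["CLORE_API_KEY"]}


def list_orders() -> list:
    """Return the current orders for this account."""
    data = requests.get(f"{BASE_URL}/v1/my_orders", headers=HEADERS).json()
    if data.get("code") != 0:
        raise RuntimeError(f"API error: {data}")
    return data.get("orders", [])


def cancel_order(order_id: int) -> None:
    """Cancel a single order by ID."""
    data = requests.post(
        f"{BASE_URL}/v1/cancel_order", headers=HEADERS, json={"id": order_id}
    ).json()
    if data.get("code") != 0:
        raise RuntimeError(f"API error: {data}")
    print(f"Cancelled order {order_id}")


if __name__ == "__main__":
    orders = list_orders()
    for o in orders:
        print(f"Order {o['order_id']}: status={o.get('status')}")
    if len(sys.argv) > 1 and sys.argv[1] == "--cancel-all":
        for o in orders:
            cancel_order(o["order_id"])
```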

## Step 3: Client SDK

```python
# inference_client.py
"""Client SDK for Clore inference API."""

import requests
from typing import List, Dict, Optional
import time

class InferenceClient:
    """Client for Clore inference API."""
    
    def __init__(self, api_url: str, timeout: int = 60):
        self.api_url = api_url.rstrip("/")
        self.timeout = timeout
    
    def health(self) -> Dict:
        """Check API health."""
        response = requests.get(f"{self.api_url}/health", timeout=10)
        response.raise_for_status()
        return response.json()
    
    def generate(
        self,
        inputs: str | List[str],
        max_new_tokens: int = 256,
        temperature: float = 0.7,
        top_p: float = 0.9,
        **kwargs
    ) -> List[str]:
        """Generate text."""
        
        parameters = {
            "max_new_tokens": max_new_tokens,
            "temperature": temperature,
            "top_p": top_p,
            **kwargs
        }
        
        response = requests.post(
            f"{self.api_url}/generate",
            json={"inputs": inputs, "parameters": parameters},
            timeout=self.timeout
        )
        response.raise_for_status()
        
        return response.json()["outputs"]
    
    def batch_generate(
        self,
        inputs_list: List[str],
        batch_size: int = 8,
        **kwargs
    ) -> List[str]:
        """Generate for many inputs efficiently."""
        
        results = []
        for i in range(0, len(inputs_list), batch_size):
            batch = inputs_list[i:i + batch_size]
            batch_results = self.generate(batch, **kwargs)
            results.extend(batch_results)
        
        return results
    
    def wait_for_ready(self, timeout: int = 120) -> bool:
        """Wait for API to be ready."""
        start = time.time()
        while time.time() - start < timeout:
            try:
                health = self.health()
                if health.get("model_loaded"):
                    return True
            except requests.RequestException:
                pass  # API not reachable yet; keep polling
            time.sleep(5)
        return False


# Usage example
if __name__ == "__main__":
    client = InferenceClient("http://your-clore-server:8000")
    
    # Wait for ready
    print("Waiting for API...")
    if client.wait_for_ready():
        print("API ready!")
        
        # Single generation
        result = client.generate("The future of AI is")
        print(f"Output: {result[0]}")
        
        # Batch generation
        prompts = [
            "Write a haiku about",
            "The best programming language is",
            "In the year 2050,"
        ]
        results = client.batch_generate(prompts)
        for prompt, result in zip(prompts, results):
            print(f"{prompt}: {result[:50]}...")
```
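Network hiccups and server restarts show up client-side as connection errors or 5xx responses. A thin retry wrapper, sketched below on top of `InferenceClient` (the backoff schedule is an illustrative choice, not part of the API), keeps callers from handling that themselves:

```python
# inference_client_retry.py
"""Retry wrapper around InferenceClient for transient failures (sketch)."""

import time

import requests

from inference_client import InferenceClient


def generate_with_retry(client: InferenceClient, inputs, retries: int = 3,
                        backoff_s: float = 2.0, **kwargs):
    """Call client.generate(), retrying transient errors with exponential backoff."""
    last_error = None
    for attempt in range(retries):
        try:
            return client.generate(inputs, **kwargs)
        except requests.HTTPError as e:
            # 4xx errors are not transient; give up immediately.
            if e.response is not None and e.response.status_code < 500:
                raise
            last_error = e
        except (requests.ConnectionError, requests.Timeout) as e:
            last_error = e
        if attempt < retries - 1:
            # Exponential backoff: 2s, 4s, 8s, ...
            time.sleep(backoff_s * (2 ** attempt))
    raise last_error


if __name__ == "__main__":
    client = InferenceClient("http://your-clore-server:8000")
    print(generate_with_retry(client, "The future of AI is")[0])
```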

## Quick Start

```bash
# Deploy inference API
export CLORE_API_KEY="your_key"
python deploy_inference.py $CLORE_API_KEY gpt2

# Test with curl
curl http://your-api-url/generate \
  -H "Content-Type: application/json" \
  -d '{"inputs": "Hello world", "parameters": {"max_new_tokens": 50}}'
```

## Performance Benchmarks

| Model      | GPU      | Tokens/sec | Latency (p50) | Batch Size |
| ---------- | -------- | ---------- | ------------- | ---------- |
| GPT-2      | RTX 4090 | ~150       | 45ms          | 8          |
| Llama 7B   | RTX 4090 | ~35        | 120ms         | 4          |
| Mistral 7B | RTX 4090 | ~40        | 100ms         | 4          |
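
These figures are indicative; real numbers depend on prompt length, generation settings, and the specific host. You can measure your own deployment with a short script like this sketch, which reuses the Step 3 client and estimates token counts from a rough words-per-token ratio:

```python
# benchmark.py
"""Rough latency/throughput check against a running deployment (sketch)."""

import time

from inference_client import InferenceClient

client = InferenceClient("http://your-clore-server:8000")
prompts = ["The future of AI is"] * 8

start = time.time()
outputs = client.generate(prompts, max_new_tokens=128)
elapsed = time.time() - start

# Crude token estimate: whitespace-split words divided by ~0.75 words/token.
approx_tokens = sum(len(o.split()) for o in outputs) / 0.75

print(f"Batch of {len(prompts)}: {elapsed:.2f}s total")
print(f"~{approx_tokens / elapsed:.0f} tokens/sec (rough estimate)")
```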

## Cost Analysis

| Model      | Requests/hr | Cost/hr | Cost/1K requests |
| ---------- | ----------- | ------- | ---------------- |
| GPT-2      | ~10,000     | $0.35   | $0.035           |
| Llama 7B   | ~2,000      | $0.40   | $0.20            |
| Mistral 7B | ~2,500      | $0.40   | $0.16            |

*Compared to the OpenAI API at ~$0.50-2.00 per 1K requests*
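
The per-request column follows directly from hourly price and sustained throughput: divide the hourly cost by thousands of requests per hour. A two-line check of the table values:

```python
def cost_per_1k(price_per_hour: float, requests_per_hour: float) -> float:
    """Cost per 1,000 requests = hourly price / (requests per hour / 1000)."""
    return price_per_hour / (requests_per_hour / 1000)

print(cost_per_1k(0.35, 10_000))  # GPT-2 on RTX 4090  -> 0.035
print(cost_per_1k(0.40, 2_000))   # Llama 7B           -> 0.20
```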

> 📚 See also: [How to Deploy an LLM Inference Server on Clore.ai](https://blog.clore.ai/how-to-deploy-llm-inference-server/)

## Next Steps

* [Batch Inference at Scale](https://docs.clore.ai/dev/inference-and-deployment/batch-inference)
* [Auto-Scaling Inference Workers](https://docs.clore.ai/dev/inference-and-deployment/auto-scaling-workers)
* [Multi-Model Inference Router](https://docs.clore.ai/dev/inference-and-deployment/model-router)
