# Auto-Scaling Inference Workers

## What We're Building

An auto-scaling system that dynamically provisions and de-provisions Clore.ai GPU workers based on queue depth. It scales your inference capacity up when demand increases and back down when it drops, so you pay only for what you use.

**Key Features:**

* Queue-based scaling (Redis/SQS/RabbitMQ)
* Configurable scaling thresholds
* Min/max worker limits
* Cool-down periods to prevent thrashing
* Cost tracking and budgeting
* Graceful worker shutdown
* Health monitoring

## Prerequisites

* Clore.ai account with API key
* Python 3.10+
* Redis or other message queue

```bash
pip install requests redis
```

## Architecture Overview

```
┌─────────────────────────────────────────────────────────────────┐
│                    Auto-Scaling System                           │
├─────────────────────────────────────────────────────────────────┤
│                                                                  │
│   ┌──────────────┐                                              │
│   │  Producers   │                                              │
│   │  (API/Jobs)  │                                              │
│   └──────┬───────┘                                              │
│          │                                                       │
│          ▼                                                       │
│   ┌──────────────┐     ┌───────────────┐     ┌──────────────┐  │
│   │    Queue     │────▶│  Scale        │────▶│   Workers    │  │
│   │   (Redis)    │     │  Controller   │     │  (Clore.ai)  │  │
│   └──────────────┘     └───────────────┘     └──────────────┘  │
│          │                    │                    │            │
│          │                    ▼                    │            │
│          │             ┌───────────────┐           │            │
│          │             │  Metrics      │           │            │
│          │             │  • Queue depth│           │            │
│          │             │  • Workers    │           │            │
│          │             │  • Cost       │           │            │
│          │             └───────────────┘           │            │
│          │                                         │            │
│          └─────────────────────────────────────────┘            │
│                                                                  │
└─────────────────────────────────────────────────────────────────┘
```
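The queue carries JSON-encoded job payloads. The exact schema is up to you; the example worker later in this guide assumes at minimum an `id` field (and optionally an `input` field), along these lines:

```python
import json

# Hypothetical job payload -- the example worker in this guide reads
# job_data["id"] and, in your inference logic, job_data["input"].
job = {
    "id": "job_001",
    "input": {"prompt": "a photo of a cat"},
}

encoded = json.dumps(job)      # what gets pushed onto the Redis list
decoded = json.loads(encoded)  # what the worker sees after BLPOP
assert decoded["id"] == "job_001"
```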

## Full Script: Auto-Scaling GPU Workers

```python
#!/usr/bin/env python3
"""
Auto-Scaling GPU Workers for Clore.ai

Dynamically scales GPU workers based on queue depth.

Usage:
    python auto_scaler.py --api-key YOUR_API_KEY --redis-url redis://localhost:6379 \
        --queue inference_queue --min-workers 0 --max-workers 5
"""

import argparse
import json
import os
import sys
import time
import secrets
import threading
import requests
import redis
from typing import Dict, List, Optional
from dataclasses import dataclass, field, asdict
from datetime import datetime
from enum import Enum


class WorkerStatus(Enum):
    STARTING = "starting"
    RUNNING = "running"
    DRAINING = "draining"  # Finishing current work before shutdown
    STOPPING = "stopping"
    STOPPED = "stopped"


@dataclass
class Worker:
    """Represents a GPU worker."""
    id: str
    order_id: int
    ssh_host: str
    ssh_port: int
    ssh_password: str
    gpu_type: str
    hourly_cost: float
    status: WorkerStatus = WorkerStatus.STARTING
    started_at: datetime = field(default_factory=datetime.now)
    jobs_processed: int = 0
    current_job: Optional[str] = None


@dataclass
class ScalingConfig:
    """Scaling configuration."""
    min_workers: int = 0
    max_workers: int = 10
    target_queue_per_worker: int = 5  # Target jobs per worker
    scale_up_threshold: int = 10      # Queue depth to trigger scale up
    scale_down_threshold: int = 2     # Queue depth to trigger scale down
    scale_up_cooldown: int = 60       # Seconds between scale up
    scale_down_cooldown: int = 300    # Seconds between scale down
    gpu_type: str = "RTX 3080"
    max_price_usd: float = 0.30
    docker_image: str = "nvidia/cuda:12.8.0-base-ubuntu22.04"
    max_hourly_budget: float = 10.0   # Max spend per hour


@dataclass
class ScalingMetrics:
    """Current scaling metrics."""
    queue_depth: int = 0
    active_workers: int = 0
    starting_workers: int = 0
    draining_workers: int = 0
    total_jobs_processed: int = 0
    current_hourly_cost: float = 0.0
    last_scale_up: Optional[datetime] = None
    last_scale_down: Optional[datetime] = None


class AutoScaler:
    """Auto-scales GPU workers based on queue depth."""
    
    BASE_URL = "https://api.clore.ai"
    
    def __init__(self, api_key: str, redis_url: str, queue_name: str, config: ScalingConfig):
        self.api_key = api_key
        self.headers = {"auth": api_key}
        self.redis = redis.from_url(redis_url)
        self.queue_name = queue_name
        self.config = config
        
        self.workers: Dict[str, Worker] = {}
        self.metrics = ScalingMetrics()
        self._lock = threading.Lock()
        self._running = False
    
    def _api(self, method: str, endpoint: str, **kwargs) -> Dict:
        """Make API request with retry on rate limiting."""
        url = f"{self.BASE_URL}{endpoint}"
        for attempt in range(3):
            response = requests.request(method, url, headers=self.headers, timeout=30, **kwargs)
            data = response.json()
            if data.get("code") == 5:  # Rate limited -- back off exponentially
                time.sleep(2 ** attempt)
                continue
            if data.get("code") != 0:
                raise Exception(f"API Error: {data}")
            return data
        raise Exception("Max retries exceeded")
    
    def _log(self, message: str):
        """Log with timestamp."""
        timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        print(f"[{timestamp}] {message}")
    
    def get_queue_depth(self) -> int:
        """Get current queue depth."""
        return self.redis.llen(self.queue_name)
    
    def _find_gpu(self) -> Optional[Dict]:
        """Find available GPU matching criteria."""
        data = self._api("GET", "/v1/marketplace")
        
        for server in data.get("servers", []):
            if server.get("rented"):
                continue
            
            gpus = server.get("gpu_array", [])
            if not any(self.config.gpu_type.lower() in g.lower() for g in gpus):
                continue
            
            price = server.get("price", {}).get("usd", {}).get("spot")
            if price and price <= self.config.max_price_usd:
                return {"id": server["id"], "gpus": gpus, "price": price}
        
        return None
    
    def _provision_worker(self) -> Optional[Worker]:
        """Provision a new GPU worker."""
        server = self._find_gpu()
        if not server:
            self._log(f"⚠️ No {self.config.gpu_type} available under ${self.config.max_price_usd}/hr")
            return None
        
        worker_id = f"w_{secrets.token_hex(4)}"
        ssh_password = secrets.token_urlsafe(16)
        
        order_data = {
            "renting_server": server["id"],
            "type": "spot",
            "currency": "CLORE-Blockchain",
            "image": self.config.docker_image,
            "ports": {"22": "tcp"},
            "env": {
                "NVIDIA_VISIBLE_DEVICES": "all",
                "WORKER_ID": worker_id,
                "REDIS_URL": os.environ.get("REDIS_URL", "redis://localhost:6379"),
                "QUEUE_NAME": self.queue_name
            },
            "ssh_password": ssh_password,
            "spotprice": server["price"] * 1.1
        }
        
        result = self._api("POST", "/v1/create_order", json=order_data)
        order_id = result["order_id"]
        
        # Wait for the server to come online (up to ~3 minutes)
        for _ in range(90):
            orders = self._api("GET", "/v1/my_orders")["orders"]
            order = next((o for o in orders if o["order_id"] == order_id), None)
            
            if order and order.get("status") == "running":
                # Connection string looks like "ssh user@host -p PORT"
                conn = order["connection"]["ssh"]
                parts = conn.split()
                ssh_host = parts[1].split("@")[1]
                ssh_port = int(parts[-1]) if "-p" in conn else 22
                
                worker = Worker(
                    id=worker_id,
                    order_id=order_id,
                    ssh_host=ssh_host,
                    ssh_port=ssh_port,
                    ssh_password=ssh_password,
                    gpu_type=self.config.gpu_type,
                    hourly_cost=server["price"],
                    status=WorkerStatus.RUNNING
                )
                
                return worker
            
            time.sleep(2)
        
        # Timeout - cancel
        self._api("POST", "/v1/cancel_order", json={"id": order_id})
        return None
    
    def _terminate_worker(self, worker: Worker):
        """Terminate a worker."""
        try:
            self._api("POST", "/v1/cancel_order", json={"id": worker.order_id})
            worker.status = WorkerStatus.STOPPED
        except Exception as e:
            self._log(f"⚠️ Failed to terminate worker {worker.id}: {e}")
    
    def scale_up(self) -> bool:
        """Scale up by one worker."""
        with self._lock:
            # Check cooldown
            if self.metrics.last_scale_up:
                elapsed = (datetime.now() - self.metrics.last_scale_up).total_seconds()
                if elapsed < self.config.scale_up_cooldown:
                    return False
            
            # Check max workers
            active = len([w for w in self.workers.values() if w.status in (WorkerStatus.STARTING, WorkerStatus.RUNNING)])
            if active >= self.config.max_workers:
                return False
            
            # Check budget
            current_cost = sum(w.hourly_cost for w in self.workers.values() if w.status == WorkerStatus.RUNNING)
            if current_cost >= self.config.max_hourly_budget:
                self._log(f"⚠️ Budget limit reached: ${current_cost:.2f}/hr")
                return False
        
        self._log(f"📈 Scaling UP (queue: {self.metrics.queue_depth}, workers: {active})")
        
        worker = self._provision_worker()
        if worker:
            with self._lock:
                self.workers[worker.id] = worker
                self.metrics.last_scale_up = datetime.now()
            self._log(f"✅ Worker {worker.id} started @ ${worker.hourly_cost:.3f}/hr")
            return True
        
        return False
    
    def scale_down(self) -> bool:
        """Scale down by one worker."""
        with self._lock:
            # Check cooldown
            if self.metrics.last_scale_down:
                elapsed = (datetime.now() - self.metrics.last_scale_down).total_seconds()
                if elapsed < self.config.scale_down_cooldown:
                    return False
            
            # Check min workers
            active = [w for w in self.workers.values() if w.status == WorkerStatus.RUNNING]
            if len(active) <= self.config.min_workers:
                return False
            
            # Find worker to terminate (oldest, not processing)
            idle_workers = [w for w in active if not w.current_job]
            if not idle_workers:
                # Mark one for draining
                worker = active[-1]
                worker.status = WorkerStatus.DRAINING
                self._log(f"🔄 Worker {worker.id} draining...")
                return True
            
            # Terminate oldest idle worker
            worker = min(idle_workers, key=lambda w: w.started_at)
        
        self._log(f"📉 Scaling DOWN (queue: {self.metrics.queue_depth}, workers: {len(active)})")
        
        self._terminate_worker(worker)
        
        with self._lock:
            del self.workers[worker.id]
            self.metrics.last_scale_down = datetime.now()
        
        self._log(f"🛑 Worker {worker.id} terminated")
        return True
    
    def _update_metrics(self):
        """Update current metrics."""
        self.metrics.queue_depth = self.get_queue_depth()
        
        with self._lock:
            self.metrics.active_workers = len([w for w in self.workers.values() if w.status == WorkerStatus.RUNNING])
            self.metrics.starting_workers = len([w for w in self.workers.values() if w.status == WorkerStatus.STARTING])
            self.metrics.draining_workers = len([w for w in self.workers.values() if w.status == WorkerStatus.DRAINING])
            self.metrics.current_hourly_cost = sum(w.hourly_cost for w in self.workers.values() if w.status == WorkerStatus.RUNNING)
    
    def _scaling_decision(self):
        """Make scaling decision based on metrics."""
        self._update_metrics()
        
        queue = self.metrics.queue_depth
        workers = self.metrics.active_workers + self.metrics.starting_workers
        
        # Scale up?
        if queue >= self.config.scale_up_threshold:
            # Calculate desired workers
            desired = max(self.config.min_workers, queue // self.config.target_queue_per_worker)
            desired = min(desired, self.config.max_workers)
            
            if workers < desired:
                self.scale_up()
        
        # Scale down?
        elif queue <= self.config.scale_down_threshold and workers > self.config.min_workers:
            self.scale_down()
    
    def _cleanup_draining(self):
        """Clean up draining workers that have finished."""
        with self._lock:
            draining = [w for w in self.workers.values() if w.status == WorkerStatus.DRAINING and not w.current_job]
        
        for worker in draining:
            self._terminate_worker(worker)
            with self._lock:
                self.workers.pop(worker.id, None)  # Tolerate concurrent removal
            self._log(f"🛑 Drained worker {worker.id} terminated")
    
    def _monitor_loop(self):
        """Main monitoring loop."""
        while self._running:
            try:
                self._scaling_decision()
                self._cleanup_draining()
            except Exception as e:
                self._log(f"❌ Error in monitor loop: {e}")
            
            time.sleep(10)  # Check every 10 seconds
    
    def start(self):
        """Start the auto-scaler."""
        self._running = True
        
        # Ensure minimum workers
        for _ in range(self.config.min_workers):
            self.scale_up()
        
        thread = threading.Thread(target=self._monitor_loop, daemon=True)
        thread.start()
        
        self._log(f"🚀 Auto-scaler started (min: {self.config.min_workers}, max: {self.config.max_workers})")
        
        return thread
    
    def stop(self):
        """Stop the auto-scaler and terminate all workers."""
        self._running = False
        
        self._log("🛑 Stopping auto-scaler...")
        
        # Terminate all workers
        for worker in list(self.workers.values()):
            self._terminate_worker(worker)
        
        self.workers.clear()
        self._log("✅ All workers terminated")
    
    def status(self) -> Dict:
        """Get current status."""
        self._update_metrics()
        
        return {
            "queue_depth": self.metrics.queue_depth,
            "workers": {
                "active": self.metrics.active_workers,
                "starting": self.metrics.starting_workers,
                "draining": self.metrics.draining_workers
            },
            "cost": {
                "hourly": self.metrics.current_hourly_cost,
                "daily_estimate": self.metrics.current_hourly_cost * 24
            },
            "config": asdict(self.config),
            "worker_details": [
                {
                    "id": w.id,
                    "status": w.status.value,
                    "gpu": w.gpu_type,
                    "cost": w.hourly_cost,
                    "jobs": w.jobs_processed,
                    "uptime_min": (datetime.now() - w.started_at).total_seconds() / 60
                }
                for w in self.workers.values()
            ]
        }


def main():
    parser = argparse.ArgumentParser(description="Auto-Scaling GPU Workers")
    parser.add_argument("--api-key", required=True)
    parser.add_argument("--redis-url", default="redis://localhost:6379")
    parser.add_argument("--queue", default="inference_queue")
    parser.add_argument("--min-workers", type=int, default=0)
    parser.add_argument("--max-workers", type=int, default=5)
    parser.add_argument("--gpu", default="RTX 3080")
    parser.add_argument("--max-price", type=float, default=0.30)
    parser.add_argument("--scale-up-threshold", type=int, default=10)
    parser.add_argument("--scale-down-threshold", type=int, default=2)
    parser.add_argument("--max-hourly-budget", type=float, default=10.0)
    parser.add_argument("action", choices=["start", "status"], nargs="?", default="start")
    args = parser.parse_args()
    
    config = ScalingConfig(
        min_workers=args.min_workers,
        max_workers=args.max_workers,
        gpu_type=args.gpu,
        max_price_usd=args.max_price,
        scale_up_threshold=args.scale_up_threshold,
        scale_down_threshold=args.scale_down_threshold,
        max_hourly_budget=args.max_hourly_budget
    )
    
    scaler = AutoScaler(args.api_key, args.redis_url, args.queue, config)
    
    if args.action == "start":
        try:
            scaler.start()
            print("\nAuto-scaler running. Press Ctrl+C to stop.\n")
            
            while True:
                time.sleep(30)
                status = scaler.status()
                print(f"[Status] Queue: {status['queue_depth']} | "
                      f"Workers: {status['workers']['active']} | "
                      f"Cost: ${status['cost']['hourly']:.2f}/hr")
        except KeyboardInterrupt:
            print("\nShutting down...")
            scaler.stop()
    
    elif args.action == "status":
        status = scaler.status()
        print(json.dumps(status, indent=2, default=str))


if __name__ == "__main__":
    main()
```
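The cooldown checks in `scale_up` and `scale_down` are the part of the controller most worth verifying in isolation, since a wrong comparison there causes exactly the thrashing they exist to prevent. A minimal standalone sketch of the same gate (the function name is ours, not part of the script above):

```python
from datetime import datetime, timedelta
from typing import Optional

def cooldown_elapsed(last_action: Optional[datetime], cooldown_s: int,
                     now: Optional[datetime] = None) -> bool:
    """True if enough time has passed since the last scaling action."""
    if last_action is None:
        return True  # Never scaled before -- no cooldown to wait out
    now = now or datetime.now()
    return (now - last_action).total_seconds() >= cooldown_s

t0 = datetime(2025, 1, 1, 12, 0, 0)
assert cooldown_elapsed(None, 60)
assert not cooldown_elapsed(t0, 60, now=t0 + timedelta(seconds=30))
assert cooldown_elapsed(t0, 60, now=t0 + timedelta(seconds=61))
```

Injecting `now` as a parameter keeps the gate testable without sleeping through real cooldowns.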

## Example Worker Script

```python
#!/usr/bin/env python3
"""
GPU Worker - processes jobs from queue.
"""

import os
import time
import redis
import json

REDIS_URL = os.environ.get("REDIS_URL", "redis://localhost:6379")
QUEUE_NAME = os.environ.get("QUEUE_NAME", "inference_queue")
WORKER_ID = os.environ.get("WORKER_ID", "worker_1")

r = redis.from_url(REDIS_URL)

print(f"Worker {WORKER_ID} started")

while True:
    # Blocking pop from queue
    job = r.blpop(QUEUE_NAME, timeout=30)
    
    if job:
        _, data = job
        job_data = json.loads(data)
        
        print(f"Processing job: {job_data['id']}")
        
        # Your inference logic here
        # result = model.predict(job_data['input'])
        time.sleep(1)  # Simulate work
        
        # Push result
        result = {"job_id": job_data["id"], "result": "processed"}
        r.lpush(f"results:{job_data['id']}", json.dumps(result))
        
        print(f"Completed job: {job_data['id']}")
```
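The worker above consumes jobs, but nothing in this guide produces them yet. A minimal producer sketch, with the queue name and payload fields chosen to match the worker script (the `make_job` helper is ours):

```python
#!/usr/bin/env python3
"""Enqueue inference jobs onto the Redis list the workers consume."""

import json
import uuid

def make_job(payload: dict) -> str:
    """Build a JSON job message with the fields the example worker expects."""
    return json.dumps({"id": f"job_{uuid.uuid4().hex[:8]}", "input": payload})

# Usage against a live Redis (requires a running server):
#   import redis
#   r = redis.from_url("redis://localhost:6379")
#   r.rpush("inference_queue", make_job({"prompt": "a photo of a cat"}))
```

Note that RPUSH on the producer side pairs with the worker's BLPOP to give FIFO ordering; using LPUSH on both ends would process newest jobs first.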

## Scaling Behavior

| Queue Depth | Workers (0-5 range) | Action            |
| ----------- | ------------------- | ----------------- |
| 0-2         | min                 | Scale down to min |
| 3-9         | 1-2                 | Maintain          |
| 10-19       | 2-3                 | Scale up          |
| 20-24       | 4                   | Scale up          |
| 25+         | 5 (max)             | At capacity       |
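The worker counts follow directly from the sizing math in `_scaling_decision`: desired workers = queue depth `//` `target_queue_per_worker` (default 5), clamped to the min/max range. As a standalone check with the default config:

```python
def desired_workers(queue_depth: int, target_per_worker: int = 5,
                    min_workers: int = 0, max_workers: int = 5) -> int:
    """Mirror of the sizing math in _scaling_decision."""
    desired = max(min_workers, queue_depth // target_per_worker)
    return min(desired, max_workers)

assert desired_workers(12) == 2
assert desired_workers(19) == 3
assert desired_workers(24) == 4
assert desired_workers(50) == 5  # capped at max_workers
```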

## Cost Example

| Queue Load     | Workers | Hourly Cost | Daily Cost |
| -------------- | ------- | ----------- | ---------- |
| Low (0-10)     | 1       | $0.30       | $7.20      |
| Medium (10-30) | 2-3     | $0.60-0.90  | $14-22     |
| High (30+)     | 5       | $1.50       | $36.00     |
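The figures above assume every worker rents at the $0.30/hr price ceiling; actual spot prices are usually lower. The projection is just workers × hourly rate:

```python
def cost_estimate(workers: int, hourly_rate: float = 0.30) -> dict:
    """Rough cost projection; real spot prices vary per server."""
    hourly = workers * hourly_rate
    return {"hourly": round(hourly, 2), "daily": round(hourly * 24, 2)}

assert cost_estimate(1) == {"hourly": 0.30, "daily": 7.20}
assert cost_estimate(5) == {"hourly": 1.50, "daily": 36.00}
```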

## Next Steps

* [Batch Inference](https://docs.clore.ai/dev/inference-and-deployment/batch-inference)
* [Model Router](https://docs.clore.ai/dev/inference-and-deployment/model-router)
* [REST API Deployment](https://docs.clore.ai/dev/inference-and-deployment/rest-api-deployment)
