#!/usr/bin/env python3
"""
Auto-Scaling GPU Workers for Clore.ai
Dynamically scales GPU workers based on queue depth.
Usage:
python auto_scaler.py --api-key YOUR_API_KEY --redis-url redis://localhost:6379 \
--queue inference_queue --min-workers 0 --max-workers 5
"""
import argparse
import json
import os
import sys
import time
import secrets
import threading
import requests
import redis
from typing import Dict, List, Optional
from dataclasses import dataclass, field, asdict
from datetime import datetime
from enum import Enum
class WorkerStatus(Enum):
    """Lifecycle states of a GPU worker."""
    STARTING = "starting"  # provisioned but not yet confirmed running
    RUNNING = "running"    # serving jobs; counted toward cost and capacity
    DRAINING = "draining"  # Finishing current work before shutdown
    STOPPING = "stopping"  # shutdown requested
    STOPPED = "stopped"    # order cancelled; no longer billed
@dataclass
class Worker:
    """Represents a GPU worker."""
    id: str                  # internal id, e.g. "w_<8 hex chars>"
    order_id: int            # Clore.ai order id; used to cancel the rental
    ssh_host: str            # parsed from the order's ssh connection string
    ssh_port: int            # 22 unless the connection string carries "-p"
    ssh_password: str        # generated with secrets.token_urlsafe at provisioning
    gpu_type: str            # requested GPU model, e.g. "RTX 3080"
    hourly_cost: float       # USD/hr spot price at provisioning time
    status: WorkerStatus = WorkerStatus.STARTING
    started_at: datetime = field(default_factory=datetime.now)  # local time of creation
    jobs_processed: int = 0
    current_job: Optional[str] = None  # job in flight; None means idle (safe to terminate)
@dataclass
class ScalingConfig:
    """Scaling configuration."""
    min_workers: int = 0                  # never scale below this many RUNNING workers
    max_workers: int = 10                 # hard cap on STARTING + RUNNING workers
    target_queue_per_worker: int = 5  # Target jobs per worker
    scale_up_threshold: int = 10  # Queue depth to trigger scale up
    scale_down_threshold: int = 2  # Queue depth to trigger scale down
    scale_up_cooldown: int = 60  # Seconds between scale up
    scale_down_cooldown: int = 300  # Seconds between scale down
    gpu_type: str = "RTX 3080"            # substring-matched (case-insensitive) against gpu_array
    max_price_usd: float = 0.30           # max acceptable spot price, USD/hr
    docker_image: str = "nvidia/cuda:12.8.0-base-ubuntu22.04"  # image for each worker
    max_hourly_budget: float = 10.0  # Max spend per hour
@dataclass
class ScalingMetrics:
    """Current scaling metrics."""
    queue_depth: int = 0                       # Redis list length at last refresh
    active_workers: int = 0                    # workers in RUNNING state
    starting_workers: int = 0                  # workers in STARTING state
    draining_workers: int = 0                  # workers in DRAINING state
    total_jobs_processed: int = 0
    current_hourly_cost: float = 0.0           # sum of hourly_cost over RUNNING workers
    last_scale_up: Optional[datetime] = None   # None until first scale-up (cooldown disabled)
    last_scale_down: Optional[datetime] = None # None until first scale-down
class AutoScaler:
    """Auto-scales GPU workers based on queue depth.

    A background monitor thread polls the Redis queue every 10 seconds,
    compares depth against the configured thresholds, and provisions or
    retires spot GPU instances through the Clore.ai marketplace API.

    Thread-safety: ``self.workers`` and ``self.metrics`` are guarded by
    ``self._lock`` (a non-reentrant Lock). Slow network calls (provisioning,
    termination) are deliberately performed OUTSIDE the lock.
    """

    BASE_URL = "https://api.clore.ai"

    def __init__(self, api_key: str, redis_url: str, queue_name: str, config: ScalingConfig):
        self.api_key = api_key
        self.headers = {"auth": api_key}  # Clore.ai expects a plain "auth" header
        self.redis = redis.from_url(redis_url)
        self.queue_name = queue_name
        self.config = config
        self.workers: Dict[str, Worker] = {}
        self.metrics = ScalingMetrics()
        self._lock = threading.Lock()  # guards self.workers / self.metrics
        self._running = False

    def _api(self, method: str, endpoint: str, **kwargs) -> Dict:
        """Make API request with retry on rate limiting.

        Retries up to 3 times with exponential backoff when the API reports
        code 5 (rate limited); raises on any other non-zero code.
        """
        url = f"{self.BASE_URL}{endpoint}"
        for attempt in range(3):
            # BUG FIX: **kwargs was previously dropped, so json=... bodies for
            # create_order/cancel_order were never actually sent.
            response = requests.request(method, url, headers=self.headers, timeout=30, **kwargs)
            data = response.json()
            if data.get("code") == 5:  # rate limited -> back off and retry
                time.sleep(2 ** attempt)
                continue
            if data.get("code") != 0:
                raise Exception(f"API Error: {data}")
            return data
        raise Exception("Max retries")

    def _log(self, message: str):
        """Log with timestamp."""
        timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        print(f"[{timestamp}] {message}")

    def get_queue_depth(self) -> int:
        """Get current queue depth (length of the Redis list)."""
        return self.redis.llen(self.queue_name)

    def _find_gpu(self) -> Optional[Dict]:
        """Find an un-rented marketplace server matching GPU type and price.

        Returns {"id", "gpus", "price"} for the first match, or None.
        """
        data = self._api("GET", "/v1/marketplace")
        for server in data.get("servers", []):
            if server.get("rented"):
                continue
            gpus = server.get("gpu_array", [])
            # Case-insensitive substring match, e.g. "RTX 3080" in "1x RTX 3080 10GB".
            if not any(self.config.gpu_type.lower() in g.lower() for g in gpus):
                continue
            price = server.get("price", {}).get("usd", {}).get("spot")
            if price and price <= self.config.max_price_usd:
                return {"id": server["id"], "gpus": gpus, "price": price}
        return None

    def _provision_worker(self) -> Optional[Worker]:
        """Create a spot order for a new GPU worker and wait for it to boot.

        Returns the Worker once its SSH endpoint is known, or None when no
        suitable GPU exists or the server fails to start within ~3 minutes
        (in which case the order is cancelled to stop billing).
        """
        server = self._find_gpu()
        if not server:
            self._log(f"⚠️ No {self.config.gpu_type} available under ${self.config.max_price_usd}/hr")
            return None
        worker_id = f"w_{secrets.token_hex(4)}"
        ssh_password = secrets.token_urlsafe(16)  # secrets, not random: this is a credential
        order_data = {
            "renting_server": server["id"],
            "type": "spot",
            "currency": "CLORE-Blockchain",
            "image": self.config.docker_image,
            "ports": {"22": "tcp"},
            "env": {
                "NVIDIA_VISIBLE_DEVICES": "all",
                "WORKER_ID": worker_id,
                "REDIS_URL": os.environ.get("REDIS_URL", "redis://localhost:6379"),
                "QUEUE_NAME": self.queue_name
            },
            "ssh_password": ssh_password,
            "spotprice": server["price"] * 1.1  # bid 10% over current spot
        }
        result = self._api("POST", "/v1/create_order", json=order_data)
        order_id = result["order_id"]
        # Poll for up to 90 * 2s = 3 minutes for the order to reach "running".
        for _ in range(90):
            orders = self._api("GET", "/v1/my_orders")["orders"]
            order = next((o for o in orders if o["order_id"] == order_id), None)
            if order and order.get("status") == "running":
                # Connection string looks like "ssh user@host [-p PORT]".
                conn = order["connection"]["ssh"]
                parts = conn.split()
                ssh_host = parts[1].split("@")[1]
                ssh_port = int(parts[-1]) if "-p" in conn else 22
                return Worker(
                    id=worker_id,
                    order_id=order_id,
                    ssh_host=ssh_host,
                    ssh_port=ssh_port,
                    ssh_password=ssh_password,
                    gpu_type=self.config.gpu_type,
                    hourly_cost=server["price"],
                    status=WorkerStatus.RUNNING
                )
            time.sleep(2)
        # Timeout: cancel the order so we don't pay for an unusable server.
        self._api("POST", "/v1/cancel_order", json={"id": order_id})
        return None

    def _terminate_worker(self, worker: Worker):
        """Cancel the worker's order; marks it STOPPED only on success."""
        try:
            self._api("POST", "/v1/cancel_order", json={"id": worker.order_id})
            worker.status = WorkerStatus.STOPPED
        except Exception as e:
            # Best-effort: log and leave status unchanged so callers can retry.
            self._log(f"⚠️ Failed to terminate worker {worker.id}: {e}")

    def scale_up(self) -> bool:
        """Provision one additional worker if cooldown, capacity, and budget allow.

        Returns True when a worker was provisioned and registered.
        """
        with self._lock:
            # Cooldown: avoid thrashing the marketplace.
            if self.metrics.last_scale_up:
                elapsed = (datetime.now() - self.metrics.last_scale_up).total_seconds()
                if elapsed < self.config.scale_up_cooldown:
                    return False
            # Capacity cap counts workers that are (or will be) consuming budget.
            active = len([w for w in self.workers.values()
                          if w.status in (WorkerStatus.STARTING, WorkerStatus.RUNNING)])
            if active >= self.config.max_workers:
                return False
            # Budget cap over currently-running workers.
            current_cost = sum(w.hourly_cost for w in self.workers.values()
                               if w.status == WorkerStatus.RUNNING)
            if current_cost >= self.config.max_hourly_budget:
                self._log(f"⚠️ Budget limit reached: ${current_cost:.2f}/hr")
                return False
        # Provision OUTSIDE the lock: it blocks on network for minutes.
        self._log(f"📈 Scaling UP (queue: {self.metrics.queue_depth}, workers: {active})")
        worker = self._provision_worker()
        if worker:
            with self._lock:
                self.workers[worker.id] = worker
                self.metrics.last_scale_up = datetime.now()
            self._log(f"✅ Worker {worker.id} started @ ${worker.hourly_cost:.3f}/hr")
            return True
        return False

    def scale_down(self) -> bool:
        """Retire one worker if cooldown and min-worker limits allow.

        Prefers terminating the oldest idle worker; if every running worker is
        busy, marks one as DRAINING so _cleanup_draining() can terminate it
        after its current job finishes. Returns True when an action was taken.
        """
        with self._lock:
            # Cooldown: scale down much more slowly than up.
            if self.metrics.last_scale_down:
                elapsed = (datetime.now() - self.metrics.last_scale_down).total_seconds()
                if elapsed < self.config.scale_down_cooldown:
                    return False
            active = [w for w in self.workers.values() if w.status == WorkerStatus.RUNNING]
            if len(active) <= self.config.min_workers:
                return False
            idle_workers = [w for w in active if not w.current_job]
            if not idle_workers:
                # Everyone is busy: mark one for draining instead of killing a job.
                worker = active[-1]
                worker.status = WorkerStatus.DRAINING
                self._log(f"📉 Worker {worker.id} draining...")
                return True
            # Oldest idle worker is the cheapest to lose (longest-billed hour).
            worker = min(idle_workers, key=lambda w: w.started_at)
        # Terminate OUTSIDE the lock (slow network call; lock is non-reentrant).
        self._log(f"📉 Scaling DOWN (queue: {self.metrics.queue_depth}, workers: {len(active)})")
        self._terminate_worker(worker)
        with self._lock:
            # BUG FIX: only drop tracking once cancellation succeeded; otherwise
            # a failed cancel would leak a billed order we no longer track.
            if worker.status == WorkerStatus.STOPPED:
                del self.workers[worker.id]
            self.metrics.last_scale_down = datetime.now()
        self._log(f"📉 Worker {worker.id} terminated")
        return True

    def _update_metrics(self):
        """Refresh queue depth and per-status worker counts/cost (single pass)."""
        self.metrics.queue_depth = self.get_queue_depth()
        with self._lock:
            running = starting = draining = 0
            hourly_cost = 0.0
            for w in self.workers.values():
                if w.status == WorkerStatus.RUNNING:
                    running += 1
                    hourly_cost += w.hourly_cost
                elif w.status == WorkerStatus.STARTING:
                    starting += 1
                elif w.status == WorkerStatus.DRAINING:
                    draining += 1
            self.metrics.active_workers = running
            self.metrics.starting_workers = starting
            self.metrics.draining_workers = draining
            self.metrics.current_hourly_cost = hourly_cost

    def _scaling_decision(self):
        """Make one scaling decision based on current metrics."""
        self._update_metrics()
        queue = self.metrics.queue_depth
        workers = self.metrics.active_workers + self.metrics.starting_workers
        if queue >= self.config.scale_up_threshold:
            # Desired worker count proportional to queue depth, clamped to limits.
            desired = max(self.config.min_workers, queue // self.config.target_queue_per_worker)
            desired = min(desired, self.config.max_workers)
            if workers < desired:
                self.scale_up()
        elif queue <= self.config.scale_down_threshold and workers > self.config.min_workers:
            self.scale_down()

    def _cleanup_draining(self):
        """Terminate draining workers whose current job has finished."""
        with self._lock:
            drained = [w for w in self.workers.values()
                       if w.status == WorkerStatus.DRAINING and not w.current_job]
        # Network calls happen outside the lock.
        for worker in drained:
            self._terminate_worker(worker)
            with self._lock:
                # Only untrack on confirmed cancellation; failures retry next tick.
                if worker.status == WorkerStatus.STOPPED:
                    self.workers.pop(worker.id, None)
            self._log(f"📉 Drained worker {worker.id} terminated")

    def _monitor_loop(self):
        """Main monitoring loop; runs until stop() clears the flag."""
        while self._running:
            try:
                self._scaling_decision()
                self._cleanup_draining()
            except Exception as e:
                # Keep the loop alive across transient API/Redis failures.
                self._log(f"❌ Error in monitor loop: {e}")
            time.sleep(10)  # Check every 10 seconds

    def start(self):
        """Start the auto-scaler; returns the daemon monitor thread."""
        self._running = True
        # Ensure minimum workers before the loop takes over.
        for _ in range(self.config.min_workers):
            self.scale_up()
        thread = threading.Thread(target=self._monitor_loop, daemon=True)
        thread.start()
        self._log(f"🚀 Auto-scaler started (min: {self.config.min_workers}, max: {self.config.max_workers})")
        return thread

    def stop(self):
        """Stop the auto-scaler and terminate all workers."""
        self._running = False
        self._log("🛑 Stopping auto-scaler...")
        for worker in list(self.workers.values()):
            self._terminate_worker(worker)
        self.workers.clear()
        self._log("✅ All workers terminated")

    def status(self) -> Dict:
        """Return a JSON-serializable snapshot of queue, workers, and cost."""
        self._update_metrics()
        return {
            "queue_depth": self.metrics.queue_depth,
            "workers": {
                "active": self.metrics.active_workers,
                "starting": self.metrics.starting_workers,
                "draining": self.metrics.draining_workers
            },
            "cost": {
                "hourly": self.metrics.current_hourly_cost,
                "daily_estimate": self.metrics.current_hourly_cost * 24
            },
            "config": asdict(self.config),
            "worker_details": [
                {
                    "id": w.id,
                    "status": w.status.value,
                    "gpu": w.gpu_type,
                    "cost": w.hourly_cost,
                    "jobs": w.jobs_processed,
                    "uptime_min": (datetime.now() - w.started_at).total_seconds() / 60
                }
                for w in self.workers.values()
            ]
        }
def main():
    """Parse CLI arguments and run the requested auto-scaler action."""
    parser = argparse.ArgumentParser(description="Auto-Scaling GPU Workers")
    parser.add_argument("--api-key", required=True)
    parser.add_argument("--redis-url", default="redis://localhost:6379")
    parser.add_argument("--queue", default="inference_queue")
    parser.add_argument("--min-workers", type=int, default=0)
    parser.add_argument("--max-workers", type=int, default=5)
    parser.add_argument("--gpu", default="RTX 3080")
    parser.add_argument("--max-price", type=float, default=0.30)
    parser.add_argument("--scale-up-threshold", type=int, default=10)
    parser.add_argument("--scale-down-threshold", type=int, default=2)
    parser.add_argument("--max-hourly-budget", type=float, default=10.0)
    parser.add_argument("action", choices=["start", "status", "stop"], nargs="?", default="start")
    args = parser.parse_args()

    config = ScalingConfig(
        min_workers=args.min_workers,
        max_workers=args.max_workers,
        gpu_type=args.gpu,
        max_price_usd=args.max_price,
        scale_up_threshold=args.scale_up_threshold,
        scale_down_threshold=args.scale_down_threshold,
        max_hourly_budget=args.max_hourly_budget
    )
    scaler = AutoScaler(args.api_key, args.redis_url, args.queue, config)

    if args.action == "start":
        try:
            scaler.start()
            print("\nAuto-scaler running. Press Ctrl+C to stop.\n")
            # Foreground status loop; Ctrl+C triggers a clean shutdown.
            while True:
                time.sleep(30)
                status = scaler.status()
                print(f"[Status] Queue: {status['queue_depth']} | "
                      f"Workers: {status['workers']['active']} | "
                      f"Cost: ${status['cost']['hourly']:.2f}/hr")
        except KeyboardInterrupt:
            print("\nShutting down...")
            scaler.stop()
    elif args.action == "status":
        status = scaler.status()
        print(json.dumps(status, indent=2, default=str))
    elif args.action == "stop":
        # BUG FIX: "stop" was an accepted choice but previously fell through
        # silently. A fresh process holds no worker state, so the best we can
        # do is run the shutdown path (terminates any tracked orders) and log it.
        scaler.stop()

if __name__ == "__main__":
    main()