#!/usr/bin/env python3
"""
Batch Inference Pipeline for Clore.ai
Process millions of images across multiple GPU workers.
Usage:
python batch_inference.py --api-key YOUR_API_KEY --input ./images/ --output ./results/ \
--model resnet50 --workers 3 --batch-size 32
"""
import argparse
import json
import os
import sys
import time
import secrets
import threading
import queue
import requests
import paramiko
from scp import SCPClient
from typing import Dict, List, Optional, Tuple
from dataclasses import dataclass, field
from datetime import datetime
from pathlib import Path
from concurrent.futures import ThreadPoolExecutor, as_completed
from tqdm import tqdm
@dataclass
class Worker:
    """GPU worker for inference.

    Represents one rented Clore.ai server.  The SSH/SCP handles are
    opened by ``BatchInferencePipeline.setup_workers`` and closed in
    ``cleanup``; until then they are ``None``.
    """
    id: str                 # pipeline-local id, e.g. "w_00"
    order_id: int           # Clore.ai order id (needed to cancel the rental)
    ssh_host: str
    ssh_port: int
    ssh_password: str
    gpu_model: str          # e.g. "RTX 3080" (first entry of the server's gpu_array)
    hourly_cost: float      # spot price in USD/hr, used for cost accounting
    status: str = "starting"    # lifecycle: starting -> ready | failed
    images_processed: int = 0
    # Live connections (private); populated lazily by setup_workers().
    _ssh: Optional[paramiko.SSHClient] = None
    _scp: Optional[SCPClient] = None
@dataclass
class BatchJob:
    """A batch of images to process on a single worker."""
    id: str                             # e.g. "batch_000042"
    images: List[str]                   # local paths of the images in this batch
    worker_id: Optional[str] = None     # set when the batch is assigned
    status: str = "pending"             # pending -> processing -> completed | failed
    start_time: Optional[datetime] = None
    end_time: Optional[datetime] = None
    # Per-image dicts: {"image", "prediction", "confidence"} or {"image", "error"}.
    results: List[Dict] = field(default_factory=list)
    error: Optional[str] = None         # batch-level error message, if any
@dataclass
class PipelineStats:
    """Aggregated progress and cost counters for one pipeline run."""
    total_images: int = 0
    processed_images: int = 0   # images with a successful prediction
    failed_images: int = 0      # images whose inference (or upload) errored
    total_batches: int = 0
    completed_batches: int = 0
    failed_batches: int = 0
    start_time: datetime = field(default_factory=datetime.now)
    total_cost: float = 0.0     # estimated worker rental cost in USD

    @property
    def progress_percent(self) -> float:
        """Percentage of images processed; 0.0 when nothing is queued."""
        if self.total_images == 0:
            return 0.0  # was int 0; annotated return type is float
        return (self.processed_images / self.total_images) * 100

    @property
    def elapsed_seconds(self) -> float:
        """Wall-clock seconds since the run started."""
        return (datetime.now() - self.start_time).total_seconds()

    @property
    def images_per_second(self) -> float:
        """Overall throughput; 0.0 before any time has elapsed.

        Guarded with <= 0 (not == 0) so a clock adjustment can never
        produce a negative rate or a division by a tiny negative number.
        """
        if self.elapsed_seconds <= 0:
            return 0.0
        return self.processed_images / self.elapsed_seconds
class BatchInferencePipeline:
    """Process images in batches across multiple GPU workers."""

    # Clore.ai REST API root.
    BASE_URL = "https://api.clore.ai"
    # Docker image started on every rented server.
    INFERENCE_IMAGE = "pytorch/pytorch:2.7.1-cuda12.8-cudnn9-runtime"

    def __init__(self, api_key: str, model_name: str = "resnet50"):
        """Store credentials and initialise empty worker/queue state.

        :param api_key: Clore.ai API key (sent in the ``auth`` header).
        :param model_name: torchvision classification model to run,
            e.g. ``resnet50``.
        """
        self.api_key = api_key
        self.headers = {"auth": api_key}
        self.model_name = model_name
        self.workers: Dict[str, Worker] = {}
        self.job_queue: queue.Queue = queue.Queue()
        self.result_queue: queue.Queue = queue.Queue()
        self.stats = PipelineStats()
        # Guards shared-stats updates from concurrent batch threads.
        self._lock = threading.Lock()
        self._running = False
def _api(self, method: str, endpoint: str, **kwargs) -> Dict:
"""Make API request."""
url = f"{self.BASE_URL}{endpoint}"
for attempt in range(3):
response = requests.request(method, url, headers=self.headers, timeout=30)
data = response.json()
if data.get("code") == 5:
time.sleep(2 ** attempt)
continue
if data.get("code") != 0:
raise Exception(f"API Error: {data}")
return data
raise Exception("Max retries")
def _log(self, message: str):
"""Log with timestamp."""
ts = datetime.now().strftime("%H:%M:%S")
print(f"[{ts}] {message}")
    def provision_workers(self, count: int, gpu_type: str = "RTX 3080", max_price: float = 0.30) -> int:
        """Rent up to *count* spot GPU servers matching *gpu_type*.

        Scans the marketplace for un-rented servers whose GPU list matches
        *gpu_type* (case-insensitive substring) with a spot price of at
        most *max_price* USD/hr, places spot orders bidding 10% above the
        listed price, waits for each order to reach "running", and
        registers a Worker for it.

        :returns: the number of workers that actually came up.
        """
        self._log(f"🔍 Provisioning {count} workers ({gpu_type}, max ${max_price}/hr)...")
        servers = self._api("GET", "/v1/marketplace")["servers"]
        # Find available servers
        candidates = []
        for s in servers:
            if s.get("rented"):
                continue
            gpus = s.get("gpu_array", [])
            # Substring match against every GPU string on the server.
            if not any(gpu_type.lower() in g.lower() for g in gpus):
                continue
            price = s.get("price", {}).get("usd", {}).get("spot")
            if price and price <= max_price:
                candidates.append({"id": s["id"], "gpus": gpus, "price": price})
        if len(candidates) < count:
            self._log(f"⚠️ Only {len(candidates)} workers available (requested {count})")
        # Provision workers
        provisioned = 0
        for i, server in enumerate(candidates[:count]):
            worker_id = f"w_{i:02d}"
            # Cryptographically strong per-worker password (secrets, not random).
            ssh_password = secrets.token_urlsafe(16)
            order_data = {
                "renting_server": server["id"],
                "type": "spot",
                "currency": "CLORE-Blockchain",
                "image": self.INFERENCE_IMAGE,
                "ports": {"22": "tcp"},
                "env": {"NVIDIA_VISIBLE_DEVICES": "all"},
                "ssh_password": ssh_password,
                # Bid slightly above the current spot price to win the order.
                "spotprice": server["price"] * 1.1
            }
            try:
                result = self._api("POST", "/v1/create_order", json=order_data)
                order_id = result["order_id"]
                # Wait for server: poll every 2s, up to 90 tries (~3 minutes).
                for _ in range(90):
                    orders = self._api("GET", "/v1/my_orders")["orders"]
                    order = next((o for o in orders if o["order_id"] == order_id), None)
                    if order and order.get("status") == "running":
                        # Parse host/port out of the "ssh user@host -p PORT"
                        # connection string returned by the API.
                        conn = order["connection"]["ssh"]
                        parts = conn.split()
                        ssh_host = parts[1].split("@")[1]
                        ssh_port = int(parts[-1]) if "-p" in conn else 22
                        worker = Worker(
                            id=worker_id,
                            order_id=order_id,
                            ssh_host=ssh_host,
                            ssh_port=ssh_port,
                            ssh_password=ssh_password,
                            gpu_model=server["gpus"][0] if server["gpus"] else "GPU",
                            hourly_cost=server["price"]
                        )
                        self.workers[worker_id] = worker
                        provisioned += 1
                        self._log(f"✅ Worker {worker_id} ready @ ${server['price']:.3f}/hr")
                        break
                    time.sleep(2)
            except Exception as e:
                # Best-effort: a failed order does not abort the others.
                self._log(f"❌ Failed to provision worker: {e}")
        return provisioned
    def setup_workers(self):
        """Open SSH/SCP connections to every worker and warm up the model.

        Each worker installs torch/torchvision/pillow, loads the configured
        torchvision model once, and caches its weights at /tmp/model.pt.
        Workers end with status "ready" on success or "failed" otherwise.
        """
        self._log("📦 Setting up workers...")
        # Shell script executed remotely; the embedded Python heredoc is
        # quoted ('EOF') so nothing inside it is expanded by the shell.
        setup_script = f"""
pip install -q torch torchvision pillow
python3 << 'EOF'
import torch
from torchvision import models, transforms
from PIL import Image
import json
# Load model
model = models.{self.model_name}(pretrained=True)
model = model.cuda().eval()
# Save for later use
torch.save(model.state_dict(), '/tmp/model.pt')
print("Model loaded and cached")
EOF
"""
        for worker in self.workers.values():
            try:
                # Connect SSH (auto-accept unknown host keys: hosts are ephemeral)
                worker._ssh = paramiko.SSHClient()
                worker._ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
                worker._ssh.connect(
                    worker.ssh_host,
                    port=worker.ssh_port,
                    username="root",
                    password=worker.ssh_password,
                    timeout=30
                )
                worker._scp = SCPClient(worker._ssh.get_transport())
                # Run setup; recv_exit_status() blocks until the script exits.
                stdin, stdout, stderr = worker._ssh.exec_command(setup_script, timeout=300)
                stdout.channel.recv_exit_status()
                worker.status = "ready"
                self._log(f" {worker.id}: Ready ({worker.gpu_model})")
            except Exception as e:
                worker.status = "failed"
                self._log(f" {worker.id}: Setup failed - {e}")
def create_batches(self, image_paths: List[str], batch_size: int) -> List[BatchJob]:
"""Create batch jobs from image paths."""
batches = []
for i in range(0, len(image_paths), batch_size):
batch_images = image_paths[i:i + batch_size]
batch = BatchJob(
id=f"batch_{i // batch_size:06d}",
images=batch_images
)
batches.append(batch)
self.stats.total_images = len(image_paths)
self.stats.total_batches = len(batches)
return batches
def process_batch_on_worker(self, worker: Worker, batch: BatchJob, output_dir: str) -> BatchJob:
"""Process a batch on a specific worker."""
batch.worker_id = worker.id
batch.status = "processing"
batch.start_time = datetime.now()
try:
# Create remote directories
worker._ssh.exec_command("mkdir -p /tmp/input /tmp/output")
time.sleep(0.5)
# Upload images
for img_path in batch.images:
try:
worker._scp.put(img_path, f"/tmp/input/{os.path.basename(img_path)}")
except Exception as e:
batch.results.append({"image": img_path, "error": str(e)})
# Run inference
inference_script = f"""
import torch
from torchvision import models, transforms
from PIL import Image
import json
import os
model = models.{self.model_name}(pretrained=True)
model = model.cuda().eval()
transform = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
results = []
for img_file in os.listdir('/tmp/input'):
try:
img_path = f'/tmp/input/{{img_file}}'
img = Image.open(img_path).convert('RGB')
input_tensor = transform(img).unsqueeze(0).cuda()
with torch.no_grad():
output = model(input_tensor)
pred = torch.argmax(output, dim=1).item()
conf = torch.softmax(output, dim=1).max().item()
results.append({{"image": img_file, "prediction": pred, "confidence": conf}})
except Exception as e:
results.append({{"image": img_file, "error": str(e)}})
print("RESULTS:" + json.dumps(results))
# Cleanup
import shutil
shutil.rmtree('/tmp/input')
os.makedirs('/tmp/input')
"""
stdin, stdout, stderr = worker._ssh.exec_command(
f"python3 -c '{inference_script}'",
timeout=300
)
output = stdout.read().decode()
# Parse results
for line in output.split("\n"):
if line.startswith("RESULTS:"):
batch.results = json.loads(line[8:])
break
batch.status = "completed"
batch.end_time = datetime.now()
# Update stats
with self._lock:
worker.images_processed += len(batch.images)
self.stats.processed_images += len([r for r in batch.results if "error" not in r])
self.stats.failed_images += len([r for r in batch.results if "error" in r])
self.stats.completed_batches += 1
except Exception as e:
batch.status = "failed"
batch.error = str(e)
batch.end_time = datetime.now()
with self._lock:
self.stats.failed_batches += 1
return batch
    def run(self, input_dir: str, output_dir: str, batch_size: int = 32, num_workers: int = 3,
            gpu_type: str = "RTX 3080", max_price: float = 0.30) -> PipelineStats:
        """Run the batch inference pipeline end to end.

        Collects images under *input_dir* (recursively), splits them into
        batches, rents and sets up GPU workers, processes batches with one
        batch in flight per worker, writes all results to
        ``<output_dir>/results.json``, estimates cost, and cleans up.

        :returns: the populated ``PipelineStats`` for this run.
        """
        self.stats = PipelineStats()
        os.makedirs(output_dir, exist_ok=True)
        # Collect images
        self._log("📂 Collecting images...")
        image_extensions = {'.jpg', '.jpeg', '.png', '.bmp', '.webp'}
        image_paths = []
        for root, _, files in os.walk(input_dir):
            for f in files:
                if Path(f).suffix.lower() in image_extensions:
                    image_paths.append(os.path.join(root, f))
        self._log(f" Found {len(image_paths):,} images")
        if not image_paths:
            self._log("❌ No images found")
            return self.stats
        # Create batches
        batches = self.create_batches(image_paths, batch_size)
        self._log(f" Created {len(batches):,} batches")
        # Provision workers
        provisioned = self.provision_workers(num_workers, gpu_type, max_price)
        if provisioned == 0:
            self._log("❌ No workers available")
            return self.stats
        # Setup workers; drop any that failed SSH/model setup.
        self.setup_workers()
        ready_workers = [w for w in self.workers.values() if w.status == "ready"]
        if not ready_workers:
            self._log("❌ No workers ready")
            self.cleanup()
            return self.stats
        # Process batches
        self._log(f"\n🚀 Starting inference with {len(ready_workers)} workers...")
        self._running = True
        all_results = []
        with tqdm(total=len(batches), desc="Processing") as pbar:
            with ThreadPoolExecutor(max_workers=len(ready_workers)) as executor:
                # Submit initial batches: exactly one in-flight batch per worker.
                futures = {}
                batch_iter = iter(batches)
                for worker in ready_workers:
                    try:
                        batch = next(batch_iter)
                        future = executor.submit(self.process_batch_on_worker, worker, batch, output_dir)
                        futures[future] = (worker, batch)
                    except StopIteration:
                        break
                # Process results and submit new batches: each worker whose
                # batch finished immediately receives the next pending batch
                # until the iterator is exhausted.
                while futures:
                    done_futures = [f for f in futures if f.done()]
                    for future in done_futures:
                        worker, batch = futures.pop(future)
                        try:
                            result = future.result()
                            all_results.extend(result.results)
                            pbar.update(1)
                            # Submit next batch to this worker
                            try:
                                next_batch = next(batch_iter)
                                new_future = executor.submit(self.process_batch_on_worker, worker, next_batch, output_dir)
                                futures[new_future] = (worker, next_batch)
                            except StopIteration:
                                pass
                        except Exception as e:
                            self._log(f"❌ Batch failed: {e}")
                            pbar.update(1)
                    time.sleep(0.1)  # short poll interval to avoid a busy spin
        # Save results
        results_file = os.path.join(output_dir, "results.json")
        with open(results_file, "w") as f:
            json.dump(all_results, f, indent=2)
        # Calculate cost
        # NOTE(review): every worker is billed here for the full pipeline
        # duration (from stats.start_time), not from its own provisioning
        # time — a slight overestimate.
        for worker in self.workers.values():
            runtime_hours = (datetime.now() - self.stats.start_time).total_seconds() / 3600
            self.stats.total_cost += runtime_hours * worker.hourly_cost
        self._log(f"\n{'='*60}")
        self._log("📊 BATCH INFERENCE COMPLETE")
        self._log(f"{'='*60}")
        self._log(f" Images: {self.stats.processed_images:,}/{self.stats.total_images:,}")
        self._log(f" Failed: {self.stats.failed_images:,}")
        self._log(f" Time: {self.stats.elapsed_seconds:.1f}s")
        self._log(f" Speed: {self.stats.images_per_second:.1f} images/sec")
        self._log(f" Cost: ${self.stats.total_cost:.4f}")
        self._log(f" Results: {results_file}")
        # Cleanup
        self.cleanup()
        return self.stats
def cleanup(self):
"""Terminate all workers."""
self._log("\n🧹 Cleaning up workers...")
for worker in self.workers.values():
try:
if worker._scp:
worker._scp.close()
if worker._ssh:
worker._ssh.close()
self._api("POST", "/v1/cancel_order", json={"id": worker.order_id})
self._log(f" {worker.id}: Terminated")
except Exception as e:
self._log(f" {worker.id}: Cleanup failed - {e}")
self.workers.clear()
def main():
    """CLI entry point: parse arguments, run the pipeline, set exit code.

    Exits 1 when more than 10% of images failed, 130 on Ctrl-C.
    """
    parser = argparse.ArgumentParser(description="Batch Inference Pipeline")
    parser.add_argument("--api-key", required=True)
    parser.add_argument("--input", "-i", required=True, help="Input directory")
    parser.add_argument("--output", "-o", required=True, help="Output directory")
    parser.add_argument("--model", default="resnet50", help="Model name")
    parser.add_argument("--batch-size", type=int, default=32)
    parser.add_argument("--workers", type=int, default=3)
    parser.add_argument("--gpu", default="RTX 3080")
    parser.add_argument("--max-price", type=float, default=0.30)
    args = parser.parse_args()

    pipeline = BatchInferencePipeline(args.api_key, args.model)
    try:
        stats = pipeline.run(
            input_dir=args.input,
            output_dir=args.output,
            batch_size=args.batch_size,
            num_workers=args.workers,
            gpu_type=args.gpu,
            max_price=args.max_price
        )
    except KeyboardInterrupt:
        # BUG FIX: without this, aborting mid-run skipped cleanup() and
        # left rented GPUs running and billing.
        pipeline.cleanup()
        sys.exit(130)
    # Exit with error if too many failures
    if stats.failed_images > stats.total_images * 0.1:
        sys.exit(1)


if __name__ == "__main__":
    main()