A unified GPU orchestration layer that spans Clore.ai and other cloud providers, automatically selecting the most cost-effective option for your workloads. Route jobs to the cheapest available GPUs across multiple providers with a single API.
Key Features:
- Unified API across Clore.ai, AWS, GCP, Azure, and Lambda Labs
- Automatic cost optimization and provider selection
# orchestrator.py
from typing import List, Dict, Optional
from dataclasses import dataclass
import logging
from providers.base import (
CloudProvider, GPUInstance, ProviderCapacity,
JobRequest, JobResult, GPUType
)
from providers.clore import CloreProvider
from providers.aws import AWSProvider
logger = logging.getLogger(__name__)
@dataclass
class PriceComparison:
    """Price comparison across providers for a single GPU type.

    Produced by MultiCloudOrchestrator.compare_prices; `providers` is
    sorted by ascending price, so the first entry is the cheapest offer.
    """
    gpu_type: GPUType
    # One dict per offer: {provider, price, available, is_spot, region}
    providers: List[Dict]
    # Name of the cheapest provider, or "" when no provider has this GPU.
    cheapest_provider: str
    # Cheapest hourly price, or float('inf') when no provider has this GPU.
    cheapest_price: float
class MultiCloudOrchestrator:
    """Orchestrate GPU jobs across multiple cloud providers.

    Keeps a registry of providers (by name) and a mapping of active job
    IDs to the provider that hosts them, so jobs submitted through this
    orchestrator can later be queried and terminated.
    """

    def __init__(self):
        self.providers: Dict[str, CloudProvider] = {}
        self._active_jobs: Dict[str, str] = {}  # job_id -> provider name

    def add_provider(self, provider: CloudProvider):
        """Register a cloud provider, keyed by its `name` attribute."""
        self.providers[provider.name] = provider
        logger.info(f"Added provider: {provider.name}")

    def remove_provider(self, name: str):
        """Unregister a cloud provider; no-op if it is not registered."""
        if name in self.providers:
            del self.providers[name]

    def get_all_capacity(self) -> List[ProviderCapacity]:
        """Aggregate available GPU capacity from every registered provider.

        A provider that raises is logged and skipped, so one unreachable
        provider does not break aggregation for the rest.
        """
        all_capacity = []
        for name, provider in self.providers.items():
            try:
                all_capacity.extend(provider.get_available_gpus())
            except Exception as e:
                logger.warning(f"Failed to get capacity from {name}: {e}")
        return all_capacity

    def compare_prices(self, gpu_type: GPUType) -> PriceComparison:
        """Compare prices across all providers for a GPU type.

        Returns a PriceComparison whose offers are sorted by ascending
        price. When no provider has the GPU, cheapest_provider is "" and
        cheapest_price is infinity.
        """
        matching = [c for c in self.get_all_capacity() if c.gpu_type == gpu_type]
        providers = [
            {
                "provider": c.provider,
                "price": c.price_per_hour,
                "available": c.available_count,
                "is_spot": c.is_spot,
                "region": c.region,
            }
            for c in matching
        ]
        providers.sort(key=lambda x: x["price"])
        return PriceComparison(
            gpu_type=gpu_type,
            providers=providers,
            cheapest_provider=providers[0]["provider"] if providers else "",
            cheapest_price=providers[0]["price"] if providers else float('inf'),
        )

    def select_provider(self, request: JobRequest) -> Optional[str]:
        """Select the best provider for a job request, or None if none fits.

        Offers are considered in ascending price order; the first offer
        that passes the price cap, spot preference, and availability
        checks wins.
        """
        comparison = self.compare_prices(request.gpu_type)
        for p in comparison.providers:
            # Enforce the caller's hourly price ceiling, if any.
            if request.max_price_per_hour and p["price"] > request.max_price_per_hour:
                continue
            # When spot is preferred, an on-demand offer is still accepted
            # as long as its price stays within 1.5x of the cheapest offer.
            if request.prefer_spot and not p["is_spot"]:
                if p["price"] > comparison.cheapest_price * 1.5:
                    continue
            # The provider must have enough GPUs for the whole job.
            if p["available"] < request.gpu_count:
                continue
            return p["provider"]
        return None

    def submit_job(self, request: JobRequest, provider: Optional[str] = None) -> JobResult:
        """Submit a job, optionally to a specific provider.

        When `provider` is omitted, the cheapest suitable provider is
        chosen via select_provider. Returns a failed JobResult (rather
        than raising) when no provider is available or the named provider
        is not configured.
        """
        if not provider:
            provider = self.select_provider(request)
            if not provider:
                return JobResult(
                    success=False,
                    job_id="",
                    provider="",
                    error=f"No provider available for {request.gpu_type.value}"
                )
        if provider not in self.providers:
            return JobResult(
                success=False,
                job_id="",
                provider=provider,
                error=f"Provider {provider} not configured"
            )
        logger.info(f"Submitting job to {provider}: {request.gpu_type.value}")
        result = self.providers[provider].launch_instance(request)
        if result.success:
            # Track the job so status/terminate calls can be routed back.
            self._active_jobs[result.job_id] = provider
        return result

    def submit_with_failover(self, request: JobRequest) -> JobResult:
        """Submit a job, falling back to the next-cheapest provider on failure.

        Tries providers in ascending price order and returns the first
        successful launch; returns a failed JobResult when every eligible
        provider fails.
        """
        comparison = self.compare_prices(request.gpu_type)
        for p in comparison.providers:
            if request.max_price_per_hour and p["price"] > request.max_price_per_hour:
                continue
            # Skip offers that cannot host the whole job (mirrors the
            # availability check in select_provider); the launch would
            # fail anyway and only waste an attempt.
            if p["available"] < request.gpu_count:
                continue
            provider_name = p["provider"]
            if provider_name not in self.providers:
                continue
            logger.info(f"Trying provider {provider_name}...")
            result = self.providers[provider_name].launch_instance(request)
            if result.success:
                self._active_jobs[result.job_id] = provider_name
                return result
            logger.warning(f"Provider {provider_name} failed: {result.error}")
        return JobResult(
            success=False,
            job_id="",
            provider="",
            error="All providers failed"
        )

    def terminate_job(self, job_id: str) -> bool:
        """Terminate a tracked job; returns True on success.

        Returns False when the job is unknown to this orchestrator or its
        provider is no longer registered.
        """
        provider_name = self._active_jobs.get(job_id)
        if not provider_name:
            logger.warning(f"Unknown job: {job_id}")
            return False
        provider = self.providers.get(provider_name)
        if not provider:
            return False
        success = provider.terminate_instance(job_id)
        if success:
            del self._active_jobs[job_id]
        return success

    def get_job_status(self, job_id: str) -> Optional[GPUInstance]:
        """Get the status of a tracked job, or None if it is unknown."""
        provider_name = self._active_jobs.get(job_id)
        if not provider_name:
            return None
        provider = self.providers.get(provider_name)
        if not provider:
            return None
        return provider.get_instance_status(job_id)

    def list_all_jobs(self) -> List[GPUInstance]:
        """List all instances across all providers.

        NOTE: this reflects whatever each provider reports, which may
        include instances not submitted through this orchestrator.
        Providers that fail to respond are logged and skipped.
        """
        all_jobs = []
        for name, provider in self.providers.items():
            try:
                all_jobs.extend(provider.list_instances())
            except Exception as e:
                logger.warning(f"Failed to list jobs from {name}: {e}")
        return all_jobs

    def terminate_all_jobs(self) -> int:
        """Terminate every tracked job; returns the number terminated."""
        count = 0
        # Iterate over a snapshot: terminate_job mutates _active_jobs.
        for job_id in list(self._active_jobs.keys()):
            if self.terminate_job(job_id):
                count += 1
        return count

    def get_total_cost_per_hour(self) -> float:
        """Total hourly cost of all running instances across providers."""
        jobs = self.list_all_jobs()
        return sum(j.price_per_hour for j in jobs if j.status == "running")
#!/usr/bin/env python3
"""
Multi-Cloud GPU Orchestrator
Usage:
python multi_cloud.py --action compare --gpu RTX_4090
python multi_cloud.py --action submit --gpu A100_40GB --max-price 2.0
python multi_cloud.py --action list
python multi_cloud.py --action terminate --job-id abc123
"""
import argparse
import json
from orchestrator import MultiCloudOrchestrator, GPUType, JobRequest
from providers.clore import CloreProvider
from providers.aws import AWSProvider
def setup_orchestrator(clore_key: str = None) -> MultiCloudOrchestrator:
    """Build an orchestrator wired up with every available provider.

    Clore.ai is added only when an API key is supplied. AWS is added on a
    best-effort basis and skipped silently when it cannot be constructed
    (e.g. missing credentials).
    """
    orchestrator = MultiCloudOrchestrator()
    if clore_key:
        orchestrator.add_provider(CloreProvider(clore_key))
    try:
        orchestrator.add_provider(AWSProvider())
    except Exception:
        # Deliberate best-effort: AWS is optional.
        pass
    return orchestrator
def _parse_gpu_type(name):
    """Resolve a --gpu argument to a GPUType; print an error and return None if invalid."""
    try:
        return GPUType[name.upper()]
    except KeyError:
        valid = ", ".join(t.name for t in GPUType)
        print(f"Unknown GPU type: {name}. Valid types: {valid}")
        return None


def _do_compare(orch, args):
    """Handle 'compare': print a per-provider price table for one GPU type."""
    if not args.gpu:
        print("--gpu required for compare")
        return
    gpu_type = _parse_gpu_type(args.gpu)
    if gpu_type is None:
        return
    comparison = orch.compare_prices(gpu_type)
    print(f"π° Price Comparison: {gpu_type.value}")
    print("-" * 60)
    for p in comparison.providers:
        spot_label = "π’ Spot" if p["is_spot"] else "π΅ On-Demand"
        print(f" {p['provider']:10} ${p['price']:.3f}/hr {p['available']:3} avail {spot_label}")
    print("-" * 60)
    print(f"π Cheapest: {comparison.cheapest_provider} @ ${comparison.cheapest_price:.3f}/hr")


def _do_submit(orch, args):
    """Handle 'submit': launch a job, with failover unless a provider is forced."""
    if not args.gpu:
        print("--gpu required for submit")
        return
    gpu_type = _parse_gpu_type(args.gpu)
    if gpu_type is None:
        return
    request = JobRequest(
        gpu_type=gpu_type,
        gpu_count=1,
        image=args.image,
        max_price_per_hour=args.max_price,
        prefer_spot=True
    )
    if args.provider:
        result = orch.submit_job(request, provider=args.provider)
    else:
        result = orch.submit_with_failover(request)
    if result.success:
        print("β Job submitted successfully!")
        print(f" Job ID: {result.job_id}")
        print(f" Provider: {result.provider}")
        if result.instance:
            print(f" SSH: {result.instance.ssh_user}@{result.instance.ssh_host}:{result.instance.ssh_port}")
            print(f" Price: ${result.instance.price_per_hour:.3f}/hr")
    else:
        print(f"β Job failed: {result.error}")


def _do_list(orch):
    """Handle 'list': print all jobs plus the total hourly cost of running ones."""
    jobs = orch.list_all_jobs()
    if not jobs:
        print("No active jobs")
        return
    print(f"π Active Jobs ({len(jobs)})")
    print("-" * 70)
    total_cost = 0
    for job in jobs:
        print(f" {job.instance_id:15} {job.provider:8} {job.gpu_type.value:12} "
              f"${job.price_per_hour:.3f}/hr {job.status}")
        if job.status == "running":
            total_cost += job.price_per_hour
    print("-" * 70)
    print(f"π΅ Total: ${total_cost:.2f}/hr")


def _do_status(orch, args):
    """Handle 'status': print details for a single job."""
    if not args.job_id:
        print("--job-id required for status")
        return
    instance = orch.get_job_status(args.job_id)
    if instance:
        print(f"π Job Status: {args.job_id}")
        print(f" Provider: {instance.provider}")
        print(f" GPU: {instance.gpu_type.value}")
        print(f" Status: {instance.status}")
        print(f" Price: ${instance.price_per_hour:.3f}/hr")
        if instance.ssh_host:
            print(f" SSH: {instance.ssh_user}@{instance.ssh_host}:{instance.ssh_port}")
    else:
        print(f"β Job not found: {args.job_id}")


def _do_terminate(orch, args):
    """Handle 'terminate': kill one job by ID, or all tracked jobs if no ID given."""
    if not args.job_id:
        count = orch.terminate_all_jobs()
        print(f"π Terminated {count} jobs")
    else:
        success = orch.terminate_job(args.job_id)
        if success:
            print(f"β Job {args.job_id} terminated")
        else:
            print(f"β Failed to terminate {args.job_id}")


def main():
    """CLI entry point: parse arguments and dispatch to the action handlers."""
    parser = argparse.ArgumentParser(description="Multi-Cloud GPU Orchestrator")
    parser.add_argument("--action", required=True,
                        choices=["compare", "submit", "list", "terminate", "status"])
    parser.add_argument("--gpu", help="GPU type (e.g., RTX_4090, A100_40GB)")
    parser.add_argument("--max-price", type=float, help="Max price per hour")
    parser.add_argument("--provider", help="Specific provider to use")
    parser.add_argument("--job-id", help="Job ID for status/terminate")
    parser.add_argument("--clore-key", help="Clore.ai API key")
    parser.add_argument("--image", default="nvidia/cuda:12.8.0-base-ubuntu22.04")
    args = parser.parse_args()

    orch = setup_orchestrator(args.clore_key)
    if not orch.providers:
        print("β No providers configured!")
        return
    print(f"β Providers: {list(orch.providers.keys())}")
    print()

    if args.action == "compare":
        _do_compare(orch, args)
    elif args.action == "submit":
        _do_submit(orch, args)
    elif args.action == "list":
        _do_list(orch)
    elif args.action == "status":
        _do_status(orch, args)
    elif args.action == "terminate":
        _do_terminate(orch, args)
# Script entry point: run the CLI only when executed directly, not on import.
if __name__ == "__main__":
    main()