# Distributed Training Across Multiple Servers

## What We're Building

A distributed PyTorch training system that orchestrates multiple Clore GPUs, handles inter-node communication, and delivers near-linear scaling for large-model training — at a fraction of traditional cloud costs.

## Prerequisites

* Clore.ai API key
* Python 3.10+
* Local packages for the orchestration scripts: `requests`, `paramiko`, `scp`
* Understanding of PyTorch DDP (DistributedDataParallel)

## Step 1: Multi-Node Cluster Manager

```python
# cluster_manager.py
"""Manage a cluster of Clore GPU nodes for distributed training."""

import time
import threading
from typing import List, Dict, Optional
from dataclasses import dataclass
import requests

@dataclass
class NodeConfig:
    """Configuration for a cluster node."""
    server_id: int
    gpus: List[str]
    price_usd: float
    host: str = ""
    port: int = 22
    rank: int = 0

class CloreCluster:
    """Manage a distributed training cluster on Clore."""
    
    BASE_URL = "https://api.clore.ai"
    
    def __init__(self, api_key: str, ssh_password: str = "DistTrain123!"):
        self.api_key = api_key
        self.headers = {"auth": api_key}
        self.ssh_password = ssh_password
        self.nodes: List[NodeConfig] = []
        self.orders: List[int] = []
        self.master_addr: str = ""
        self.master_port: int = 29500
    
    def _request(self, method: str, endpoint: str, **kwargs) -> Dict:
        url = f"{self.BASE_URL}{endpoint}"
        response = requests.request(method, url, headers=self.headers, **kwargs)
        response.raise_for_status()  # surface HTTP errors before parsing JSON
        data = response.json()
        if data.get("code") != 0:
            raise Exception(f"API Error: {data}")
        return data
    
    def find_servers(self, gpu_type: str, count: int, 
                     max_price: float = 1.0) -> List[Dict]:
        """Find multiple available servers with matching GPUs."""
        
        servers = self._request("GET", "/v1/marketplace")["servers"]
        
        candidates = []
        for server in servers:
            if server.get("rented"):
                continue
            
            gpus = server.get("gpu_array", [])
            if not any(gpu_type in g for g in gpus):
                continue
            
            price = server.get("price", {}).get("usd", {}).get("on_demand_clore", 999)
            if price > max_price:
                continue
            
            candidates.append({
                "id": server["id"],
                "gpus": gpus,
                "gpu_count": len(gpus),
                "price": price,
                "reliability": server.get("reliability", 0)
            })
        
        # Sort by reliability, then price
        candidates.sort(key=lambda x: (-x["reliability"], x["price"]))
        
        if len(candidates) < count:
            raise Exception(f"Only found {len(candidates)} servers, need {count}")
        
        return candidates[:count]
    
    def provision_cluster(self, gpu_type: str, num_nodes: int,
                          max_price: float = 1.0,
                          image: str = "pytorch/pytorch:2.7.1-cuda12.8-cudnn9-devel"
                          ) -> List[NodeConfig]:
        """Provision a multi-node cluster."""
        
        print(f"🔍 Finding {num_nodes} {gpu_type} servers...")
        servers = self.find_servers(gpu_type, num_nodes, max_price)
        
        total_gpus = sum(s["gpu_count"] for s in servers)
        total_cost = sum(s["price"] for s in servers)
        print(f"   Found {len(servers)} servers with {total_gpus} GPUs @ ${total_cost:.2f}/hr total")
        
        # Provision all nodes in parallel
        print(f"\n📦 Provisioning {num_nodes} nodes...")
        threads = []
        results = [None] * num_nodes
        
        def provision_node(idx, server):
            order = self._request("POST", "/v1/create_order", json={
                "renting_server": server["id"],
                "type": "on-demand",
                "currency": "CLORE-Blockchain",
                "image": image,
                "ports": {
                    "22": "tcp",
                    "29500": "tcp",  # NCCL master port
                    "29501": "tcp",  # Additional NCCL port
                },
                "env": {
                    "NVIDIA_VISIBLE_DEVICES": "all",
                    "NCCL_DEBUG": "INFO"
                },
                "ssh_password": self.ssh_password
            })
            results[idx] = (server, order)
        
        for idx, server in enumerate(servers):
            t = threading.Thread(target=provision_node, args=(idx, server))
            threads.append(t)
            t.start()
        
        for t in threads:
            t.join()
        
        # A provisioning thread that raised an exception leaves its slot as None;
        # fail fast with a clear error instead of crashing on unpacking below
        if any(r is None for r in results):
            raise Exception("One or more nodes failed to provision; aborting")
        
        # Wait for all nodes to be ready
        print(f"⏳ Waiting for nodes to start...")
        for idx, (server, order) in enumerate(results):
            order_id = order["order_id"]
            self.orders.append(order_id)
            
            for _ in range(120):
                orders = self._request("GET", "/v1/my_orders")["orders"]
                current = next((o for o in orders if o["order_id"] == order_id), None)
                
                if current and current.get("status") == "running":
                    ssh_info = current["connection"]["ssh"]
                    parts = ssh_info.split()
                    host = parts[1].split("@")[1]
                    port = int(parts[3]) if len(parts) > 3 else 22
                    
                    node = NodeConfig(
                        server_id=server["id"],
                        gpus=server["gpus"],
                        price_usd=server["price"],
                        host=host,
                        port=port,
                        rank=idx
                    )
                    self.nodes.append(node)
                    print(f"   ✅ Node {idx}: {host}:{port} ({len(server['gpus'])} GPUs)")
                    break
                
                time.sleep(2)
            else:
                raise Exception(f"Timeout waiting for node {idx}")
        
        # Set master address (first node)
        self.master_addr = self.nodes[0].host
        
        return self.nodes
    
    def get_world_size(self) -> int:
        """Get total number of GPUs across all nodes."""
        return sum(len(node.gpus) for node in self.nodes)
    
    def get_cluster_info(self) -> Dict:
        """Get cluster information for distributed training."""
        return {
            "nodes": [
                {
                    "rank": node.rank,
                    "host": node.host,
                    "port": node.port,
                    "gpus": len(node.gpus)
                }
                for node in self.nodes
            ],
            "master_addr": self.master_addr,
            "master_port": self.master_port,
            "world_size": self.get_world_size(),
            "num_nodes": len(self.nodes),
            "total_cost_per_hour": sum(n.price_usd for n in self.nodes)
        }
    
    def shutdown(self):
        """Shutdown the cluster."""
        print("\n🧹 Shutting down cluster...")
        for order_id in self.orders:
            try:
                self._request("POST", "/v1/cancel_order", json={"id": order_id})
                print(f"   ✅ Cancelled order {order_id}")
            except Exception as e:
                print(f"   ⚠️  Failed to cancel {order_id}: {e}")
        
        self.nodes = []
        self.orders = []


if __name__ == "__main__":
    import sys
    
    api_key = sys.argv[1] if len(sys.argv) > 1 else "YOUR_API_KEY"
    
    cluster = CloreCluster(api_key)
    
    try:
        # Provision 2-node cluster
        nodes = cluster.provision_cluster(
            gpu_type="RTX 4090",
            num_nodes=2,
            max_price=0.50
        )
        
        info = cluster.get_cluster_info()
        print(f"\n📊 Cluster Ready:")
        print(f"   Nodes: {info['num_nodes']}")
        print(f"   Total GPUs: {info['world_size']}")
        print(f"   Cost: ${info['total_cost_per_hour']:.2f}/hr")
        print(f"   Master: {info['master_addr']}:{info['master_port']}")
        
        input("\nPress Enter to shutdown cluster...")
        
    finally:
        cluster.shutdown()
```
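
The trickiest assumption in `provision_cluster` is the shape of the SSH connection string (`ssh root@HOST -p PORT`). The parsing logic can be pulled out and sanity-checked on its own; the function name here is illustrative, not part of the Clore API:

```python
def parse_ssh_connection(ssh_info: str) -> tuple:
    """Extract (host, port) from a string like 'ssh root@1.2.3.4 -p 10022'."""
    parts = ssh_info.split()
    host = parts[1].split("@")[1]                    # 'root@1.2.3.4' -> '1.2.3.4'
    port = int(parts[3]) if len(parts) > 3 else 22   # fall back to the default SSH port
    return host, port

print(parse_ssh_connection("ssh root@203.0.113.7 -p 10022"))  # ('203.0.113.7', 10022)
print(parse_ssh_connection("ssh root@203.0.113.7"))           # ('203.0.113.7', 22)
```

If the marketplace ever changes the connection-string format, this parsing is the single place the cluster manager would break, so it is worth a unit test.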

## Step 2: Distributed Training Script

```python
# distributed_train.py
"""Distributed PyTorch training script for multi-node clusters."""

import os
import argparse
import torch
import torch.nn as nn
import torch.optim as optim
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler
from torchvision import datasets, transforms, models
import time

def setup_distributed():
    """Initialize distributed training."""
    
    # Get distributed info from environment
    rank = int(os.environ.get("RANK", 0))
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    master_addr = os.environ.get("MASTER_ADDR", "localhost")
    master_port = os.environ.get("MASTER_PORT", "29500")
    
    # Set device
    torch.cuda.set_device(local_rank)
    device = torch.device(f"cuda:{local_rank}")
    
    # Initialize process group
    if world_size > 1:
        dist.init_process_group(
            backend="nccl",
            init_method=f"tcp://{master_addr}:{master_port}",
            world_size=world_size,
            rank=rank
        )
        print(f"[Rank {rank}] Initialized (local_rank={local_rank}, world_size={world_size})")
    
    return rank, world_size, local_rank, device


def cleanup_distributed():
    """Cleanup distributed training."""
    if dist.is_initialized():
        dist.destroy_process_group()


def get_model(name: str, num_classes: int, device):
    """Get model and wrap with DDP if distributed."""
    
    model_fn = getattr(models, name, None)
    if model_fn is None:
        raise ValueError(f"Unknown model: {name}")
    
    model = model_fn(weights="IMAGENET1K_V1")
    
    # Replace the final layer (ResNet-family models expose it as `.fc`;
    # other architectures may name it `.classifier` and would need different handling)
    if hasattr(model, 'fc'):
        model.fc = nn.Linear(model.fc.in_features, num_classes)
    
    model = model.to(device)
    
    # Wrap with DDP if distributed
    if dist.is_initialized():
        model = DDP(model, device_ids=[device.index])
    
    return model


def get_data_loaders(batch_size: int, world_size: int, rank: int):
    """Get distributed data loaders."""
    
    # CIFAR-10 images are 32x32; RandomResizedCrop upsamples them to the
    # 224x224 input size the ImageNet-pretrained backbone expects
    transform = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    
    # Note: with multiple processes per node, concurrent download=True can race
    # on the shared ./data directory; pre-download or gate the download on local_rank 0
    train_data = datasets.CIFAR10('./data', train=True, download=True, transform=transform)
    
    # Use distributed sampler
    if world_size > 1:
        sampler = DistributedSampler(train_data, num_replicas=world_size, rank=rank)
    else:
        sampler = None
    
    loader = DataLoader(
        train_data, 
        batch_size=batch_size,
        shuffle=(sampler is None),
        sampler=sampler,
        num_workers=4,
        pin_memory=True
    )
    
    return loader, sampler


def train_epoch(model, loader, criterion, optimizer, device, rank, epoch, sampler=None):
    """Train for one epoch."""
    
    model.train()
    
    if sampler:
        sampler.set_epoch(epoch)
    
    total_loss = 0
    correct = 0
    total = 0
    
    start_time = time.time()
    
    for batch_idx, (data, target) in enumerate(loader):
        data, target = data.to(device), target.to(device)
        
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
        pred = output.argmax(dim=1)
        correct += pred.eq(target).sum().item()
        total += target.size(0)
        
        if batch_idx % 50 == 0 and rank == 0:
            throughput = total / (time.time() - start_time)
            print(f"  [Epoch {epoch}] Batch {batch_idx}/{len(loader)} | "
                  f"Loss: {loss.item():.4f} | Throughput: {throughput:.1f} samples/s")
    
    return total_loss / len(loader), 100. * correct / total


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", default="resnet18")
    parser.add_argument("--epochs", type=int, default=10)
    parser.add_argument("--batch-size", type=int, default=64)
    parser.add_argument("--lr", type=float, default=0.001)
    args = parser.parse_args()
    
    # Setup distributed
    rank, world_size, local_rank, device = setup_distributed()
    
    is_main = rank == 0
    
    if is_main:
        print(f"\n🚀 Distributed Training")
        print(f"   World Size: {world_size}")
        print(f"   Batch Size per GPU: {args.batch_size}")
        print(f"   Effective Batch Size: {args.batch_size * world_size}")
        print(f"   Model: {args.model}\n")
    
    # Model
    model = get_model(args.model, num_classes=10, device=device)
    
    # Data
    # Scale batch size per GPU
    loader, sampler = get_data_loaders(args.batch_size, world_size, rank)
    
    # Training
    criterion = nn.CrossEntropyLoss()
    
    # Scale learning rate with world size (the linear scaling rule was derived
    # for SGD; sqrt scaling is sometimes preferred for Adam-family optimizers,
    # so treat this as a starting point)
    scaled_lr = args.lr * world_size
    optimizer = optim.AdamW(model.parameters(), lr=scaled_lr)
    
    # Train
    for epoch in range(args.epochs):
        if is_main:
            print(f"\n📈 Epoch {epoch + 1}/{args.epochs}")
        
        loss, acc = train_epoch(model, loader, criterion, optimizer, 
                               device, rank, epoch, sampler)
        
        # Sync metrics
        if world_size > 1:
            loss_tensor = torch.tensor([loss], device=device)
            acc_tensor = torch.tensor([acc], device=device)
            dist.all_reduce(loss_tensor, op=dist.ReduceOp.AVG)
            dist.all_reduce(acc_tensor, op=dist.ReduceOp.AVG)
            loss = loss_tensor.item()
            acc = acc_tensor.item()
        
        if is_main:
            print(f"  Loss: {loss:.4f} | Accuracy: {acc:.1f}%")
    
    # Save model (main process only)
    if is_main:
        model_to_save = model.module if hasattr(model, 'module') else model
        torch.save(model_to_save.state_dict(), "distributed_model.pt")
        print("\n✅ Training complete! Model saved to distributed_model.pt")
    
    cleanup_distributed()


if __name__ == "__main__":
    main()
```
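
The batch-size and learning-rate bookkeeping that `main()` prints is worth seeing in isolation. A minimal sketch, assuming the same linear LR scaling the script uses (no GPUs required; the function name is illustrative):

```python
def scaling_plan(per_gpu_batch: int, base_lr: float, world_size: int) -> dict:
    """Effective batch size and linearly scaled learning rate."""
    return {
        "effective_batch": per_gpu_batch * world_size,  # each rank processes its own shard
        "scaled_lr": base_lr * world_size,              # linear scaling rule
    }

# Two nodes x two GPUs with the defaults from distributed_train.py
print(scaling_plan(per_gpu_batch=64, base_lr=0.001, world_size=4))
# {'effective_batch': 256, 'scaled_lr': 0.004}
```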

## Step 3: Cluster Training Orchestrator

```python
# run_distributed.py
"""Orchestrate distributed training across Clore cluster."""

import os
import time
import json
import threading
import paramiko
from typing import List, Dict
from cluster_manager import CloreCluster, NodeConfig

class DistributedTrainer:
    """Run distributed training on Clore cluster."""
    
    def __init__(self, cluster: CloreCluster):
        self.cluster = cluster
        self.ssh_clients: Dict[int, paramiko.SSHClient] = {}
    
    def connect_all(self):
        """Connect to all nodes via SSH."""
        for node in self.cluster.nodes:
            client = paramiko.SSHClient()
            client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
            
            for attempt in range(5):
                try:
                    client.connect(
                        node.host, 
                        port=node.port,
                        username="root",
                        password=self.cluster.ssh_password,
                        timeout=30
                    )
                    self.ssh_clients[node.rank] = client
                    print(f"✅ Connected to node {node.rank}")
                    break
                except Exception as e:
                    if attempt < 4:
                        time.sleep(5)
                    else:
                        raise e
    
    def run_on_node(self, rank: int, command: str, stream: bool = False) -> str:
        """Run command on specific node."""
        client = self.ssh_clients[rank]
        stdin, stdout, stderr = client.exec_command(command, get_pty=True)
        
        if stream:
            output = ""
            for line in iter(stdout.readline, ""):
                print(f"[Node {rank}] {line}", end="")
                output += line
            return output
        else:
            return stdout.read().decode()
    
    def run_on_all(self, command: str, parallel: bool = True):
        """Run command on all nodes."""
        if parallel:
            threads = []
            for rank in self.ssh_clients:
                t = threading.Thread(
                    target=self.run_on_node,
                    args=(rank, command, False)
                )
                threads.append(t)
                t.start()
            for t in threads:
                t.join()
        else:
            for rank in self.ssh_clients:
                self.run_on_node(rank, command)
    
    def setup_nodes(self):
        """Setup all nodes for distributed training."""
        print("\n🔧 Setting up nodes...")
        
        # Install dependencies
        self.run_on_all("pip install --upgrade pip && pip install wandb tensorboard")
        
        # Create workspace
        self.run_on_all("mkdir -p /workspace")
        
        print("✅ All nodes ready")
    
    def upload_script(self, local_path: str, remote_path: str = "/workspace/train.py"):
        """Upload training script to all nodes."""
        from scp import SCPClient
        
        for rank, client in self.ssh_clients.items():
            with SCPClient(client.get_transport()) as scp:
                scp.put(local_path, remote_path)
            print(f"📤 Uploaded to node {rank}")
    
    def run_distributed_training(
        self,
        train_script: str,
        args: str = ""
    ):
        """Run distributed training across the cluster."""
        
        cluster_info = self.cluster.get_cluster_info()
        master_addr = cluster_info["master_addr"]
        master_port = cluster_info["master_port"]
        world_size = cluster_info["world_size"]
        
        print(f"\n🚀 Starting distributed training")
        print(f"   Master: {master_addr}:{master_port}")
        print(f"   World Size: {world_size}")
        
        # Calculate local ranks per node
        local_ranks = {}
        current_rank = 0
        for node in self.cluster.nodes:
            local_ranks[node.rank] = (current_rank, len(node.gpus))
            current_rank += len(node.gpus)
        
        # Start training on all nodes
        threads = []
        
        def run_node_training(node_rank: int):
            # torchrun derives each process's global rank from --node_rank and
            # --nproc_per_node, so only the local GPU count is needed here
            _, num_gpus = local_ranks[node_rank]
            
            # Build torchrun command
            cmd = f"""
export MASTER_ADDR={master_addr}
export MASTER_PORT={master_port}
export WORLD_SIZE={world_size}

cd /workspace && torchrun \\
    --nproc_per_node={num_gpus} \\
    --nnodes={len(self.cluster.nodes)} \\
    --node_rank={node_rank} \\
    --master_addr={master_addr} \\
    --master_port={master_port} \\
    {train_script} {args}
"""
            self.run_on_node(node_rank, cmd, stream=True)
        
        for node in self.cluster.nodes:
            t = threading.Thread(target=run_node_training, args=(node.rank,))
            threads.append(t)
            t.start()
        
        for t in threads:
            t.join()
        
        print("\n✅ Distributed training complete!")
    
    def download_results(self, local_dir: str = "./distributed_results"):
        """Download results from master node."""
        from scp import SCPClient
        
        os.makedirs(local_dir, exist_ok=True)
        
        # Download from master (rank 0)
        client = self.ssh_clients[0]
        with SCPClient(client.get_transport()) as scp:
            scp.get("/workspace/distributed_model.pt", local_dir)
        
        print(f"📥 Downloaded results to {local_dir}")
    
    def close(self):
        """Close all SSH connections."""
        for client in self.ssh_clients.values():
            client.close()


def main():
    import sys
    
    api_key = sys.argv[1] if len(sys.argv) > 1 else "YOUR_API_KEY"
    
    # Create cluster
    cluster = CloreCluster(api_key)
    
    try:
        # Provision 2-node cluster
        nodes = cluster.provision_cluster(
            gpu_type="RTX 4090",
            num_nodes=2,
            max_price=0.50
        )
        
        # Create trainer
        trainer = DistributedTrainer(cluster)
        trainer.connect_all()
        trainer.setup_nodes()
        
        # Upload training script
        trainer.upload_script("distributed_train.py")
        
        # Run training
        trainer.run_distributed_training(
            train_script="train.py",
            args="--model resnet18 --epochs 5 --batch-size 64"
        )
        
        # Download results
        trainer.download_results()
        
        # Print summary
        info = cluster.get_cluster_info()
        print(f"\n📊 Training Summary:")
        print(f"   Nodes: {info['num_nodes']}")
        print(f"   Total GPUs: {info['world_size']}")
        print(f"   Cost/hr: ${info['total_cost_per_hour']:.2f}")
        
    finally:
        # trainer may not exist if provisioning failed before it was created
        if 'trainer' in locals():
            trainer.close()
        cluster.shutdown()


if __name__ == "__main__":
    main()
```
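
`run_distributed_training` gives each node a contiguous block of global ranks sized by its GPU count; torchrun then derives the per-process ranks from `--node_rank` and `--nproc_per_node`. The bookkeeping on its own (function name and GPU counts are illustrative):

```python
def assign_rank_blocks(gpu_counts: list) -> dict:
    """Map node index -> (first global rank, local GPU process count)."""
    blocks = {}
    start = 0
    for node_idx, num_gpus in enumerate(gpu_counts):
        blocks[node_idx] = (start, num_gpus)
        start += num_gpus   # the next node's ranks begin where this one's end
    return blocks

# A 2-node cluster where node 0 has 4 GPUs and node 1 has 2
print(assign_rank_blocks([4, 2]))   # {0: (0, 4), 1: (4, 2)}
```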

## Scaling Guide

| Nodes | GPUs | Effective Batch | Est. Speedup | Cost/hr |
| ----- | ---- | --------------- | ------------ | ------- |
| 1     | 1    | 64              | 1.0x         | \~$0.35 |
| 1     | 2    | 128             | 1.9x         | \~$0.70 |
| 2     | 4    | 256             | 3.6x         | \~$1.40 |
| 4     | 8    | 512             | 6.8x         | \~$2.80 |
| 8     | 16   | 1024            | 12.0x        | \~$5.60 |

*Speedup estimates assume near-linear scaling with proper batch-size and learning-rate adjustment; real multi-node numbers depend heavily on inter-node network bandwidth.*
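
Dividing each estimated speedup by its GPU count makes the sub-linear trend in the table explicit:

```python
# (gpu_count, estimated_speedup) pairs from the scaling table above
rows = [(1, 1.0), (2, 1.9), (4, 3.6), (8, 6.8), (16, 12.0)]
efficiencies = {gpus: speedup / gpus for gpus, speedup in rows}

for gpus, eff in efficiencies.items():
    # fraction of ideal linear scaling retained at this cluster size
    print(f"{gpus:>2} GPUs: {eff:.0%} parallel efficiency")
```

Efficiency drops from 100% on a single GPU to roughly 75% at 16 GPUs, which is why the cost column grows faster than the speedup column.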

## Best Practices

1. **Scale batch size with GPUs**: `effective_batch = per_gpu_batch * world_size`
2. **Scale learning rate**: `scaled_lr = base_lr * world_size` (linear scaling rule; derived for SGD, so treat it as a starting point with Adam-family optimizers)
3. **Use gradient accumulation** for very large effective batches
4. **Debug NCCL early**: set `NCCL_DEBUG=INFO` to surface communication problems, and pin the right network interface with `NCCL_SOCKET_IFNAME` if nodes have multiple NICs
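
For point 3, the number of micro-batches to accumulate per optimizer step follows directly from the target effective batch. A small helper (the function name is illustrative):

```python
import math

def accumulation_steps(target_batch: int, per_gpu_batch: int, world_size: int) -> int:
    """Micro-batches to accumulate so one optimizer step sees ~target_batch samples."""
    per_step = per_gpu_batch * world_size   # samples per forward/backward across all ranks
    return max(1, math.ceil(target_batch / per_step))

# Reach an effective batch of 4096 on 8 GPUs at 64 samples each
print(accumulation_steps(4096, 64, 8))   # 8
```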

## Cost Comparison: 8-GPU Training

| Provider | Config        | Cost/hr  | 10hr Training |
| -------- | ------------- | -------- | ------------- |
| Clore.ai | 2x4-GPU nodes | \~$2.80  | \~$28         |
| AWS      | p4d.24xlarge  | \~$32.77 | \~$328        |
| GCP      | a2-highgpu-8g | \~$29.39 | \~$294        |

**Savings: \~90% compared to major cloud providers**
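
The headline figure can be reproduced from the table (list prices are approximate and drift over time):

```python
clore, aws, gcp = 2.80, 32.77, 29.39   # $/hr figures from the table above

print(f"vs AWS: {1 - clore / aws:.0%} cheaper")   # vs AWS: 91% cheaper
print(f"vs GCP: {1 - clore / gcp:.0%} cheaper")   # vs GCP: 90% cheaper
```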

## Next Steps

* [Fine-Tuning Models with Hugging Face](https://docs.clore.ai/dev/machine-learning-and-training/huggingface-finetuning)
* [Hyperparameter Sweeps with Optuna](https://docs.clore.ai/dev/machine-learning-and-training/hyperparameter-sweeps)
* [Auto-Scaling ML Training Pipeline](https://docs.clore.ai/dev/machine-learning-and-training/auto-scaling-pipeline)
