# Training a PyTorch Model on Clore

## What We're Building

A complete PyTorch training pipeline that automatically provisions a GPU, trains a model with checkpointing, logs metrics to Weights & Biases, and handles cleanup — all for a fraction of cloud GPU costs.

## Prerequisites

* Clore.ai API key
* Python 3.10+
* PyTorch experience

## Step 1: Project Setup

```text
# requirements.txt
torch>=2.7.1
torchvision>=0.22.0
wandb>=0.15.0
requests>=2.28.0
paramiko>=3.0.0
scp>=0.14.0
```

```python
# config.py
"""Training configuration."""

from dataclasses import dataclass
from typing import Optional, List

@dataclass
class TrainingConfig:
    """Hyperparameters and run settings for a single training job.

    All fields are JSON-compatible so the config can be written out and
    shipped to the remote GPU as a plain dict.
    """
    # Model
    model_name: str = "resnet18"  # torchvision model constructor name
    num_classes: int = 10  # size of the replacement classification head
    pretrained: bool = True  # start from ImageNet weights
    
    # Training
    epochs: int = 10
    batch_size: int = 64
    learning_rate: float = 0.001
    weight_decay: float = 1e-4
    
    # Data
    dataset: str = "cifar10"  # "cifar10" or "cifar100" (see get_data_loaders)
    num_workers: int = 4  # DataLoader worker processes
    
    # Checkpointing
    checkpoint_dir: str = "/workspace/checkpoints"  # path on the remote server
    checkpoint_every: int = 1  # epochs
    
    # Logging
    wandb_project: str = "clore-training"
    wandb_run_name: Optional[str] = None  # None -> wandb auto-generates a name
    log_every: int = 100  # batches
    
    # GPU
    gpu_type: str = "RTX 4090"  # substring-matched against marketplace GPU names
    max_price_usd: float = 0.50  # maximum hourly rental price in USD

@dataclass
class CloreConfig:
    """Connection settings for provisioning a Clore GPU server."""

    api_key: str  # Clore.ai API key (required)
    image: str = "pytorch/pytorch:2.7.1-cuda12.8-cudnn9-devel"
    ssh_password: str = "PyTorchTrain123!"
    # Port mapping {container_port: protocol}; None selects the default below.
    # FIX: annotated Optional (the old `dict = None` annotation was wrong).
    ports: Optional[dict] = None

    def __post_init__(self):
        # FIX: test `is None` instead of truthiness, so an explicitly-passed
        # empty mapping is respected rather than silently replaced.
        # Defaults: 22 for SSH, 6006 for TensorBoard over HTTP.
        if self.ports is None:
            self.ports = {"22": "tcp", "6006": "http"}
```

## Step 2: The Training Script (Runs on GPU)

```python
# train.py
"""PyTorch training script - runs on the Clore GPU."""

import os
import sys
import argparse
import json
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, models
import wandb
from datetime import datetime


def get_device():
    """Select CUDA when available, otherwise fall back to the CPU."""
    if not torch.cuda.is_available():
        print("⚠️  Using CPU (no GPU available)")
        return torch.device("cpu")

    # Report which GPU we got and how much memory it has.
    print(f"🎮 Using GPU: {torch.cuda.get_device_name(0)}")
    print(f"   Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
    return torch.device("cuda")


def get_model(name: str, num_classes: int, pretrained: bool = True):
    """Build a torchvision model and adapt its head to `num_classes`.

    Args:
        name: torchvision model constructor name (e.g. "resnet18").
        num_classes: output classes for the replacement head.
        pretrained: load the model's default ImageNet weights when True.

    Raises:
        ValueError: if `name` is not a torchvision model constructor.
    """
    model_fn = getattr(models, name, None)
    if model_fn is None:
        raise ValueError(f"Unknown model: {name}")

    # BUG FIX: torchvision's weight enums are CamelCase (e.g. `ResNet18_Weights`),
    # so the old probe `hasattr(models, f"{name.upper()}_Weights")` never matched
    # and the code always fell through to the removed legacy `weights=True` form.
    # The string "DEFAULT" resolves to each model's best available weights in the
    # modern torchvision weights API.
    model = model_fn(weights="DEFAULT") if pretrained else model_fn()

    # Replace the final layer for our number of classes.
    if hasattr(model, 'fc'):  # ResNet-style head
        model.fc = nn.Linear(model.fc.in_features, num_classes)
    elif hasattr(model, 'classifier'):  # VGG/DenseNet/MobileNet-style head
        if isinstance(model.classifier, nn.Sequential):
            model.classifier[-1] = nn.Linear(model.classifier[-1].in_features, num_classes)
        else:
            model.classifier = nn.Linear(model.classifier.in_features, num_classes)

    return model


def get_data_loaders(dataset: str, batch_size: int, num_workers: int):
    """Build train/val DataLoaders with standard ImageNet-style preprocessing.

    Raises:
        ValueError: if `dataset` is not a supported dataset name.
    """
    # ImageNet channel statistics, shared by both pipelines.
    normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    val_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ])

    # Dispatch table keyed by lowercase dataset name.
    supported = {"cifar10": datasets.CIFAR10, "cifar100": datasets.CIFAR100}
    dataset_cls = supported.get(dataset.lower())
    if dataset_cls is None:
        raise ValueError(f"Unknown dataset: {dataset}")

    train_data = dataset_cls('./data', train=True, download=True, transform=train_transform)
    val_data = dataset_cls('./data', train=False, download=True, transform=val_transform)

    train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True,
                              num_workers=num_workers, pin_memory=True)
    val_loader = DataLoader(val_data, batch_size=batch_size, shuffle=False,
                            num_workers=num_workers, pin_memory=True)
    return train_loader, val_loader


def train_epoch(model, loader, criterion, optimizer, device, epoch, log_every=100):
    """Run one training epoch; return (mean batch loss, accuracy %)."""
    model.train()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0

    for step, (inputs, labels) in enumerate(loader):
        inputs, labels = inputs.to(device), labels.to(device)

        # Standard forward / backward / update cycle.
        optimizer.zero_grad()
        logits = model(inputs)
        batch_loss = criterion(logits, labels)
        batch_loss.backward()
        optimizer.step()

        # Accumulate running metrics.
        loss_sum += batch_loss.item()
        n_correct += logits.argmax(dim=1).eq(labels).sum().item()
        n_seen += labels.size(0)

        # Periodic progress report (always fires on the first batch).
        if step % log_every == 0:
            running_acc = 100. * n_correct / n_seen
            print(f"  Batch {step}/{len(loader)} | Loss: {batch_loss.item():.4f} | "
                  f"Acc: {running_acc:.1f}%")

            wandb.log({
                "train/batch_loss": batch_loss.item(),
                "train/batch_acc": running_acc,
                "epoch": epoch,
                "batch": step
            })

    return loss_sum / len(loader), 100. * n_correct / n_seen


def validate(model, loader, criterion, device):
    """Evaluate the model on `loader`; return (mean batch loss, accuracy %)."""
    model.eval()
    loss_sum, n_correct, n_seen = 0.0, 0, 0

    # No gradients needed during evaluation.
    with torch.no_grad():
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)
            logits = model(inputs)
            loss_sum += criterion(logits, labels).item()
            n_correct += logits.argmax(dim=1).eq(labels).sum().item()
            n_seen += labels.size(0)

    return loss_sum / len(loader), 100. * n_correct / n_seen


def save_checkpoint(model, optimizer, epoch, loss, acc, path):
    """Persist model/optimizer state plus metrics to `path`.

    Creates any missing parent directories so callers can write into a
    fresh run directory without pre-creating it.
    """
    # FIX: `os.makedirs("")` raises FileNotFoundError, so only create the
    # parent when the path actually has a directory component.
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
        'accuracy': acc
    }, path)
    print(f"💾 Checkpoint saved: {path}")


def load_checkpoint(model, optimizer, path):
    """Restore model/optimizer state from `path`; return the saved epoch.

    Returns 0 (i.e. train from scratch) when no checkpoint exists at `path`.
    """
    if not os.path.exists(path):
        return 0
    # FIX: map_location="cpu" lets a checkpoint written on a GPU box be
    # resumed on a CPU-only host; load_state_dict then copies tensors back
    # onto each parameter's own device.
    checkpoint = torch.load(path, map_location="cpu")
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    epoch = checkpoint['epoch']
    print(f"📂 Resumed from checkpoint (epoch {epoch})")
    return epoch


def main():
    """Train a classifier from a JSON config file.

    Workflow: parse CLI args, load the config, build model/data/optimizer,
    optionally resume from a checkpoint, then run the epoch loop with wandb
    logging, periodic checkpointing, and best-model tracking.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', type=str, default='config.json')  # JSON training config path
    parser.add_argument('--resume', type=str, default=None)  # optional checkpoint to resume from
    args = parser.parse_args()
    
    # Load config (plain dict whose keys mirror TrainingConfig's fields)
    with open(args.config) as f:
        config = json.load(f)
    
    # Setup device (CUDA if available, else CPU)
    device = get_device()
    
    # Initialize wandb; the full config is attached for reproducibility
    wandb.init(
        project=config.get('wandb_project', 'clore-training'),
        name=config.get('wandb_run_name'),
        config=config
    )
    
    # Model (head already resized to num_classes by get_model)
    model = get_model(
        config['model_name'],
        config['num_classes'],
        config.get('pretrained', True)
    ).to(device)
    
    print(f"📊 Model: {config['model_name']} ({sum(p.numel() for p in model.parameters())/1e6:.1f}M params)")
    
    # Data
    train_loader, val_loader = get_data_loaders(
        config['dataset'],
        config['batch_size'],
        config.get('num_workers', 4)
    )
    print(f"📦 Dataset: {config['dataset']} (train: {len(train_loader.dataset)}, val: {len(val_loader.dataset)})")
    
    # Training setup: AdamW with cosine LR decay over the whole run
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(
        model.parameters(),
        lr=config['learning_rate'],
        weight_decay=config.get('weight_decay', 1e-4)
    )
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=config['epochs'])
    
    # Resume from checkpoint if provided (load_checkpoint returns 0 when absent)
    start_epoch = 0
    if args.resume:
        start_epoch = load_checkpoint(model, optimizer, args.resume)
    
    # Training loop
    best_acc = 0
    checkpoint_dir = config.get('checkpoint_dir', '/workspace/checkpoints')
    
    print(f"\n🚀 Starting training for {config['epochs']} epochs...")
    print("="*60)
    
    for epoch in range(start_epoch, config['epochs']):
        print(f"\n📈 Epoch {epoch+1}/{config['epochs']}")
        print("-"*40)
        
        # Train one epoch
        train_loss, train_acc = train_epoch(
            model, train_loader, criterion, optimizer, device, epoch,
            log_every=config.get('log_every', 100)
        )
        
        # Validate
        val_loss, val_acc = validate(model, val_loader, criterion, device)
        
        # Update scheduler (stepped once per epoch, matching T_max above)
        scheduler.step()
        
        # Log epoch-level metrics to stdout and wandb
        print(f"\n  Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.1f}%")
        print(f"  Val Loss:   {val_loss:.4f} | Val Acc:   {val_acc:.1f}%")
        
        wandb.log({
            "train/epoch_loss": train_loss,
            "train/epoch_acc": train_acc,
            "val/loss": val_loss,
            "val/acc": val_acc,
            "lr": optimizer.param_groups[0]['lr'],
            "epoch": epoch + 1
        })
        
        # Save a periodic checkpoint every `checkpoint_every` epochs
        if (epoch + 1) % config.get('checkpoint_every', 1) == 0:
            save_checkpoint(
                model, optimizer, epoch + 1, val_loss, val_acc,
                f"{checkpoint_dir}/checkpoint_epoch_{epoch+1}.pt"
            )
        
        # Track and save the best model by validation accuracy
        if val_acc > best_acc:
            best_acc = val_acc
            save_checkpoint(
                model, optimizer, epoch + 1, val_loss, val_acc,
                f"{checkpoint_dir}/best_model.pt"
            )
            print(f"  🏆 New best accuracy: {best_acc:.1f}%")
    
    print("\n" + "="*60)
    print(f"✅ Training complete! Best accuracy: {best_acc:.1f}%")
    print(f"📁 Checkpoints saved to: {checkpoint_dir}")
    
    wandb.finish()


if __name__ == "__main__":
    main()
```

## Step 3: Remote Training Orchestrator

```python
# remote_trainer.py
"""Orchestrate PyTorch training on Clore GPUs."""

import os
import time
import json
import tempfile
import paramiko
from scp import SCPClient
from typing import Dict, Optional
import requests

class CloreTrainer:
    """Run PyTorch training jobs on Clore GPUs.

    Lifecycle (driven end-to-end by `run_training_job`):
    find_gpu -> provision -> connect_ssh -> setup_environment ->
    upload_files -> train -> download_results -> cleanup.
    Cleanup always runs, even when the job fails, so billing stops.
    """
    
    BASE_URL = "https://api.clore.ai"
    
    # Seconds to wait on any single Clore API call before giving up.
    REQUEST_TIMEOUT = 30
    
    def __init__(self, api_key: str, wandb_key: Optional[str] = None):
        """
        Args:
            api_key: Clore.ai API key (sent as the `auth` header).
            wandb_key: optional Weights & Biases API key; when provided, the
                remote environment is logged in to wandb during setup.
        """
        self.api_key = api_key
        self.wandb_key = wandb_key
        self.headers = {"auth": api_key}
        self.ssh_client = None      # paramiko.SSHClient after connect_ssh()
        self.current_order = None   # active order dict; consumed by cleanup()
    
    def _request(self, method: str, endpoint: str, **kwargs) -> Dict:
        """Issue a Clore API request and return the decoded JSON body.

        Raises:
            Exception: if the API reports a non-zero `code`.
        """
        url = f"{self.BASE_URL}{endpoint}"
        # FIX: requests has no default timeout — without one a stalled API
        # call would hang the whole job forever. Callers may still override.
        kwargs.setdefault("timeout", self.REQUEST_TIMEOUT)
        response = requests.request(method, url, headers=self.headers, **kwargs)
        data = response.json()
        if data.get("code") != 0:
            raise Exception(f"API Error: {data}")
        return data
    
    def find_gpu(self, gpu_type: str = "RTX 4090", 
                  max_price: float = 0.50) -> Dict:
        """Return the first un-rented marketplace server whose GPUs match
        `gpu_type` (substring match) at or under `max_price` USD/hr.

        Raises:
            Exception: if no matching server is available.
        """
        servers = self._request("GET", "/v1/marketplace")["servers"]
        
        for server in servers:
            if server.get("rented"):
                continue
            
            # `gpu_type` is matched as a substring of each GPU name string.
            gpus = server.get("gpu_array", [])
            if not any(gpu_type in g for g in gpus):
                continue
            
            # Missing price info defaults to 999 so the server is skipped.
            price = server.get("price", {}).get("usd", {}).get("on_demand_clore", 999)
            if price <= max_price:
                return {
                    "id": server["id"],
                    "gpus": gpus,
                    "price": price,
                    "specs": server.get("specs", {})
                }
        
        raise Exception(f"No {gpu_type} available under ${max_price}/hr")
    
    def provision(self, server_id: int, image: str, ssh_password: str) -> Dict:
        """Create an on-demand order for `server_id` and wait until it runs.

        Polls the order status every 2s, for up to ~4 minutes.

        Raises:
            Exception: if the server never reaches "running" in time.
        """
        order = self._request("POST", "/v1/create_order", json={
            "renting_server": server_id,
            "type": "on-demand",
            "currency": "CLORE-Blockchain",
            "image": image,
            "ports": {"22": "tcp", "6006": "http"},
            "env": {"NVIDIA_VISIBLE_DEVICES": "all"},
            "ssh_password": ssh_password
        })
        
        order_id = order["order_id"]
        print(f"📦 Created order {order_id}")
        
        # Wait for the container to come up (120 polls x 2s).
        for _ in range(120):
            orders = self._request("GET", "/v1/my_orders")["orders"]
            current = next((o for o in orders if o["order_id"] == order_id), None)
            if current and current.get("status") == "running":
                self.current_order = current
                return current
            time.sleep(2)
        
        raise Exception("Timeout waiting for server")
    
    def connect_ssh(self, host: str, port: int, password: str):
        """Open an SSH session as root, retrying up to 5 times.

        Retries cover the window where the container is running but sshd
        is not yet accepting connections.
        """
        self.ssh_client = paramiko.SSHClient()
        self.ssh_client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
        
        for attempt in range(5):
            try:
                self.ssh_client.connect(host, port=port, username="root", 
                                        password=password, timeout=30)
                print(f"✅ SSH connected to {host}:{port}")
                return
            except Exception:
                if attempt < 4:
                    print(f"⏳ Connection attempt {attempt+1}/5...")
                    time.sleep(10)
                else:
                    # FIX: bare `raise` preserves the original traceback
                    # (`raise e` would reset it to this line).
                    raise
    
    def upload_files(self, local_files: Dict[str, str]):
        """Upload files over SCP; mapping is {local_path: remote_path}."""
        with SCPClient(self.ssh_client.get_transport()) as scp:
            for local_path, remote_path in local_files.items():
                scp.put(local_path, remote_path)
                print(f"📤 Uploaded {local_path} → {remote_path}")
    
    def run_command(self, command: str, stream: bool = True) -> str:
        """Run a command on the server and return its stdout.

        When `stream` is True, output is echoed line-by-line as it arrives
        (useful for long-running training runs).
        """
        stdin, stdout, stderr = self.ssh_client.exec_command(command, get_pty=True)
        
        output = ""
        if stream:
            for line in iter(stdout.readline, ""):
                print(line, end="")
                output += line
        else:
            output = stdout.read().decode()
        
        return output
    
    def setup_environment(self):
        """Install wandb, create the workspace, and (optionally) log in to wandb."""
        print("\n🔧 Setting up environment...")
        
        commands = [
            "pip install --upgrade pip",
            "pip install wandb",
            "mkdir -p /workspace/checkpoints"
        ]
        
        if self.wandb_key:
            commands.append(f"wandb login {self.wandb_key}")
        
        for cmd in commands:
            self.run_command(cmd, stream=False)
        
        print("✅ Environment ready")
    
    def train(self, config: Dict, train_script: str = "train.py"):
        """Write `config` to /workspace/config.json and launch the training script."""
        
        # Ship the config as a heredoc so no extra upload round-trip is needed.
        config_path = "/workspace/config.json"
        self.run_command(f"cat > {config_path} << 'EOF'\n{json.dumps(config, indent=2)}\nEOF")
        
        # Run training, streaming its output to our stdout.
        print("\n🚀 Starting training...")
        self.run_command(f"cd /workspace && python {train_script} --config {config_path}")
    
    def download_results(self, remote_dir: str, local_dir: str):
        """Recursively copy `remote_dir` down into `local_dir` over SCP."""
        os.makedirs(local_dir, exist_ok=True)
        
        with SCPClient(self.ssh_client.get_transport()) as scp:
            scp.get(remote_dir, local_dir, recursive=True)
        
        print(f"📥 Downloaded results to {local_dir}")
    
    def cleanup(self):
        """Close SSH and cancel the active order (stops billing).

        Cancellation is best-effort: this runs inside a `finally` block, so a
        failed API call here must not mask the original training error.
        """
        if self.ssh_client:
            self.ssh_client.close()
        
        if self.current_order:
            order_id = self.current_order["order_id"]
            try:
                self._request("POST", "/v1/cancel_order", json={"id": order_id})
                print(f"✅ Order {order_id} cancelled")
            except Exception as e:
                # FIX: don't raise from cleanup — warn loudly instead, since the
                # order keeps billing until it is cancelled manually.
                print(f"⚠️  Could not cancel order {order_id}: {e} — cancel it manually!")
    
    def run_training_job(
        self,
        config: Dict,
        train_script_path: str,
        gpu_type: str = "RTX 4090",
        max_price: float = 0.50,
        image: str = "pytorch/pytorch:2.7.1-cuda12.8-cudnn9-devel",
        ssh_password: str = "PyTorchTrain123!",
        output_dir: str = "./training_results"
    ):
        """Run a complete training job: provision, train, download, cleanup.

        Args:
            config: training config dict passed through to the remote script.
            train_script_path: local path to the script uploaded as train.py.
            gpu_type / max_price: marketplace search criteria.
            image: Docker image to boot on the rented server.
            ssh_password: root password set on the container.
            output_dir: local directory that receives /workspace/checkpoints.
        """
        start_time = time.time()
        
        try:
            # Find GPU
            print(f"🔍 Finding {gpu_type} under ${max_price}/hr...")
            gpu = self.find_gpu(gpu_type, max_price)
            print(f"   Found server {gpu['id']}: {gpu['gpus']} @ ${gpu['price']:.2f}/hr")
            
            # Provision
            print(f"\n📦 Provisioning server...")
            order = self.provision(gpu["id"], image, ssh_password)
            
            # Parse connection string — assumes the API returns it in the
            # form "ssh root@HOST -p PORT" (TODO confirm against live API).
            ssh_info = order["connection"]["ssh"]
            parts = ssh_info.split()
            host = parts[1].split("@")[1]
            port = int(parts[3]) if len(parts) > 3 else 22
            
            # Connect
            print(f"\n🔗 Connecting to {host}:{port}...")
            self.connect_ssh(host, port, ssh_password)
            
            # Setup
            self.setup_environment()
            
            # Upload training script
            self.upload_files({train_script_path: "/workspace/train.py"})
            
            # Train
            self.train(config)
            
            # Download results
            self.download_results("/workspace/checkpoints", output_dir)
            
            # Calculate approximate cost from wall-clock rental time.
            duration = (time.time() - start_time) / 3600
            cost = duration * gpu["price"]
            
            print("\n" + "="*60)
            print("✅ Training complete!")
            print(f"⏱️  Duration: {duration:.2f} hours")
            print(f"💰 Cost: ${cost:.2f}")
            print(f"📁 Results: {output_dir}")
            
        finally:
            self.cleanup()


# Usage example: first CLI argument is the Clore API key.
if __name__ == "__main__":
    import sys

    cli_key = sys.argv[1] if len(sys.argv) > 1 else "YOUR_API_KEY"
    trainer = CloreTrainer(cli_key, os.environ.get("WANDB_API_KEY"))

    demo_config = {
        "model_name": "resnet18",
        "num_classes": 10,
        "pretrained": True,
        "epochs": 5,
        "batch_size": 64,
        "learning_rate": 0.001,
        "dataset": "cifar10",
        "wandb_project": "clore-pytorch-demo",
        "checkpoint_dir": "/workspace/checkpoints",
        "log_every": 50,
    }

    trainer.run_training_job(
        config=demo_config,
        train_script_path="train.py",
        gpu_type="RTX 4090",
        max_price=0.50,
        output_dir="./my_training_results",
    )
```

## Step 4: One-Shot Training Script

```python
#!/usr/bin/env python3
"""
One-shot: Train a model on Clore GPU with a single command.

Usage:
    python run_training.py --api-key YOUR_KEY --model resnet50 --epochs 10
"""

import argparse
import os
import sys
from remote_trainer import CloreTrainer

def _parse_args():
    """Define and parse the CLI arguments for a remote training run."""
    parser = argparse.ArgumentParser(description="Train PyTorch models on Clore GPUs")

    # Required
    parser.add_argument("--api-key", required=True, help="Clore API key")

    # Model config
    parser.add_argument("--model", default="resnet18", help="Model architecture")
    parser.add_argument("--dataset", default="cifar10", help="Dataset to use")
    parser.add_argument("--num-classes", type=int, default=10, help="Number of classes")
    parser.add_argument("--pretrained", action="store_true", help="Use pretrained weights")

    # Training config
    parser.add_argument("--epochs", type=int, default=10, help="Number of epochs")
    parser.add_argument("--batch-size", type=int, default=64, help="Batch size")
    parser.add_argument("--lr", type=float, default=0.001, help="Learning rate")

    # GPU config
    parser.add_argument("--gpu-type", default="RTX 4090", help="GPU type to rent")
    parser.add_argument("--max-price", type=float, default=0.50, help="Max hourly price")

    # Output
    parser.add_argument("--output", default="./results", help="Output directory")
    parser.add_argument("--wandb-project", default="clore-training", help="W&B project")

    return parser.parse_args()


def main():
    """CLI entry point: assemble the config and launch a Clore training job."""
    args = _parse_args()

    # Config dict consumed by train.py on the remote GPU.
    train_config = {
        "model_name": args.model,
        "num_classes": args.num_classes,
        "pretrained": args.pretrained,
        "epochs": args.epochs,
        "batch_size": args.batch_size,
        "learning_rate": args.lr,
        "dataset": args.dataset,
        "wandb_project": args.wandb_project,
        "checkpoint_dir": "/workspace/checkpoints",
    }

    # The W&B key is picked up from the environment, never from the CLI.
    trainer = CloreTrainer(args.api_key, os.environ.get("WANDB_API_KEY"))

    banner = "=" * 50
    print("🚀 Clore PyTorch Training")
    print(banner)
    print(f"Model: {args.model}")
    print(f"Dataset: {args.dataset}")
    print(f"Epochs: {args.epochs}")
    print(f"GPU: {args.gpu_type} (max ${args.max_price}/hr)")
    print(banner + "\n")

    trainer.run_training_job(
        config=train_config,
        train_script_path=os.path.join(os.path.dirname(__file__), "train.py"),
        gpu_type=args.gpu_type,
        max_price=args.max_price,
        output_dir=args.output,
    )


if __name__ == "__main__":
    main()
```

## Quick Start

```bash
# Set your API key
export CLORE_API_KEY="your_api_key"
export WANDB_API_KEY="your_wandb_key"  # Optional

# Train ResNet18 on CIFAR-10
python run_training.py \
    --api-key $CLORE_API_KEY \
    --model resnet18 \
    --dataset cifar10 \
    --epochs 10 \
    --pretrained \
    --gpu-type "RTX 4090" \
    --max-price 0.50

# Train larger model
python run_training.py \
    --api-key $CLORE_API_KEY \
    --model resnet50 \
    --epochs 20 \
    --batch-size 32 \
    --gpu-type "RTX 4090"
```

## Cost Comparison

| Model    | Epochs | Clore (RTX 4090) | AWS (p4d.24xlarge) | Savings |
| -------- | ------ | ---------------- | ------------------ | ------- |
| ResNet18 | 10     | \~$0.20          | \~$5.30            | 96%     |
| ResNet50 | 20     | \~$0.80          | \~$21.20           | 96%     |
| ViT-Base | 50     | \~$4.00          | \~$106.00          | 96%     |

*Rough estimates assuming \~$0.40/hr for a Clore RTX 4090 versus \~$10.60/hr for AWS A100-class capacity (table column: p4d.24xlarge) — actual marketplace prices vary.*

## Next Steps

* [Distributed Training Across Multiple Servers](https://docs.clore.ai/dev/machine-learning-and-training/distributed-training)
* [Fine-Tuning Models with Hugging Face](https://docs.clore.ai/dev/machine-learning-and-training/huggingface-finetuning)
* [Hyperparameter Sweeps with Optuna](https://docs.clore.ai/dev/machine-learning-and-training/hyperparameter-sweeps)
