# train.py
"""PyTorch training script - runs on the Clore GPU."""

import os
import argparse
import json

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, models
import wandb
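
# main() reads its hyperparameters from a JSON config file. A minimal
# illustrative example (values are placeholders, not tuned settings):
# model_name, num_classes, dataset, batch_size, learning_rate and epochs are
# required; the remaining keys fall back to the defaults used in main().
#
# {
#   "model_name": "resnet18",
#   "num_classes": 10,
#   "pretrained": true,
#   "dataset": "cifar10",
#   "batch_size": 128,
#   "num_workers": 4,
#   "learning_rate": 3e-4,
#   "weight_decay": 1e-4,
#   "epochs": 20,
#   "log_every": 100,
#   "checkpoint_every": 1,
#   "checkpoint_dir": "/workspace/checkpoints",
#   "wandb_project": "clore-training",
#   "wandb_run_name": "resnet18-cifar10"
# }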

def get_device():
    """Get the best available device."""
    if torch.cuda.is_available():
        device = torch.device("cuda")
        print(f"🎮 Using GPU: {torch.cuda.get_device_name(0)}")
        print(f"   Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
    else:
        device = torch.device("cpu")
        print("⚠️ Using CPU (no GPU available)")
    return device

def get_model(name: str, num_classes: int, pretrained: bool = True):
    """Build a torchvision model by name and swap its head for `num_classes` outputs."""
    model_fn = getattr(models, name, None)
    if model_fn is None:
        raise ValueError(f"Unknown model: {name}")
    # "DEFAULT" resolves to the best available pretrained weights (torchvision >= 0.13).
    model = model_fn(weights="DEFAULT" if pretrained else None)
    # Replace the final layer for our number of classes
    if hasattr(model, 'fc'):
        model.fc = nn.Linear(model.fc.in_features, num_classes)
    elif hasattr(model, 'classifier'):
        if isinstance(model.classifier, nn.Sequential):
            model.classifier[-1] = nn.Linear(model.classifier[-1].in_features, num_classes)
        else:
            model.classifier = nn.Linear(model.classifier.in_features, num_classes)
    return model
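
# Note: the head replacement in get_model() covers the common torchvision
# families that expose either `fc` (e.g. ResNet) or `classifier` (e.g. VGG,
# MobileNet); architectures with differently named heads would need another branch.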

def get_data_loaders(dataset: str, batch_size: int, num_workers: int):
    """Get data loaders for training and validation."""
    # Transforms
    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    val_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    # Dataset
    if dataset.lower() == "cifar10":
        train_data = datasets.CIFAR10('./data', train=True, download=True, transform=train_transform)
        val_data = datasets.CIFAR10('./data', train=False, download=True, transform=val_transform)
    elif dataset.lower() == "cifar100":
        train_data = datasets.CIFAR100('./data', train=True, download=True, transform=train_transform)
        val_data = datasets.CIFAR100('./data', train=False, download=True, transform=val_transform)
    else:
        raise ValueError(f"Unknown dataset: {dataset}")
    train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True,
                              num_workers=num_workers, pin_memory=True)
    val_loader = DataLoader(val_data, batch_size=batch_size, shuffle=False,
                            num_workers=num_workers, pin_memory=True)
    return train_loader, val_loader
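
# Note: CIFAR images are 32x32; the transforms above scale them up to 224x224
# so the ImageNet-pretrained backbones (and ImageNet normalization statistics)
# can be reused unchanged, trading extra compute for simplicity.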

def train_epoch(model, loader, criterion, optimizer, device, epoch, log_every=100):
    """Train for one epoch."""
    model.train()
    total_loss = 0
    correct = 0
    total = 0
    for batch_idx, (data, target) in enumerate(loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        pred = output.argmax(dim=1)
        correct += pred.eq(target).sum().item()
        total += target.size(0)
        if batch_idx % log_every == 0:
            print(f"  Batch {batch_idx}/{len(loader)} | Loss: {loss.item():.4f} | "
                  f"Acc: {100.*correct/total:.1f}%")
            wandb.log({
                "train/batch_loss": loss.item(),
                "train/batch_acc": 100. * correct / total,
                "epoch": epoch,
                "batch": batch_idx
            })
    return total_loss / len(loader), 100. * correct / total

def validate(model, loader, criterion, device):
    """Validate the model."""
    model.eval()
    total_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for data, target in loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            loss = criterion(output, target)
            total_loss += loss.item()
            pred = output.argmax(dim=1)
            correct += pred.eq(target).sum().item()
            total += target.size(0)
    return total_loss / len(loader), 100. * correct / total

def save_checkpoint(model, optimizer, epoch, loss, acc, path):
    """Save training checkpoint."""
    os.makedirs(os.path.dirname(path), exist_ok=True)
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
        'accuracy': acc
    }, path)
    print(f"💾 Checkpoint saved: {path}")

def load_checkpoint(model, optimizer, path):
    """Load training checkpoint."""
    if os.path.exists(path):
        # map_location lets a checkpoint saved on GPU be restored on any device;
        # load_state_dict then moves the tensors onto the model's device.
        checkpoint = torch.load(path, map_location='cpu')
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        epoch = checkpoint['epoch']
        print(f"📂 Resumed from checkpoint (epoch {epoch})")
        return epoch
    return 0

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', type=str, default='config.json')
    parser.add_argument('--resume', type=str, default=None)
    args = parser.parse_args()

    # Load config
    with open(args.config) as f:
        config = json.load(f)

    # Setup device
    device = get_device()

    # Initialize wandb
    wandb.init(
        project=config.get('wandb_project', 'clore-training'),
        name=config.get('wandb_run_name'),
        config=config
    )

    # Model
    model = get_model(
        config['model_name'],
        config['num_classes'],
        config.get('pretrained', True)
    ).to(device)
    print(f"📊 Model: {config['model_name']} ({sum(p.numel() for p in model.parameters())/1e6:.1f}M params)")

    # Data
    train_loader, val_loader = get_data_loaders(
        config['dataset'],
        config['batch_size'],
        config.get('num_workers', 4)
    )
    print(f"📦 Dataset: {config['dataset']} (train: {len(train_loader.dataset)}, val: {len(val_loader.dataset)})")

    # Training setup
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(
        model.parameters(),
        lr=config['learning_rate'],
        weight_decay=config.get('weight_decay', 1e-4)
    )
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=config['epochs'])

    # Resume from checkpoint if provided
    start_epoch = 0
    if args.resume:
        start_epoch = load_checkpoint(model, optimizer, args.resume)

    # Training loop
    best_acc = 0
    checkpoint_dir = config.get('checkpoint_dir', '/workspace/checkpoints')
    print(f"\n🚀 Starting training for {config['epochs']} epochs...")
    print("="*60)
    for epoch in range(start_epoch, config['epochs']):
        print(f"\n📈 Epoch {epoch+1}/{config['epochs']}")
        print("-"*40)

        # Train
        train_loss, train_acc = train_epoch(
            model, train_loader, criterion, optimizer, device, epoch,
            log_every=config.get('log_every', 100)
        )

        # Validate
        val_loss, val_acc = validate(model, val_loader, criterion, device)

        # Update scheduler
        scheduler.step()

        # Log
        print(f"\n  Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.1f}%")
        print(f"  Val Loss:   {val_loss:.4f} | Val Acc:   {val_acc:.1f}%")
        wandb.log({
            "train/epoch_loss": train_loss,
            "train/epoch_acc": train_acc,
            "val/loss": val_loss,
            "val/acc": val_acc,
            "lr": optimizer.param_groups[0]['lr'],
            "epoch": epoch + 1
        })

        # Save checkpoint
        if (epoch + 1) % config.get('checkpoint_every', 1) == 0:
            save_checkpoint(
                model, optimizer, epoch + 1, val_loss, val_acc,
                f"{checkpoint_dir}/checkpoint_epoch_{epoch+1}.pt"
            )

        # Save best model
        if val_acc > best_acc:
            best_acc = val_acc
            save_checkpoint(
                model, optimizer, epoch + 1, val_loss, val_acc,
                f"{checkpoint_dir}/best_model.pt"
            )
            print(f"  🏆 New best accuracy: {best_acc:.1f}%")

    print("\n" + "="*60)
    print(f"✅ Training complete! Best accuracy: {best_acc:.1f}%")
    print(f"📁 Checkpoints saved to: {checkpoint_dir}")
    wandb.finish()


if __name__ == "__main__":
    main()