# Batch Processing

Process large-scale workloads efficiently on CLORE.AI GPUs.

{% hint style="success" %}
Find the right GPU here: [CLORE.AI Marketplace](https://clore.ai/marketplace).
{% endhint %}

## Batch Infrastructure with the clore-ai SDK (Recommended)

The official SDK makes batch GPU provisioning simple, with async support:

```python
import asyncio
from clore_ai import AsyncCloreAI

async def batch_deploy(server_ids):
    """Deploy to multiple servers concurrently."""
    async with AsyncCloreAI() as client:
        tasks = [
            client.create_order(
                server_id=sid,
                image="cloreai/ubuntu22.04-cuda12",
                type="on-demand",
                currency="bitcoin",
                ssh_password="BatchPass123",
                ports={"22": "tcp"}
            )
            for sid in server_ids
        ]
        results = await asyncio.gather(*tasks, return_exceptions=True)
        for sid, result in zip(server_ids, results):
            if isinstance(result, Exception):
                print(f"❌ Server {sid}: {result}")
            else:
                print(f"✅ Server {sid}: Order {result.id}")
        return results

# Deploy to 5 servers at once
asyncio.run(batch_deploy([142, 305, 891, 450, 612]))
```

→ See the [Python SDK guide](https://docs.clore.ai/guides/guides_v2-zh/gao-ji/python-sdk) and [CLI automation](https://docs.clore.ai/guides/guides_v2-zh/gao-ji/cli-automation) to learn more.

***

## When to Use Batch Processing

* Processing hundreds or thousands of items
* Transforming large datasets
* Generating images/videos at scale
* Bulk transcription
* Training data preparation

***

## LLM Batch Processing

### vLLM Batching API

vLLM handles batching automatically through continuous batching:

```python
from openai import OpenAI

client = OpenAI(base_url="http://server:8000/v1", api_key="dummy")

# Synchronous processing (sequential requests; vLLM only batches requests that arrive concurrently)
def process_batch_sync(prompts):
    results = []
    for prompt in prompts:
        response = client.chat.completions.create(
            model="meta-llama/Llama-3.1-8B-Instruct",
            messages=[{"role": "user", "content": prompt}]
        )
        results.append(response.choices[0].message.content)
    return results

# Process 100 prompts
prompts = [f"Summarize topic {i}" for i in range(100)]
results = process_batch_sync(prompts)
```

### Async Batch Processing (Faster)

```python
import asyncio
from openai import AsyncOpenAI

client = AsyncOpenAI(base_url="http://server:8000/v1", api_key="dummy")

async def process_single(prompt):
    response = await client.chat.completions.create(
        model="meta-llama/Llama-3.1-8B-Instruct",
        messages=[{"role": "user", "content": prompt}]
    )
    return response.choices[0].message.content

async def process_batch_async(prompts, max_concurrent=10):
    semaphore = asyncio.Semaphore(max_concurrent)

    async def limited_process(prompt):
        async with semaphore:
            return await process_single(prompt)

    tasks = [limited_process(p) for p in prompts]
    return await asyncio.gather(*tasks)

# Process 1000 prompts with 10 concurrent requests
prompts = [f"Generate description for product {i}" for i in range(1000)]
results = asyncio.run(process_batch_async(prompts, max_concurrent=10))
```

### Batch Processing with Progress Tracking

```python
import asyncio
from tqdm.asyncio import tqdm
from openai import AsyncOpenAI

client = AsyncOpenAI(base_url="http://server:8000/v1", api_key="dummy")

async def process_with_progress(prompts, max_concurrent=10):
    semaphore = asyncio.Semaphore(max_concurrent)
    results = []

    async def process_one(prompt, idx):
        async with semaphore:
            response = await client.chat.completions.create(
                model="meta-llama/Llama-3.1-8B-Instruct",
                messages=[{"role": "user", "content": prompt}]
            )
            return idx, response.choices[0].message.content

    tasks = [process_one(p, i) for i, p in enumerate(prompts)]

    for coro in tqdm.as_completed(tasks, total=len(tasks)):
        idx, result = await coro
        results.append((idx, result))

    # Restore the original order
    results.sort(key=lambda x: x[0])
    return [r[1] for r in results]

# Run
prompts = ["..." for _ in range(500)]
results = asyncio.run(process_with_progress(prompts))
```

### Saving Progress for Long Batches

```python
import json
from pathlib import Path
from openai import OpenAI

client = OpenAI(base_url="http://server:8000/v1", api_key="dummy")

def process_batch_with_checkpoint(prompts, checkpoint_file="checkpoint.json"):
    # Load the checkpoint if one exists
    checkpoint = Path(checkpoint_file)
    if checkpoint.exists():
        with open(checkpoint) as f:
            data = json.load(f)
            results = data['results']
            start_idx = data['last_completed'] + 1
        print(f"Resuming from index {start_idx}")
    else:
        results = [None] * len(prompts)
        start_idx = 0

    # Process the remaining items
    for i in range(start_idx, len(prompts)):
        try:
            response = client.chat.completions.create(
                model="meta-llama/Llama-3.1-8B-Instruct",
                messages=[{"role": "user", "content": prompts[i]}]
            )
            results[i] = response.choices[0].message.content

            # Save a checkpoint every 10 items
            if i % 10 == 0:
                with open(checkpoint_file, 'w') as f:
                    json.dump({'results': results, 'last_completed': i}, f)
                print(f"Saved checkpoint at {i}")

        except Exception as e:
            print(f"Error at {i}: {e}")
            # Save a checkpoint on error
            with open(checkpoint_file, 'w') as f:
                json.dump({'results': results, 'last_completed': i - 1}, f)
            raise

    # Remove the checkpoint once finished
    if checkpoint.exists():
        checkpoint.unlink()

    return results
```

***

## Image Generation Batching

### SD WebUI Batch Processing

```python
import requests
import base64
from pathlib import Path
from concurrent.futures import ThreadPoolExecutor
from tqdm import tqdm

SD_API = "http://server:7860"

def generate_image(prompt, output_path):
    response = requests.post(f'{SD_API}/sdapi/v1/txt2img', json={
        'prompt': prompt,
        'negative_prompt': 'blurry, low quality',
        'steps': 20,
        'width': 512,
        'height': 512
    })

    image_data = base64.b64decode(response.json()['images'][0])

    with open(output_path, 'wb') as f:
        f.write(image_data)

    return output_path

def batch_generate(prompts, output_dir, max_workers=4):
    Path(output_dir).mkdir(exist_ok=True)

    tasks = [
        (prompt, f"{output_dir}/image_{i:04d}.png")
        for i, prompt in enumerate(prompts)
    ]

    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        results = list(tqdm(
            executor.map(lambda x: generate_image(*x), tasks),
            total=len(tasks)
        ))

    return results

# Generate 100 images
prompts = [f"A beautiful landscape, style {i}" for i in range(100)]
batch_generate(prompts, "./outputs", max_workers=4)
```

### ComfyUI Batching with a Queue

```python
import json
import urllib.request
import time
from pathlib import Path

SERVER = "server:8188"

def queue_prompt(workflow):
    data = json.dumps({"prompt": workflow}).encode('utf-8')
    req = urllib.request.Request(f"http://{SERVER}/prompt", data=data)
    return json.loads(urllib.request.urlopen(req).read())

def get_history(prompt_id):
    with urllib.request.urlopen(f"http://{SERVER}/history/{prompt_id}") as response:
        return json.loads(response.read())

def batch_generate_comfyui(prompts, base_workflow_path, output_dir):
    Path(output_dir).mkdir(exist_ok=True)

    # Load the base workflow
    with open(base_workflow_path) as f:
        base_workflow = json.load(f)

    prompt_ids = []

    # Queue all the prompts
    for i, prompt in enumerate(prompts):
        workflow = json.loads(json.dumps(base_workflow))  # deep copy; dict.copy() would share the nested node dicts
        # Modify the prompt node (adjust node IDs to your workflow)
        workflow["6"]["inputs"]["text"] = prompt
        # Set the output filename
        workflow["9"]["inputs"]["filename_prefix"] = f"batch_{i:04d}"

        result = queue_prompt(workflow)
        prompt_ids.append(result['prompt_id'])
        print(f"Queued {i+1}/{len(prompts)}")

    # Wait for everything to finish
    print("Waiting for generation...")
    completed = set()
    while len(completed) < len(prompt_ids):
        for pid in prompt_ids:
            if pid not in completed:
                history = get_history(pid)
                if pid in history:
                    completed.add(pid)
                    print(f"Completed {len(completed)}/{len(prompt_ids)}")
        time.sleep(1)

    print("All done!")
```

### FLUX Batch Processing

```python
import torch
from diffusers import FluxPipeline
from pathlib import Path
from tqdm import tqdm

# Load the model only once
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell",
    torch_dtype=torch.bfloat16
)
pipe.to("cuda")

def batch_generate_flux(prompts, output_dir, batch_size=4):
    Path(output_dir).mkdir(exist_ok=True)

    for i in tqdm(range(0, len(prompts), batch_size)):
        batch_prompts = prompts[i:i + batch_size]

        # Generate the batch
        images = pipe(
            batch_prompts,
            height=1024,
            width=1024,
            num_inference_steps=4,
            guidance_scale=0.0
        ).images

        # Save the images
        for j, img in enumerate(images):
            img.save(f"{output_dir}/image_{i+j:04d}.png")

# Generate 102 images (3 animals × 34) in batches of 4
prompts = [f"A {animal} in a forest" for animal in ["cat", "dog", "fox"] * 34]
batch_generate_flux(prompts, "./flux_outputs", batch_size=4)
```

***

## Audio Batch Processing

### Batch Transcription with Whisper

```python
import whisper
from pathlib import Path
from tqdm import tqdm
import json

model = whisper.load_model("large-v3")

def batch_transcribe(audio_files, output_dir):
    Path(output_dir).mkdir(exist_ok=True)
    results = {}

    for audio_path in tqdm(audio_files):
        try:
            result = model.transcribe(str(audio_path))

            results[audio_path.name] = {
                'text': result['text'],
                'language': result['language'],
                'segments': result['segments']
            }

            # Save the individual transcript
            output_file = Path(output_dir) / f"{audio_path.stem}.json"
            with open(output_file, 'w') as f:
                json.dump(results[audio_path.name], f, indent=2)

        except Exception as e:
            print(f"Error processing {audio_path}: {e}")
            results[audio_path.name] = {'error': str(e)}

    # Save the combined results
    with open(f"{output_dir}/all_transcripts.json", 'w') as f:
        json.dump(results, f, indent=2)

    return results

# Transcribe all audio files in a directory
audio_files = list(Path("./audio").glob("*.mp3"))
results = batch_transcribe(audio_files, "./transcripts")
```

### Parallel Whisper (Multi-GPU)

```python
import whisper
from concurrent.futures import ProcessPoolExecutor
import torch

def transcribe_on_gpu(args):
    audio_path, gpu_id = args
    torch.cuda.set_device(gpu_id)
    model = whisper.load_model("large-v3", device=f"cuda:{gpu_id}")
    result = model.transcribe(str(audio_path))
    return audio_path, result['text']

def parallel_transcribe(audio_files, num_gpus=2):
    # Distribute the files across GPUs (round-robin)
    tasks = [(f, i % num_gpus) for i, f in enumerate(audio_files)]

    with ProcessPoolExecutor(max_workers=num_gpus) as executor:
        results = list(executor.map(transcribe_on_gpu, tasks))

    return dict(results)
```

***

## Video Batch Processing

### Batch Video Generation (SVD)

```python
from diffusers import StableVideoDiffusionPipeline
from diffusers.utils import load_image, export_to_video
from pathlib import Path
from tqdm import tqdm
import torch

pipe = StableVideoDiffusionPipeline.from_pretrained(
    "stabilityai/stable-video-diffusion-img2vid-xt",
    torch_dtype=torch.float16,
    variant="fp16"
)
pipe.to("cuda")

def batch_generate_videos(image_paths, output_dir):
    Path(output_dir).mkdir(exist_ok=True)

    for img_path in tqdm(image_paths):
        try:
            image = load_image(str(img_path))
            image = image.resize((1024, 576))

            frames = pipe(
                image,
                num_frames=25,
                decode_chunk_size=8
            ).frames[0]

            output_path = Path(output_dir) / f"{img_path.stem}.mp4"
            export_to_video(frames, str(output_path), fps=7)

        except Exception as e:
            print(f"Error processing {img_path}: {e}")

# Process all the images
images = list(Path("./input_images").glob("*.png"))
batch_generate_videos(images, "./output_videos")
```

***

## Data Pipeline Patterns

### Producer-Consumer Pattern

```python
import asyncio
from asyncio import Queue

async def producer(queue, items, num_workers):
    """Put items on the queue."""
    for item in items:
        await queue.put(item)
    # Signal completion: one sentinel per worker
    for _ in range(num_workers):
        await queue.put(None)

async def consumer(queue, results, worker_id):
    """Process items from the queue."""
    while True:
        item = await queue.get()
        if item is None:
            break

        try:
            result = await process_item(item)  # process_item: your per-item coroutine
            results.append(result)
        except Exception as e:
            print(f"Worker {worker_id} error: {e}")

        queue.task_done()

async def run_pipeline(items, num_workers=5):
    queue = Queue(maxsize=100)
    results = []

    # Start the workers
    workers = [
        asyncio.create_task(consumer(queue, results, i))
        for i in range(num_workers)
    ]

    # Run the producer
    await producer(queue, items, num_workers)

    # Wait for the workers to finish
    await asyncio.gather(*workers)

    return results

items = list(range(1000))
results = asyncio.run(run_pipeline(items))
```

### Map-Reduce Pattern

```python
from concurrent.futures import ProcessPoolExecutor
from functools import reduce

def map_function(item):
    """Process a single item."""
    return process(item)  # process: your per-item logic

def reduce_function(results):
    """Combine the results."""
    return combine(results)  # combine: your aggregation logic

def map_reduce(items, num_workers=4):
    # Map phase
    with ProcessPoolExecutor(max_workers=num_workers) as executor:
        mapped = list(executor.map(map_function, items))

    # Reduce phase
    result = reduce_function(mapped)

    return result
```
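`process` and `combine` above are placeholders for your own logic. As a minimal concrete instantiation (a hypothetical sum-of-squares workload, not from the original), the same pattern looks like this:

```python
from concurrent.futures import ProcessPoolExecutor

def square(x):
    """Map: process a single item."""
    return x * x

def map_reduce_squares(items, num_workers=4):
    # Map phase: square every item across worker processes
    with ProcessPoolExecutor(max_workers=num_workers) as executor:
        mapped = list(executor.map(square, items))
    # Reduce phase: combine the partial results
    return sum(mapped)

if __name__ == "__main__":
    print(map_reduce_squares(range(10)))  # 285
```

Note that functions handed to `ProcessPoolExecutor` must be defined at module top level so they can be pickled for the worker processes.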

***

## Optimization Tips

### 1. Set Concurrency Appropriately

```python
# LLM: match vLLM's maximum batch size
max_concurrent = 10  # vLLM default

# Image generation: 1-4 depending on VRAM
max_concurrent = 2  # SD WebUI
max_concurrent = 4  # FLUX on an RTX 4090

# Transcription: one per GPU
max_concurrent = num_gpus
```
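Rather than guessing, you can also measure: time the same fixed workload at a few concurrency caps and keep the fastest. A sketch using a simulated request, where `fake_request` stands in for your real API call:

```python
import asyncio
import time

async def probe_concurrency(request_fn, n_requests=40, levels=(1, 5, 10, 20)):
    """Time an identical workload at several concurrency caps."""
    timings = {}
    for level in levels:
        sem = asyncio.Semaphore(level)

        async def limited():
            async with sem:
                await request_fn()

        start = time.perf_counter()
        await asyncio.gather(*(limited() for _ in range(n_requests)))
        timings[level] = time.perf_counter() - start
    return timings

async def fake_request():
    await asyncio.sleep(0.01)  # stand-in for a real API call

timings = asyncio.run(probe_concurrency(fake_request))
for level, t in sorted(timings.items()):
    print(f"concurrency {level:2d}: {t:.2f}s")
```

Against a real endpoint, throughput stops improving once the server's batch capacity is saturated; pick the smallest cap near that plateau.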

### 2. Tune the Batch Size

```python
# Too small: the GPU sits underutilized
# Too large: OOM (out-of-memory) errors

# Image generation batch sizes:
# RTX 3060: batch_size = 1
# RTX 3090: batch_size = 2-4
# RTX 4090: batch_size = 4-8
# A100: batch_size = 8-16
```
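If you don't know the safe batch size up front, a common pattern is to halve it whenever a batch fails to fit. A generic sketch; with PyTorch you would pass `oom_exc=torch.cuda.OutOfMemoryError` and call `torch.cuda.empty_cache()` before retrying:

```python
def run_with_oom_backoff(process_fn, items, batch_size, oom_exc=MemoryError):
    """Process items in batches, halving the batch size whenever one does not fit."""
    results = []
    i = 0
    while i < len(items):
        batch = items[i:i + batch_size]
        try:
            results.extend(process_fn(batch))
            i += len(batch)
        except oom_exc:
            if batch_size == 1:
                raise  # even a single item does not fit
            batch_size //= 2
            print(f"OOM, retrying with batch_size={batch_size}")
    return results
```

The reduced batch size sticks for the rest of the run, which is usually what you want since memory pressure rarely goes away mid-job.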

### 3. Memory Management

```python
import gc
import torch

def clear_memory():
    gc.collect()
    torch.cuda.empty_cache()

# Call between large batches
for batch in batches:
    process_batch(batch)
    clear_memory()
```

### 4. Save Intermediate Results

```python
# Always checkpoint long-running jobs
CHECKPOINT_INTERVAL = 100

for i, item in enumerate(items):
    results.append(process(item))

    if i % CHECKPOINT_INTERVAL == 0:
        save_checkpoint(results, i)
```
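`save_checkpoint` above is left undefined; a minimal sketch using an atomic write (write to a temp file, then rename), so a crash mid-save never leaves a half-written checkpoint:

```python
import json
import os

def save_checkpoint(results, last_index, path="checkpoint.json"):
    """Atomically persist the results so far."""
    tmp = path + ".tmp"
    with open(tmp, "w") as f:
        json.dump({"results": results, "last_completed": last_index}, f)
    os.replace(tmp, path)  # atomic rename: readers see the old or new file, never a partial one

def load_checkpoint(path="checkpoint.json"):
    """Return the saved state, or None if no checkpoint exists."""
    if not os.path.exists(path):
        return None
    with open(path) as f:
        return json.load(f)
```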

***

## Cost Optimization

### Estimate Before You Run

```python
def estimate_cost(num_items, time_per_item_sec, hourly_rate):
    total_hours = (num_items * time_per_item_sec) / 3600
    total_cost = total_hours * hourly_rate
    return total_hours, total_cost

# Example: 10,000 images at 3 seconds each on an RTX 4090
hours, cost = estimate_cost(10000, 3, 0.10)
print(f"Estimated: {hours:.1f} hours, ${cost:.2f}")
# Output: Estimated: 8.3 hours, $0.83
```

### Use Spot Instances

* 30-50% cheaper
* Well suited to batch jobs (they tolerate interruption)
* Checkpoint frequently
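Many providers send SIGTERM shortly before reclaiming a spot instance (verify the exact behavior and grace period for your provider). A sketch that finishes the current item, saves a checkpoint, and exits cleanly when that signal arrives:

```python
import json
import signal

stop_requested = False

def _request_stop(signum, frame):
    global stop_requested
    stop_requested = True  # finish the current item, then save and exit

signal.signal(signal.SIGTERM, _request_stop)

def run_interruptible(items, process_fn, checkpoint_file="checkpoint.json"):
    results = []
    for i, item in enumerate(items):
        results.append(process_fn(item))
        # Checkpoint periodically, and always before an interrupted exit
        if stop_requested or i % 10 == 0:
            with open(checkpoint_file, "w") as f:
                json.dump({"results": results, "last_completed": i}, f)
        if stop_requested:
            print(f"SIGTERM received, stopped cleanly after item {i}")
            break
    return results
```

On restart, load the checkpoint and resume from `last_completed + 1`, as in the earlier checkpointing example.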

### Process During Off-Peak Hours

* Queue jobs for low-demand periods
* Typically better GPU availability
* Potentially lower spot prices

***

## Next Steps

* [API Integration](https://docs.clore.ai/guides/guides_v2-zh/gao-ji/api-integration) - build your API
* [Multi-GPU Setup](https://docs.clore.ai/guides/guides_v2-zh/gao-ji/multi-gpu-setup) - scale up
* [Cost Calculator](https://docs.clore.ai/guides/guides_v2-zh/kuai-su-ru-men/cost-calculator) - estimate costs
