正在加载,请稍候…

PyTorch 计算机视觉:从训练到生产部署

完整指南,涵盖使用 PyTorch 进行图像分类、目标检测、模型优化(TorchScript、ONNX、TensorRT)及生产部署。

PyTorch 计算机视觉:从训练到生产部署

PyTorch 计算机视觉:生产指南

使用 ResNet 进行迁移学习

import torch
import torch.nn as nn
import torchvision.models as models
from torchvision import transforms, datasets
from torch.utils.data import DataLoader

def build_classifier(num_classes: int, backbone: str = "resnet50") -> nn.Module:
    """Create a pretrained torchvision backbone with a fresh classification head.

    The constructor named by ``backbone`` is looked up on ``torchvision.models``
    and called with ``weights="IMAGENET1K_V1"``. All but the last 20 parameter
    tensors of that pretrained network are frozen, then its ``fc`` attribute is
    replaced by ``Linear(in_features, 512) -> ReLU -> Dropout(0.5) ->
    Linear(512, num_classes)``.

    Args:
        num_classes: Number of output logits produced by the new head.
        backbone: Attribute name of the model constructor in
            ``torchvision.models``; the resulting model must expose ``fc``.

    Returns:
        The modified model.
    """
    constructor = getattr(models, backbone)
    net = constructor(weights="IMAGENET1K_V1")

    # Freeze the early layers: everything except the final 20 parameter
    # tensors (counted before the head is swapped out below).
    pretrained_params = list(net.parameters())
    for tensor in pretrained_params[:-20]:
        tensor.requires_grad_(False)

    # Replace the final layer with a small MLP head.
    head_inputs = net.fc.in_features
    head_layers = [
        nn.Linear(head_inputs, 512),
        nn.ReLU(),
        nn.Dropout(0.5),
        nn.Linear(512, num_classes),
    ]
    net.fc = nn.Sequential(*head_layers)
    return net

# Data augmentation / preprocessing pipelines.
# Per-channel normalization statistics, identical for training and validation.
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]

# Training: random crop to 224, random horizontal flip and colour jitter,
# followed by tensor conversion and normalization.
train_transforms = transforms.Compose(
    [
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
        transforms.ToTensor(),
        transforms.Normalize(mean=list(_NORM_MEAN), std=list(_NORM_STD)),
    ]
)

# Validation: deterministic resize to 256 and centre crop to 224,
# followed by the same tensor conversion and normalization.
val_transforms = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=list(_NORM_MEAN), std=list(_NORM_STD)),
    ]
)

PyTorch 计算机视觉:从训练到生产部署 插图

混合精度训练循环

from torch.cuda.amp import autocast, GradScaler
from torch.optim import AdamW
from torch.optim.lr_scheduler import OneCycleLR

def train_epoch(model, loader, optimizer, scheduler, scaler, device):
    """Run one mixed-precision training pass over ``loader``.

    For each batch the forward pass and loss are computed under ``autocast``;
    the loss is scaled for the backward pass, gradients are unscaled and
    clipped to a total norm of 1.0, and then the optimizer (through the
    scaler), the scaler itself and the scheduler are each stepped once.

    Args:
        model: Module to train; switched to training mode.
        loader: Iterable of ``(images, labels)`` batches whose ``dataset``
            attribute supports ``len()``.
        optimizer: Optimizer updating ``model``'s parameters.
        scheduler: Learning-rate scheduler stepped once per batch.
        scaler: Gradient scaler used for loss scaling and the optimizer step.
        device: Device the batches are moved to.

    Returns:
        ``(mean_loss, accuracy)``: the mean of the per-batch losses and the
        number of correct predictions divided by ``len(loader.dataset)``.
        ``(0.0, 0.0)`` when the loader yields no batches.
    """
    model.train()
    # Built once: the loss module is stateless, so re-creating it per batch
    # is wasted work.
    criterion = nn.CrossEntropyLoss()
    total_loss, correct = 0.0, 0
    num_batches = 0

    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()

        with autocast():  # mixed precision
            outputs = model(images)
            loss = criterion(outputs, labels)

        scaler.scale(loss).backward()
        # Unscale before clipping so the threshold applies to true gradients.
        scaler.unscale_(optimizer)
        nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        scaler.step(optimizer)
        scaler.update()
        scheduler.step()

        total_loss += loss.item()
        correct += (outputs.argmax(1) == labels).sum().item()
        num_batches += 1

    # An empty loader would otherwise divide by zero.
    if num_batches == 0:
        return 0.0, 0.0

    return total_loss / num_batches, correct / len(loader.dataset)

def train(
    num_classes=10,
    epochs=20,
    train_dir="data/train",
    val_dir="data/val",
    batch_size=32,
    val_batch_size=64,
    lr=1e-3,
):
    """Fine-tune a classifier on image folders and return the trained model.

    Builds the model with ``build_classifier``, loads both splits with
    ``ImageFolder``, and trains with AdamW (weight decay 0.01), a one-cycle
    learning-rate schedule and a gradient scaler, printing loss and accuracy
    after every epoch.

    Args:
        num_classes: Number of output classes for the classifier head.
        epochs: Number of training epochs.
        train_dir: Root directory of the training ``ImageFolder``.
        val_dir: Root directory of the validation ``ImageFolder``.
        batch_size: Training batch size.
        val_batch_size: Validation batch size.
        lr: Optimizer learning rate, also used as the schedule's ``max_lr``.

    Returns:
        The trained model, left on the selected device.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = build_classifier(num_classes).to(device)

    train_ds = datasets.ImageFolder(train_dir, transform=train_transforms)
    val_ds = datasets.ImageFolder(val_dir, transform=val_transforms)

    train_loader = DataLoader(
        train_ds, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True
    )
    val_loader = DataLoader(val_ds, batch_size=val_batch_size, shuffle=False, num_workers=4)

    optimizer = AdamW(model.parameters(), lr=lr, weight_decay=0.01)
    # The schedule is sized for exactly one scheduler step per training batch.
    scheduler = OneCycleLR(
        optimizer, max_lr=lr, steps_per_epoch=len(train_loader), epochs=epochs
    )
    scaler = GradScaler()

    for epoch in range(epochs):
        train_loss, train_acc = train_epoch(
            model, train_loader, optimizer, scheduler, scaler, device
        )
        # NOTE(review): `evaluate` is not defined in the visible source —
        # confirm it is provided elsewhere and returns validation accuracy.
        val_acc = evaluate(model, val_loader, device)
        print(f"Epoch {epoch+1}: loss={train_loss:.4f}, train_acc={train_acc:.4f}, val_acc={val_acc:.4f}")

    return model

PyTorch 计算机视觉:从训练到生产部署 插图

导出为 ONNX

def export_onnx(model, output_path: str = "model.onnx"):
    """Export ``model`` to an ONNX file with a dynamic batch dimension.

    The model is put in eval mode and traced with a random
    ``(1, 3, 224, 224)`` input. The graph input is named ``"image"`` and the
    output ``"logits"``; axis 0 of both is marked dynamic.

    Args:
        model: Module to export.
        output_path: Destination path of the ONNX file.
    """
    model.eval()

    # Create the tracing input on the model's own device: a CPU tensor fed to
    # a model that lives on CUDA fails with a device-mismatch error.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")
    dummy_input = torch.randn(1, 3, 224, 224, device=device)

    # Gradients are not needed while tracing.
    with torch.no_grad():
        torch.onnx.export(
            model, dummy_input, output_path,
            export_params=True,
            opset_version=17,
            input_names=["image"],
            output_names=["logits"],
            dynamic_axes={"image": {0: "batch_size"}, "logits": {0: "batch_size"}},
        )
    print(f"Exported to {output_path}")

# 验证 ONNX
import onnxruntime as ort
import numpy as np

def validate_onnx(model_path: str, input_image: np.ndarray) -> np.ndarray:
    """Run an ONNX model once with ONNX Runtime and return its first output.

    Args:
        model_path: Path to the ONNX model file.
        input_image: Input array; cast to ``float32`` before inference.

    Returns:
        The first output array produced by the session.
    """
    # List the CPU provider as an explicit fallback so the session can still
    # be created on hosts where the CUDA provider is unavailable.
    sess = ort.InferenceSession(
        model_path,
        providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
    )
    # Read the input name from the model instead of assuming "image".
    input_name = sess.get_inputs()[0].name
    outputs = sess.run(None, {input_name: input_image.astype(np.float32)})
    return outputs[0]

PyTorch 计算机视觉:从训练到生产部署 插图

TensorRT 优化

import tensorrt as trt

def build_trt_engine(
    onnx_path: str,
    engine_path: str,
    fp16: bool = True,
    batch_sizes: tuple = (1, 8, 32),
):
    """Build a TensorRT engine from an ONNX file and write it to disk.

    Args:
        onnx_path: Path of the ONNX model to parse.
        engine_path: Destination path of the serialized engine.
        fp16: Enable the FP16 builder flag when true.
        batch_sizes: ``(min, opt, max)`` batch sizes for the optimization
            profile; only used for inputs whose first dimension is dynamic.

    Raises:
        RuntimeError: If the ONNX model cannot be parsed or the build fails.
        ValueError: If an input has a dynamic dimension other than the first.
    """
    logger = trt.Logger(trt.Logger.WARNING)
    builder = trt.Builder(logger)
    network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
    parser = trt.OnnxParser(network, logger)

    with open(onnx_path, "rb") as f:
        parsed = parser.parse(f.read())
    # parse() reports failure through its return value; surface parser errors
    # instead of building from a broken network.
    if not parsed:
        errors = "; ".join(str(parser.get_error(i)) for i in range(parser.num_errors))
        raise RuntimeError(f"Failed to parse ONNX model {onnx_path}: {errors}")

    config = builder.create_builder_config()
    workspace_bytes = 1 << 30  # 1GB
    # Prefer the memory-pool API; fall back to the attribute it replaced on
    # TensorRT versions that do not have it.
    if hasattr(config, "set_memory_pool_limit"):
        config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, workspace_bytes)
    else:
        config.max_workspace_size = workspace_bytes

    if fp16:
        config.set_flag(trt.BuilderFlag.FP16)

    # Inputs with a dynamic batch axis (dimension reported as -1) need an
    # optimization profile giving min/opt/max shapes.
    min_batch, opt_batch, max_batch = batch_sizes
    profile = None
    for index in range(network.num_inputs):
        tensor = network.get_input(index)
        shape = tuple(tensor.shape)
        if -1 not in shape:
            continue
        if -1 in shape[1:]:
            raise ValueError(
                f"Input {tensor.name!r} has dynamic non-batch dimensions {shape}; "
                "only a dynamic batch dimension is supported"
            )
        if profile is None:
            profile = builder.create_optimization_profile()
        fixed_dims = shape[1:]
        profile.set_shape(
            tensor.name,
            (min_batch, *fixed_dims),
            (opt_batch, *fixed_dims),
            (max_batch, *fixed_dims),
        )
    if profile is not None:
        config.add_optimization_profile(profile)

    # Prefer build_serialized_network; fall back to build_engine + serialize
    # on versions that lack it.
    if hasattr(builder, "build_serialized_network"):
        serialized = builder.build_serialized_network(network, config)
    else:
        engine = builder.build_engine(network, config)
        serialized = engine.serialize() if engine is not None else None
    # A failed build yields None rather than raising.
    if serialized is None:
        raise RuntimeError(f"TensorRT failed to build an engine from {onnx_path}")

    with open(engine_path, "wb") as f:
        f.write(serialized)
    print(f"TensorRT engine saved: {engine_path}")

生产推理服务器

from fastapi import FastAPI, File, UploadFile
from PIL import Image
import io

# FastAPI application instance; the /predict route is registered on it.
app = FastAPI()

class Predictor:
    """Classify PIL images with an ONNX Runtime session.

    Images are preprocessed with ``val_transforms``, run through the model,
    and the logits are converted to softmax probabilities.
    """

    def __init__(self, model_path: str, class_names: list):
        """Load the ONNX model and remember the class labels.

        Args:
            model_path: Path to the ONNX model file.
            class_names: Labels indexed by the model's output class index.
        """
        self.session = ort.InferenceSession(model_path)
        # Read the input tensor name from the model instead of assuming
        # it is "image".
        self.input_name = self.session.get_inputs()[0].name
        self.class_names = class_names
        self.transform = val_transforms

    def predict(self, image: Image.Image, top_k: int = 5) -> dict:
        """Return the ``top_k`` most probable classes for ``image``.

        Args:
            image: Image to classify.
            top_k: Maximum number of predictions to return; negative values
                are treated as 0.

        Returns:
            ``{"predictions": [{"class": ..., "confidence": ...}, ...]}``
            ordered from highest to lowest confidence.
        """
        # Add a leading batch dimension of size 1 for the session.
        tensor = self.transform(image).unsqueeze(0).numpy()
        logits = self.session.run(None, {self.input_name: tensor})[0]
        probs = torch.softmax(torch.tensor(logits[0]), dim=0).numpy()
        # argsort is ascending: reverse it, then keep the first top_k indices.
        top_idx = probs.argsort()[::-1][:max(top_k, 0)]
        return {
            "predictions": [
                {"class": self.class_names[i], "confidence": float(probs[i])}
                for i in top_idx
            ]
        }

# Module-level predictor used by the /predict handler; constructing it creates
# the ONNX Runtime session for "model.onnx" when this module is loaded.
predictor = Predictor("model.onnx", class_names=["cat", "dog", "bird"])

@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded image and return the predictor's result.

    Args:
        file: Uploaded image file.

    Returns:
        The dictionary produced by ``predictor.predict``.

    Raises:
        HTTPException: With status 400 when the upload cannot be decoded as
            an image.
    """
    # Local import keeps this handler self-contained; the file-level fastapi
    # import does not bring HTTPException into scope.
    from fastapi import HTTPException

    contents = await file.read()
    try:
        image = Image.open(io.BytesIO(contents)).convert("RGB")
    except (OSError, Image.DecompressionBombError) as exc:
        # Uploads are untrusted: report undecodable or oversized images as a
        # client error instead of an unhandled exception.
        raise HTTPException(
            status_code=400, detail="Uploaded file is not a valid image"
        ) from exc
    return predictor.predict(image)

性能基准测试

| 格式 | 延迟 (ms) | 吞吐量 (img/s) |
| --- | --- | --- |
| PyTorch FP32 | 45 | 22 |
| PyTorch FP16 | 28 | 36 |
| ONNX Runtime | 18 | 56 |
| TensorRT FP16 | 8 | 125 |