
PyTorch 计算机视觉：生产指南
使用 ResNet 进行迁移学习
import torch
import torch.nn as nn
import torchvision.models as models
from torchvision import transforms, datasets
from torch.utils.data import DataLoader
def build_classifier(
    num_classes: int,
    backbone: str = "resnet50",
    *,
    weights: "str | None" = "IMAGENET1K_V1",
    trainable_tail: int = 20,
    hidden_dim: int = 512,
    dropout: float = 0.5,
) -> nn.Module:
    """Build a transfer-learning classifier on a torchvision ResNet-style backbone.

    Loads ``torchvision.models.<backbone>`` with ``weights``, freezes every
    parameter tensor except the last ``trainable_tail`` ones, and replaces the
    ``fc`` head with ``Linear -> ReLU -> Dropout -> Linear(num_classes)``.

    Args:
        num_classes: number of output logits (must be >= 1).
        backbone: name of a torchvision model constructor that exposes a
            ``Linear`` ``fc`` head (the ResNet family).
        weights: torchvision weights identifier, or ``None`` for random init.
        trainable_tail: how many trailing parameter *tensors* (not layers) of the
            pretrained model stay trainable. The count is taken before the head
            is replaced; for torchvision ResNets the original ``fc`` weight and
            bias are the last two tensors, so ``trainable_tail - 2`` backbone
            tensors remain trainable. The newly built head is always trainable.
        hidden_dim: width of the hidden layer in the new head.
        dropout: dropout probability in the new head.

    Returns:
        The modified model (on CPU).

    Raises:
        ValueError: for invalid sizes, or a backbone without a ``Linear`` ``fc``.
    """
    if num_classes < 1:
        raise ValueError(f"num_classes must be >= 1, got {num_classes}")
    if trainable_tail < 0:
        raise ValueError(f"trainable_tail must be >= 0, got {trainable_tail}")

    model = getattr(models, backbone)(weights=weights)
    if not isinstance(getattr(model, "fc", None), nn.Linear):
        raise ValueError(
            f"backbone {backbone!r} has no Linear `fc` head; only ResNet-style models are supported"
        )

    # Freeze the early layers. Compute the cut explicitly: a slice like
    # params[:-k] would freeze nothing for k == 0 and go wrong for k > len.
    params = list(model.parameters())
    n_frozen = max(len(params) - trainable_tail, 0)
    for param in params[:n_frozen]:
        param.requires_grad = False

    # Replace the final layer with a small MLP head (fresh modules are trainable).
    in_features = model.fc.in_features
    model.fc = nn.Sequential(
        nn.Linear(in_features, hidden_dim),
        nn.ReLU(),
        nn.Dropout(dropout),
        nn.Linear(hidden_dim, num_classes),
    )
    return model
# Data-augmentation / preprocessing pipelines.


def _imagenet_normalize() -> transforms.Normalize:
    """Return a fresh Normalize transform using the ImageNet channel statistics."""
    return transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])


# Training: random resized crop, horizontal flip and colour jitter for augmentation.
_train_steps = [
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    transforms.ToTensor(),
    _imagenet_normalize(),
]
train_transforms = transforms.Compose(_train_steps)

# Validation/inference: deterministic resize + centre crop to 224x224.
_val_steps = [
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    _imagenet_normalize(),
]
val_transforms = transforms.Compose(_val_steps)

混合精度训练循环
from torch.cuda.amp import autocast, GradScaler
from torch.optim import AdamW
from torch.optim.lr_scheduler import OneCycleLR
def train_epoch(model, loader, optimizer, scheduler, scaler, device):
    """Run one training epoch with mixed precision and gradient clipping.

    Args:
        model: the network to train (switched to train mode).
        loader: iterable of ``(images, labels)`` batches.
        optimizer: optimizer over the model's trainable parameters.
        scheduler: LR scheduler stepped once per batch (e.g. ``OneCycleLR``).
        scaler: a ``GradScaler`` (may be disabled, e.g. on CPU).
        device: ``torch.device`` or device string the batches are moved to.

    Returns:
        ``(mean_loss, accuracy)``. Both are averaged over the samples actually
        seen, so a short final batch or ``drop_last=True`` does not skew them.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    model.train()
    criterion = nn.CrossEntropyLoss()  # stateless; build once, not per batch
    # Autocast only helps on CUDA; enabling the CUDA autocast on a CPU-only
    # run just emits a warning and disables itself.
    use_amp = torch.device(device).type == "cuda"

    loss_sum, correct, seen = 0.0, 0, 0
    for images, labels in loader:
        images = images.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)
        optimizer.zero_grad()
        with torch.autocast(device_type="cuda", enabled=use_amp):  # mixed precision
            outputs = model(images)
            loss = criterion(outputs, labels)
        scaler.scale(loss).backward()
        # Unscale first so the clipping threshold applies to the true gradients.
        scaler.unscale_(optimizer)
        nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        scaler.step(optimizer)
        scaler.update()
        scheduler.step()

        batch_size = labels.size(0)
        loss_sum += loss.item() * batch_size  # weight by batch size for a true per-sample mean
        correct += (outputs.argmax(1) == labels).sum().item()
        seen += batch_size

    if seen == 0:
        raise ValueError("train_epoch: loader yielded no samples")
    return loss_sum / seen, correct / seen
def evaluate(model, loader, device) -> float:
    """Return the top-1 accuracy of ``model`` over ``loader``.

    Switches the model to eval mode and runs without gradient tracking.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    model.eval()
    correct, seen = 0, 0
    with torch.no_grad():
        for images, labels in loader:
            images = images.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)
            correct += (model(images).argmax(1) == labels).sum().item()
            seen += labels.size(0)
    if seen == 0:
        raise ValueError("evaluate: loader yielded no samples")
    return correct / seen


def train(num_classes=10, epochs=20, *, train_dir="data/train", val_dir="data/val", lr=1e-3):
    """Fine-tune a ``build_classifier`` model on ImageFolder data and return it.

    Args:
        num_classes: size of the classification head; must cover every class
            folder found in ``train_dir``.
        epochs: number of training epochs.
        train_dir: ImageFolder root for training images.
        val_dir: ImageFolder root for validation images (same class folders).
        lr: peak learning rate for AdamW / OneCycleLR.

    Returns:
        The trained model, left on the selected device.

    Raises:
        ValueError: if train/val class folders differ, or there are more
            classes than ``num_classes``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    use_cuda = device.type == "cuda"

    train_ds = datasets.ImageFolder(train_dir, transform=train_transforms)
    val_ds = datasets.ImageFolder(val_dir, transform=val_transforms)
    # ImageFolder assigns label indices from the sorted folder names, so a
    # missing or extra folder in val silently shifts every label.
    if train_ds.class_to_idx != val_ds.class_to_idx:
        raise ValueError(
            "train/val class folders differ; labels would not line up: "
            f"{sorted(train_ds.class_to_idx)} vs {sorted(val_ds.class_to_idx)}"
        )
    if len(train_ds.classes) > num_classes:
        raise ValueError(
            f"dataset has {len(train_ds.classes)} classes but num_classes={num_classes}"
        )

    # Pinned memory only helps (and only avoids warnings) when copying to CUDA.
    train_loader = DataLoader(
        train_ds, batch_size=32, shuffle=True, num_workers=4, pin_memory=use_cuda
    )
    val_loader = DataLoader(
        val_ds, batch_size=64, shuffle=False, num_workers=4, pin_memory=use_cuda
    )

    model = build_classifier(num_classes).to(device)
    # Hand the optimizer only the unfrozen parameters.
    trainable = [p for p in model.parameters() if p.requires_grad]
    optimizer = AdamW(trainable, lr=lr, weight_decay=0.01)
    scheduler = OneCycleLR(
        optimizer, max_lr=lr, steps_per_epoch=len(train_loader), epochs=epochs
    )
    scaler = GradScaler(enabled=use_cuda)  # loss scaling is a no-op without CUDA

    for epoch in range(epochs):
        train_loss, train_acc = train_epoch(
            model, train_loader, optimizer, scheduler, scaler, device
        )
        val_acc = evaluate(model, val_loader, device)
        print(f"Epoch {epoch+1}: loss={train_loss:.4f}, train_acc={train_acc:.4f}, val_acc={val_acc:.4f}")
    return model

导出为 ONNX
def export_onnx(
    model,
    output_path: str = "model.onnx",
    *,
    image_size: int = 224,
    opset_version: int = 17,
):
    """Export ``model`` to ONNX with a dynamic batch dimension.

    The model is switched to eval mode (and left there) so dropout and
    batch-norm are traced with inference behaviour. The graph input is named
    ``"image"`` with shape ``(batch_size, 3, image_size, image_size)``; the
    output is named ``"logits"``.

    Args:
        model: ``torch.nn.Module`` to export.
        output_path: destination ``.onnx`` file.
        image_size: spatial size of the square dummy input used for tracing.
        opset_version: ONNX opset to target.
    """
    model.eval()
    # Create the dummy input on the model's own device: a model coming out of
    # train() lives on CUDA when available, and tracing it with a CPU tensor
    # fails with a device-mismatch error.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")
    dummy_input = torch.randn(1, 3, image_size, image_size, device=device)
    torch.onnx.export(
        model,
        dummy_input,
        output_path,
        export_params=True,
        opset_version=opset_version,
        input_names=["image"],
        output_names=["logits"],
        dynamic_axes={"image": {0: "batch_size"}, "logits": {0: "batch_size"}},
    )
    print(f"Exported to {output_path}")
# 验证 ONNX
import onnxruntime as ort
import numpy as np
def validate_onnx(model_path: str, input_image: np.ndarray, providers=None) -> np.ndarray:
    """Run ``input_image`` through an ONNX model and return its first output.

    Args:
        model_path: path to the ``.onnx`` file.
        input_image: array convertible to float32. A 3-D ``(C, H, W)`` array
            gets a leading batch axis; anything else is passed through as-is
            (models written by ``export_onnx`` expect ``(N, 3, H, W)``).
        providers: ONNX Runtime execution providers in priority order.
            Defaults to CUDA when the installed build offers it, then CPU.

    Returns:
        The first model output (the ``"logits"`` tensor for ``export_onnx`` models).
    """
    if providers is None:
        available = ort.get_available_providers()
        # Always keep CPU as an explicit fallback so this works on machines
        # (or onnxruntime builds) without CUDA.
        providers = [
            p for p in ("CUDAExecutionProvider", "CPUExecutionProvider") if p in available
        ]
    sess = ort.InferenceSession(model_path, providers=providers)

    batch = np.ascontiguousarray(input_image, dtype=np.float32)
    if batch.ndim == 3:
        batch = batch[np.newaxis]  # single CHW image -> batch of one
    # Use the graph's actual input name instead of assuming "image".
    input_name = sess.get_inputs()[0].name
    outputs = sess.run(None, {input_name: batch})
    return outputs[0]

TensorRT 优化
import tensorrt as trt
def build_trt_engine(
    onnx_path: str,
    engine_path: str,
    fp16: bool = True,
    *,
    workspace_bytes: int = 1 << 30,
    max_batch_size: int = 32,
    opt_batch_size: int = 8,
):
    """Build a TensorRT engine from an ONNX file and write it to ``engine_path``.

    Inputs whose batch (first) dimension is dynamic, as produced by
    ``export_onnx``, get an optimization profile covering batch sizes
    ``1..max_batch_size``, tuned for ``opt_batch_size``. TensorRT refuses to
    build a network with dynamic inputs without such a profile.

    Args:
        onnx_path: source ONNX model.
        engine_path: destination for the serialized engine.
        fp16: enable FP16 kernels.
        workspace_bytes: builder workspace memory limit (default 1 GiB).
        max_batch_size: largest batch the engine must accept.
        opt_batch_size: batch size TensorRT optimizes kernels for.

    Raises:
        ValueError: on inconsistent batch sizes or non-batch dynamic dimensions.
        RuntimeError: if ONNX parsing or the engine build fails.
    """
    if not 1 <= opt_batch_size <= max_batch_size:
        raise ValueError(
            f"need 1 <= opt_batch_size <= max_batch_size, got {opt_batch_size}, {max_batch_size}"
        )

    logger = trt.Logger(trt.Logger.WARNING)
    builder = trt.Builder(logger)
    # EXPLICIT_BATCH is required on TensorRT 8 and deprecated (implicit) on 10.
    flags = 0
    explicit_batch = getattr(trt.NetworkDefinitionCreationFlag, "EXPLICIT_BATCH", None)
    if explicit_batch is not None:
        flags |= 1 << int(explicit_batch)
    network = builder.create_network(flags)

    parser = trt.OnnxParser(network, logger)
    with open(onnx_path, "rb") as f:
        # parse() reports failure via its return value; ignoring it would
        # leave a half-built network and crash later on a None engine.
        if not parser.parse(f.read()):
            errors = "\n".join(str(parser.get_error(i)) for i in range(parser.num_errors))
            raise RuntimeError(f"failed to parse ONNX model {onnx_path!r}:\n{errors}")

    config = builder.create_builder_config()
    if hasattr(config, "set_memory_pool_limit"):  # TensorRT >= 8.4
        config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, workspace_bytes)
    else:
        config.max_workspace_size = workspace_bytes
    if fp16:
        config.set_flag(trt.BuilderFlag.FP16)

    profile = None
    for i in range(network.num_inputs):
        tensor = network.get_input(i)
        shape = tuple(tensor.shape)
        if -1 not in shape:
            continue  # static input, no profile entry needed
        if -1 in shape[1:]:
            raise ValueError(
                f"input {tensor.name!r} has dynamic non-batch dims {shape}; "
                "only a dynamic batch dimension is supported"
            )
        if profile is None:
            profile = builder.create_optimization_profile()
        rest = shape[1:]
        profile.set_shape(
            tensor.name, (1, *rest), (opt_batch_size, *rest), (max_batch_size, *rest)
        )
    if profile is not None:
        config.add_optimization_profile(profile)

    if hasattr(builder, "build_serialized_network"):  # TensorRT >= 8.0
        serialized = builder.build_serialized_network(network, config)
    else:
        engine = builder.build_engine(network, config)
        serialized = engine.serialize() if engine is not None else None
    if serialized is None:
        raise RuntimeError("TensorRT engine build failed; see the TensorRT log output")

    with open(engine_path, "wb") as f:
        f.write(serialized)
    print(f"TensorRT engine saved: {engine_path}")
生产推理服务器
from fastapi import FastAPI, File, UploadFile
from PIL import Image
import io
# Module-level FastAPI application; the `/predict` endpoint is registered on it
# via the `@app.post` decorator further down.
app = FastAPI()
class Predictor:
    """Image classifier backed by an ONNX Runtime session.

    Preprocesses PIL images with ``val_transforms`` and returns the top-k
    classes with softmax confidences.
    """

    # Execution providers in preference order; filtered against what the
    # installed onnxruntime build actually offers.
    _PREFERRED_PROVIDERS = ("CUDAExecutionProvider", "CPUExecutionProvider")

    def __init__(self, model_path: str, class_names: list, providers=None):
        """Load the ONNX model and check it against ``class_names``.

        Args:
            model_path: path to the ``.onnx`` model.
            class_names: label for each output logit, in index order.
            providers: ONNX Runtime execution providers in priority order;
                defaults to CUDA (if available) followed by CPU.

        Raises:
            ValueError: if the model's static output width differs from
                ``len(class_names)``.
        """
        if providers is None:
            available = ort.get_available_providers()
            providers = [p for p in self._PREFERRED_PROVIDERS if p in available]
        # Pass providers explicitly: GPU-enabled onnxruntime builds (>= 1.9)
        # refuse to create a session when the argument is omitted.
        self.session = ort.InferenceSession(model_path, providers=providers)
        self.class_names = list(class_names)  # private copy of the caller's list
        self.transform = val_transforms
        self.input_name = self.session.get_inputs()[0].name

        # Fail at startup rather than on every request: with more logits than
        # names, top-k indices run past class_names (IndexError per request).
        out_shape = self.session.get_outputs()[0].shape
        num_logits = out_shape[-1] if out_shape else None
        if isinstance(num_logits, int) and num_logits != len(self.class_names):
            raise ValueError(
                f"model outputs {num_logits} logits but {len(self.class_names)} class names were given"
            )

    def predict(self, image: Image.Image, top_k: int = 5) -> dict:
        """Classify one PIL image.

        Args:
            image: input image (RGB expected by the normalization transform).
            top_k: number of highest-probability classes to return.

        Returns:
            ``{"predictions": [{"class": name, "confidence": p}, ...]}`` sorted
            by descending confidence.

        Raises:
            ValueError: if the model output width does not match ``class_names``.
        """
        tensor = self.transform(image).unsqueeze(0).numpy()  # add batch axis
        logits = self.session.run(None, {self.input_name: tensor})[0][0]
        if logits.shape[0] != len(self.class_names):
            # Covers models whose output width was dynamic at load time.
            raise ValueError(
                f"model returned {logits.shape[0]} logits for {len(self.class_names)} class names"
            )
        probs = torch.softmax(torch.tensor(logits), dim=0).numpy()
        top_idx = probs.argsort()[::-1][:top_k]
        return {
            "predictions": [
                {"class": self.class_names[i], "confidence": float(probs[i])}
                for i in top_idx
            ]
        }
predictor = Predictor("model.onnx", class_names=["cat", "dog", "bird"])


@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded image and return its top predictions.

    Image decoding and model inference are synchronous and compute-bound, so
    they run in the default thread-pool executor rather than on the event
    loop; otherwise one slow request would stall every other request served
    by this worker.

    Raises:
        HTTPException: 400 if the upload cannot be decoded as an image.
    """
    import asyncio

    from fastapi import HTTPException

    contents = await file.read()
    loop = asyncio.get_running_loop()

    def _decode(data: bytes) -> Image.Image:
        # .convert() forces the actual decode (Image.open is lazy), so corrupt
        # payloads fail here rather than inside predictor.predict().
        return Image.open(io.BytesIO(data)).convert("RGB")

    try:
        image = await loop.run_in_executor(None, _decode, contents)
    except OSError as exc:  # PIL.UnidentifiedImageError subclasses OSError
        raise HTTPException(status_code=400, detail="Invalid or unsupported image file") from exc
    return await loop.run_in_executor(None, predictor.predict, image)
性能基准测试
| 格式 | 延迟 (ms) | 吞吐量 (img/s) |
|---|---|---|
| PyTorch FP32 | 45 | 22 |
| PyTorch FP16 | 28 | 36 |
| ONNX Runtime | 18 | 56 |
| TensorRT FP16 | 8 | 125 |

注：表中吞吐量约等于 1000 / 延迟，即逐张（批大小为 1）推理时的吞吐量。