
RLHF 与大语言模型对齐
现代大语言模型通过 RLHF(基于人类反馈的强化学习)、DPO(直接偏好优化)和宪法式 AI 等技术,使模型行为与人类偏好保持一致(即"对齐")。
RLHF 概述
预训练 LLM → SFT(监督微调)→ 奖励模型 → PPO 训练 → 对齐后的 LLM

直接偏好优化(DPO)
DPO 简化了 RLHF 流程:它直接在偏好对(chosen/rejected)上优化策略模型,无需单独训练奖励模型,也无需 PPO 强化学习循环。不过 DPO 仍需要一个冻结的参考模型,参数 beta 控制策略模型偏离参考模型的程度。
from trl import DPOTrainer, DPOConfig
from datasets import Dataset
# Preference dataset format: every record pairs a prompt with a preferred
# ("chosen") completion and a dispreferred ("rejected") completion.
preference_data = [
    dict(
        prompt="What is the capital of France?",
        chosen="The capital of France is Paris, which has been the country's capital since the 12th century.",
        rejected="France's capital is Lyon, a major city in southeastern France.",
    ),
    dict(
        prompt="Explain recursion simply.",
        chosen="Recursion is when a function calls itself to solve smaller versions of the same problem, like Russian dolls.",
        rejected="Recursion is a programming technique involving self-referential function invocations.",
    ),
]
# Wrap the raw preference records in a Hugging Face Dataset for the trainers.
preference_ds = Dataset.from_list(preference_data)

# --- DPO training ---
dpo_config = DPOConfig(
    # Where checkpoints go and how often metrics are logged.
    output_dir="./dpo-output",
    logging_steps=10,
    # Optimisation schedule: 4 examples per device per step, gradients
    # accumulated over 4 steps before each optimizer update.
    num_train_epochs=3,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    learning_rate=5e-7,
    bf16=True,
    # Controls how far the trained policy may deviate from the reference model.
    beta=0.1,
)

# NOTE(review): `model` and `tokenizer` are not defined in this snippet; they
# are expected to be created by earlier code — confirm.
dpo_trainer = DPOTrainer(
    model=model,
    args=dpo_config,
    # No explicit reference model: a copy of `model` serves as the reference.
    ref_model=None,
    processing_class=tokenizer,
    train_dataset=preference_ds,
)
dpo_trainer.train()
奖励模型训练
from transformers import AutoModelForSequenceClassification
from trl import RewardTrainer, RewardConfig
# Reward model: a sequence-classification head with a single output, so each
# (prompt, completion) sequence is scored with one scalar reward.
reward_model = AutoModelForSequenceClassification.from_pretrained(
    "meta-llama/Llama-3.1-8B-Instruct",
    num_labels=1,  # single scalar reward score
)

# Llama tokenizers typically define no padding token, and a
# sequence-classification model needs config.pad_token_id to find the last
# real token of each padded sequence; without it, batched training (batch
# size > 1) fails. Fall back to EOS as the pad token and keep the model
# config in sync with the tokenizer.
# NOTE(review): assumes `tokenizer` (defined elsewhere) belongs to the same
# Llama-3.1 checkpoint as the reward model — confirm.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
reward_model.config.pad_token_id = tokenizer.pad_token_id

reward_config = RewardConfig(
    output_dir="./reward-model",
    num_train_epochs=3,
    per_device_train_batch_size=8,
    learning_rate=1e-5,
    bf16=True,
)
reward_trainer = RewardTrainer(
    model=reward_model,
    args=reward_config,
    # Same prompt/chosen/rejected preference records used for DPO above.
    train_dataset=preference_ds,
    processing_class=tokenizer,
)
reward_trainer.train()

使用 Llama Guard 的安全护栏
from transformers import pipeline
# Llama Guard, loaded as a text-generation pipeline, acts as the
# content-moderation model: it is prompted with a conversation and generates
# a safety verdict.
safety_checker = pipeline(
    task="text-generation",
    model="meta-llama/LlamaGuard-7b",
    device_map="auto",  # let Transformers/Accelerate choose device placement
)
def check_safety(user_message: str, ai_response: str = None) -> dict:
prompt = f"""[INST] Task: Check if there is unsafe content in the message.
<BEGIN CONVERSATION>
User: {user_message}
{f"Agent: {ai_response}" if ai_response else ""}
<END CONVERSATION>
Provide your safety assessment. [/INST]"""
result = safety_checker(prompt, max_new_tokens=100)[0]["generated_text"]
is_safe = "safe" in result.lower() and "unsafe" not in result.lower()
return {
"is_safe": is_safe,
"assessment": result.split("[/INST]")[-1].strip(),
}
宪法式 AI 方法
from langchain_openai import ChatOpenAI
# Principles the critique model checks each response against.
CONSTITUTION = """
1. Do not help with illegal activities
2. Do not generate harmful or hateful content
3. Be honest about being an AI
4. Respect user privacy
5. Do not spread misinformation
"""


def constitutional_check(response: str, *, llm=None) -> tuple[str, bool]:
    """Check a response against CONSTITUTION and revise it on violation.

    Args:
        response: The AI response to audit.
        llm: Chat model exposing ``invoke(prompt)`` whose result has a string
            ``.content``. Defaults to a new ``ChatOpenAI(model="gpt-4o-mini")``.

    Returns:
        ``(text, compliant)``:

        - ``(response, True)`` if the critique starts with ``COMPLIANT``
          (ignoring surrounding whitespace and case);
        - ``(revised, False)`` if the critique contains a non-empty
          ``REVISION:`` section;
        - ``(response, False)`` otherwise, i.e. a violation without a usable
          revision or output that ignores the requested format. Compliance
          could not be confirmed, so the check fails closed.
    """
    if llm is None:
        llm = ChatOpenAI(model="gpt-4o-mini")

    critique_prompt = "\n".join(
        [
            f'Here is an AI response: "{response}"',
            "Constitutional principles:",
            CONSTITUTION,
            "Does this response violate any principle? If yes, explain which one and rewrite the response to comply.",
            'If no violations, just say "COMPLIANT".',
            "Format: COMPLIANT or VIOLATION: [principle] REVISION: [revised response]",
        ]
    )
    critique = llm.invoke(critique_prompt).content.strip()

    if critique.upper().startswith("COMPLIANT"):
        return response, True

    # Everything after the first marker is the revision, even if the revised
    # text itself happens to contain "REVISION:".
    _, marker, revised = critique.partition("REVISION:")
    revised = revised.strip()
    if marker and revised:
        return revised, False

    # Previously this path returned True, silently passing violations that
    # lacked a REVISION section.
    return response, False

输出格式控制与护栏
from guardrails import Guard
from guardrails.hub import ToxicLanguage, ProfanityFree
# Output guard: ToxicLanguage is configured with on_fail="exception" (fail the
# call on toxic output); ProfanityFree uses on_fail="fix" (repair the output
# instead of failing).
guard = Guard().use_many(
    ToxicLanguage(threshold=0.5, on_fail="exception"),
    ProfanityFree(on_fail="fix"),
)


def safe_generate(prompt: str) -> str:
    """Generate a reply to ``prompt`` through the guard and return the validated text.

    The LLM call is made by Guardrails via ``llm_api`` and its output is run
    through the validators configured on ``guard``.

    NOTE(review): ``client`` is not defined in this snippet; it is presumably
    an OpenAI client created elsewhere — confirm.
    NOTE(review): ``validated_output`` may be None if validation fails without
    raising; with the validators above that should not happen — confirm.
    """
    outcome = guard(
        llm_api=client.chat.completions.create,
        messages=[{"role": "user", "content": prompt}],
        model="gpt-4o-mini",
    )
    # Read the documented attribute instead of tuple-unpacking the outcome,
    # which relied on backward-compat iteration and bound unused locals.
    return outcome.validated_output
对齐方法对比
| 方法 | 用例 | 优缺点 |
|---|---|---|
| RLHF | 通用对齐 | 强大但复杂 |
| DPO | 更简单的训练 | 无需奖励模型 |
| PPO | 细粒度控制 | 训练不稳定 |
| KTO | 二元反馈 | 数据收集更容易 |
| 宪法式 AI | 基于规则 | 可解释性强 |