
RLHF and Constitutional AI: Aligning Large Language Models with Human Preferences

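Understand LLM alignment techniques such as RLHF, DPO, and Constitutional AI, and learn how to implement preference learning, reward modeling, and safety guardrails for production AI systems.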


RLHF and LLM Alignment

Modern LLMs are aligned with human preferences through techniques such as RLHF, DPO, and Constitutional AI.

RLHF Overview

Pretrained LLM → SFT (supervised fine-tuning) → reward model (trained on human preference comparisons) → PPO training → aligned LLM
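The PPO stage of this pipeline usually does not maximize the raw reward-model score on its own. A common formulation (InstructGPT-style RLHF) subtracts a KL penalty that keeps the policy close to the SFT model, so it cannot drift into text that merely games the reward model. The sketch below computes that penalized reward for one sampled response; the function name `kl_penalized_reward`, the sequence-level formulation, and the toy log-probabilities are illustrative assumptions, not any library's API.

```python
def kl_penalized_reward(
    reward: float,
    policy_logprobs: list[float],
    ref_logprobs: list[float],
    beta: float = 0.1,
) -> float:
    """Reward-model score minus a KL penalty against the frozen SFT model.

    policy_logprobs / ref_logprobs: per-token log-probabilities of the sampled
    response under the policy being trained and under the SFT reference model.
    """
    if len(policy_logprobs) != len(ref_logprobs):
        raise ValueError("log-prob sequences must have the same length")
    # Summing log pi - log pi_ref over the sampled tokens gives a
    # single-sample estimate of KL(pi || pi_ref) for this response.
    kl_estimate = sum(p - r for p, r in zip(policy_logprobs, ref_logprobs))
    return reward - beta * kl_estimate


# The policy assigns each token 0.5 nats more log-probability than the reference
print(kl_penalized_reward(1.0, [-1.0, -2.0], [-1.5, -2.5], beta=0.1))
```

Real implementations usually apply the penalty per token and add the reward-model score at the final token before running PPO.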


Direct Preference Optimization (DPO)

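DPO simplifies RLHF: it optimizes the policy directly on preference pairs (a prompt plus a chosen and a rejected response), so no separately trained reward model or RL loop is needed. A frozen reference model keeps the policy from drifting too far.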

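from datasets import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import DPOConfig, DPOTrainer

# Policy model to align (assumed here: the same Llama 3.1 checkpoint used for the reward model below)
model_name = "meta-llama/Llama-3.1-8B-Instruct"
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto")
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token  # Llama 3.1 has no pad token by default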

# Preference dataset format: a prompt, a preferred ("chosen") and a dispreferred ("rejected") response
preference_data = [
    {
        "prompt": "What is the capital of France?",
        "chosen": "The capital of France is Paris, which has been the country's capital since the 12th century.",
        "rejected": "France's capital is Lyon, a major city in southeastern France.",
    },
    {
        "prompt": "Explain recursion simply.",
        "chosen": "Recursion is when a function calls itself to solve smaller versions of the same problem, like Russian dolls.",
        "rejected": "Recursion is a programming technique involving self-referential function invocations.",
    },
]

preference_ds = Dataset.from_list(preference_data)

# DPO training
dpo_config = DPOConfig(
    output_dir="./dpo-output",
    num_train_epochs=3,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    learning_rate=5e-7,
    beta=0.1,  # Strength of the pull toward the reference model (higher = stays closer)
    bf16=True,
    logging_steps=10,
)

dpo_trainer = DPOTrainer(
    model=model,
    ref_model=None,  # None: DPOTrainer uses a frozen copy of `model` as the reference
    args=dpo_config,
    train_dataset=preference_ds,
    processing_class=tokenizer,
)

dpo_trainer.train()
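With its default `loss_type="sigmoid"`, `DPOTrainer` optimizes the standard DPO objective: for each pair it measures how much the policy raises the log-probability of the chosen and rejected responses relative to the reference model, scales the difference by `beta`, and applies a logistic loss. The dependency-free sketch below computes this for a single pair from summed response log-probabilities; `dpo_loss` and `neg_log_sigmoid` are hypothetical helpers for illustration, not TRL's internal functions.

```python
import math


def neg_log_sigmoid(x: float) -> float:
    """Numerically stable -log(sigmoid(x))."""
    if x >= 0:
        return math.log1p(math.exp(-x))
    return -x + math.log1p(math.exp(x))


def dpo_loss(
    policy_chosen_logp: float,
    policy_rejected_logp: float,
    ref_chosen_logp: float,
    ref_rejected_logp: float,
    beta: float = 0.1,
) -> float:
    """DPO loss for one preference pair.

    Each argument is the summed log-probability of a full response given the prompt.
    """
    # Implicit rewards: beta * log(pi / pi_ref) for each response
    chosen_reward = beta * (policy_chosen_logp - ref_chosen_logp)
    rejected_reward = beta * (policy_rejected_logp - ref_rejected_logp)
    # Loss falls as the policy favors "chosen" more strongly than the reference does
    return neg_log_sigmoid(chosen_reward - rejected_reward)


# Before training the policy equals the reference, so the loss is log(2)
print(dpo_loss(-12.0, -15.0, -12.0, -15.0))
# The policy has shifted probability toward the chosen response: lower loss
print(dpo_loss(-10.0, -17.0, -12.0, -15.0))
```

Raising `beta` makes the same log-probability shift count for more, so the policy reaches a low loss with smaller departures from the reference model.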

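Reward Model Training

In the classic RLHF pipeline, the reward model is a sequence classifier with a single scalar output, trained on prompt/chosen/rejected pairs so that it scores chosen responses above rejected ones.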

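from transformers import AutoModelForSequenceClassification, AutoTokenizer
from trl import RewardConfig, RewardTrainer

model_name = "meta-llama/Llama-3.1-8B-Instruct"
reward_model = AutoModelForSequenceClassification.from_pretrained(
    model_name,
    num_labels=1,  # a single scalar reward score
)
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token  # Llama 3.1 has no pad token by default
# The classification head pools at the last non-padding token, so batches > 1 need a pad id
reward_model.config.pad_token_id = tokenizer.pad_token_id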

reward_config = RewardConfig(
    output_dir="./reward-model",
    num_train_epochs=3,
    per_device_train_batch_size=8,
    learning_rate=1e-5,
    bf16=True,
)

reward_trainer = RewardTrainer(
    model=reward_model,
    args=reward_config,
    train_dataset=preference_ds,
    processing_class=tokenizer,
)

reward_trainer.train()
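`RewardTrainer` fits the scalar head with a pairwise (Bradley-Terry) objective: the loss is -log sigmoid(r_chosen - r_rejected), averaged over the batch (TRL can optionally subtract a per-pair margin). The sketch below computes this loss, and the preference probability it implies, from already-computed scores; `pairwise_reward_loss` and `preference_probability` are hypothetical helpers, not TRL's API.

```python
import math


def _neg_log_sigmoid(x: float) -> float:
    """Numerically stable -log(sigmoid(x))."""
    if x >= 0:
        return math.log1p(math.exp(-x))
    return -x + math.log1p(math.exp(x))


def pairwise_reward_loss(chosen_rewards: list[float], rejected_rewards: list[float]) -> float:
    """Mean Bradley-Terry loss: -log(sigmoid(r_chosen - r_rejected))."""
    if not chosen_rewards or len(chosen_rewards) != len(rejected_rewards):
        raise ValueError("need equally sized, non-empty reward lists")
    total = sum(
        _neg_log_sigmoid(r_chosen - r_rejected)
        for r_chosen, r_rejected in zip(chosen_rewards, rejected_rewards)
    )
    return total / len(chosen_rewards)


def preference_probability(r_chosen: float, r_rejected: float) -> float:
    """Bradley-Terry probability that the chosen response is preferred."""
    # sigmoid(d) == exp(-(-log sigmoid(d))), which avoids overflow for large |d|
    return math.exp(-_neg_log_sigmoid(r_chosen - r_rejected))


# Second pair is mis-ranked (0.5 < 1.0), so it contributes most of the loss
print(pairwise_reward_loss([2.0, 0.5], [0.0, 1.0]))
print(preference_probability(2.0, 0.0))
```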


Safety Guardrails with Llama Guard

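from transformers import pipeline

# Llama Guard: an LLM fine-tuned to classify prompts and responses as safe or unsafe
safety_checker = pipeline(
    "text-generation",
    model="meta-llama/LlamaGuard-7b",
    device_map="auto",
)

def check_safety(user_message: str, ai_response: str | None = None) -> dict:
    # Simplified prompt: the official Llama Guard template also lists the
    # unsafe-content categories the model should check against.
    conversation = f"User: {user_message}"
    if ai_response:
        conversation += f"\n\nAgent: {ai_response}"

    prompt = f"""[INST] Task: Check if there is unsafe content in the conversation.

<BEGIN CONVERSATION>

{conversation}

<END CONVERSATION>

Provide your safety assessment. The first line must read 'safe' or 'unsafe'. [/INST]"""

    # return_full_text=False returns only the generated verdict; the echoed prompt
    # contains the word "unsafe" and would otherwise make every check fail.
    result = safety_checker(prompt, max_new_tokens=100, return_full_text=False)[0]["generated_text"]
    assessment = result.strip()
    # Llama Guard replies "safe", or "unsafe" followed by the violated category codes
    first_line = assessment.splitlines()[0].strip().lower() if assessment else ""

    return {
        "is_safe": first_line == "safe",
        "assessment": assessment,
    }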

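Constitutional AI Approach

In Anthropic's Constitutional AI, a model critiques and revises its own outputs against a written set of principles (the "constitution"), and the results are used as training data. The code below applies the same critique-and-revise step at inference time, as a runtime check on a single response.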

from langchain_openai import ChatOpenAI

CONSTITUTION = """
1. Do not help with illegal activities
2. Do not generate harmful or hateful content
3. Be honest about being an AI
4. Respect user privacy
5. Do not spread misinformation
"""

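# Critic model: created once and reused; temperature=0 for more consistent verdicts
llm = ChatOpenAI(model="gpt-4o-mini", temperature=0)

def constitutional_check(response: str) -> tuple[str, bool]:
    """Critique a response against the constitution and revise it if needed.

    Returns (response_to_use, was_compliant).
    """
    critique_prompt = f"""Here is an AI response: "{response}"

Constitutional principles:
{CONSTITUTION}

Does this response violate any principle? If yes, explain which one and rewrite the response to comply.
If no violations, just say "COMPLIANT".

Format: COMPLIANT or VIOLATION: [principle] REVISION: [revised response]"""

    critique = llm.invoke(critique_prompt).content.strip()

    if critique.startswith("COMPLIANT"):
        return response, True

    if "REVISION:" in critique:
        # Everything after the first "REVISION:" marker is the rewritten response
        revised = critique.split("REVISION:", 1)[1].strip()
        return revised, False

    # Unparseable critique: fail closed instead of silently treating it as compliant
    return response, False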
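The runtime check above can also generate training data. A simplified sketch, assuming each (original, revised) pair is stored as a preference record in the same `prompt`/`chosen`/`rejected` format the DPO example uses, with the revision as `chosen` and the original as `rejected`. In Anthropic's published method, revisions feed supervised fine-tuning and a separate AI-feedback stage labels preferences between sampled responses, so treat this as an approximation. `revisions_to_preference_pairs` and the stub functions are hypothetical.

```python
from typing import Callable


def revisions_to_preference_pairs(
    prompts: list[str],
    generate: Callable[[str], str],
    critique_and_revise: Callable[[str], tuple[str, bool]],
) -> list[dict]:
    """Turn critique-and-revise results into prompt/chosen/rejected records."""
    pairs = []
    for prompt in prompts:
        original = generate(prompt)
        revised, compliant = critique_and_revise(original)
        # Only a non-compliant response with a genuinely different revision
        # gives a preference signal; compliant responses are skipped.
        if not compliant and revised.strip() and revised != original:
            pairs.append({"prompt": prompt, "chosen": revised, "rejected": original})
    return pairs


# Hypothetical stubs standing in for an LLM and for constitutional_check
def stub_generate(prompt: str) -> str:
    return "Bob's password is hunter2" if "password" in prompt else "Hi!"


def stub_critique(response: str) -> tuple[str, bool]:
    if "password" in response:
        return "I can't share someone else's password.", False
    return response, True


pairs = revisions_to_preference_pairs(
    ["What's Bob's password?", "Say hi"], stub_generate, stub_critique
)
print(pairs)
```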


Output Format Control and Guardrails

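from guardrails import Guard
# Hub validators must be installed first, e.g.:
#   guardrails hub install hub://guardrails/toxic_language
#   guardrails hub install hub://guardrails/profanity_free
from guardrails.hub import ProfanityFree, ToxicLanguage
from openai import OpenAI

client = OpenAI()  # reads OPENAI_API_KEY from the environment

guard = Guard().use_many(
    ToxicLanguage(threshold=0.5, on_fail="exception"),  # raise an error on toxic output
    ProfanityFree(on_fail="fix"),  # use the validator's fixed output instead
)

def safe_generate(prompt: str) -> str | None:
    outcome = guard(
        llm_api=client.chat.completions.create,
        messages=[{"role": "user", "content": prompt}],
        model="gpt-4o-mini",
    )
    # guard() returns a ValidationOutcome; validated_output is the post-validation text
    return outcome.validated_output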
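The Guardrails snippet validates content; output *format* can be enforced in the same spirit by validating the reply and re-asking the model with the error when it fails (Guardrails exposes similar behavior through `on_fail="reask"`). The sketch below does this for JSON in plain Python, without Guardrails; `generate_json`, the `call_llm` callable, and the re-ask wording are hypothetical.

```python
import json
from typing import Callable


def generate_json(
    prompt: str,
    call_llm: Callable[[str], str],
    required_keys: list[str],
    max_retries: int = 2,
) -> dict:
    """Ask for a JSON object, validate it, and re-ask with the error on failure."""
    message = prompt
    error = "no attempts made"
    for _ in range(max_retries + 1):
        raw = call_llm(message)
        try:
            data = json.loads(raw)
        except json.JSONDecodeError as exc:
            error = f"invalid JSON ({exc.msg})"
        else:
            if not isinstance(data, dict):
                error = "expected a JSON object"
            else:
                missing = [k for k in required_keys if k not in data]
                if not missing:
                    return data
                error = f"missing keys {missing}"
        # Re-ask: repeat the task and tell the model what was wrong
        message = (
            f"{prompt}\n\nYour previous reply was rejected ({error}). "
            "Reply with a single JSON object only."
        )
    raise ValueError(f"no valid JSON after {max_retries + 1} attempts: {error}")


# Stub LLM: first reply is malformed, second is valid
replies = iter(['Sure! {"answer": 42', '{"answer": 42, "confidence": 0.9}'])
print(generate_json("Answer as JSON.", lambda msg: next(replies), ["answer", "confidence"]))
```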

Comparing Alignment Methods

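| Method | Use case | Trade-offs |
|---|---|---|
| RLHF | General-purpose alignment | Powerful but complex |
| DPO | Simpler training | No separate reward model needed |
| PPO | Fine-grained control | Training can be unstable |
| KTO | Binary (good/bad) feedback | Data is easier to collect |
| Constitutional AI | Principle-based (rules in a written constitution) | Highly interpretable |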
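KTO learns from per-response good/bad labels instead of pairs. TRL's `KTOTrainer` expects this "unpaired" format as rows with `prompt`, `completion`, and a boolean `label`. If you already have pairwise data like the DPO example's, one simple (lossy) conversion labels every chosen response desirable and every rejected one undesirable; `pairs_to_kto` is a hypothetical helper.

```python
def pairs_to_kto(pairs: list[dict]) -> list[dict]:
    """Split prompt/chosen/rejected pairs into unpaired KTO rows."""
    rows = []
    for pair in pairs:
        # chosen -> desirable (True), rejected -> undesirable (False)
        rows.append({"prompt": pair["prompt"], "completion": pair["chosen"], "label": True})
        rows.append({"prompt": pair["prompt"], "completion": pair["rejected"], "label": False})
    return rows


pairs = [{"prompt": "What is the capital of France?", "chosen": "Paris.", "rejected": "Lyon."}]
for row in pairs_to_kto(pairs):
    print(row)
```

Native binary feedback, such as thumbs-up/down logs, maps onto the same rows directly with no pairing step, which is why KTO data is easier to collect.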