正在加载,请稍候…

使用 LoRA 和 QLoRA 微调大语言模型:实用指南

高效使用 LoRA 和 QLoRA 微调大语言模型,涵盖数据集准备、训练、评估及使用 Hugging Face 和 vLLM 部署。

使用 LoRA 和 QLoRA 微调大语言模型:实用指南

使用 LoRA 和 QLoRA 微调大语言模型

LoRA 可将 7B 模型的微调显存需求从约 28GB 降至约 6GB。

QLoRA 配置

import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model, TaskType

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B-Instruct",
    quantization_config=bnb_config,
    device_map="auto",
)

lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                    "gate_proj", "up_proj", "down_proj"],
    lora_dropout=0.05,
    task_type=TaskType.CAUSAL_LM,
)

model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
# trainable params: 20,185,088 || 0.25%

使用 LoRA 和 QLoRA 微调大语言模型:实用指南插图

SFTTrainer 训练

from trl import SFTTrainer, SFTConfig

args = SFTConfig(
    output_dir="./checkpoints",
    num_train_epochs=3,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    gradient_checkpointing=True,
    optim="paged_adamw_8bit",
    learning_rate=2e-4,
    lr_scheduler_type="cosine",
    warmup_ratio=0.03,
    bf16=True,
    max_seq_length=2048,
    report_to="wandb",
)

trainer = SFTTrainer(
    model=model, args=args,
    train_dataset=train_ds, eval_dataset=eval_ds,
    processing_class=tokenizer,
)
trainer.train()
trainer.model.save_pretrained("./lora-adapter")

使用 LoRA 和 QLoRA 微调大语言模型:实用指南插图

合并与部署

from peft import PeftModel

base = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B-Instruct",
    torch_dtype=torch.float16, device_map="cpu"
)
peft = PeftModel.from_pretrained(base, "./lora-adapter")
merged = peft.merge_and_unload()
merged.save_pretrained("./merged-model", safe_serialization=True)

使用 LoRA 和 QLoRA 微调大语言模型:实用指南插图

使用 vLLM 提供服务

python -m vllm.entrypoints.openai.api_server \
  --model ./merged-model --port 8000 --quantization awq

超参数指南

参数 小数据集 大数据集
Rank (r) 8-16 32-64
Alpha 16-32 64-128
LR 1e-4 2e-4
Epochs 5-10 2-3