Skip to content

知识点卡片:LLM推理优化技术

基本信息

属性内容
知识点LLM推理优化 (KV Cache/Flash Attention/量化/vLLM)
掌握程度★★★★★
学习优先级P0
预估时间8小时
面试频率★★★★★

技术全景

推理优化技术栈
├── 注意力优化
│   ├── KV Cache
│   ├── Flash Attention
│   ├── PagedAttention (vLLM)
│   └── GQA/MQA
├── 量化
│   ├── INT8 / INT4
│   ├── GPTQ / AWQ
│   └── FP8 / FP16
├── 并行
│   ├── Continuous Batching
│   ├── Tensor Parallelism
│   └── Pipeline Parallelism
└── 其他
    ├── 投机解码 (Speculative Decoding)
    └── 算子融合 (Kernel Fusion)

1. KV Cache

python
"""
自回归生成时,每个新token需要与所有历史token的K、V计算注意力。

无cache方案:
每次重新计算所有历史的K、V → O(n²) 重复计算

KV Cache:
缓存历史的K、V,新token只需:
1. 计算新token的Q、K、V
2. Q与[历史的K; 新K]计算注意力
3. 新V与[历史的V; 新V]拼接
→ 将O(n³)降到O(n²)

显存占用:2 * n_layers * n_heads * seq_len * d_head * dtype_size
"""

# KV Cache示意
class AttentionWithKVCache(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        self.W_qkv = nn.Linear(d_model, 3 * d_model)

    def forward(self, x, past_kv=None):
        B, T, C = x.shape
        qkv = self.W_qkv(x)
        q, k, v = qkv.chunk(3, dim=-1)

        if past_kv is not None:
            past_k, past_v = past_kv
            k = torch.cat([past_k, k], dim=1)  # 拼接历史K
            v = torch.cat([past_v, v], dim=1)  # 拼接历史V

        # 只计算最后一个位置需要的新注意力
        # ...
        return output, (k, v)  # 返回新的KV Cache

2. Flash Attention

核心思想:分块计算 + Online Softmax

标准Attention的问题:
- 需要O(N²)显存存储注意力矩阵
- 显存带宽成为瓶颈

Flash Attention的解决方案:
1. 将Q、K、V分成小块
2. 逐块加载到SRAM中计算
3. 使用Online Softmax避免存储完整N×N矩阵
4. 计算结果直接写回HBM

结果:
- 显存:O(N²) → O(N)
- 速度:2-4x加速
- 精度:数学等价(非近似)

3. PagedAttention (vLLM)

KV Cache的分页管理:

问题:传统KV Cache是连续分配的,导致:
- 内部碎片(预分配过大)
- 外部碎片(无法复用)
- 显存利用率仅20-40%

PagedAttention:
- 类似操作系统虚拟内存
- KV Cache分块存储在非连续物理页中
- 按需分配,动态增长
- 多请求间可共享相同prefix的KV Cache(如system prompt)

效果:显存利用率提升到接近100%,吞吐量提升2-4倍

4. 量化技术

方法精度压缩比特点
FP1616-bit2x无损,标配
INT88-bit4x小精度损失
INT44-bit8x中等精度损失
GPTQ4-bit8x基于OBQ,效果好
AWQ4-bit8x保留显著权重,效果好
python
# HuggingFace 加载量化模型
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# 4-bit量化加载
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4"
)
model = AutoModelForCausalLM.from_pretrained(
    "model_name",
    quantization_config=bnb_config,
    device_map="auto"
)

5. 投机解码 (Speculative Decoding)

问题:自回归生成是串行的(每次1 token),慢

投机解码:
1. 用小型快速模型(draft model)生成k个候选token
2. 用大模型一次性验证这k个token
3. 接受匹配的token,修正第一个不匹配的token
4. 重复

加速原理:
- 每次生成k个token(而非1个)
- 大模型只需1次前向验证(而非k次)
- 加速2-3x,数学等价(输出完全相同)

面试高频问题

Q1: KV Cache为什么占显存这么大?

: 以Llama-2-7B为例(FP16):

每层KV Cache = 2(batch) × 32(heads) × seq_len × 128(d_head) × 2(bytes)
              = 16384 × seq_len bytes

32层 × 16384 × 4096(seq_len) ≈ 2GB per batch

当batch=8, seq_len=4096时:KV Cache ≈ 16GB
接近甚至超过模型参数本身的显存(14GB)

Q2: Flash Attention为什么是精确的而非近似的?

: Flash Attention使用分块计算+Online Softmax,这是数值计算技巧而非数学近似。关键在于:

  • Softmax可以通过两个pass来精确计算(先算max,再算exp和sum)
  • 分块在每个pass中使用正确的修正因子
  • 最终结果与标准Attention数学等价

相关知识点