Skip to content

知识点卡片:Self-Attention 与 Multi-Head Attention

基本信息

属性内容
知识点Self-Attention / Multi-Head Attention
掌握程度★★★★★
学习优先级P0
面试必考
必须手写

核心原理

Self-Attention(自注意力)通过Query-Key-Value机制实现序列内各位置的交互:

每个位置 i 对所有位置 j 计算注意力:
- Query (Q): 我想查询什么
- Key (K): 我包含什么信息
- Value (V): 我提供什么信息

注意力分数 = Q · K^T / √d_k
注意力权重 = softmax(注意力分数)
输出 = 注意力权重 · V

数学公式

Attention(Q, K, V) = softmax(QK^T / √d_k) · V

MultiHead(Q, K, V) = Concat(head_1, ..., head_h) · W^O
where head_i = Attention(QW_i^Q, KW_i^K, VW_i^V)

PyTorch实现

1. 标准Self-Attention

python
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class SelfAttention(nn.Module):
    """标准Self-Attention"""
    def __init__(self, d_model):
        super().__init__()
        self.d_k = d_model
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def forward(self, x):
        # x: (batch, seq_len, d_model)
        B, T, C = x.shape

        # 线性变换得到Q, K, V
        Q = self.W_q(x)  # (B, T, C)
        K = self.W_k(x)
        V = self.W_v(x)

        # 计算注意力分数
        # (B, T, C) @ (B, C, T) -> (B, T, T)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)

        # softmax得到注意力权重
        attn_weights = F.softmax(scores, dim=-1)

        # 加权求和
        # (B, T, T) @ (B, T, C) -> (B, T, C)
        context = torch.matmul(attn_weights, V)

        # 输出变换
        output = self.W_o(context)
        return output, attn_weights

2. Multi-Head Attention

python
class MultiHeadAttention(nn.Module):
    """Multi-Head Self-Attention"""
    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0
        self.d_k = d_model // num_heads
        self.num_heads = num_heads

        # 合并为一个大线性层(效率更高)
        self.W_qkv = nn.Linear(d_model, 3 * d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def forward(self, x, mask=None):
        B, T, C = x.shape

        # 同时计算Q, K, V
        qkv = self.W_qkv(x).reshape(B, T, 3, self.num_heads, self.d_k)
        qkv = qkv.permute(2, 0, 3, 1, 4)  # (3, B, H, T, Dk)
        Q, K, V = qkv[0], qkv[1], qkv[2]

        # 计算注意力
        scores = Q @ K.transpose(-2, -1) / math.sqrt(self.d_k)

        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)

        attn = F.softmax(scores, dim=-1)
        context = attn @ V  # (B, H, T, Dk)

        # 合并多头
        context = context.transpose(1, 2).reshape(B, T, C)
        return self.W_o(context), attn

3. Flash Attention(简化版原理)

python
def flash_attention(Q, K, V, scale=1.0):
    """
    Flash Attention的核心思想:分块计算 + 数值稳定
    实际实现需要CUDA kernel,这里展示原理
    """
    # 标准attention: O(N²) 显存
    # Flash attention: O(N) 显存,通过分块避免 materialize 整个 attention matrix

    B, H, N, D = Q.shape
    block_size = 64

    # 分块计算
    output = torch.zeros_like(Q)

    # 只需保存 row_max 和 row_sum,实现 online softmax
    for i in range(0, N, block_size):
        for j in range(0, N, block_size):
            # 计算当前块
            q_block = Q[..., i:i+block_size, :]
            k_block = K[..., j:j+block_size, :]
            v_block = V[..., j:j+block_size, :]

            # ... 分块注意力计算 ...
            # 利用数值稳定的 online softmax
            pass

    return output

复杂度分析

操作时间复杂度空间复杂度
标准AttentionO(N²·d)O(N²)
Flash AttentionO(N²·d)O(N·d)
N = 序列长度
d = 每头维度
h = 注意力头数

标准Attention需要保存 N×N 的注意力矩阵 → O(N²) 空间
Flash Attention通过分块计算,只需保存对角线元素 → O(N) 空间

面试高频问题

Q1: Self-Attention的计算过程(手写)

python
def attention(Q, K, V):
    # 1. 计算注意力分数
    scores = Q @ K.T / math.sqrt(d_k)  # (T, T)

    # 2. softmax归一化
    attn_weights = F.softmax(scores, dim=-1)  # (T, T)

    # 3. 加权求和
    output = attn_weights @ V  # (T, d)

    return output, attn_weights

Q2: 为什么要除以 √d_k?

  1. 防止梯度消失:当d_k较大时,Q·K的点积可能很大,使softmax进入饱和区,梯度接近0
  2. 方差归一化:假设Q和K的每个元素是独立随机变量,均值为0,方差为1,则Q·K的方差为d_k,除以√d_k使其方差恢复到1

Q3: Multi-Head Attention vs Single Head?

方面Single HeadMulti-Head
表示子空间只有一个多个
并行计算简单可并行
参数量较少略多(但每头维度减小)
效果可能欠拟合通常更好
Multi-Head的本质:让每个头学习不同的表示子空间

例如:
- 头1可能关注语法结构
- 头2可能关注语义相似性
- 头3可能关注位置信息

Q4: Self-Attention vs RNN/LSTM?

方面Self-AttentionRNN/LSTM
距离O(1)O(N)(距离远梯度消失)
并行易并行难并行(顺序依赖)
捕获依赖全局局部(长期依赖困难)
计算量O(N²·d)O(N·d)
内存O(N²)O(N·d)

Q5: Pre-LN vs Post-LN?

Post-LN: 
    x → Attention → + → LN → FFN → + → LN → output
    (层Norm在残差之后,更难训练,需要warmup)

Pre-LN:
    x → LN → Attention → + → LN → FFN → + → output
    (层Norm在残差之前,更稳定,已成主流)
方面Post-LNPre-LN
训练稳定性差(需要warmup)
理论效果可能更好略差
实际使用主流(Llama/Qwen等都用Pre-LN)

代码练习

python
import torch

# 验证Multi-Head Attention实现
d_model = 256
num_heads = 8
batch_size = 2
seq_len = 10

x = torch.randn(batch_size, seq_len, d_model)
mha = MultiHeadAttention(d_model, num_heads)

output, attn = mha(x)
print(f"Output shape: {output.shape}")  # (2, 10, 256)
print(f"Attention shape: {attn.shape}")  # (2, 8, 10, 10)

# 检查注意力权重
print(f"Attention weights sum: {attn[0, 0].sum().item():.4f}")  # 应接近1.0

相关知识点