Skip to content

知识点卡片:梯度裁剪

基本信息

属性内容
知识点梯度裁剪 (Gradient Clipping)
掌握程度★★★☆☆
学习优先级P1
预估时间2小时
面试频率★★★☆☆

核心原理

梯度裁剪防止梯度爆炸,将过大梯度限制在阈值内:

clip_grad_norm: 按L2范数裁剪
if ‖g‖ > max_norm:
    g = g * max_norm / ‖g‖

clip_grad_value: 按元素值裁剪
g = clamp(g, -max_value, max_value)

代码实现

python
import torch
import torch.nn as nn

# PyTorch内置
model = nn.Linear(10, 5)
loss_fn = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# 训练循环
for batch in dataloader:
    optimizer.zero_grad()
    loss = loss_fn(model(batch['x']), batch['y'])
    loss.backward()

    # 梯度裁剪(建议放在backward和step之间)
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    # 或按值裁剪
    # torch.nn.utils.clip_grad_value_(model.parameters(), clip_value=0.5)

    optimizer.step()

# 从零实现
def clip_grad_norm(parameters, max_norm):
    """手动实现梯度裁剪"""
    total_norm = 0.0
    for p in parameters:
        if p.grad is not None:
            param_norm = p.grad.data.norm(2)
            total_norm += param_norm.item() ** 2
    total_norm = total_norm ** 0.5

    clip_coef = max_norm / (total_norm + 1e-6)
    if clip_coef < 1:
        for p in parameters:
            if p.grad is not None:
                p.grad.data.mul_(clip_coef)

    return total_norm

在哪些模型中需要

模型类型需要程度原因
RNN/LSTM★★★★★循环结构容易梯度爆炸
Transformer★★★★☆长序列时梯度可能爆炸
LLM训练★★★★★大规模训练标配
CNN★★☆☆☆通常不需要

面试高频问题

Q1: 梯度裁剪为什么有效?

: 梯度爆炸时,参数更新过大,可能导致loss变为NaN或发散。裁剪将梯度的最大范数限制在阈值内,相当于在损失曲面上做了投影,保证每次参数更新不会太大。

Q2: clip_grad_norm vs clip_grad_value?

  • clip_grad_norm:保持梯度方向,只限制整体范数(推荐)
  • clip_grad_value:逐元素限制,可能改变梯度方向

相关知识点