知识点卡片:梯度裁剪
基本信息
| 属性 | 内容 |
|---|---|
| 知识点 | 梯度裁剪 (Gradient Clipping) |
| 掌握程度 | ★★★☆☆ |
| 学习优先级 | P1 |
| 预估时间 | 2小时 |
| 面试频率 | ★★★☆☆ |
核心原理
梯度裁剪防止梯度爆炸,将过大梯度限制在阈值内:
clip_grad_norm: 按L2范数裁剪
if ‖g‖ > max_norm:
g = g * max_norm / ‖g‖
clip_grad_value: 按元素值裁剪
g = clamp(g, -max_value, max_value)代码实现
python
import torch
import torch.nn as nn
# PyTorch内置
model = nn.Linear(10, 5)
loss_fn = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
# 训练循环
for batch in dataloader:
optimizer.zero_grad()
loss = loss_fn(model(batch['x']), batch['y'])
loss.backward()
# 梯度裁剪(建议放在backward和step之间)
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
# 或按值裁剪
# torch.nn.utils.clip_grad_value_(model.parameters(), clip_value=0.5)
optimizer.step()
# 从零实现
def clip_grad_norm(parameters, max_norm):
"""手动实现梯度裁剪"""
total_norm = 0.0
for p in parameters:
if p.grad is not None:
param_norm = p.grad.data.norm(2)
total_norm += param_norm.item() ** 2
total_norm = total_norm ** 0.5
clip_coef = max_norm / (total_norm + 1e-6)
if clip_coef < 1:
for p in parameters:
if p.grad is not None:
p.grad.data.mul_(clip_coef)
return total_norm在哪些模型中需要
| 模型类型 | 需要程度 | 原因 |
|---|---|---|
| RNN/LSTM | ★★★★★ | 循环结构容易梯度爆炸 |
| Transformer | ★★★★☆ | 长序列时梯度可能爆炸 |
| LLM训练 | ★★★★★ | 大规模训练标配 |
| CNN | ★★☆☆☆ | 通常不需要 |
面试高频问题
Q1: 梯度裁剪为什么有效?
答: 梯度爆炸时,参数更新过大,可能导致loss变为NaN或发散。裁剪将梯度的最大范数限制在阈值内,相当于在损失曲面上做了投影,保证每次参数更新不会太大。
Q2: clip_grad_norm vs clip_grad_value?
- clip_grad_norm:保持梯度方向,只限制整体范数(推荐)
- clip_grad_value:逐元素限制,可能改变梯度方向