知识点卡片:初始化策略
基本信息
| 属性 | 内容 |
|---|---|
| 知识点 | 权重初始化 (Xavier/He/Kaiming) |
| 掌握程度 | ★★★★☆ |
| 学习优先级 | P1 |
| 预估时间 | 4小时 |
| 面试频率 | ★★★☆☆ |
核心原理
好的初始化应该保持前向传播的激活值和反向传播的梯度方差不变:
前向方差要求:Var(a_l) ≈ Var(a_{l-1})
→ n_in * Var(w) = 1 → Var(w) = 1/n_in
反向方差要求:Var(∂L/∂a_{l-1}) ≈ Var(∂L/∂a_l)
→ n_out * Var(w) = 1 → Var(w) = 1/n_out
折中(Xavier/Glorot):
Var(w) = 2 / (n_in + n_out)
均匀分布范围:[-√(6/(n_in+n_out)), √(6/(n_in+n_out))]
ReLU修正(He/Kaiming):
ReLU将一半输入置零 → 方差减半
补偿:Var(w) = 2 / n_in
正态分布:N(0, √(2/n_in))代码实现
python
import torch
import torch.nn as nn
import math
def xavier_uniform_(tensor, gain=1.0):
"""Xavier均匀初始化"""
fan_in, fan_out = nn.init._calculate_fan_in_and_fan_out(tensor)
std = gain * math.sqrt(2.0 / (fan_in + fan_out))
bound = math.sqrt(3.0) * std
with torch.no_grad():
return tensor.uniform_(-bound, bound)
def kaiming_normal_(tensor, mode='fan_in', nonlinearity='relu'):
"""He/Kaiming正态初始化"""
fan = nn.init._calculate_correct_fan(tensor, mode)
gain = nn.init.calculate_gain(nonlinearity)
std = gain / math.sqrt(fan)
with torch.no_grad():
return tensor.normal_(0, std)
# PyTorch内置
nn.init.xavier_uniform_(weight)
nn.init.xavier_normal_(weight)
nn.init.kaiming_uniform_(weight, nonlinearity='relu')
nn.init.kaiming_normal_(weight, nonlinearity='relu')
# 应用于模型
def init_weights(model):
for m in model.modules():
if isinstance(m, nn.Linear):
nn.init.kaiming_normal_(m.weight, nonlinearity='relu')
if m.bias is not None:
nn.init.zeros_(m.bias)
elif isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight)
if m.bias is not None:
nn.init.zeros_(m.bias)各种初始化对比
| 初始化 | 适用激活函数 | 方差 | 说明 |
|---|---|---|---|
| 全零 | 任何 | 0 | ❌ 对称性问题,所有神经元学相同 |
| 标准正态 | 任何 | 1 | ❌ 多层后方差爆炸/消失 |
| Xavier/Glorot | tanh/sigmoid | 2/(n_in+n_out) | ✅ 对称激活函数 |
| He/Kaiming | ReLU/LeakyReLU | 2/n_in | ✅ ReLU修正 |
| 正交初始化 | - | - | ✅ 保持向量长度 |
面试高频问题
Q1: 为什么不能初始化为全零?
答: 如果所有权重初始化为0,所有神经元计算相同输出、接收相同梯度、进行相同更新——它们永远无法分化(对称性问题)。偏置可以初始化为0,因为每个神经元的输入不同,对称性会被打破。
Q2: Xavier和He初始化的区别?
答: Xavier假设激活函数是线性的(tanh近似线性区域),推导出 Var(w)=2/(n_in+n_out)。但ReLU将一半输出截为0,方差实际减半。He初始化乘以2来补偿:Var(w)=2/n_in。