Skip to content

知识点卡片:初始化策略

基本信息

属性内容
知识点权重初始化 (Xavier/He/Kaiming)
掌握程度★★★★☆
学习优先级P1
预估时间4小时
面试频率★★★☆☆

核心原理

好的初始化应该保持前向传播的激活值和反向传播的梯度方差不变:

前向方差要求:Var(a_l) ≈ Var(a_{l-1})
  → n_in * Var(w) = 1 → Var(w) = 1/n_in

反向方差要求:Var(∂L/∂a_{l-1}) ≈ Var(∂L/∂a_l)
  → n_out * Var(w) = 1 → Var(w) = 1/n_out

折中(Xavier/Glorot):
  Var(w) = 2 / (n_in + n_out)
  均匀分布范围:[-√(6/(n_in+n_out)), √(6/(n_in+n_out))]

ReLU修正(He/Kaiming):
  ReLU将一半输入置零 → 方差减半
  补偿:Var(w) = 2 / n_in
  正态分布:N(0, √(2/n_in))

代码实现

python
import torch
import torch.nn as nn
import math

def xavier_uniform_(tensor, gain=1.0):
    """Xavier均匀初始化"""
    fan_in, fan_out = nn.init._calculate_fan_in_and_fan_out(tensor)
    std = gain * math.sqrt(2.0 / (fan_in + fan_out))
    bound = math.sqrt(3.0) * std
    with torch.no_grad():
        return tensor.uniform_(-bound, bound)

def kaiming_normal_(tensor, mode='fan_in', nonlinearity='relu'):
    """He/Kaiming正态初始化"""
    fan = nn.init._calculate_correct_fan(tensor, mode)
    gain = nn.init.calculate_gain(nonlinearity)
    std = gain / math.sqrt(fan)
    with torch.no_grad():
        return tensor.normal_(0, std)

# PyTorch内置
nn.init.xavier_uniform_(weight)
nn.init.xavier_normal_(weight)
nn.init.kaiming_uniform_(weight, nonlinearity='relu')
nn.init.kaiming_normal_(weight, nonlinearity='relu')

# 应用于模型
def init_weights(model):
    for m in model.modules():
        if isinstance(m, nn.Linear):
            nn.init.kaiming_normal_(m.weight, nonlinearity='relu')
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(m.weight)
            if m.bias is not None:
                nn.init.zeros_(m.bias)

各种初始化对比

初始化适用激活函数方差说明
全零任何0❌ 对称性问题,所有神经元学相同
标准正态任何1❌ 多层后方差爆炸/消失
Xavier/Glorottanh/sigmoid2/(n_in+n_out)✅ 对称激活函数
He/KaimingReLU/LeakyReLU2/n_in✅ ReLU修正
正交初始化--✅ 保持向量长度

面试高频问题

Q1: 为什么不能初始化为全零?

: 如果所有权重初始化为0,所有神经元计算相同输出、接收相同梯度、进行相同更新——它们永远无法分化(对称性问题)。偏置可以初始化为0,因为每个神经元的输入不同,对称性会被打破。

Q2: Xavier和He初始化的区别?

: Xavier假设激活函数是线性的(tanh近似线性区域),推导出 Var(w)=2/(n_in+n_out)。但ReLU将一半输出截为0,方差实际减半。He初始化乘以2来补偿:Var(w)=2/n_in。


相关知识点