• 深度学习——Softmax与交叉熵:从原理到梯度推导
  • 2025-11-18 14:14:21
  • 深入理解Softmax与交叉熵:从原理到梯度推导

    在深度学习中,Softmax函数与交叉熵损失( Cross-Entropy Loss )是分类任务的核心组件。本文将从数学原理出发,推导它们的梯度计算过程,并解释参数更新中涉及的矩阵求导关键点。无论你是刚入门的新手还是希望巩固基础的开发者,这篇博客都将为你提供清晰的洞见。

    1. Softmax函数:从Logits到概率分布

    1.1 定义与公式

    Softmax函数将一组实数( logits )转换为概率分布。对于输入向量 $ z = [z_1, z_2, \dots, z_K] $,其输出为:

    \[s_i = \frac{e^{z_i}}{\sum_{j=1}^K e^{z_j}} \quad \text{( 其中 } i=1,2,\dots,K \text{ )}

    \]

    输出满足 $ \sum_{i=1}^K s_i = 1 $,代表每个类别的预测概率。

    1.2 数值稳定性技巧

    实际实现时,为避免指数运算导致数值溢出,通常对输入进行平移:

    \[s_i = \frac{e^{z_i - \max(z)}}{\sum_{j=1}^K e^{z_j - \max(z)}}

    \]

    减去最大值 $ \max(z) $ 后,指数结果范围更合理,且不影响概率分布。

    2. 交叉熵损失函数:衡量预测与真实的差距

    2.1 公式与意义

    对于真实标签 $ y $( one-hot编码 )和预测概率 $ s $,交叉熵损失定义为:

    \[L = -\sum_{i=1}^K y_i \log s_i

    \]

    核心作用:

    当预测概率 $ s_i $ 接近真实标签 $ y_i $ 时,损失趋近于0。

    对错误预测敏感( 梯度大 ),推动模型快速修正。

    2.2 为什么选择交叉熵?

    与Softmax天然匹配:梯度形式简单( 见下文推导 )。

    相比均方误差( MSE ),交叉熵在分类任务中收敛更快。

    3. 梯度推导:从损失到Softmax输入的导数

    3.1 关键公式推导

    假设真实标签为第 $ k $ 类( 即 $ y_k=1 $,其余为0 ),则损失简化为:

    \[L = -\log s_k

    \]

    计算损失对Softmax输入 $ z_i $ 的梯度 $ \frac{\partial L}{\partial z_i} $:

    当 $ i = k $( 真实类别 ):

    \[\frac{\partial L}{\partial z_i} = s_i - 1

    \]

    当 $ i \neq k $( 非真实类别 ):

    \[\frac{\partial L}{\partial z_i} = s_i

    \]

    合并公式:

    \[\frac{\partial L}{\partial z_i} = s_i - y_i

    \]

    其中 $ y_i $ 是真实标签的one-hot编码。

    3.2 直观理解

    梯度是预测概率与真实标签的差值。

    模型通过减少正确类别的概率梯度( $ s_i -1 $ )和增加错误类别的梯度( $ s_i $ )来更新参数。

    4. 反向传播中的矩阵求导:参数如何更新?

    4.1 从标量梯度到矩阵梯度

    神经网络的参数通常是矩阵形式( 如全连接层的权重矩阵 $ W $ )。反向传播中,梯度计算需遵循矩阵维度匹配规则。

    示例:全连接层 $ Z = WX + b $

    输入 $ X $ 维度:$ (n \times m) $

    权重 $ W $ 维度:$ (d \times n) $

    输出 $ Z $ 维度:$ (d \times m) $

    已知损失对 $ Z $ 的梯度 $ \frac{\partial L}{\partial Z} $( 维度 $ d \times m $ ),则:

    \[\frac{\partial L}{\partial W} = \frac{\partial L}{\partial Z} \cdot X^T \quad \text{( 维度 } d \times n \text{ )}

    \]

    \[\frac{\partial L}{\partial b} = \sum_{\text{batch}} \frac{\partial L}{\partial Z} \quad \text{( 沿批次维度求和 )}

    \]

    4.2 维数分析:为什么是 $ X^T $?

    $ \frac{\partial L}{\partial Z} $ 的维度:$ d \times m $

    $ X $ 的维度:$ n \times m $

    矩阵乘法 $ \frac{\partial L}{\partial Z} \cdot X^T $ 确保结果维度为 $ d \times n $,与 $ W $ 一致。

    5. 参数更新:优化器的角色

    5.1 梯度下降更新公式

    权重矩阵的更新通过优化器实现。以基础梯度下降为例:

    \[W_{t+1} = W_t - \eta \cdot \frac{\partial L}{\partial W}

    \]

    其中 $ \eta $ 是学习率。

    5.2 常用优化方法

    SGD( 随机梯度下降 ):直接使用梯度更新。

    Momentum:引入动量项加速收敛。

    Adam:自适应学习率,适合稀疏梯度场景。

    PyTorch代码示例:

    # 定义模型和优化器

    model = nn.Linear(in_features=10, out_features=3)

    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    # 前向传播与损失计算

    outputs = model(inputs)

    loss = F.cross_entropy(outputs, labels)

    # 反向传播与参数更新

    optimizer.zero_grad()

    loss.backward()

    optimizer.step()

    6. 总结与思考

    6.1 关键点回顾

    Softmax:将logits转换为概率分布。

    交叉熵损失:衡量预测与真实的差异,梯度形式简洁。

    矩阵求导:通过维数分析和链式法则高效计算梯度。

    参数更新:梯度指导优化方向,学习率控制步长。

    6.2 为什么需要深入理解梯度?

    模型调试:诊断梯度消失/爆炸问题。

    定制优化策略:调整学习率或设计新型优化器。

    理论扎实性:避免“黑箱”操作,掌握模型底层行为。

    希望这篇博客能帮助你彻底理解Softmax、交叉熵及其梯度推导的核心原理。无论是手写推导还是代码实践,这些知识都是构建强大分类模型的基石。