memo: DL | Loss Functions

(2022-5-25) 损失函数是输入样本batch的函数,不同batch的误差函数不同,如果在一个batch上某 w 的导数为零,在下一个batch上该 w 的导数不为零,就可以继续修正,而不会停滞在鞍点。误差函数横轴是 w,纵轴是 error。0

3 loss functions

损失函数是为了: 衡量两个概率模型间的差别, 三种思路:最小二乘法(MSE),极大似然估计(MLE),交叉熵(CE) ¹

MSE

对于单分类问题(是or不是),也就是抛硬币(反面的概率已经蕴含在正面的概率之中了),那么:
最小二乘就是:概率(sigmoid输出0~1)减标签。为了依据误差修正w,误差取平方使其可导。如果损失函数中代入的是预测的正(反)概率,那损失函数是个二次曲线 L = (prob - label)²,横坐标是概率,纵坐标是loss,当prob=target, 预测出来的值与离散的观察值最接近。

对于输出是多维的,在各个维度上都是二次曲线,多元最小二乘(多元线性回归): J(w) = (Xw-Y)ᵀ(Xw-Y)。解∇J(w) = 0,就是求 pseudoinverse matrix

如果这样能一次求出 w,为什么还要梯度下降呢? 因为有时 X 不可逆(可以加正则化项解决 note), 而且有的激活函数 g 也不可逆(通过手动做归一化解决,例如 B-ELM 中使用了 minmax 函数)。

当MSE 用于回归问题,loss=∑ᵢ(yᵢ-wxᵢ-b)²是凸函数,直接求导等于零,即可求出解析解;但是用于分类问题,输出需要经过sigmoid/softmax变成概率,loss=∑ᵢ(yᵢ-1/(1+e⁻ʷˣⁱ))²是非凸的,不能直接求解析解,而且不宜优化 ³

MLE

极大似然估计:各个可能的假设模型产生训练样本标签的分布的概率是多少,目标就是找到概率最大时对应的模型(加个负号取最小);∏ᵢ pᵢˣ (1-pᵢ)¹⁻ˣ

CE

交叉熵:网络模型要与人脑中的模型足够接近,某一事件在网络模型中发生对应的信息量要接近在人脑中发生对应的信息量,多个事件要以他们在人脑中发生的概率加权。∑ᵢ humanᵢ(-log₂ netᵢ)

最小二乘可以用于回归,即网络输出可以是任意的数值;而极大似然估计和交叉熵都是基于概率的,网络的输出是概率,位于0-1之间,所以采用MLE或CE损失函数时,输出层神经元的激活函数需要用sigmoid,把输出压缩到0-1之间; 而隐藏层都可以用ReLu。 多类别问题输出用 softmax 激活,得到各类别的概率分布。

交叉熵认为各类别相互独立,每一维是一个二分类器,单个样本的概率(似然)是:P₁ʸ¹ ⋅ P₂ʸ² ⋅ … ⋅ Pₖʸᵏ, 所以需要用 softmax 做一下归一化


MSE 与 CE 区别

(Google search: “为什么不用mse做损失函数”)

MSE 不适合分类问题²

工程角度:如果用MSE做分类,对 softmax 的输出使用 MSE,即正确类的概率越接近 1 越好,其他类的概率越小越好: minimize Loss = (prob_true-1)² + ∑(prob_other)²。 但是在 Loss 的梯度表达式中存在 prob_true 这个因子,可能在训练初期 prob_true 很小,梯度趋于0,无法更新。 而在用 CE 做Loss时,它的梯度中不含有单独的 prob_true 这一项(被消掉了),就不易发生梯度消失²

(2022-11-06) 分类问题常使用 softmax,所以适合使用CE;而回归问题不常使用softmax,所以适合使用 MSE。

理论角度:二者假设不同,MSE假设观察到的 y’=真实y+高斯噪声,所以通过极大似然法求解一组参数使得对应的高斯噪声最小的情况。所以MSE求解出来的值会更偏向于各个离散的观察值。而CE的假设应该是多分类情况下,拟合不同类别的概率分布。“多分类问题的分布符合多项式分布,CE是多项式分布的最大似然

交叉熵不适用回归问题

在(多)分类问题中,交叉熵的损失函数只和分类正确的预测结果有关系,而MSE的损失函数还和错误的分类有关系,因此该"分类"函数除了让正确的分类尽量变大,还会让错误的分类变得平均,但实际在分类问题中,MSE 的这个调整是没有必要的

多分类问题 中的 “类别” 对应到 多元回归问题 中的 “特征”。对于一个连续的输出量,应是由各个特征共同作用的,分别有不同的贡献,而不能只看重某一个特征,所以CE不适合回归问题。但也可以用


损失函数的性质

(2023-02-17)

  1. 可微分性
  2. 可导性
  3. 单调性
  4. 凸性
  5. 可分离性
  6. 可表示性

借助 pytorch 可视化损失函数的导数9

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
import torch
import matplotlib.pyplot as plt

def abs_func(x):
    return x.abs()

x = torch.linspace(-2,2,100)
x.requires_grad_(True)
y = abs_func(x)
plt.plot(x.detach().numpy(), y.detach().numpy())
plt.show()

y_prime = torch.autograd.grad(y.sum(), x, create_graph=True)[0]
plt.plot(x.detach().numpy(), y_prime.detach().numpy())
plt.show()

Ref