本文主要是介绍分类问题为什么用交叉熵损失不用 MSE 损失,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!
-
本文说明以下问题
- MSE 损失主要适用与回归问题,因为优化 MSE 等价于对高斯分布模型做极大似然估计,而简单回归中做服从高斯分布的假设是比较合理的
- 交叉熵损失主要适用于多分类问题,因为优化交叉熵损失等价于对多项式分布模型做极大似然估计,而多分类问题通常服从多项式分布
事实上,最大似然估计往往将损失建模为负对数似然,这样的损失一定等价于定义在训练集上的经验分布和定义在模型上的概率分布间的交叉熵,这个交叉熵根据模型定义有时可以转化为不同的损失,这块可以参考:信息论概念详细梳理:信息量、信息熵、条件熵、互信息、交叉熵、KL散度、JS散度 4.1 节
-
先明确本文讨论的多分类问题的符号:
- 训练样本集大小为 N N N
- 类别数为 K K K
- 第 i i i 个样本 x ( i ) \pmb{x}^{(i)} x(i) 的真实标记概率分布为 y ( i ) = { y 1 ( i ) , y 2 ( i ) , . . . , y K ( i ) } \pmb{y}^{(i)}=\{y_1^{(i)},y_2^{(i)},...,y_K^{(i)}\} y(i)={y1(i),y2(i),...,yK(i)},事实上这是一个 one-hot 向量
- 第 i i i 个样本 x ( i ) \pmb{x}^{(i)} x(i) 的预测标记概率分布为 y ^ ( i ) = { y ^ 1 ( i ) , y ^ 2 ( i ) , . . . , y ^ K ( i ) } \pmb{\hat{y}}^{(i)}=\{\hat{y}_1^{(i)},\hat{y}_2^{(i)},...,\hat{y}_K^{(i)}\} y^(i)={y^1(i),y^2(i),...,y^K(i)}
这种情况下,MSE 损失和交叉熵损失分别为
- MSE 损失: L = 1 N ∑ i N ∣ ∣ y ( i ) − y ^ ( i ) ∣ ∣ 2 = 1 N ∑ i = 1 N ∑ k = 1 K ( y k ( i ) − y ^ k ( i ) ) 2 L = \frac{1}{N}\sum_{i}^N ||\pmb{y}^{(i)}-\pmb{\hat{y}}^{(i)}||^2 = \frac{1}{N}\sum_{i=1}^N\sum_{k=1}^K(y_k^{(i)}-\hat{y}_k^{(i)})^2 L=N1∑iN∣∣y(i)−y^(i)∣∣2=N1∑i=1N∑k=1K(yk(i)−y^k(i))2
- 交叉熵损失: L = − 1 N ∑ i = 1 N ∑ k = 1 K y k ( i ) l o g y ^ k ( i ) L=-\frac{1}{N}\sum_{i=1}^N\sum_{k=1}^Ky_k^{(i)}log\hat{y}_k^{(i)} L=−N1∑i=1N∑k=1Kyk(i)logy^k(i)
文章目录
- 1. 概率角度
- 1.1 优化 MSE 损失等价于高斯分布的最大似然估计
- 1.2 优化交叉熵损失等价于多项式分布的最大似然
- 1.2.1 多项式分布
- 1.2.2 优化交叉熵损失等价于多项式分布的最大似然
- 2. 梯度角度
- 3. 直观角度
1. 概率角度
1.1 优化 MSE 损失等价于高斯分布的最大似然估计
-
我们可以把第 i i i 个样本 x ( i ) \pmb{x}^{(i)} x(i) 的真实标记值 y ( i ) \pmb{y}^{(i)} y(i) 看做预测标记值 y ^ ( i ) \hat{\pmb{y}}^{(i)} y^(i) 加上噪音误差 e ( i ) \pmb{e}^{(i)} e(i) 所得,假设误差 e ( i ) ∼ N ( 0 , B ) \pmb{e}^{(i)}\sim N(0,\pmb{B}) e(i)∼N(0,B) 服从期望为 0,协方差矩阵为 B \pmb{B} B 的 K K K 维高斯分布,则样本的真实标签 y ( i ) = y ^ ( i ) + e ( i ) = f ( x ( i ) , w ) + e ( i ) ∼ N ( f ( x ( i ) , w ) , B ) \pmb{y}^{(i)} = \hat{\pmb{y}}^{(i)}+\pmb{e}^{(i)} = f(\pmb{x}^{(i)},\pmb{w})+\pmb{e}^{(i)} \sim N(f(\pmb{x}^{(i)},\pmb{w}),\pmb{B}) y(i)=y^(i)+e(i)=f(x(i),w)+e(i)∼N(f(x(i),w),B) 也服从期望为 y ^ ( i ) = f ( x ( i ) , w ) \pmb{\hat{y}}^{(i)} = f(\pmb{x}^{(i)},\pmb{w}) y^(i)=f(x(i),w),协方差矩阵为 B \pmb{B} B 的 K K K 维高斯分布,有
p ( y ( i ) ∣ x ( i ) , w ) = 1 ( 2 π ) n / 2 ∣ B ∣ 1 / 2 e − 1 2 △ ( i ) p(\pmb{y}^{(i)}|\pmb{x}^{(i)},\pmb{w}) = \frac{1}{(2\pi)^{n/2}|\pmb{B}|^{1/2}}e^{-\frac{1}{2}\triangle^{(i)}} p(y(i)∣x(i),w)=(2π)n/2∣B∣1/21e−21△(i) 其中 △ ( i ) = [ y ( i ) − y ^ ( i ) ] ⊤ B − 1 [ y ( i ) − y ^ ( i ) ] \triangle^{(i)} = [\pmb{y}^{(i)}-\hat{\pmb{y}}^{(i)}]^\top\pmb{B}^{-1}[\pmb{y}^{(i)}-\hat{\pmb{y}}^{(i)}] △(i)=[y(i)−y^(i)]⊤B−1[y(i)−y^(i)] -
由于样本独立同分布,整个样本集的似然函数为
L ( w ) = p ( y ( 1 ) ∣ x ( 1 ) , w ) p ( y ( 2 ) ∣ x ( 2 ) , w ) . . . . p ( y ( N ) ∣ x ( N ) , w ) = ∏ i = 1 N p ( y ( i ) ∣ x ( i ) , w ) L(\pmb{w}) =p(\pmb{y}^{(1)}|\pmb{x}^{(1)},\pmb{w})p(\pmb{y}^{(2)}|\pmb{x}^{(2)},\pmb{w})....p(\pmb{y}^{(N)}|\pmb{x}^{(N)},\pmb{w}) = \prod\limits_{i=1}^Np(\pmb{y}^{(i)}|\pmb{x}^{(i)},\pmb{w}) L(w)=p(y(1)∣x(1),w)p(y(2)∣x(2),w)....p(y(N)∣x(N),w)=i=1∏Np(y(i)∣x(i),w) 通过最大化对数似然函数的方式得到参数 w \pmb{w} w 的估计值,即
w ^ = arg max w ^ L ( w ) = arg max w ^ ∏ i = 1 N p ( y ( i ) ∣ x ( i ) , w ) = arg max w ^ ∑ i = 1 N log p ( y ( i ) ∣ x ( i ) , w ) = arg max w ^ ∑ i = 1 N ( log e − 1 2 △ ( i ) ) = arg max w ^ ∑ i = 1 N − △ ( i ) = arg min w ^ ∑ i = 1 N △ ( i ) = arg min w ^ ∑ i = 1 N [ y ( i ) − y ^ ( i ) ] ⊤ B − 1 [ y ( i ) − y ^ ( i ) ] \begin{aligned} \hat{\pmb{w}} &= \argmax\limits_{\mathbf{\hat{w}}} L(\pmb{w}) \\ &= \argmax\limits_{\mathbf{\hat{w}}}\prod_{i=1}^Np(y^{(i)}|\pmb{x}^{(i)},\pmb{w}) \\ &= \argmax\limits_{\mathbf{\hat{w}}}\sum_{i=1}^N\log p(y^{(i)}|\pmb{x}^{(i)},\pmb{w}) \\ &= \argmax\limits_{\mathbf{\hat{w}}}\sum_{i=1}^N (\log e^{-\frac{1}{2}\triangle^{(i)}})\\ &= \argmax\limits_{\mathbf{\hat{w}}} \sum_{i=1}^N-\triangle^{(i)} \\ &= \argmin\limits_{\mathbf{\hat{w}}}\sum_{i=1}^N \triangle^{(i)} \\ &= \argmin\limits_{\mathbf{\hat{w}}}\sum_{i=1}^N [\pmb{y}^{(i)}-\hat{\pmb{y}}^{(i)}]^\top\pmb{B}^{-1}[\pmb{y}^{(i)}-\hat{\pmb{y}}^{(i)}] \end{aligned} w^=w^argmaxL(w)=w^argmaxi=1∏Np(y(i)∣x(i),w)=w^argmaxi=1∑Nlogp(y(i)∣x(i),w)=w^argmaxi=1∑N(loge−21△(i))=w^argmaxi=1∑N−△(i)=w^argmini=1∑N△(i)=w^argmini=1∑N[y(i)−y^(i)]⊤B−1[y(i)−y^(i)] -
这时考虑特殊情况
- 当 K K K 维正态分布的各个维度相互独立且各项同性时,协方差矩阵变为单位矩阵,有
△ ( i ) = ∣ ∣ y ( i ) − y ^ ( i ) ∣ ∣ 2 \triangle^{(i)} = ||\pmb{y}^{(i)}-\hat{\pmb{y}}^{(i)}||^2 △(i)=∣∣y(i)−y^(i)∣∣2 这时最大似然估计的结果为
w ^ = arg min w ^ ∑ i = 1 N ∣ ∣ y ( i ) − y ^ ( i ) ∣ ∣ 2 \hat{\pmb{w}} = \argmin\limits_{\mathbf{\hat{w}}} \sum_{i=1}^N||\pmb{y}^{(i)}-\hat{\pmb{y}}^{(i)}||^2 w^=w^argmini=1∑N∣∣y(i)−y^(i)∣∣2 这和最小化 MSE 损失的优化目标一致 - 进一步特殊化,当 K = 1 K=1 K=1 时, B = 1 \pmb{B}=1 B=1,退化到一元线性回归情况
- 当 K K K 维正态分布的各个维度相互独立且各项同性时,协方差矩阵变为单位矩阵,有
-
我们认为,误差是由于随机的、无数的、独立的、多个因素造成的,因此根据中心极限定理,预测误差在大样本量的情况下确实服从正态分布,因此优化 MSE 损失等价于对高斯分布做最大似然估计
-
注意到,只有当样本标记服从多个维度相互独立且各项同性的多维高斯分布时,优化 MSE 损失才等价于做最大似然估计,而这往往很难达成,这也是 MSE 不适用于多分类问题而适用于一元线性回归的原因之一(后者自动满足这个条件)
1.2 优化交叉熵损失等价于多项式分布的最大似然
1.2.1 多项式分布
- 我们借助从伯努利分布到二项分布的变形,从 categorical 分布导出多项式分布
- 先看熟悉的伯努利分布:抛硬币正面朝上的概率为 θ \theta θ,这时抛 1 次硬币,出现正面次数 X = m ∈ { 0 , 1 } X=m\in\{0,1\} X=m∈{0,1} 的概率服从伯努利分布
P ( X = m ∣ θ ) = θ m ( 1 − θ ) 1 − m , m ∈ { 0 , 1 } P(X=m|\theta) = \theta^m(1-\theta)^{1-m},\space m\in\{0,1\} P(X=m∣θ)=θm(1−θ)1−m, m∈{0,1} 再看二项分布:抛硬币正面朝上概率为 θ \theta θ ,抛 n n n 次硬币,出现正面次数 X = m X=m X=m 的概率服从二项分布
P ( X = m ∣ θ , n ) = C n m θ m ( 1 − θ ) n − m P(X=m|\theta,n) = C_n^m\theta^m(1-\theta)^{n-m} P(X=m∣θ,n)=Cnmθm(1−θ)n−m - categorical 分布可以类比伯努利分布: K K K 面骰子,每一面出现概率分别为 θ 1 , θ 2 , . . . , θ K \theta_1,\theta_2,...,\theta_K θ1,θ2,...,θK,抛 1 次骰子,第 p p p 面出现次数 X = m p ∈ { 0 , 1 } X=m_p\in\{0,1\} X=mp∈{0,1} 的概率服从 categorical 分布(下式 ∑ k = 1 K θ i = 1 \sum_{k=1}^K\theta_i=1 ∑k=1Kθi=1, m k ∈ { 0 , 1 } m_k\in\{0,1\} mk∈{0,1}, ∑ k = 1 K m k = 1 \sum_{k=1}^Km_k=1 ∑k=1Kmk=1)
P ( X = m p ∣ θ 1 , θ 2 , . . . , θ K ) = ∏ k = 1 K θ k m k P(X=m_p|\theta_1,\theta_2,...,\theta_K) = \prod_{k=1}^K \theta_k^{m_k} P(X=mp∣θ1,θ2,...,θK)=k=1∏Kθkmk 多项式分布可以类比二项分布: K K K 面骰子,每一面出现概率分别为 θ 1 , θ 2 , . . . , θ K \theta_1,\theta_2,...,\theta_K θ1,θ2,...,θK,抛 N N N 次骰子,第 1 面到第 k k k 面出现次数为 X 1 = m 1 , X 2 = m 2 , . . . , X K = m K X_1=m_1,X_2=m_2,...,X_K=m_K X1=m1,X2=m2,...,XK=mK 的概率服从多项分布 (下式 ∑ k = 1 K θ k = 1 \sum_{k=1}^K\theta_k=1 ∑k=1Kθk=1, ∑ k = 1 K m k = n \sum_{k=1}^Km_k=n ∑k=1Kmk=n)
P ( X 1 = m 1 , X 2 = m 2 , . . . , X K = m K ∣ θ 1 , θ 2 , . . . , θ K , N ) = C N m 1 θ 1 m 1 C N − m 1 m 2 θ 2 m 2 . . . C N − m 1 − m 2 − . . . − m k − 1 m K θ K m K = N ! m 1 ! m 2 ! . . . m K ! ∏ k = 1 K θ k m K \begin{aligned} P(X_1=m_1,X_2=m_2,...,X_K=m_K|\theta_1,\theta_2,...,\theta_K,N) &= C_N^{m_1}\theta_1^{m_1}C_{N-m_1}^{m_2}\theta_2^{m_2}...C_{N-m_1-m_2-...-m_{k-1}}^{m_K}\theta_K^{m_K}\\ &= \frac{N!}{m_1!m_2!...m_K!}\prod_{k=1}^K\theta_k^{m_K} \end{aligned} P(X1=m1,X2=m2,...,XK=mK∣θ1,θ2,...,θK,N)=CNm1θ1m1CN−m1m2θ2m2...CN−m1−m2−...−mk−1mKθKmK=m1!m2!...mK!N!k=1∏KθkmK - 对于多分类问题来说,可以把总类别数看做这里的 K K K,把各个类别的预测概率(模型输出概率)看做这里的 θ \theta θ,把总样本数看做这里的 N N N,把样本真实标记分布(one-hot 向量)看做这里的 m m m,则多分类问题的样本集可以看做服从多项式分布,上式可改写为
N ! m 1 ! m 2 ! . . . m K ! ∏ k = 1 K ( y ^ k ( i ) ) ∑ i y k ( i ) = N ! m 1 ! m 2 ! . . . m K ! ∏ i = 1 N ∏ k = 1 K ( y ^ k ( i ) ) y k ( i ) \begin{aligned} \frac{N!}{m_1!m_2!...m_K!}\prod_{k=1}^K(\hat{y}_k^{(i)})^{\sum_i y_k^{(i)}} = \frac{N!}{m_1!m_2!...m_K!} \prod_{i=1}^N\prod_{k=1}^K(\hat{y}_k^{(i)})^{y_k^{(i)}} \end{aligned} m1!m2!...mK!N!k=1∏K(y^k(i))∑iyk(i)=m1!m2!...mK!N!i=1∏Nk=1∏K(y^k(i))yk(i) 前面的 N ! m 1 ! m 2 ! . . . m K ! \frac{N!}{m_1!m_2!...m_K!} m1!m2!...mK!N! 是归一化系数,目的是使总和等于 1 以满足概率形式,它是个常数,并不重要 - 如果对分布的形式还不是很清楚,可以看这个例子
1.2.2 优化交叉熵损失等价于多项式分布的最大似然
- 根据 categorical 分布,第 i i i 个样本 x i \pmb{x}_i xi 真实标记 y ( i ) \pmb{y}^{(i)} y(i) 出现概率为
p ( y ( i ) ∣ x ( i ) , w ) = ∏ k = 1 K ( y ^ k ( i ) ) y k ( i ) p(\pmb{y}^{(i)}|\pmb{x}^{(i)},\pmb{w}) = \prod_{k=1}^K (\hat{y}_k^{(i)})^{y_k^{(i)}} p(y(i)∣x(i),w)=k=1∏K(y^k(i))yk(i) 注意这里是 y ( i ) \pmb{y}^{(i)} y(i) 是 one-hot 向量, y k y_k yk 中只有一个值为 1,其他都是 0 - 似然函数为
L ( w ) = ∏ i = 1 N p ( y ( i ) ∣ x ( i ) , w ) = ∏ i = 1 N ∏ k = 1 K ( y ^ k ( i ) ) y k ( i ) L(\pmb{w}) = \prod_{i=1}^Np(\pmb{y}^{(i)}|\pmb{x}^{(i)},\pmb{w}) = \prod_{i=1}^N\prod_{k=1}^K (\hat{y}_k^{(i)})^{y_k^{(i)}} L(w)=i=1∏Np(y(i)∣x(i),w)=i=1∏Nk=1∏K(y^k(i))yk(i) 通过最大化对数似然函数的方式得到参数 w \pmb{w} w 的估计值,即
w ^ = arg max w ^ ∏ i = 1 N ∏ k = 1 K ( y ^ k ( i ) ) y k ( i ) = arg max w ^ ∑ i = 1 N ∑ k = 1 K y k ( i ) log y ^ k ( i ) = arg min w ^ − ∑ i = 1 N ∑ k = 1 K y k ( i ) log y ^ k ( i ) \begin{aligned} \hat{\pmb{w}} &= \argmax\limits_{\mathbf{\hat{w}}}\prod_{i=1}^N\prod_{k=1}^K (\hat{y}_k^{(i)})^{y_k^{(i)}}\\ &= \argmax\limits_{\mathbf{\hat{w}}}\sum_{i=1}^N\sum_{k=1}^K {y_k^{(i)}}\log \hat{y}_k^{(i)}\\ &= \argmin\limits_{\mathbf{\hat{w}}} - \sum_{i=1}^N\sum_{k=1}^K {y_k^{(i)}}\log \hat{y}_k^{(i)}\\ \end{aligned} w^=w^argmaxi=1∏Nk=1∏K(y^k(i))yk(i)=w^argmaxi=1∑Nk=1∑Kyk(i)logy^k(i)=w^argmin−i=1∑Nk=1∑Kyk(i)logy^k(i) 这和最小化交叉熵损失的优化目标一致
2. 梯度角度
- 从梯度角度看,如果使用了 sigmoid 或类似形状的激活函数,对于本文开头提出的多分类问题,在计算梯度时
- 用 MSE 损失,参数梯度关于绝对误差是一个凹函数形式,导致更新强度和绝对误差值不成正比,优化过程低效
- 用交叉熵损失,参数梯度关于绝对误差是线性函数形式,更新强度和绝对误差值成正比,优化过程高效稳定
- 具体推导请参考:为什么使用交叉熵作为损失函数
3. 直观角度
- 回过头看两个损失函数
- MSE 损失: L = 1 N ∑ i N ∣ ∣ y ( i ) − y ^ ( i ) ∣ ∣ 2 = 1 N ∑ i = 1 N ∑ k = 1 K ( y k ( i ) − y ^ k ( i ) ) 2 L = \frac{1}{N}\sum_{i}^N ||\pmb{y}^{(i)}-\pmb{\hat{y}}^{(i)}||^2 = \frac{1}{N}\sum_{i=1}^N\sum_{k=1}^K(y_k^{(i)}-\hat{y}_k^{(i)})^2 L=N1i∑N∣∣y(i)−y^(i)∣∣2=N1i=1∑Nk=1∑K(yk(i)−y^k(i))2
- 交叉熵损失: L = − 1 N ∑ i = 1 N ∑ k = 1 K y k ( i ) l o g y ^ k ( i ) L=-\frac{1}{N}\sum_{i=1}^N\sum_{k=1}^Ky_k^{(i)}log\hat{y}_k^{(i)} L=−N1i=1∑Nk=1∑Kyk(i)logy^k(i) 由于 y ( i ) \pmb{y}^{(i)} y(i) 是 one-hot 向量,假设 k i k_i ki 是第 i i i 个样本标记类别,可以进一步化简为 L = − 1 N ∑ i = 1 N l o g y ^ k i ( i ) L=-\frac{1}{N}\sum_{i=1}^Nlog\hat{y}_{k_i}^{(i)} L=−N1i=1∑Nlogy^ki(i)
- 可见:MSE无差别地关注全部类别上预测概率和真实概率的差;交叉熵关注的是正确类别的预测概率
- 如果真实标签是 ( 1 , 0 , 0 ) (1, 0, 0) (1,0,0),模型1的预测标签是 ( 0.8 , 0.2 , 0 ) (0.8, 0.2, 0) (0.8,0.2,0),模型2的是 ( 0.8 , 0.1 , 0.1 ) (0.8, 0.1, 0.1) (0.8,0.1,0.1),那么MSE-based 认为模型2更好;交叉熵-based认为一样。从最终预测的类别上看,模型1和模型2的真实输出其实是一样的
- 再换个角度,MSE对残差大的样例惩罚更大些。比如真实标签分别是 ( 1 , 0 , 0 ) (1, 0, 0) (1,0,0),模型1的预测标签是 ( 0.8 , 0.2 , 0 ) (0.8, 0.2, 0) (0.8,0.2,0),模型2的是 ( 0.9 , 0.1 , 0 ) (0.9, 0.1, 0) (0.9,0.1,0),即使输出的标签都是类别0, 但 MSE-based 算出来模型1的误差是模型2的4倍;而交叉熵-based算出来模型1的误差是模型2的2倍左右。为了弥补模型1在这个样例上的损失,MSE-based需要3个完美预测的样例才能达到和模型2一样的损失,而交叉熵-based只需要一个。实际上,模型输出正确的类别,0.8可能已经是个不错的概率了.
- 本段参考:MSE vs 交叉熵
这篇关于分类问题为什么用交叉熵损失不用 MSE 损失的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!