本文主要是介绍基于胶囊网络的Fashion-MNIST数据集的10分类,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!
胶囊网络
原文:Dynamic Routing Between Capsules
源码:https://github.com/XifengGuo/CapsNet-Fashion-MNIST
数据集
Fashion-MNIST数据集由70000张 28 ∗ 28 28*28 28∗28大小的灰度图像组成,共有10个类别,每一类别各有7000张图像。数据集划分为两部分,即训练集和测试集。其中,训练集共有60000张图像,每个类别各有6000张;测试集共有10000张图像,每一类别各有1000张。
胶囊网络结构
网络模型
采用CapsNet网络模型,该网络由两部分组成:编码器和解码器。前3层网络为编码器,即卷积层、PrimaryCaps层和DigitCaps层;后3层网络为解码器,即三层全连接层。
编码器
编码器以 28 ∗ 28 28*28 28∗28大小的Fashion-MNIST图像作为输入,以 16 ∗ 10 16*10 16∗10大小的矩阵作为输出。
论文数据集为MNIST
卷积层
该层用于检测图像的基本特征。卷积核大小为 9 ∗ 9 9*9 9∗9,步长为1,filter数为256,激活函数为Relu。输出大小为 20 ∗ 20 ∗ 256 20*20*256 20∗20∗256。
PrimaryCaps层
该层接受卷积层检测到的基本特征,用于生成特征组合。该层共有32个PrimaryCapsules,每个PrimaryCapsules由8个卷积核为 9 ∗ 9 9*9 9∗9,步长为2的卷积组成。输出大小为 6 ∗ 6 ∗ 8 ∗ 32 6*6*8*32 6∗6∗8∗32。
DigitCaps层
该层由10个16维的DigitCapsules构成,每一个DigitCapsule对应一个类别。在DigitCapsules内部,每个输入通过 8 ∗ 16 8*16 8∗16的权重矩阵将8维输入空间映射至16维Capsules输出空间。输出大小为 16 ∗ 10 16*10 16∗10。
损失函数
L k = T k m a x ( 0 , m + − ∣ ∣ v k ∣ ∣ ) 2 + λ ( 1 − T k ) m a x ( 0 , ∣ ∣ v k ∣ ∣ − m − ) 2 L_k = T_k \, max(0, m^+ - ||v_k||)^2 + \lambda(1 - T_k) \, max(0, ||v_k|| - m^-)^2 Lk=Tkmax(0,m+−∣∣vk∣∣)2+λ(1−Tk)max(0,∣∣vk∣∣−m−)2
其中,若真实标签 k k k与预测标签 k k k相同,则 T k = 1 T_k = 1 Tk=1,否则为0。 m + m^+ m+和 m − m^- m−分别为0.9和0.1。 λ = 0.5 \lambda = 0.5 λ=0.5用于确保训练中的数值稳定性。
v j = ∥ s j ∥ 2 1 + ∥ s j ∥ 2 s j ∥ s j ∥ v_j = \frac{\|s_j\|^2}{1+\|s_j\|^2}\frac{s_j}{\|s_j\|} vj=1+∥sj∥2∥sj∥2∥sj∥sj
v j v_j vj表示第 j j j个capsule输出的向量。
s j = ∑ i c i j u ^ j ∣ i s_j = \sum_i c_{ij} \hat{u}_{j|i} sj=i∑ciju^j∣i
s j s_j sj为高层capsules的输入。 c i j = e x p ( b i , j ) ∑ k e x p ( b i k ) c_{ij}=\frac{exp(b_{i,j})}{\sum_kexp(b_ik)} cij=∑kexp(bik)exp(bi,j)为耦合系数,其中 b i j = b i j + u ^ j ∣ i ⋅ v j b_{ij} = b_{ij} + \hat{u}_{j|i} \cdot v_j bij=bij+u^j∣i⋅vj,初始时 b i j = 0 b_{ij} = 0 bij=0。
u ^ j ∣ i = W i j u i \hat{u}_{j|i} = W_{ij}u_i u^j∣i=Wijui
W i j W_{ij} Wij 表示权重矩阵, u i u_i ui为低层capsules的输出, u ^ i j \hat{u}_{ij} u^ij为预测向量,可视为底层capsules的输出向量进行仿射变换。
动态路由算法
解码器
解码器由三层全连接层构成,用于重建图像,损失函数为MSE函数。训练时仅使用正确的DigitCap向量。
实现细节
初始学习率为0.001,其随迭代次数增大而衰减,batch size为100,共100个epoch。
结果
这篇关于基于胶囊网络的Fashion-MNIST数据集的10分类的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!