基于胶囊网络的Fashion-MNIST数据集的10分类

2024-08-27 06:48

本文主要是介绍基于胶囊网络的Fashion-MNIST数据集的10分类,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

胶囊网络


原文:Dynamic Routing Between Capsules
源码:https://github.com/XifengGuo/CapsNet-Fashion-MNIST


数据集

Fashion-MNIST数据集由70000张 28 ∗ 28 28*28 2828大小的灰度图像组成,共有10个类别,每一类别各有7000张图像。数据集划分为两部分,即训练集和测试集。其中,训练集共有60000张图像,每个类别各有6000张;测试集共有10000张图像,每一类别各有1000张。

胶囊网络结构

网络模型

采用CapsNet网络模型,该网络由两部分组成:编码器和解码器。前3层网络为编码器,即卷积层、PrimaryCaps层和DigitCaps层;后3层网络为解码器,即三层全连接层。

编码器

编码器

编码器以 28 ∗ 28 28*28 2828大小的Fashion-MNIST图像作为输入,以 16 ∗ 10 16*10 1610大小的矩阵作为输出。

论文数据集为MNIST

卷积层

该层用于检测图像的基本特征。卷积核大小为 9 ∗ 9 9*9 99,步长为1,filter数为256,激活函数为Relu。输出大小为 20 ∗ 20 ∗ 256 20*20*256 2020256

PrimaryCaps层

该层接受卷积层检测到的基本特征,用于生成特征组合。该层共有32个PrimaryCapsules,每个PrimaryCapsules由8个卷积核为 9 ∗ 9 9*9 99,步长为2的卷积组成。输出大小为 6 ∗ 6 ∗ 8 ∗ 32 6*6*8*32 66832

DigitCaps层

该层由10个16维的DigitCapsules构成,每一个DigitCapsule对应一个类别。在DigitCapsules内部,每个输入通过 8 ∗ 16 8*16 816的权重矩阵将8维输入空间映射至16维Capsules输出空间。输出大小为 16 ∗ 10 16*10 1610

损失函数

L k = T k m a x ( 0 , m + − ∣ ∣ v k ∣ ∣ ) 2 + λ ( 1 − T k ) m a x ( 0 , ∣ ∣ v k ∣ ∣ − m − ) 2 L_k = T_k \, max(0, m^+ - ||v_k||)^2 + \lambda(1 - T_k) \, max(0, ||v_k|| - m^-)^2 Lk=Tkmax(0,m+vk)2+λ(1Tk)max(0,vkm)2

其中,若真实标签 k k k与预测标签 k k k相同,则 T k = 1 T_k = 1 Tk=1,否则为0。 m + m^+ m+ m − m^- m分别为0.9和0.1。 λ = 0.5 \lambda = 0.5 λ=0.5用于确保训练中的数值稳定性。

v j = ∥ s j ∥ 2 1 + ∥ s j ∥ 2 s j ∥ s j ∥ v_j = \frac{\|s_j\|^2}{1+\|s_j\|^2}\frac{s_j}{\|s_j\|} vj=1+sj2sj2sjsj

v j v_j vj表示第 j j j个capsule输出的向量。

s j = ∑ i c i j u ^ j ∣ i s_j = \sum_i c_{ij} \hat{u}_{j|i} sj=iciju^ji

s j s_j sj为高层capsules的输入。 c i j = e x p ( b i , j ) ∑ k e x p ( b i k ) c_{ij}=\frac{exp(b_{i,j})}{\sum_kexp(b_ik)} cij=kexp(bik)exp(bi,j)为耦合系数,其中 b i j = b i j + u ^ j ∣ i ⋅ v j b_{ij} = b_{ij} + \hat{u}_{j|i} \cdot v_j bij=bij+u^jivj,初始时 b i j = 0 b_{ij} = 0 bij=0

u ^ j ∣ i = W i j u i \hat{u}_{j|i} = W_{ij}u_i u^ji=Wijui

W i j W_{ij} Wij 表示权重矩阵, u i u_i ui为低层capsules的输出, u ^ i j \hat{u}_{ij} u^ij为预测向量,可视为底层capsules的输出向量进行仿射变换。

动态路由算法

动态路由算法

解码器

解码器

解码器由三层全连接层构成,用于重建图像,损失函数为MSE函数。训练时仅使用正确的DigitCap向量。

实现细节

初始学习率为0.001,其随迭代次数增大而衰减,batch size为100,共100个epoch。

结果

![][4]

这篇关于基于胶囊网络的Fashion-MNIST数据集的10分类的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1110942

相关文章

大模型研发全揭秘:客服工单数据标注的完整攻略

在人工智能(AI)领域,数据标注是模型训练过程中至关重要的一步。无论你是新手还是有经验的从业者,掌握数据标注的技术细节和常见问题的解决方案都能为你的AI项目增添不少价值。在电信运营商的客服系统中,工单数据是客户问题和解决方案的重要记录。通过对这些工单数据进行有效标注,不仅能够帮助提升客服自动化系统的智能化水平,还能优化客户服务流程,提高客户满意度。本文将详细介绍如何在电信运营商客服工单的背景下进行

基于MySQL Binlog的Elasticsearch数据同步实践

一、为什么要做 随着马蜂窝的逐渐发展,我们的业务数据越来越多,单纯使用 MySQL 已经不能满足我们的数据查询需求,例如对于商品、订单等数据的多维度检索。 使用 Elasticsearch 存储业务数据可以很好的解决我们业务中的搜索需求。而数据进行异构存储后,随之而来的就是数据同步的问题。 二、现有方法及问题 对于数据同步,我们目前的解决方案是建立数据中间表。把需要检索的业务数据,统一放到一张M

关于数据埋点,你需要了解这些基本知识

产品汪每天都在和数据打交道,你知道数据来自哪里吗? 移动app端内的用户行为数据大多来自埋点,了解一些埋点知识,能和数据分析师、技术侃大山,参与到前期的数据采集,更重要是让最终的埋点数据能为我所用,否则可怜巴巴等上几个月是常有的事。   埋点类型 根据埋点方式,可以区分为: 手动埋点半自动埋点全自动埋点 秉承“任何事物都有两面性”的道理:自动程度高的,能解决通用统计,便于统一化管理,但个性化定

基于人工智能的图像分类系统

目录 引言项目背景环境准备 硬件要求软件安装与配置系统设计 系统架构关键技术代码示例 数据预处理模型训练模型预测应用场景结论 1. 引言 图像分类是计算机视觉中的一个重要任务,目标是自动识别图像中的对象类别。通过卷积神经网络(CNN)等深度学习技术,我们可以构建高效的图像分类系统,广泛应用于自动驾驶、医疗影像诊断、监控分析等领域。本文将介绍如何构建一个基于人工智能的图像分类系统,包括环境

使用SecondaryNameNode恢复NameNode的数据

1)需求: NameNode进程挂了并且存储的数据也丢失了,如何恢复NameNode 此种方式恢复的数据可能存在小部分数据的丢失。 2)故障模拟 (1)kill -9 NameNode进程 [lytfly@hadoop102 current]$ kill -9 19886 (2)删除NameNode存储的数据(/opt/module/hadoop-3.1.4/data/tmp/dfs/na

异构存储(冷热数据分离)

异构存储主要解决不同的数据,存储在不同类型的硬盘中,达到最佳性能的问题。 异构存储Shell操作 (1)查看当前有哪些存储策略可以用 [lytfly@hadoop102 hadoop-3.1.4]$ hdfs storagepolicies -listPolicies (2)为指定路径(数据存储目录)设置指定的存储策略 hdfs storagepolicies -setStoragePo

Hadoop集群数据均衡之磁盘间数据均衡

生产环境,由于硬盘空间不足,往往需要增加一块硬盘。刚加载的硬盘没有数据时,可以执行磁盘数据均衡命令。(Hadoop3.x新特性) plan后面带的节点的名字必须是已经存在的,并且是需要均衡的节点。 如果节点不存在,会报如下错误: 如果节点只有一个硬盘的话,不会创建均衡计划: (1)生成均衡计划 hdfs diskbalancer -plan hadoop102 (2)执行均衡计划 hd

认识、理解、分类——acm之搜索

普通搜索方法有两种:1、广度优先搜索;2、深度优先搜索; 更多搜索方法: 3、双向广度优先搜索; 4、启发式搜索(包括A*算法等); 搜索通常会用到的知识点:状态压缩(位压缩,利用hash思想压缩)。

【Prometheus】PromQL向量匹配实现不同标签的向量数据进行运算

✨✨ 欢迎大家来到景天科技苑✨✨ 🎈🎈 养成好习惯,先赞后看哦~🎈🎈 🏆 作者简介:景天科技苑 🏆《头衔》:大厂架构师,华为云开发者社区专家博主,阿里云开发者社区专家博主,CSDN全栈领域优质创作者,掘金优秀博主,51CTO博客专家等。 🏆《博客》:Python全栈,前后端开发,小程序开发,人工智能,js逆向,App逆向,网络系统安全,数据分析,Django,fastapi

Linux 网络编程 --- 应用层

一、自定义协议和序列化反序列化 代码: 序列化反序列化实现网络版本计算器 二、HTTP协议 1、谈两个简单的预备知识 https://www.baidu.com/ --- 域名 --- 域名解析 --- IP地址 http的端口号为80端口,https的端口号为443 url为统一资源定位符。CSDNhttps://mp.csdn.net/mp_blog/creation/editor