tensorflow之MNIST手写字符集训练可视化

2024-06-08 22:38

本文主要是介绍tensorflow之MNIST手写字符集训练可视化,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

简介

很多人认为卷积神经是一个黑箱子,把图片输入,输出结果为有监督式的学习(supervised learning),贴标签的形式,即可达到分类的效果。那么计算机到底做了什么事情呢?训练过程结果如何可视化?下面进行简单的介绍。

模型的搭建

@author XT
#第1层convolutional
W1 = tf.Variable(tf.truncated_normal([5,5,1,K],stddev=0.1),dtype=tf.float32,name='W1') #[filterheight,filterwith,input_channel,output_channel]
b1 = tf.Variable(tf.ones([K])/10,dtype=tf.float32,name='b1')
#第2层convolutional
W2 = tf.Variable(tf.truncated_normal([4,4,K,L],stddev=0.1),dtype=tf.float32,name='W2')
b2 = tf.Variable(tf.ones([L])/10,dtype=tf.float32,name='b2')
#第3层convolutional
W3 = tf.Variable(tf.truncated_normal([4,4,L,M],stddev=0.1),dtype=tf.float32,name='W3')
b3 = tf.Variable(tf.ones([M])/10,dtype=tf.float32,name='b3')
#convolutional out fully connected layer
W4 = tf.Variable(tf.truncated_normal([7*7*M,N],stddev=0.1),dtype=tf.float32,name='W4')
b4 = tf.Variable(tf.ones([N])/10,dtype=tf.float32,name='b4')
#output
W5 = tf.Variable(tf.truncated_normal([N,n_class],stddev=0.1),dtype=tf.float32,name='W5')#要随最后层修改
b5 = tf.Variable(tf.ones([n_class]),dtype=tf.float32,name='b5')

这里搭建了较简单的卷积神经网络,使用了3层卷积的权值,后加全连接层,最后是输出。结构为:
这里写图片描述

代码

@author XT
#Model
pkeep = tf.placeholder(tf.float32)
x = tf.placeholder(tf.float32, [None,784])#!!!注意图片格式大小    
x_image = tf.reshape(x,[-1,28,28,1])stride=1 # output is 28x28
Y1 = tf.nn.relu(tf.nn.conv2d(x_image,W1,strides=[1,stride,stride,1],padding='SAME')+b1)#sigmoid很差,要用relu
YF1 = tf.nn.dropout(Y1,pkeep)
h_poolYF1 = tf.nn.max_pool(YF1,ksize=[1,2,2,1],strides=[1,1,1,1],padding='SAME')#池化步长为1stride=2 # output is 14x14
Y2 = tf.nn.relu(tf.nn.conv2d(h_poolYF1,W2,strides=[1,stride,stride,1],padding='SAME')+b2)#sigmoid
YF2 = tf.nn.dropout(Y2,pkeep)
h_poolYF2 = tf.nn.max_pool(YF2,ksize=[1,2,2,1],strides=[1,1,1,1],padding='SAME')#池化步长为1stride = 2  # output is 7x7
Y3 = tf.nn.relu(tf.nn.conv2d(h_poolYF2,W3,strides=[1,stride,stride,1],padding='SAME')+b3)#sigmoid
YF3 = tf.nn.dropout(Y3,pkeep)
h_poolYF3 = tf.nn.max_pool(YF3,ksize=[1,2,2,1],strides=[1,1,1,1],padding='SAME')#池化步长为1# reshape the output from the third convolution for the fully connected layer
YY = tf.reshape(h_poolYF3, shape=[-1, 7*7*M])Y4 = tf.nn.relu(tf.matmul(YY,W4)+b4)#sigmoid
YF4 = tf.nn.dropout(Y4,pkeep)Ylogits = tf.matmul(YF4, W5)+b5
Y = tf.nn.softmax(Ylogits)#softmax

训练结果

1、总测试

这里写图片描述

2、Test One
这里写图片描述
输出概率:
这里写图片描述
第一卷积层:
这里写图片描述

这就是图片特征被激活的结果

部分代码

def plot_images(x, labels,max_index,name):'''plot one batch sizeimages:images_batchsize,4D tensor - [batch_size, width, height, channel]label_batch: 1D tensor - [batch_size]'''i = 0for one_pic_vic in x:one_pic_arr = np.reshape(one_pic_vic,(28,28))plt.subplot(1,1,i+1)plt.axis('off')plt.title('Label: %d   Forecast: %d'%(labels[i],max_index[i]), fontsize = 14)#采用A=0标签 +'  Forecast: '+max_index[i]plt.subplots_adjust(top=0.9)plt.imshow(one_pic_arr,cmap='gray')i+=1figure_title = nameax3  = plt.subplot(1,1,1)plt.text(0.5, -0.05, figure_title,horizontalalignment='center',fontsize=20,transform = ax3.transAxes)pylab.show()def show_rich_feature(x_relu,Node):print(x_relu.shape[1],"X",x_relu.shape[2])feature_map = tf.reshape(x_relu, [x_relu.shape[1],x_relu.shape[2],Node])images = tf.image.convert_image_dtype (feature_map, dtype=tf.uint8)images = sess.run(images)plt.figure(figsize=(10, 10))#if Node > 25,plot(5,5)for i in np.arange(0, Node):plt.subplot(2, 2, i + 1)#you need to change the subplot size if you use other layerplt.axis('off')plt.imshow(images[:,:,i])plt.show()

参考

【1】Tensorflow教程-VGG论文导读+Tensorflow实现+参数微调(fine-tuning)
http://v.youku.com/v_show/id_XMjcyNzYwMjkxMg==.html?spm=a2hzp.8244740.0.0
【2】谷歌云大会教程:没有博士学位如何玩转TensorFlow和深度学习(附资源)
http://www.sohu.com/a/128686069_465975

这篇关于tensorflow之MNIST手写字符集训练可视化的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1043509

相关文章

MiniGPT-3D, 首个高效的3D点云大语言模型,仅需一张RTX3090显卡,训练一天时间,已开源

项目主页:https://tangyuan96.github.io/minigpt_3d_project_page/ 代码:https://github.com/TangYuan96/MiniGPT-3D 论文:https://arxiv.org/pdf/2405.01413 MiniGPT-3D在多个任务上取得了SoTA,被ACM MM2024接收,只拥有47.8M的可训练参数,在一张RTX

Spark MLlib模型训练—聚类算法 PIC(Power Iteration Clustering)

Spark MLlib模型训练—聚类算法 PIC(Power Iteration Clustering) Power Iteration Clustering (PIC) 是一种基于图的聚类算法,用于在大规模数据集上进行高效的社区检测。PIC 算法的核心思想是通过迭代图的幂运算来发现数据中的潜在簇。该算法适用于处理大规模图数据,特别是在社交网络分析、推荐系统和生物信息学等领域具有广泛应用。Spa

Python:豆瓣电影商业数据分析-爬取全数据【附带爬虫豆瓣,数据处理过程,数据分析,可视化,以及完整PPT报告】

**爬取豆瓣电影信息,分析近年电影行业的发展情况** 本文是完整的数据分析展现,代码有完整版,包含豆瓣电影爬取的具体方式【附带爬虫豆瓣,数据处理过程,数据分析,可视化,以及完整PPT报告】   最近MBA在学习《商业数据分析》,大实训作业给了数据要进行数据分析,所以先拿豆瓣电影练练手,网络上爬取豆瓣电影TOP250较多,但对于豆瓣电影全数据的爬取教程很少,所以我自己做一版。 目

SigLIP——采用sigmoid损失的图文预训练方式

SigLIP——采用sigmoid损失的图文预训练方式 FesianXu 20240825 at Wechat Search Team 前言 CLIP中的infoNCE损失是一种对比性损失,在SigLIP这个工作中,作者提出采用非对比性的sigmoid损失,能够更高效地进行图文预训练,本文进行介绍。如有谬误请见谅并联系指出,本文遵守CC 4.0 BY-SA版权协议,转载请联系作者并注

Detectorn2预训练模型复现:数据准备、训练命令、日志分析与输出目录

Detectorn2预训练模型复现:数据准备、训练命令、日志分析与输出目录 在深度学习项目中,目标检测是一项重要的任务。本文将详细介绍如何使用Detectron2进行目标检测模型的复现训练,涵盖训练数据准备、训练命令、训练日志分析、训练指标以及训练输出目录的各个文件及其作用。特别地,我们将演示在训练过程中出现中断后,如何使用 resume 功能继续训练,并将我们复现的模型与Model Zoo中的

基于SSM+Vue+MySQL的可视化高校公寓管理系统

系统展示 管理员界面 宿管界面 学生界面 系统背景   当前社会各行业领域竞争压力非常大,随着当前时代的信息化,科学化发展,让社会各行业领域都争相使用新的信息技术,对行业内的各种相关数据进行科学化,规范化管理。这样的大环境让那些止步不前,不接受信息改革带来的信息技术的企业随时面临被淘汰,被取代的风险。所以当今,各个行业领域,不管是传统的教育行业

多云架构下大模型训练的存储稳定性探索

一、多云架构与大模型训练的融合 (一)多云架构的优势与挑战 多云架构为大模型训练带来了诸多优势。首先,资源灵活性显著提高,不同的云平台可以提供不同类型的计算资源和存储服务,满足大模型训练在不同阶段的需求。例如,某些云平台可能在 GPU 计算资源上具有优势,而另一些则在存储成本或性能上表现出色,企业可以根据实际情况进行选择和组合。其次,扩展性得以增强,当大模型的规模不断扩大时,单一云平

stl的sort和手写快排的运行效率哪个比较高?

STL的sort必然要比你自己写的快排要快,因为你自己手写一个这么复杂的sort,那就太闲了。STL的sort是尽量让复杂度维持在O(N log N)的,因此就有了各种的Hybrid sort algorithm。 题主你提到的先quicksort到一定深度之后就转为heapsort,这种是introsort。 每种STL实现使用的算法各有不同,GNU Standard C++ Lib

win10不用anaconda安装tensorflow-cpu并导入pycharm

记录一下防止忘了 一、前提:已经安装了python3.6.4,想用tensorflow的包 二、在pycharm中File-Settings-Project Interpreter点“+”号导入很慢,所以直接在cmd中使用 pip install -i https://mirrors.aliyun.com/pypi/simple tensorflow-cpu下载好,默认下载的tensorflow

JS手写实现深拷贝

手写深拷贝 一、通过JSON.stringify二、函数库lodash三、递归实现深拷贝基础递归升级版递归---解决环引用爆栈问题最终版递归---解决其余类型拷贝结果 一、通过JSON.stringify JSON.parse(JSON.stringify(obj))是比较常用的深拷贝方法之一 原理:利用JSON.stringify 将JavaScript对象序列化成为JSO