从tflearn Example中学习CNN(1)

2024-05-07 16:32
文章标签 学习 cnn example tflearn

本文主要是介绍从tflearn Example中学习CNN(1),希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

本人水平有限,难免会有错误,如果发现希望可以及时指出来,利人利己。哈哈哈~

     这是博客写的第一篇文章,主要想从tflearn的例子代码一步步理解CNN模型。这里插一句话,tflearn是tensorflow接口的更高层次的封装,与keras的区别时debug时可以看到源码,并且tflearn代码写的非常工整,适合我这样的菜鸟学习。

    现在深度学习异常火热,如果不会点深度学习,出门都不好意思和人家打招呼。

    这篇博客主要讲解下tflearn例子里的examples/images/convnet_mnist.py对于每个函数中涉及的参数我会一一的给出中文说明,下一篇主要讲解每个参数在CNN中的含义,以及系统讲解下CNN的构建过程。废话不多说,先看代码。


  

[python] view plain copy
  1. from __future__ import division, print_function, absolute_import  
  2.   
  3. import tflearn  
  4. from tflearn.layers.core import input_data, dropout, fully_connected  
  5. from tflearn.layers.conv import conv_2d, max_pool_2d  
  6. from tflearn.layers.normalization import local_response_normalization  
  7. from tflearn.layers.estimator import regression  
  8. #加载大名顶顶的mnist数据集(http://yann.lecun.com/exdb/mnist/)  
  9. import tflearn.datasets.mnist as mnist  
  10. X, Y, testX, testY = mnist.load_data(one_hot=True)  
  11. X = X.reshape([-128281])  
  12. testX = testX.reshape([-128281])  
  13.   
  14. network = input_data(shape=[None28281], name='input')  
  15. # CNN中的卷积操作,下面会有详细解释  
  16. network = conv_2d(network, 323, activation='relu', regularizer="L2")  
  17. # 最大池化操作  
  18. network = max_pool_2d(network, 2)  
  19. # 局部响应归一化操作  
  20. network = local_response_normalization(network)  
  21. network = conv_2d(network, 643, activation='relu', regularizer="L2")  
  22. network = max_pool_2d(network, 2)  
  23. network = local_response_normalization(network)  
  24. # 全连接操作  
  25. network = fully_connected(network, 128, activation='tanh')  
  26. # dropout操作  
  27. network = dropout(network, 0.8)  
  28. network = fully_connected(network, 256, activation='tanh')  
  29. network = dropout(network, 0.8)  
  30. network = fully_connected(network, 10, activation='softmax')  
  31. # 回归操作  
  32. network = regression(network, optimizer='adam', learning_rate=0.01,  
  33.                      loss='categorical_crossentropy', name='target')  
  34.   
  35. # Training  
  36. # DNN操作,构建深度神经网络  
  37. model = tflearn.DNN(network, tensorboard_verbose=0)  
  38. model.fit({'input': X}, {'target': Y}, n_epoch=20,  
  39.            validation_set=({'input': testX}, {'target': testY}),  
  40.            snapshot_step=100, show_metric=True, run_id='convnet_mnist')  


关于conv_2d函数,在源码里是可以看到总共有14个参数,分别如下:

1.incoming: 输入的张量,形式是[batch, height, width, in_channels]
2.nb_filter: filter的个数
3.filter_size: filter的尺寸,是int类型
4.strides: 卷积操作的步长,默认是[1,1,1,1]
5.padding: padding操作时标志位,"same"或者"valid",默认是“same”
6.activation: 激活函数(ps:这里需要了解的知识很多,会单独讲)
7.bias: bool量,如果True,就是使用bias
8.weights_init: 权重的初始化
9.bias_init: bias的初始化,默认是0,比如众所周知的线性函数y=wx+b,其中的w就相当于weights,b就是bias
10.regularizer: 正则项(这里需要讲解的东西非常多,会单独讲)
11.weight_decay: 权重下降的学习率
12.trainable: bool量,是否可以被训练
13.restore: bool量,训练的模型是否被保存
14.name: 卷积层的名称,默认是"Conv2D"

关于max_pool_2d函数,在源码里有5个参数,分别如下:
1.incoming ,类似于conv_2d里的incoming
2.kernel_size:池化时核的大小,相当于conv_2d时的filter的尺寸
3.strides:类似于conv_2d里的strides
4.padding:同上
5.name:同上

看了这么多参数,好像有些迷糊,我先用一张图解释下每个参数的意义。


其中的filter就是
[1 0 1 
 0 1 0
 1 0 1],size=3,由于每次移动filter都是一个格子,所以strides=1.

关于最大池化可以看看下面这张图,这里面 strides=1,kernel_size =2(就是每个颜色块的大小),图中示意的最大池化(可以提取出显著信息,比如在进行文本分析时可以提取一句话里的关键字,以及图像处理中显著颜色,纹理等),关于池化这里多说一句,有时需要平均池化,有时需要最小池化。

下面说说其中的padding操作,做图像处理的人对于这个操作应该不会陌生,说白了,就是填充。比如你对图像做卷积操作,比如你用的3×3的卷积核,在进行边上操作时,会发现卷积核已经超过原图像,这时需要把原图像进行扩大,扩大出来的就是填充,基本都填充0。
一下关于padding的操作转自:http://www.jianshu.com/p/05c4f1621c7e

1.输入W×W的矩阵,(这里讨论长宽相等情况,不相等的话,推导方法有区别),现在想象一下脑子里有一副W*W的图像

2.假定filter的大小是F×F,卷积核

3.步长stride为S

4.输出的宽高为new_w,new_h

上面已经提到padding总共有两种方式,same,valid

当取valid

            new_weight=new_height=(W-F+1)/S(结果向上取整),

此时输出的矩阵大小比输入时小(这里不讨论F=1时的情况,说到1*1的卷积核,大家可以看看GoogLeNet模型,其中用到1*1卷积核,这个用来降维的,tflearn代码里有GoogLeNet的复现)

当取same时,

            new_height = new_weight= W/S(结果向上取整)

在高度上需要pad的像素数为

            pad_needed_height = (new_height – 1)  × S + F - W

根据上式,输入矩阵上方添加的像素数为

            pad_top = pad_needed_height / 2  (结果取整)

下方添加的像素数为

pad_down = pad_needed_height - pad_top

以此类推,在宽度上需要pad的像素数和左右分别添加的像素数为

pad_needed_width = (new_width – 1)  × S + F - W

pad_left = pad_needed_width  / 2 (结果取整)

pad_right = pad_needed_width – pad_left

下面看图示以及计算过程:



下一篇将详细介绍下激活函数以及正则项这两个大部头。

参考文献:
1.https://ujjwalkarn.me/2016/08/11/intuitive-explanation-convnets/
2.https://github.com/tflearn/tflearn/blob/master/examples/images/convnet_mnist.py
3.http://www.jianshu.com/p/05c4f1621c7e

这篇关于从tflearn Example中学习CNN(1)的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/967842

相关文章

HarmonyOS学习(七)——UI(五)常用布局总结

自适应布局 1.1、线性布局(LinearLayout) 通过线性容器Row和Column实现线性布局。Column容器内的子组件按照垂直方向排列,Row组件中的子组件按照水平方向排列。 属性说明space通过space参数设置主轴上子组件的间距,达到各子组件在排列上的等间距效果alignItems设置子组件在交叉轴上的对齐方式,且在各类尺寸屏幕上表现一致,其中交叉轴为垂直时,取值为Vert

Ilya-AI分享的他在OpenAI学习到的15个提示工程技巧

Ilya(不是本人,claude AI)在社交媒体上分享了他在OpenAI学习到的15个Prompt撰写技巧。 以下是详细的内容: 提示精确化:在编写提示时,力求表达清晰准确。清楚地阐述任务需求和概念定义至关重要。例:不用"分析文本",而用"判断这段话的情感倾向:积极、消极还是中性"。 快速迭代:善于快速连续调整提示。熟练的提示工程师能够灵活地进行多轮优化。例:从"总结文章"到"用

【前端学习】AntV G6-08 深入图形与图形分组、自定义节点、节点动画(下)

【课程链接】 AntV G6:深入图形与图形分组、自定义节点、节点动画(下)_哔哩哔哩_bilibili 本章十吾老师讲解了一个复杂的自定义节点中,应该怎样去计算和绘制图形,如何给一个图形制作不间断的动画,以及在鼠标事件之后产生动画。(有点难,需要好好理解) <!DOCTYPE html><html><head><meta charset="UTF-8"><title>06

学习hash总结

2014/1/29/   最近刚开始学hash,名字很陌生,但是hash的思想却很熟悉,以前早就做过此类的题,但是不知道这就是hash思想而已,说白了hash就是一个映射,往往灵活利用数组的下标来实现算法,hash的作用:1、判重;2、统计次数;

零基础学习Redis(10) -- zset类型命令使用

zset是有序集合,内部除了存储元素外,还会存储一个score,存储在zset中的元素会按照score的大小升序排列,不同元素的score可以重复,score相同的元素会按照元素的字典序排列。 1. zset常用命令 1.1 zadd  zadd key [NX | XX] [GT | LT]   [CH] [INCR] score member [score member ...]

【机器学习】高斯过程的基本概念和应用领域以及在python中的实例

引言 高斯过程(Gaussian Process,简称GP)是一种概率模型,用于描述一组随机变量的联合概率分布,其中任何一个有限维度的子集都具有高斯分布 文章目录 引言一、高斯过程1.1 基本定义1.1.1 随机过程1.1.2 高斯分布 1.2 高斯过程的特性1.2.1 联合高斯性1.2.2 均值函数1.2.3 协方差函数(或核函数) 1.3 核函数1.4 高斯过程回归(Gauss

【学习笔记】 陈强-机器学习-Python-Ch15 人工神经网络(1)sklearn

系列文章目录 监督学习:参数方法 【学习笔记】 陈强-机器学习-Python-Ch4 线性回归 【学习笔记】 陈强-机器学习-Python-Ch5 逻辑回归 【课后题练习】 陈强-机器学习-Python-Ch5 逻辑回归(SAheart.csv) 【学习笔记】 陈强-机器学习-Python-Ch6 多项逻辑回归 【学习笔记 及 课后题练习】 陈强-机器学习-Python-Ch7 判别分析 【学

系统架构师考试学习笔记第三篇——架构设计高级知识(20)通信系统架构设计理论与实践

本章知识考点:         第20课时主要学习通信系统架构设计的理论和工作中的实践。根据新版考试大纲,本课时知识点会涉及案例分析题(25分),而在历年考试中,案例题对该部分内容的考查并不多,虽在综合知识选择题目中经常考查,但分值也不高。本课时内容侧重于对知识点的记忆和理解,按照以往的出题规律,通信系统架构设计基础知识点多来源于教材内的基础网络设备、网络架构和教材外最新时事热点技术。本课时知识

线性代数|机器学习-P36在图中找聚类

文章目录 1. 常见图结构2. 谱聚类 感觉后面几节课的内容跨越太大,需要补充太多的知识点,教授讲得内容跨越较大,一般一节课的内容是书本上的一章节内容,所以看视频比较吃力,需要先预习课本内容后才能够很好的理解教授讲解的知识点。 1. 常见图结构 假设我们有如下图结构: Adjacency Matrix:行和列表示的是节点的位置,A[i,j]表示的第 i 个节点和第 j 个

Node.js学习记录(二)

目录 一、express 1、初识express 2、安装express 3、创建并启动web服务器 4、监听 GET&POST 请求、响应内容给客户端 5、获取URL中携带的查询参数 6、获取URL中动态参数 7、静态资源托管 二、工具nodemon 三、express路由 1、express中路由 2、路由的匹配 3、路由模块化 4、路由模块添加前缀 四、中间件