白话RNN系列(五)

2024-09-06 05:38
文章标签 系列 rnn 白话

本文主要是介绍白话RNN系列(五),希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

前文,对于LSTM的结构进行了系统的介绍,本文,通过一个MNIST_data的例子,争取能够把LSTM的基本使用来吃透。

import tensorflow as tf
import input_data
# 导入 MINST 数据集
# from tensorflow.examples.tutorials.mnist import input_data
# one_hot=True,代表输入的是one-hot向量
mnist = input_data.read_data_sets("/MNIST_data/", one_hot=True)
# 每次输入的是一个28维的向量
n_input = 28  # MNIST data 输入 (img shape: 28*28)
# LSTM的循环体会循环28次
n_steps = 28  # timesteps
# 隐藏层神经元的数目是128
n_hidden = 128  # hidden layer num of features
n_classes = 10  # MNIST 列别 (0-9 ,一共10类)

这里需要注意,因为MNIST_data的输入实际上是一个784维的向量,来自于一张28*28大小的手写字体图片,我们设定n_steps为循环次数,而n_input代表每次会输入28个特征,也就是一行的数据。

必须要注意这两个28的区别。

# 定义初始输入的占位符
# None 其实代表的是一个批次会输入多少个数据
# n_steps 即上面的28 ,代表会循环输入28次
# n_input 值为28 ,代表每个循环内会输入28 * 1维度的向量
x = tf.placeholder("float", [None, n_steps, n_input])
# y是最终的分类结果,n_classes = 10 ,代表有10个分类
y = tf.placeholder("float", [None, n_classes])

这里,定义我们的初步输入x和y,但实际上,我们真正输入的内容还不是x和y,而是下面的x1:

# 上文,得到我们输入的x 是28 *28 维度的矩阵
# 这里,通过unstack,拆分成28 列,每一列都是28 * 1维度的列向量
x1 = tf.unstack(x, axis=1)

对于这里,我的理解是,x是28 * 28的矩阵,每一行对应图片的一行,每一列对应于图片中的一列,通过unstack方式,切分得到的每个list,实际上内部是图片的一列元素(这在当前行=列的情况下无影响),但具体使用的时候,还是需要严格区分两个28的分别作用。

到这里,我们得到了每次输入的列向量,是28 * 1维度的列向量。

# 通过tensorflow初始化一个LSTMCell
# n_hidden 代表隐藏层的节点数
# forget_bias 即遗忘门中的偏置,默认为1,可以不写
lstm_cell = tf.nn.rnn_cell.LSTMCell(n_hidden, forget_bias=1.0, name='basic_lstm_cell')
# 使用静态运行的方式得到我们的输出和隐藏状态
outputs, states = tf.contrib.rnn.static_rnn(lstm_cell, x1, dtype=tf.float32)

LSTMCell的初始化过程中,指定隐藏层神经元的数目即可;此处,我们以静态运行的状态启动,输入的x1为28 * 1维度的列向量。

# 加入一个全连接
# 得到n_classes=10 的输出用于结果分类
pred = tf.contrib.layers.fully_connected(outputs[-1], n_classes, activation_fn=None)

全连接,得到10 * 1的列向量,用于输出分类结果。

接下来,我们看下训练过程:

step = 1# Keep training until reach max iterations# batch_size 为128 ,代表每次输入128张图片进行训练# training_iters 为100000 ,定义了我们的训练次数while step * batch_size < training_iters:# 至于图片是怎么读取的,关系应该不大# 生成的其实是128 * 784的部分,对于x来说# 对于batch_Y来说,得到的是128 * 1的向量batch_x, batch_y = mnist.train.next_batch(batch_size)# 为什么会是128 * 784的数据# Reshape data to get 28 seq of 28 elements# 数据解析为128个,后面是28 * 28维的向量batch_x = batch_x.reshape((batch_size, n_steps, n_input))# Run optimization op (backprop)sess.run(optimizer, feed_dict={x: batch_x, y: batch_y})if step % display_step == 0:# 计算批次数据的准确率acc = sess.run(accuracy, feed_dict={x: batch_x, y: batch_y})# Calculate batch lossloss = sess.run(cost, feed_dict={x: batch_x, y: batch_y})print("Iter " + str(step * batch_size) + ", Minibatch Loss= " + \"{:.6f}".format(loss) + ", Training Accuracy= " + \"{:.5f}".format(acc))step += 1print(" Finished!")

里面有几个变量,我们没有定义,在这里补充下:

learning_rate = 0.001
training_iters = 100000
batch_size = 128
display_step = 10# Define loss and optimizer
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(logits=pred, labels=y))
optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(cost)

学习率,迭代次数,批次大小,display_step用于为我们确定多少步输出一次精确度:

OK,到这里我们把代码重新梳理一遍:

我们每个批次输入了batch_size = 128 的数据量,而每个数据,实际上是28 * 28 大小的矩阵,通过unstack方式,将其转化为28 个list,每个list再进行reshape,转化为28 * 1的列向量,最终作为LSTM的输入:

接着,我们搭建了num_units=128 的LSTMCell,其输出后面紧跟着一个全连接神经网络,得到最后的输出;再经过一轮softmax,得到最后的输出结果。

简化来说,我们只要能够确定清楚自己输入向量的维度,很轻松就能够使用LSTM完成我们的任务。

全部代码在此处;

# -*- coding: utf-8 -*-
import tensorflow as tfimport input_data# 导入 MINST 数据集
# from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("/MNIST_data/", one_hot=True)
n_input = 28  # MNIST data 输入 (img shape: 28*28)
n_steps = 28  # timesteps
n_hidden = 128  # hidden layer num of features
n_classes = 10  # MNIST 列别 (0-9 ,一共10类)# 这里对于上面都要整体说明下
# 这里的输入当成28个时间段,每段内容都是28个值,使用unstack将原始的输入28 * 28 调整成具有28个元素的list
# 每个元素为1 * 28的数组,这28个时序一次送入RNN中
# 由于是批次操作,所以每次都会取该批次中所有图片的一行作为一个时间序列输入tf.reset_default_graph()# tf Graph input
# 确认人家的输入和输出时怎么得到的
# 这里输入的好像是批次个数:然后step应该是max_times,即最多有多少个时序
# 每个时序输入的其实是28个
# 构建出来的n_steps=28
# n_input = 28x = tf.placeholder("float", [None, n_steps, n_input])
y = tf.placeholder("float", [None, n_classes])# 1真的是按照列切分的
# 的确是按照列切分的
# x本身是28 * 28的,现在切割成28个列,每列都是28个元素,算是一个样本
x1 = tf.unstack(x, axis=1)# 1 BasicLSTMCell
# 先构建一个包含128个cell的类lstm_cell, 然后将变形后的x1放进去生成节点outputs
# 最后通过全连接生成pred, 最后使用softmax进行分类
# 这128个cell是怎么定义出来的?lstm_cell = tf.nn.rnn_cell.LSTMCell(n_hidden, forget_bias=1.0, name='basic_lstm_cell')
outputs, states = tf.contrib.rnn.static_rnn(lstm_cell, x1, dtype=tf.float32)# 通过全连接形成的十个分类,即10 * 1的向量
# 然后通过softmax方式,与真正的输出进行交叉熵合作
pred = tf.contrib.layers.fully_connected(outputs[-1], n_classes, activation_fn=None)learning_rate = 0.001
training_iters = 100000
batch_size = 128
display_step = 10# Define loss and optimizer
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(logits=pred, labels=y))
optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(cost)# Evaluate model
correct_pred = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))# 启动session
with tf.Session() as sess:sess.run(tf.global_variables_initializer())step = 1# Keep training until reach max iterations# 这里的batch_size = 128 ,代表每次会输出128个数据进行训练# training_iters为100000while step * batch_size < training_iters:# 这里的数据,应该是横向的,一行是一个图片的数据# 至于图片是怎么读取的,关系应该不大# 生成的其实是128 * 784的部分,对于x来说# 对于batch_Y来说,得到的是128 * 1的向量batch_x, batch_y = mnist.train.next_batch(batch_size)# 为什么会是128 * 784的数据# Reshape data to get 28 seq of 28 elements# 看来理解还是不透彻,如何这么实现reshape的?# 自行解析成三维的数据# 后面的数据会自行解析成二维的部分# 数据解析为128个,后面是28 * 28维的向量batch_x = batch_x.reshape((batch_size, n_steps, n_input))# Run optimization op (backprop)out = sess.run(outputs, feed_dict={x: batch_x, y: batch_y})print(out[-1].shape)sess.run(optimizer, feed_dict={x: batch_x, y: batch_y})if step % display_step == 0:# 计算批次数据的准确率acc = sess.run(accuracy, feed_dict={x: batch_x, y: batch_y})# Calculate batch lossloss = sess.run(cost, feed_dict={x: batch_x, y: batch_y})print("Iter " + str(step * batch_size) + ", Minibatch Loss= " + \"{:.6f}".format(loss) + ", Training Accuracy= " + \"{:.5f}".format(acc))step += 1print(" Finished!")# 计算准确率 for 128 mnist test imagestest_len = 128test_data = mnist.test.images[:test_len].reshape((-1, n_steps, n_input))test_label = mnist.test.labels[:test_len]print("Testing Accuracy:", \sess.run(accuracy, feed_dict={x: test_data, y: test_label}))

 

这篇关于白话RNN系列(五)的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1141156

相关文章

Spring Security 从入门到进阶系列教程

Spring Security 入门系列 《保护 Web 应用的安全》 《Spring-Security-入门(一):登录与退出》 《Spring-Security-入门(二):基于数据库验证》 《Spring-Security-入门(三):密码加密》 《Spring-Security-入门(四):自定义-Filter》 《Spring-Security-入门(五):在 Sprin

科研绘图系列:R语言扩展物种堆积图(Extended Stacked Barplot)

介绍 R语言的扩展物种堆积图是一种数据可视化工具,它不仅展示了物种的堆积结果,还整合了不同样本分组之间的差异性分析结果。这种图形表示方法能够直观地比较不同物种在各个分组中的显著性差异,为研究者提供了一种有效的数据解读方式。 加载R包 knitr::opts_chunk$set(warning = F, message = F)library(tidyverse)library(phyl

【生成模型系列(初级)】嵌入(Embedding)方程——自然语言处理的数学灵魂【通俗理解】

【通俗理解】嵌入(Embedding)方程——自然语言处理的数学灵魂 关键词提炼 #嵌入方程 #自然语言处理 #词向量 #机器学习 #神经网络 #向量空间模型 #Siri #Google翻译 #AlexNet 第一节:嵌入方程的类比与核心概念【尽可能通俗】 嵌入方程可以被看作是自然语言处理中的“翻译机”,它将文本中的单词或短语转换成计算机能够理解的数学形式,即向量。 正如翻译机将一种语言

flume系列之:查看flume系统日志、查看统计flume日志类型、查看flume日志

遍历指定目录下多个文件查找指定内容 服务器系统日志会记录flume相关日志 cat /var/log/messages |grep -i oom 查找系统日志中关于flume的指定日志 import osdef search_string_in_files(directory, search_string):count = 0

GPT系列之:GPT-1,GPT-2,GPT-3详细解读

一、GPT1 论文:Improving Language Understanding by Generative Pre-Training 链接:https://cdn.openai.com/research-covers/languageunsupervised/language_understanding_paper.pdf 启发点:生成loss和微调loss同时作用,让下游任务来适应预训

Java基础回顾系列-第七天-高级编程之IO

Java基础回顾系列-第七天-高级编程之IO 文件操作字节流与字符流OutputStream字节输出流FileOutputStream InputStream字节输入流FileInputStream Writer字符输出流FileWriter Reader字符输入流字节流与字符流的区别转换流InputStreamReaderOutputStreamWriter 文件复制 字符编码内存操作流(

Java基础回顾系列-第五天-高级编程之API类库

Java基础回顾系列-第五天-高级编程之API类库 Java基础类库StringBufferStringBuilderStringCharSequence接口AutoCloseable接口RuntimeSystemCleaner对象克隆 数字操作类Math数学计算类Random随机数生成类BigInteger/BigDecimal大数字操作类 日期操作类DateSimpleDateForma

Java基础回顾系列-第三天-Lambda表达式

Java基础回顾系列-第三天-Lambda表达式 Lambda表达式方法引用引用静态方法引用实例化对象的方法引用特定类型的方法引用构造方法 内建函数式接口Function基础接口DoubleToIntFunction 类型转换接口Consumer消费型函数式接口Supplier供给型函数式接口Predicate断言型函数式接口 Stream API 该篇博文需重点了解:内建函数式

Java基础回顾系列-第二天-面向对象编程

面向对象编程 Java类核心开发结构面向对象封装继承多态 抽象类abstract接口interface抽象类与接口的区别深入分析类与对象内存分析 继承extends重写(Override)与重载(Overload)重写(Override)重载(Overload)重写与重载之间的区别总结 this关键字static关键字static变量static方法static代码块 代码块String类特

Java基础回顾系列-第六天-Java集合

Java基础回顾系列-第六天-Java集合 集合概述数组的弊端集合框架的优点Java集合关系图集合框架体系图java.util.Collection接口 List集合java.util.List接口java.util.ArrayListjava.util.LinkedListjava.util.Vector Set集合java.util.Set接口java.util.HashSetjava