白话RNN系列(三)

2024-09-06 05:38
文章标签 rnn 白话 系列

本文主要是介绍白话RNN系列(三),希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

紧接上文,白话RNN系列(二)。

通过generateData得到我们的样本数据之后,我们开始搭建自己的RNN:

# 每个批次输入的数据,这里定义为5,即每个批次输入5个数据
batch_size = 5
# RNN中循环的次数,即时间序列的长度
# 这里取长度为15的时间序列
truncated_backprop_length = 15
# 与时间序列相对应,占位符的维度为 5 * 15
# 完全可以每个批次输入一个数据,这里前面的5是为了提高效率
batchX_placeholder = tf.placeholder(tf.float32, [batch_size, truncated_backprop_length])
# 同样,与x对应,输出数据的占位符维度为 5 * 15
batchY_placeholder = tf.placeholder(tf.int32, [batch_size, truncated_backprop_length])

这里,我们定义了四个变量:batch_size, truncated_backprop_length, batchX_placeholder, batchY_placeholder ,后两个为tensorfFlow中的占位符。

# 接着,调用unstack方法
# 定义axis=1表示按列切分,即将5*15的矩阵,切分为15个列向量
inputs_series = tf.unstack(batchX_placeholder, axis=1)
# 同样,Y也切分成15个列,每一列都是1 * 5 的列向量
labels_series = tf.unstack(batchY_placeholder, axis=1)

这里,我们用unstack方法,规范化RNN的输入和输出:

# 这里,定义初始状态为5 * 4 的矩阵
# 5 是因为我们每个批次输入5个数据
# 4 是我们隐藏层神经元的个数
init_state = tf.placeholder(tf.float32, [batch_size, state_size])
current_state = init_state
# 预测结果序列
predictions_series = []
# 损失
losses = []

current_state即隐藏层当前状态,preictions_series 为预测结果序列,losses为损失。

前面都是变量的初始化,下面的代码是训练的重头戏了

# input_series 是5 * 15 的矩阵经过unstack 得到的,目前是15个元素的list,每个元素都是5*1的列向
# 量
# 同样,labels_series 也是15个列向量的list
# zip的作用,是为了并排获取元素
# 这个循环,我们可以清楚地看到,共循环了15次,即truncated_backprop_length 的大小
for current_input, labels in zip(inputs_series, labels_series):# reshape是必要的,保证输入向量的形状current_input = tf.reshape(current_input, [batch_size, 1])# current_input 为 5 * 1# current_state 为 5 * 4# 这里axis = 1, 则是行拼接,如果axis=0,则是列拼接# 这里,进行行拼接,拼接成 5 * 5的单元# 每一行的元素中,第一个是当前输入,后四个元素,均为隐藏层状态# 我们隐藏层神经元有4个,故隐藏层状态为4# 拼接成 5 * 5的元素,而实际上则是一个五维向量input_and_state_concatenated = tf.concat([current_input, current_state], 1) # 进行全连接:5 * 5 ,# 这是增加一个全连接层,自动初始化w和b,激活函数默认为relu函数,输出个数用num_outputs来指定# 这里的num_outputs即我们的隐藏层神经元个数,为4next_state = tf.contrib.layers.fully_connected(input_and_state_concatenated, state_size, activation_fn=tf.tanh)# 可以看到,隐藏层状态的不断复用,作为下一时刻的输入current_state = next_state# 这里,next_state的维度为5 * 4# 建立一个全连接网络,权重矩阵维度为4 * 2,目的是为了实现二分类# 产出的结果是个5 * 1的向量logits = tf.contrib.layers.fully_connected(next_state, num_classes, activation_fn=None)# labels 为 5 * 1的向量# logits是 5 * 2的向量,通过交叉熵计算本次的损失loss = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, logits=logits)# 计算整体损失losses.append(loss)# 通过softmax得出此次的预测结果predictions = tf.nn.softmax(logits)# 记录本次预测结果predictions_series.append(predictions)

这里注意,我们输出的logits维度为5 * 2, 而使用到的labels其维度为5 * 1,因此我们计算损失的时候使用了sparse_soft_max_cross_entropy_with_logits 方法。

在RNN的使用过程中,因为我们不用深入了解其中涉及到的反向传播过程,因此必须要清楚我们传入的张量的维度,这样能更好的发挥RNN的作用。

接下来,我们看下训练过程:

loss_list = []
# 批次数目等于序列总长度 // 每个批次输入数据大小 // 循环次数(即序列长度)
num_batches = total_series_length // batch_size // truncated_backprop_lengthnum_epochs = 5# 我们共训练5次,每次使用的数据都不一样# 但实际上,这里的数据分布是一样的,也可以使用完全相同的数据for epoch_idx in range(num_epochs):x, y = generateData()# 初始隐藏状态为全零# 其维度为5 * 4_current_state = np.zeros((batch_size, state_size))# for batch_idx in range(num_batches):  # 50000/ 5 /15=分成多少段# 每次输入都是15个元素,但是会输入多个start_idx = batch_idx * truncated_backprop_lengthend_idx = start_idx + truncated_backprop_length# x代表每一行的部分元素,所以,实际上应该是个二维向量# 每个输入的应该都是 5 * 15的矩阵batchX = x[:, start_idx:end_idx]print(batchX.shape)# 输入的也是5 * 15的矩阵batchY = y[:, start_idx:end_idx]print(batchY.shape)# 通过feed_dict 方式喂入数据进行训练_total_loss, _train_step, _current_state, _predictions_series = sess.run([total_loss, train_step, current_state, predictions_series],feed_dict={batchX_placeholder: batchX,batchY_placeholder: batchY,init_state: _current_state})loss_list.append(_total_loss)

至此,一个RNN的整体训练过程就结束了,完整的代码粘贴如下,大家可以自行测试。

# -*- coding: utf-8 -*-
"""
Created on Sat May 13 06:24:52 2017@author: 代码医生 qq群:40016981,公众号:xiangyuejiqiren
@blog:http://blog.csdn.net/lijin6249
"""import numpy as np
import tensorflow as tf
import matplotlib.pyplot as pltnum_epochs = 5
total_series_length = 50000
truncated_backprop_length = 15
state_size = 4
num_classes = 2
echo_step = 3
batch_size = 5
num_batches = total_series_length//batch_size//truncated_backprop_lengthdef generateData():x = np.array(np.random.choice(2, total_series_length, p=[0.5, 0.5]))#在0 和1 中选择total_series_length个数y = np.roll(x, echo_step)#向右循环移位【1111000】---【0001111】y[0:echo_step] = 0x = x.reshape((batch_size, -1))  # 5,10000y = y.reshape((batch_size, -1))return (x, y)batchX_placeholder = tf.placeholder(tf.float32, [batch_size, truncated_backprop_length])
batchY_placeholder = tf.placeholder(tf.int32, [batch_size, truncated_backprop_length])
init_state = tf.placeholder(tf.float32, [batch_size, state_size])# Unpack columns
inputs_series = tf.unstack(batchX_placeholder, axis=1)#truncated_backprop_length个序列
labels_series = tf.unstack(batchY_placeholder, axis=1)current_state = init_state
predictions_series = []
losses =[]
for current_input, labels in zip(inputs_series,labels_series):
#for current_input in inputs_series:current_input = tf.reshape(current_input, [batch_size, 1])input_and_state_concatenated = tf.concat([current_input, current_state],1)  # current_state 4  +1next_state = tf.contrib.layers.fully_connected(input_and_state_concatenated,state_size,activation_fn=tf.tanh)current_state = next_statelogits =tf.contrib.layers.fully_connected(next_state,num_classes,activation_fn=None)loss = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels,logits=logits)losses.append(loss)predictions = tf.nn.softmax(logits)predictions_series.append(predictions)total_loss = tf.reduce_mean(losses)
train_step = tf.train.AdagradOptimizer(0.3).minimize(total_loss)def plot(loss_list, predictions_series, batchX, batchY):plt.subplot(2, 3, 1)plt.cla()plt.plot(loss_list)for batch_series_idx in range(batch_size):one_hot_output_series = np.array(predictions_series)[:, batch_series_idx, :]single_output_series = np.array([(1 if out[0] < 0.5 else 0) for out in one_hot_output_series])plt.subplot(2, 3, batch_series_idx + 2)plt.cla()plt.axis([0, truncated_backprop_length, 0, 2])left_offset = range(truncated_backprop_length)left_offset2 = range(echo_step,truncated_backprop_length+echo_step)label1 = "past values"label2 = "True echo values" label3 = "Predictions"      plt.plot(left_offset2, batchX[batch_series_idx, :]*0.2+1.5, "o--b", label=label1)plt.plot(left_offset, batchY[batch_series_idx, :]*0.2+0.8,"x--b", label=label2)plt.plot(left_offset,  single_output_series*0.2+0.1 , "o--y", label=label3)plt.legend(loc='best')plt.draw()plt.pause(0.0001)with tf.Session() as sess:sess.run(tf.global_variables_initializer())plt.ion()plt.figure()plt.show()loss_list = []for epoch_idx in range(num_epochs):x,y = generateData()_current_state = np.zeros((batch_size, state_size))print("New data, epoch", epoch_idx)for batch_idx in range(num_batches):#50000/ 5 /15=分成多少段start_idx = batch_idx * truncated_backprop_lengthend_idx = start_idx + truncated_backprop_lengthbatchX = x[:,start_idx:end_idx]batchY = y[:,start_idx:end_idx]_total_loss, _train_step, _current_state, _predictions_series = sess.run([total_loss, train_step, current_state, predictions_series],feed_dict={batchX_placeholder:batchX,batchY_placeholder:batchY,init_state:_current_state})loss_list.append(_total_loss)if batch_idx%100 == 0:print("Step",batch_idx, "Loss", _total_loss)plot(loss_list, _predictions_series, batchX, batchY)plt.ioff()
plt.show()    

代码使用python3 编写。

这篇关于白话RNN系列(三)的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1141154

相关文章

Spring Security 从入门到进阶系列教程

Spring Security 入门系列 《保护 Web 应用的安全》 《Spring-Security-入门(一):登录与退出》 《Spring-Security-入门(二):基于数据库验证》 《Spring-Security-入门(三):密码加密》 《Spring-Security-入门(四):自定义-Filter》 《Spring-Security-入门(五):在 Sprin

科研绘图系列:R语言扩展物种堆积图(Extended Stacked Barplot)

介绍 R语言的扩展物种堆积图是一种数据可视化工具,它不仅展示了物种的堆积结果,还整合了不同样本分组之间的差异性分析结果。这种图形表示方法能够直观地比较不同物种在各个分组中的显著性差异,为研究者提供了一种有效的数据解读方式。 加载R包 knitr::opts_chunk$set(warning = F, message = F)library(tidyverse)library(phyl

【生成模型系列(初级)】嵌入(Embedding)方程——自然语言处理的数学灵魂【通俗理解】

【通俗理解】嵌入(Embedding)方程——自然语言处理的数学灵魂 关键词提炼 #嵌入方程 #自然语言处理 #词向量 #机器学习 #神经网络 #向量空间模型 #Siri #Google翻译 #AlexNet 第一节:嵌入方程的类比与核心概念【尽可能通俗】 嵌入方程可以被看作是自然语言处理中的“翻译机”,它将文本中的单词或短语转换成计算机能够理解的数学形式,即向量。 正如翻译机将一种语言

flume系列之:查看flume系统日志、查看统计flume日志类型、查看flume日志

遍历指定目录下多个文件查找指定内容 服务器系统日志会记录flume相关日志 cat /var/log/messages |grep -i oom 查找系统日志中关于flume的指定日志 import osdef search_string_in_files(directory, search_string):count = 0

GPT系列之:GPT-1,GPT-2,GPT-3详细解读

一、GPT1 论文:Improving Language Understanding by Generative Pre-Training 链接:https://cdn.openai.com/research-covers/languageunsupervised/language_understanding_paper.pdf 启发点:生成loss和微调loss同时作用,让下游任务来适应预训

Java基础回顾系列-第七天-高级编程之IO

Java基础回顾系列-第七天-高级编程之IO 文件操作字节流与字符流OutputStream字节输出流FileOutputStream InputStream字节输入流FileInputStream Writer字符输出流FileWriter Reader字符输入流字节流与字符流的区别转换流InputStreamReaderOutputStreamWriter 文件复制 字符编码内存操作流(

Java基础回顾系列-第五天-高级编程之API类库

Java基础回顾系列-第五天-高级编程之API类库 Java基础类库StringBufferStringBuilderStringCharSequence接口AutoCloseable接口RuntimeSystemCleaner对象克隆 数字操作类Math数学计算类Random随机数生成类BigInteger/BigDecimal大数字操作类 日期操作类DateSimpleDateForma

Java基础回顾系列-第三天-Lambda表达式

Java基础回顾系列-第三天-Lambda表达式 Lambda表达式方法引用引用静态方法引用实例化对象的方法引用特定类型的方法引用构造方法 内建函数式接口Function基础接口DoubleToIntFunction 类型转换接口Consumer消费型函数式接口Supplier供给型函数式接口Predicate断言型函数式接口 Stream API 该篇博文需重点了解:内建函数式

Java基础回顾系列-第二天-面向对象编程

面向对象编程 Java类核心开发结构面向对象封装继承多态 抽象类abstract接口interface抽象类与接口的区别深入分析类与对象内存分析 继承extends重写(Override)与重载(Overload)重写(Override)重载(Overload)重写与重载之间的区别总结 this关键字static关键字static变量static方法static代码块 代码块String类特

Java基础回顾系列-第六天-Java集合

Java基础回顾系列-第六天-Java集合 集合概述数组的弊端集合框架的优点Java集合关系图集合框架体系图java.util.Collection接口 List集合java.util.List接口java.util.ArrayListjava.util.LinkedListjava.util.Vector Set集合java.util.Set接口java.util.HashSetjava