LSTM生成文本(字符级别)

2024-06-15 16:32
文章标签 lstm 文本 字符 级别 生成

本文主要是介绍LSTM生成文本(字符级别),希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

20200817 -

引言

在网上看到过一些利用深度学习来生成文本的文章,不管生成宋词也好,生成小说也好,各种各样,都是利用深度学习的模型来生成新的东西。之前的时候,我也一直觉得,他们这种生成方式,应该就是记忆性的东西,他并没有真正的从语义的角度上理解这个文章。当然,我自己也是才疏学浅,本身就不是专门搞这种东西的人。
本篇文章中,记录一下我在网上看到的一篇利用LSTM生成文本的文章。需要注意的几个点是
1)训练过程中,输入的是什么
2)根据输出,预测的又是什么
3)最后输出的内容是否可读,又是否有意义,是否有意义是否只能从人的角度来检测

LSTM生成文本

本篇文章主要参考了另一篇文章[1],主要记录一下对数据的处理过程。

问题描述(文本生成)

利用深度学习模型生成文本,是通过已有的文本作为训练集,然后生成新的文本。
但是从我阅读完整个文章来看,他就是学习了训练文本中的一些模式,比如他文章也提到,多少个字符之后就该换行了,然后这个文章就换行了。

数据预处理流程

文章[1]中采用的数据源是《爱丽丝梦游仙境》的文章,同时是针对字符级别来进行预测。

数据输入与输出

在文章[1]中,并没有利用词嵌入的方式来将字符进行向量化,而是统计了全部的字符之后,全部按照ASCII码的数值来统计,还包含了一些特殊字符,比如"\n",","等。从这种处理的方式来看,它就是制造了一种方式,**通过输入训练字符,然后输出字符的形式来生成完整的文本。**那么,它输入的是多大,输出的又是多长的字符呢。下面来具体介绍。

在文章[1]中,它采用的方式是,**定义一个滑动窗口,滑动窗口在整个本文上一直滑动,然后输出是滑动窗口文本的下一个字符。**类似时序数据预测一样的流程,可以从它预处理数据的代码来看。

# prepare the dataset of input to output pairs encoded as integers
seq_length = 100
dataX = []
dataY = []
for i in range(0, n_chars - seq_length, 1):seq_in = raw_text[i:i + seq_length]seq_out = raw_text[i + seq_length]dataX.append([char_to_int[char] for char in seq_in])dataY.append(char_to_int[seq_out])
n_patterns = len(dataX)
print "Total Patterns: ", n_patterns

从代码中可以看到,其采用的方式就是滑动窗口一直滑动到倒数第一个字符,每次选取这些字符的后一个字符作为后续预测的结果。

数据输出的过程

关于具体到底是怎么训练模型的,这里就不不多说了,因为他预测的是一个字符,需要一个多类别的交叉熵作为损失函数,同时将结果进行one-hot编码。下面重点来说一下他们生成文本的过程。

从前文的理解中可以发现,每次输入是一个固定长度的滑动窗口大小的字符串,然后输出一个字符作为预测结果。从模型的结构来说,如果是这种角度的话,那么你的输出必然也是固定长度的内容(当然可以通过一些技巧改变这个长度,这里暂不考虑)。

既然如此,就需要一个种子(模型需要长度的字符串)来驱动模型来进行持续生成,下面来看一下代码。

# pick a random seed
start = numpy.random.randint(0, len(dataX)-1)
pattern = dataX[start]
print "Seed:"
print "\"", ''.join([int_to_char[value] for value in pattern]), "\""
# generate characters
for i in range(1000):x = numpy.reshape(pattern, (1, len(pattern), 1))x = x / float(n_vocab)prediction = model.predict(x, verbose=0)index = numpy.argmax(prediction)result = int_to_char[index]seq_in = [int_to_char[value] for value in pattern]sys.stdout.write(result)pattern.append(index)pattern = pattern[1:len(pattern)]
print "\nDone."

上述代码的整体过程就是,每次将新预测的字符添加到尾部,然后将滑动窗口往后一位,这样就是持续生成了。
注:不过,这让我想起来之前做时序数据的东西的时候,本身你预测出来的东西可能就是错的,你还用错的东西继续来作为输入,这不是积累误差吗。当然,这只是我的理解

小节

从上述的讲解中,基本上明白了,这里的文本生成是通过一个滑动窗口的字符来预测下一个字符。在原文中,也提到了其生成的单词有些的确是没有意义的。所以,看来这里还是有待提升。
当然,这里只是记录一种思路,具体的生成过程还是需要去考虑。
完整代码:

# Load LSTM network and generate text
import sys
import numpy
from keras.models import Sequential
from keras.layers import Dense
from keras.layers import Dropout
from keras.layers import LSTM
from keras.callbacks import ModelCheckpoint
from keras.utils import np_utils
# load ascii text and covert to lowercase
filename = "wonderland.txt"
raw_text = open(filename, 'r', encoding='utf-8').read()
raw_text = raw_text.lower()
# create mapping of unique chars to integers, and a reverse mapping
chars = sorted(list(set(raw_text)))
char_to_int = dict((c, i) for i, c in enumerate(chars))
int_to_char = dict((i, c) for i, c in enumerate(chars))
# summarize the loaded data
n_chars = len(raw_text)
n_vocab = len(chars)
print "Total Characters: ", n_chars
print "Total Vocab: ", n_vocab
# prepare the dataset of input to output pairs encoded as integers
seq_length = 100
dataX = []
dataY = []
for i in range(0, n_chars - seq_length, 1):seq_in = raw_text[i:i + seq_length]seq_out = raw_text[i + seq_length]dataX.append([char_to_int[char] for char in seq_in])dataY.append(char_to_int[seq_out])
n_patterns = len(dataX)
print "Total Patterns: ", n_patterns
# reshape X to be [samples, time steps, features]
X = numpy.reshape(dataX, (n_patterns, seq_length, 1))
# normalize
X = X / float(n_vocab)
# one hot encode the output variable
y = np_utils.to_categorical(dataY)
# define the LSTM model
model = Sequential()
model.add(LSTM(256, input_shape=(X.shape[1], X.shape[2])))
model.add(Dropout(0.2))
model.add(Dense(y.shape[1], activation='softmax'))
# load the network weights
filename = "weights-improvement-19-1.9435.hdf5"
model.load_weights(filename)
model.compile(loss='categorical_crossentropy', optimizer='adam')
# pick a random seed
start = numpy.random.randint(0, len(dataX)-1)
pattern = dataX[start]
print "Seed:"
print "\"", ''.join([int_to_char[value] for value in pattern]), "\""
# generate characters
for i in range(1000):x = numpy.reshape(pattern, (1, len(pattern), 1))x = x / float(n_vocab)prediction = model.predict(x, verbose=0)index = numpy.argmax(prediction)result = int_to_char[index]seq_in = [int_to_char[value] for value in pattern]sys.stdout.write(result)pattern.append(index)pattern = pattern[1:len(pattern)]
print "\nDone."

它这里使用checkpoint的方法来记录损失函数最低的模型。

思考

前文已经把文章[1]的整体思路给记录下来了,但是也引发了我的思考。一直以来都有这些问题困扰着我,配合这篇文章来说一下就是,LSTM模型到底学会了什么呢?这个东西我怎么解释呢?每次看到文章总说,LSTM模型能够学到长依赖,但是这个依赖是什么呢?之前使用时序数据的时候,这个依赖可能是利用历史数据来拟合后面数据的数值关系,但是这里又是什么关系呢?这些字符我可以给他任意编码,虽然代码中进行了归一化。
所以,这个我感觉才是我应该思考的东西,这一点其实挺难懂的。

参考

[1]Text Generation With LSTM Recurrent Neural Networks in Python with Keras

这篇关于LSTM生成文本(字符级别)的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1063978

相关文章

RedHat运维-Linux文本操作基础-AWK进阶

你不用整理,跟着敲一遍,有个印象,然后把它保存到本地,以后要用再去看,如果有了新东西,你自个再添加。这是我参考牛客上的shell编程专项题,只不过换成了问答的方式而已。不用背,就算是我自己亲自敲,我现在好多也记不住。 1. 输出nowcoder.txt文件第5行的内容 2. 输出nowcoder.txt文件第6行的内容 3. 输出nowcoder.txt文件第7行的内容 4. 输出nowcode

时序预测 | MATLAB实现LSTM时间序列未来多步预测-递归预测

时序预测 | MATLAB实现LSTM时间序列未来多步预测-递归预测 目录 时序预测 | MATLAB实现LSTM时间序列未来多步预测-递归预测基本介绍程序设计参考资料 基本介绍 MATLAB实现LSTM时间序列未来多步预测-递归预测。LSTM是一种含有LSTM区块(blocks)或其他的一种类神经网络,文献或其他资料中LSTM区块可能被描述成智能网络单元,因为

android 带与不带logo的二维码生成

该代码基于ZXing项目,这个网上能下载得到。 定义的控件以及属性: public static final int SCAN_CODE = 1;private ImageView iv;private EditText et;private Button qr_btn,add_logo;private Bitmap logo,bitmap,bmp; //logo图标private st

基于CTPN(tensorflow)+CRNN(pytorch)+CTC的不定长文本检测和识别

转发来源:https://swift.ctolib.com/ooooverflow-chinese-ocr.html chinese-ocr 基于CTPN(tensorflow)+CRNN(pytorch)+CTC的不定长文本检测和识别 环境部署 sh setup.sh 使用环境: python 3.6 + tensorflow 1.10 +pytorch 0.4.1 注:CPU环境

Linux文本三剑客sed

sed和awk grep就是查找文本当中的内容,最强大的功能就是使用扩展正则表达式 sed sed是一种流编辑器,一次处理一行内容。 如果只是展示,会放在缓冲区(模式空间),展示结束后,会从模式空间把结果删除 一行行处理,处理完当前行,才会处理下一行。直到文件的末尾。 sed的命令格式和操作选项: sed -e '操作符 ' -e '操作符' 文件1 文件2 -e表示可以跟多个操作

剑指offer(C++)--第一个只出现一次的字符

题目 在一个字符串(0<=字符串长度<=10000,全部由字母组成)中找到第一个只出现一次的字符,并返回它的位置, 如果没有则返回 -1(需要区分大小写). class Solution {public:int FirstNotRepeatingChar(string str) {map<char, int> mp;for(int i = 0; i < str.size(); ++i)m

FastAdmin/bootstrapTable 表格中生成的按钮设置成文字

公司有个系统后台框架用的是FastAdmin,后台表格的操作栏按钮只有图标,想要设置成文字。 查资料后发现其实很简单,主需要新增“text”属性即可,如下 buttons: [{name: 'acceptcompany',title: '复核企业',text:'复核企业',classname: 'btn btn-xs btn-primary btn-dialog',icon: 'fa fa-pe

PHP生成csv格式Excel,秒级别实现excel导出功能

防止报超内存,兼容中文,兼容科学技术法。 爽。。。。很爽。。。。 /*** 告诉浏览器下载csv文件* @param string $filename*/public static function downloadCsv($data, $filename, $encoding = 'utf-8'){header("Content-type: text/csv");header("Conten

PHP 读取或生成大的Excel

场景,在很多情况下,需要读取Excel文件。 常用的有PHPExcel包或者使用 maatwebsite/excel 包 但是使用这个包读取或生成excel,如果excel文件过大,很容易出现超内存情况。 解决方法: 上传:要求上传者使用.csv 文件上传。然后使用php自带的 fgetcsv()函数来读取文件。http://php.net/manual/zh/function.fgetc

linux匹配Nginx日志,某个字符开头和结尾的字符串

匹配 os=1 开头, &ip结尾的字符串 cat 2018-06-07.log | egrep -o ‘os=1.*.&ip’ 存入日志。然后使用submit 前面和后面的值去掉,剩下就是需要的字符串。 cat 2018-06-07.log | egrep -o ‘os=1.*.&ip’ >log.log