有了LSTM网络,我再也不怕老师让我写作文了

2024-04-30 22:08

本文主要是介绍有了LSTM网络,我再也不怕老师让我写作文了,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

随着深度学习的迅猛发展,人工智能的强大能力已经超出了模仿人类的简单动作,例如识别物体,如今已经能发展到自动驾驶,而且车开的比人都好的地步。目前深度学习进化出的一大功能是能够进行艺术创作,前几年google开发的DeepDream算法能够自己绘制出犹如毕加索抽象画般的艺术作品,而现在使用LSTM网络甚至可以开发出自动作曲程序,据说现在很多曲调都是由深度学习网络创作的。

很多艺术创作其实是通过序列号数据构成的,例如文章其实是一个个单词前后相邻构成,音乐是一个个音符前后相邻构成,甚至绘画也是笔触前后相邻构成,因此艺术创作从数学上看其实是时间序列数据,而LSTM忘了是最擅长处理时间序列数据的,因此只要我们训练网络识别相应艺术创作的时间序列中的数据规律,我们就可以利用网络进行相应的创作。

我们要创建的网络具有的功能是自动写作。我们把含有N个单词的句子输入网络,让网络预测第N+1个单词,然后把预测结果重新输入网络,让网络预测第N+2个单词,这种自我循环能让网络创作出跟人写出来几乎一模一样的句子。例如我们有句子"hello Tom, how are you",我们把"hello Tom, how"输入网络后网络预测下个单词是"are",然后我们继续把"hello Tom, how are"输入网络,网络预测下一个单词是"you",网络运行的基本流程如下图:

1.png

上图中数据采样很重要,通常我们会从下一个可能单词的概率分布中,选择概率最大的那个单词,但是这么做会导致生成的句子不流畅,看起来不像人写得。通用做法是在可能性最高的若干个单词集合中进行一定随机选择。例如网络预测某个词的概率是30%,那么我们引入一种随机方法,使得该词被选中的概率是30%。

我们引入的随机方法,它的随机性必须要有所控制。如果随机性为0,那么最终网络创作的句子就没有一点创意,如果随机性太高,那么得到的句子在逻辑上可能就比较离谱,因此我们要把随机性控制在某个程度。于是我们引入一个控制随机性的参数叫temperature,也就是温度的意思。

在前面章节我们多次看到,当网络要给出概率时,最后输出层时softmax,它会输出一个向量,向量中每个分量的值是0到1间的小数,所有分量加总得1.我们假设这个向量用original_distributin表示,那么我们用下面的方法引入新的随机性:

def  reweight_distribution(original_distribution, temperature=0.5):distribution = np.log(original_distribution) / temperaturedistribution = np.exp(distribution)return distribution / np.sum(distribution)

上面代码会把网络softmax层输出的结果重新打乱,打乱的程度由tenperature来控制,它的值越大,打乱的程度就越高。接下来我们做一个LSTM网络,它预测的下一个元素是字符而不是我们前面所说的单词。

深度学习网络进行文章创作时,与用于输入它的文本数据相关。如果你用莎士比亚的作品作为训练数据,网络创作的文章与莎士比亚就很像,如果我们在上面函数中引入随机性,那么网络创作结果就会有一部分像莎士比亚,有一部分又不像,而不像的那部分就是网络创作的艺术性所在,下面我们用德国超人哲学创始人尼采的文章训练网络,让我们通过深度学习再造一个新的哲学家,首先我们要加载训练数据:

import  keras
import  numpy as nppath = keras.utils.get_file('nietzche.txt', origin = 'https://s3.amazonaws.com/text-datasets/nietzsche.txt')
text = open(path).read().lower()
print('Corpus length: ', len(text))

上面代码运行后,我们会下载单词量为600893的文本数据。接着我们以60个字符为一个句子,第61个字符作为预测字符,也就是告诉网络看到这60个字符后你应该预测第61个字符,同时前后两个采样句子之间的间隔是3个字符:

maxlen = 60
step = 3
setences = []
#next_chars 对应下一个字符,以便用于训练网
next_chars = []for i in range(0, len(text) - maxlen, step):setences.append(text[i : i + maxlen])next_chars.append(text[i + maxlen])print('Number of sequentence: ', len(setences))chars = sorted(list(set(text)))
print('Unique characters: ', len(chars))
#为每个字符做编号
char_indices = dict((char, chars.index(char)) for char in chars)
print('Vectorization....')
'''
整个文本中不同字符的个数为chars, 对于当个字符我们对他进行one-hot编码,
也就是构造一个含有chars个元素的向量,根据字符对于的编号,我们把向量的对应元素设置为1,
一个句子含有60个字符,因此一行句子对应一个二维句子(maxlen, chars),矩阵的行数是maxlen,列数
是chars
'''
x = np.zeros((len(setences), maxlen, len(chars)), dtype = np.bool)
y = np.zeros((len(setences), len(chars)), dtype = np.bool)for i, setence in enumerate(setences):for t, char in enumerate(setence):x[i, t, char_indices[char]] = 1y[i, char_indices[next_chars[i]]] = 1

上面代码中构造的x就是输入数据,当输入句子是x时,我们要调教网络去预测下一个字符是y。代码先统计文本资料总共有多少个不同的字符,这些字符包含标点符号,根据运行结果显示,文本总共有57个不同字符,同时我们将不同字符进行编号。

然后构造含有57个元素的向量,当句子中某个字符出现时,我们就把向量中下标对应字符编号的元素设置为1,我们这些向量输入到网络进行训练:

from keras import layersmodel = keras.models.Sequential()
model.add(layers.LSTM(128, input_shape(maxlen, len(chars))))
model.add(layers.Dense(len(chars), activation = 'softmax'))
optimizer = leras.optimizers.RMSprop(lr = 0.01)
model.compile(loss = 'categorical_crossentropy', optimizer = optimizer)

网络输出结果对应一个含有57个元素的向量,每个元素对应相应编号的字符,元素的值表示下一个字符是对应字符的概率。我们按照前面说过的方法对网络给出的概率分布引入随机性,然后选出下一个字符,把选出的字符添加到输入句子中形成新的输入句子传入到网络,让网络以同样的方法判断下一个字符:

def  sample(preds, temperature = 1.0):preds = np.asarray(preds).astype('float64')preds = np.log(preds) / temperatureexp_preds = np.exp(preds)preds = exp_preds / np.sum(exp_preds)'''由于preds含有57个元素,每个元素表示对应字符出现的概率,我们可以把这57个元素看成一个含有57面的骰子,骰子第i面出现的概率由preds[i]决定,然后我们模拟丢一次这个57面骰子,看看出现哪一面,这一面对应的字符作为网络预测的下一个字符'''probas = np.random.multinomial(1, preds, 1)return np.argmax(probas)

接着我们启动训练流程:

import random
import sysfor epoch in range(1, 60):print('epoch:', epoch)model.fit(x, y, batch_size = 128, epochs = 1)start_index = random.randint(0, len(text) - maxlen - 1)generated_text = text[start_index: start_index + maxlen]print('---Generating with seed:"' + generated_text + '"')for temperature in [0.2, 0.5, 1.0, 1.2]:print('---temperature:', temperature)#先输出一段原文sys.stdout.write(generated_text)'''根据原文,我们让网络创作接着原文后面的400个字符组合成的段子'''for i in range(400):sampled = np.zeros((1, maxlen, len(chars)))for t, char in enumerate(generated_text):sampled[0, t, char_indices[char]] = 1.#让网络根据当前输入字符预测下一个字符preds = model.predict(sampled, verbose = 0)[0]next_index = sample(preds, temperature)next_char = chars[next_index]generated_text += next_chargenerated_text = generated_text[1:]sys.stdout.write(next_char)sys.stdout.flush()print()

上面代码将尼采的作品输入到网络进行训练,训练后网络生成的段子就会带上明显的尼采风格,代码最好通过科学上网的方式,通过谷歌的colab,运行到GPU上,如果在CPU上运行,它训练的速度会非常慢。

我们看看经过20多次循环训练后,网络生成文章的效果如下:
屏幕快照 2019-01-31 下午4.11.37.png

输出中,Generating with seed 后面的语句是我们从原文任意位置摘出的60个字符。接下来的文字是网络自动生成的段子。当temperature值越小,网络生成的段子与原文就越相似,值越大,网络生成的段子与原文差异就越大,随着epoch数量越大,也就是网络训练次数越多,它生成的段子就越通顺,而且表达的内容也越有创意。

注意到随着temperature值越大,网络合成的词语错误也越多,有些单词甚至是几个字符的随机组合。从观察上来看,temperature取值0.5的效果是最好的。

更多内容,请点击进入csdn学院

更多技术信息,包括操作系统,编译器,面试算法,机器学习,人工智能,请关照我的公众号:
这里写图片描述

这篇关于有了LSTM网络,我再也不怕老师让我写作文了的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/950046

相关文章

Linux系统配置NAT网络模式的详细步骤(附图文)

《Linux系统配置NAT网络模式的详细步骤(附图文)》本文详细指导如何在VMware环境下配置NAT网络模式,包括设置主机和虚拟机的IP地址、网关,以及针对Linux和Windows系统的具体步骤,... 目录一、配置NAT网络模式二、设置虚拟机交换机网关2.1 打开虚拟机2.2 管理员授权2.3 设置子

揭秘Python Socket网络编程的7种硬核用法

《揭秘PythonSocket网络编程的7种硬核用法》Socket不仅能做聊天室,还能干一大堆硬核操作,这篇文章就带大家看看Python网络编程的7种超实用玩法,感兴趣的小伙伴可以跟随小编一起... 目录1.端口扫描器:探测开放端口2.简易 HTTP 服务器:10 秒搭个网页3.局域网游戏:多人联机对战4.

SpringBoot使用OkHttp完成高效网络请求详解

《SpringBoot使用OkHttp完成高效网络请求详解》OkHttp是一个高效的HTTP客户端,支持同步和异步请求,且具备自动处理cookie、缓存和连接池等高级功能,下面我们来看看SpringB... 目录一、OkHttp 简介二、在 Spring Boot 中集成 OkHttp三、封装 OkHttp

Linux系统之主机网络配置方式

《Linux系统之主机网络配置方式》:本文主要介绍Linux系统之主机网络配置方式,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录一、查看主机的网络参数1、查看主机名2、查看IP地址3、查看网关4、查看DNS二、配置网卡1、修改网卡配置文件2、nmcli工具【通用

使用Python高效获取网络数据的操作指南

《使用Python高效获取网络数据的操作指南》网络爬虫是一种自动化程序,用于访问和提取网站上的数据,Python是进行网络爬虫开发的理想语言,拥有丰富的库和工具,使得编写和维护爬虫变得简单高效,本文将... 目录网络爬虫的基本概念常用库介绍安装库Requests和BeautifulSoup爬虫开发发送请求解

如何通过海康威视设备网络SDK进行Java二次开发摄像头车牌识别详解

《如何通过海康威视设备网络SDK进行Java二次开发摄像头车牌识别详解》:本文主要介绍如何通过海康威视设备网络SDK进行Java二次开发摄像头车牌识别的相关资料,描述了如何使用海康威视设备网络SD... 目录前言开发流程问题和解决方案dll库加载不到的问题老旧版本sdk不兼容的问题关键实现流程总结前言作为

SSID究竟是什么? WiFi网络名称及工作方式解析

《SSID究竟是什么?WiFi网络名称及工作方式解析》SID可以看作是无线网络的名称,类似于有线网络中的网络名称或者路由器的名称,在无线网络中,设备通过SSID来识别和连接到特定的无线网络... 当提到 Wi-Fi 网络时,就避不开「SSID」这个术语。简单来说,SSID 就是 Wi-Fi 网络的名称。比如

Java实现任务管理器性能网络监控数据的方法详解

《Java实现任务管理器性能网络监控数据的方法详解》在现代操作系统中,任务管理器是一个非常重要的工具,用于监控和管理计算机的运行状态,包括CPU使用率、内存占用等,对于开发者和系统管理员来说,了解这些... 目录引言一、背景知识二、准备工作1. Maven依赖2. Gradle依赖三、代码实现四、代码详解五

Linux 网络编程 --- 应用层

一、自定义协议和序列化反序列化 代码: 序列化反序列化实现网络版本计算器 二、HTTP协议 1、谈两个简单的预备知识 https://www.baidu.com/ --- 域名 --- 域名解析 --- IP地址 http的端口号为80端口,https的端口号为443 url为统一资源定位符。CSDNhttps://mp.csdn.net/mp_blog/creation/editor

ASIO网络调试助手之一:简介

多年前,写过几篇《Boost.Asio C++网络编程》的学习文章,一直没机会实践。最近项目中用到了Asio,于是抽空写了个网络调试助手。 开发环境: Win10 Qt5.12.6 + Asio(standalone) + spdlog 支持协议: UDP + TCP Client + TCP Server 独立的Asio(http://www.think-async.com)只包含了头文件,不依