Tensorflow lstm实现的小说撰写预测

2024-09-08 02:18

本文主要是介绍Tensorflow lstm实现的小说撰写预测,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

最近,在研究深度学习方面的知识,结合Tensorflow,完成了基于lstm的小说预测程序demo。

lstm是改进的RNN,具有长期记忆功能,相对于RNN,增加了多个门来控制输入与输出。原理方面的知识网上很多,在此,我只是将我短暂学习的tensorflow写一个预测小说的demo,如果有错误,还望大家指出。

1、将小说进行分词,去除空格,建立词汇表与id的字典,生成初始输入模型的x与y

def readfile(file_path):
    f = codecs.open(file_path, 'r', 'utf-8')
    alltext = f.read()
    alltext = re.sub(r'\s','', alltext)
    seglist = list(jieba.cut(alltext, cut_all = False))
    return seglist
    
def _build_vocab(filename):
    data = readfile(filename)
    counter = collections.Counter(data)
    count_pairs = sorted(counter.items(), key=lambda x: (-x[1], x[0]))


    words, _ = list(zip(*count_pairs))
    word_to_id = dict(zip(words, range(len(words))))
    id_to_word = dict(zip(range(len(words)),words))
    dataids = []
    for w in data:
        dataids.append(word_to_id[w])
    return word_to_id, id_to_word,dataids


def dataproducer(batch_size, num_steps):
    word_to_id, id_to_word, data = _build_vocab('F:\\ml\\code\\lstm\\1.txt')
    datalen = len(data)
    batchlen = datalen//batch_size
    epcho_size = (batchlen - 1)//num_steps


    data = tf.reshape(data[0: batchlen*batch_size], [batch_size,batchlen])
    i = tf.train.range_input_producer(epcho_size, shuffle=False).dequeue()
    x = tf.slice(data, [0,i*num_steps],[batch_size, num_steps])
    y = tf.slice(data, [0,i*num_steps+1],[batch_size, num_steps])
    x.set_shape([batch_size, num_steps])
    y.set_shape([batch_size, num_steps])
    return x,y,id_to_word

2、建立lstm模型:

lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(size, forget_bias = 0.5)
lstm_cell = tf.nn.rnn_cell.DropoutWrapper(lstm_cell, output_keep_prob = keep_prob)
cell = tf.nn.rnn_cell.MultiRNNCell([lstm_cell], num_layers)

3、根据训练数据输出误差反向调整模型

with tf.variable_scope("Model", reuse = None, initializer = initializer):#tensorflow主要通过变量空间来实现共享变量
    with tf.variable_scope("r", reuse = None, initializer = initializer):
        softmax_w = tf.get_variable('softmax_w', [size, vocab_size])
        softmax_b = tf.get_variable('softmax_b', [vocab_size])
    with tf.variable_scope("RNN", reuse = None, initializer = initializer):
        for time_step in range(num_steps):
            if time_step > 0: tf.get_variable_scope().reuse_variables()
            (cell_output, state) = cell(inputs[:, time_step, :], state,)
            outputs.append(cell_output)
            
        output = tf.reshape(outputs, [-1,size])
        
        logits = tf.matmul(output, softmax_w) + softmax_b
        loss = tf.nn.seq2seq.sequence_loss_by_example([logits], [tf.reshape(targets,[-1])], [tf.ones([batch_size*num_steps])])
        
        global_step = tf.Variable(0)
        learning_rate = tf.train.exponential_decay(
        10.0, global_step, 5000, 0.1, staircase=True)
        optimizer = tf.train.GradientDescentOptimizer(learning_rate)
        gradients, v = zip(*optimizer.compute_gradients(loss))
        gradients, _ = tf.clip_by_global_norm(gradients, 1.25)
        optimizer = optimizer.apply_gradients(zip(gradients, v), global_step=global_step)

4、预测新一轮输出

teststate = test_initial_state
        (celloutput,teststate)= cell(test_inputs, teststate)
        partial_logits = tf.matmul(celloutput, softmax_w) + softmax_b
        partial_logits = tf.nn.softmax(partial_logits)

5、根据之前建立的操作,运行tensorflow会话

sv = tf.train.Supervisor(logdir=None)
with sv.managed_session() as session:
    costs = 0
    iters = 0
    for i in range(1000):
        _,l= session.run([optimizer, cost])
        costs += l
        iters +=num_steps
        perplextity = np.exp(costs / iters)
        if i%20 == 0:
            print(perplextity)
        if i%100 == 0:
            p = random_distribution()
            b = sample(p)
            sentence = id_to_word[b[0]]
            for j in range(200):
                test_output = session.run(partial_logits, feed_dict={test_input:b})
                b = sample(test_output)
                sentence += id_to_word[b[0]]
            print(sentence)    

其中,使用sv.managed_session()后,在此会话间,将不能修改graph。如果采用普通的session,程序将会阻塞于session.run(),对于这个问题,我还是很疑惑,希望理解的人帮忙解答下。

代码地址位于https://github.com/summersunshine1/datamining/tree/master/lstm,运行时只需将readdata中文件路径修改即可。作为深度学习的入门小白,希望大家多多指点。

运行结果如下:



这篇关于Tensorflow lstm实现的小说撰写预测的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1146833

相关文章

使用Sentinel自定义返回和实现区分来源方式

《使用Sentinel自定义返回和实现区分来源方式》:本文主要介绍使用Sentinel自定义返回和实现区分来源方式,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录Sentinel自定义返回和实现区分来源1. 自定义错误返回2. 实现区分来源总结Sentinel自定

Java实现时间与字符串互相转换详解

《Java实现时间与字符串互相转换详解》这篇文章主要为大家详细介绍了Java中实现时间与字符串互相转换的相关方法,文中的示例代码讲解详细,感兴趣的小伙伴可以跟随小编一起学习一下... 目录一、日期格式化为字符串(一)使用预定义格式(二)自定义格式二、字符串解析为日期(一)解析ISO格式字符串(二)解析自定义

opencv图像处理之指纹验证的实现

《opencv图像处理之指纹验证的实现》本文主要介绍了opencv图像处理之指纹验证的实现,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学... 目录一、简介二、具体案例实现1. 图像显示函数2. 指纹验证函数3. 主函数4、运行结果三、总结一、

Springboot处理跨域的实现方式(附Demo)

《Springboot处理跨域的实现方式(附Demo)》:本文主要介绍Springboot处理跨域的实现方式(附Demo),具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不... 目录Springboot处理跨域的方式1. 基本知识2. @CrossOrigin3. 全局跨域设置4.

Spring Boot 3.4.3 基于 Spring WebFlux 实现 SSE 功能(代码示例)

《SpringBoot3.4.3基于SpringWebFlux实现SSE功能(代码示例)》SpringBoot3.4.3结合SpringWebFlux实现SSE功能,为实时数据推送提供... 目录1. SSE 简介1.1 什么是 SSE?1.2 SSE 的优点1.3 适用场景2. Spring WebFlu

基于SpringBoot实现文件秒传功能

《基于SpringBoot实现文件秒传功能》在开发Web应用时,文件上传是一个常见需求,然而,当用户需要上传大文件或相同文件多次时,会造成带宽浪费和服务器存储冗余,此时可以使用文件秒传技术通过识别重复... 目录前言文件秒传原理代码实现1. 创建项目基础结构2. 创建上传存储代码3. 创建Result类4.

SpringBoot日志配置SLF4J和Logback的方法实现

《SpringBoot日志配置SLF4J和Logback的方法实现》日志记录是不可或缺的一部分,本文主要介绍了SpringBoot日志配置SLF4J和Logback的方法实现,文中通过示例代码介绍的非... 目录一、前言二、案例一:初识日志三、案例二:使用Lombok输出日志四、案例三:配置Logback一

Python如何使用__slots__实现节省内存和性能优化

《Python如何使用__slots__实现节省内存和性能优化》你有想过,一个小小的__slots__能让你的Python类内存消耗直接减半吗,没错,今天咱们要聊的就是这个让人眼前一亮的技巧,感兴趣的... 目录背景:内存吃得满满的类__slots__:你的内存管理小助手举个大概的例子:看看效果如何?1.

Python+PyQt5实现多屏幕协同播放功能

《Python+PyQt5实现多屏幕协同播放功能》在现代会议展示、数字广告、展览展示等场景中,多屏幕协同播放已成为刚需,下面我们就来看看如何利用Python和PyQt5开发一套功能强大的跨屏播控系统吧... 目录一、项目概述:突破传统播放限制二、核心技术解析2.1 多屏管理机制2.2 播放引擎设计2.3 专

Python实现无痛修改第三方库源码的方法详解

《Python实现无痛修改第三方库源码的方法详解》很多时候,我们下载的第三方库是不会有需求不满足的情况,但也有极少的情况,第三方库没有兼顾到需求,本文将介绍几个修改源码的操作,大家可以根据需求进行选择... 目录需求不符合模拟示例 1. 修改源文件2. 继承修改3. 猴子补丁4. 追踪局部变量需求不符合很