诗词生成--pytorch(1,代码)

2024-03-11 22:04
文章标签 代码 生成 pytorch 诗词

本文主要是介绍诗词生成--pytorch(1,代码),希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

首先上代码:

settings.py


# 禁用词,包含如下字符的唐诗将被忽略
DISALLOWED_WORDS = ['(', ')', '(', ')', '__', '《', '》', '【', '】', '[', ']']
# 句子最大长度
MAX_LEN = 64
# 最小词频
MIN_WORD_FREQUENCY = 8
# 训练的batch size
BATCH_SIZE = 16
# 数据集路径
DATASET_PATH = './poetry.txt'
# 每个epoch训练完成后,随机生成SHOW_NUM首古诗作为展示
SHOW_NUM = 5
# 共训练多少个epoch
TRAIN_EPOCHS = 20
# 最佳权重保存路径
BEST_MODEL_PATH = './best_model.h5'

dataset.py


from collections import Counter
import math
import numpy as np
import tensorflow as tf
import settingsclass Tokenizer:"""分词器"""def __init__(self, token_dict):# 词->编号的映射self.token_dict = token_dict# 编号->词的映射self.token_dict_rev = {value: key for key, value in self.token_dict.items()}# 词汇表大小self.vocab_size = len(self.token_dict)def id_to_token(self, token_id):"""给定一个编号,查找词汇表中对应的词:param token_id: 带查找词的编号:return: 编号对应的词"""return self.token_dict_rev[token_id]def token_to_id(self, token):"""给定一个词,查找它在词汇表中的编号未找到则返回低频词[UNK]的编号:param token: 带查找编号的词:return: 词的编号"""return self.token_dict.get(token, self.token_dict['[UNK]'])def encode(self, tokens):"""给定一个字符串s,在头尾分别加上标记开始和结束的特殊字符,并将它转成对应的编号序列:param tokens: 待编码字符串:return: 编号序列"""# 加上开始标记token_ids = [self.token_to_id('[CLS]'), ]# 加入字符串编号序列for token in tokens:token_ids.append(self.token_to_id(token))# 加上结束标记token_ids.append(self.token_to_id('[SEP]'))return token_idsdef decode(self, token_ids):"""给定一个编号序列,将它解码成字符串:param token_ids: 待解码的编号序列:return: 解码出的字符串"""# 起止标记字符特殊处理spec_tokens = {'[CLS]', '[SEP]'}# 保存解码出的字符的listtokens = []for token_id in token_ids:token = self.id_to_token(token_id)if token in spec_tokens:continuetokens.append(token)# 拼接字符串return ''.join(tokens)# 禁用词
disallowed_words = settings.DISALLOWED_WORDS
# 句子最大长度
max_len = settings.MAX_LEN
# 最小词频
min_word_frequency = settings.MIN_WORD_FREQUENCY
# mini batch 大小
batch_size = settings.BATCH_SIZE# 加载数据集
with open(settings.DATASET_PATH, 'r', encoding='utf-8') as f:lines = f.readlines()# 将冒号统一成相同格式lines = [line.replace(':', ':') for line in lines]
# 数据集列表
poetry = []
# 逐行处理读取到的数据
for line in lines:# 有且只能有一个冒号用来分割标题if line.count(':') != 1:continue# 后半部分不能包含禁止词__, last_part = line.split(':')ignore_flag = Falsefor dis_word in disallowed_words:if dis_word in last_part:ignore_flag = Truebreakif ignore_flag:continue# 长度不能超过最大长度if len(last_part) > max_len - 2:continuepoetry.append(last_part.replace('\n', ''))# 统计词频
counter = Counter()
for line in poetry:counter.update(line)
# 过滤掉低频词
_tokens = [(token, count) for token, count in counter.items() if count >= min_word_frequency]
# 按词频排序
_tokens = sorted(_tokens, key=lambda x: -x[1])
# 去掉词频,只保留词列表
_tokens = [token for token, count in _tokens]# 将特殊词和数据集中的词拼接起来
_tokens = ['[PAD]', '[UNK]', '[CLS]', '[SEP]'] + _tokens
# 创建词典 token->id映射关系
token_id_dict = dict(zip(_tokens, range(len(_tokens))))
# 使用新词典重新建立分词器
tokenizer = Tokenizer(token_id_dict)
# 混洗数据
np.random.shuffle(poetry)class PoetryDataGenerator:"""古诗数据集生成器"""def __init__(self, data, random=False):# 数据集self.data = data# batch sizeself.batch_size = batch_size# 每个epoch迭代的步数self.steps = int(math.floor(len(self.data) / self.batch_size))# 每个epoch开始时是否随机混洗self.random = randomdef sequence_padding(self, data, length=None, padding=None):"""将给定数据填充到相同长度:param data: 待填充数据:param length: 填充后的长度,不传递此参数则使用data中的最大长度:param padding: 用于填充的数据,不传递此参数则使用[PAD]的对应编号:return: 填充后的数据"""# 计算填充长度if length is None:length = max(map(len, data))# 计算填充数据if padding is None:padding = tokenizer.token_to_id('[PAD]')# 开始填充outputs = []for line in data:padding_length = length - len(line)# 不足就进行填充if padding_length > 0:outputs.append(np.concatenate([line, [padding] * padding_length]))# 超过就进行截断else:outputs.append(line[:length])return np.array(outputs)def __len__(self):return self.stepsdef __iter__(self):total = len(self.data)# 是否随机混洗if self.random:np.random.shuffle(self.data)# 迭代一个epoch,每次yield一个batchfor start in range(0, total, self.batch_size):end = min(start + self.batch_size, total)batch_data = []# 逐一对古诗进行编码for single_data in self.data[start:end]:batch_data.append(tokenizer.encode(single_data))# 填充为相同长度batch_data = self.sequence_padding(batch_data)# yield x,yyield batch_data[:, :-1], tf.one_hot(batch_data[:, 1:], tokenizer.vocab_size)del batch_datadef for_fit(self):"""创建一个生成器,用于训练"""# 死循环,当数据训练一个epoch之后,重新迭代数据while True:# 委托生成器yield from self.__iter__()

utils.py


import numpy as np
import settingsdef generate_random_poetry(tokenizer, model, s=''):"""随机生成一首诗:param tokenizer: 分词器:param model: 用于生成古诗的模型:param s: 用于生成古诗的起始字符串,默认为空串:return: 一个字符串,表示一首古诗"""# 将初始字符串转成tokentoken_ids = tokenizer.encode(s)# 去掉结束标记[SEP]token_ids = token_ids[:-1]while len(token_ids) < settings.MAX_LEN:# 进行预测,只保留第一个样例(我们输入的样例数只有1)的、最后一个token的预测的、不包含[PAD][UNK][CLS]的概率分布output = model(np.array([token_ids, ], dtype=np.int32))_probas = output.numpy()[0, -1, 3:]del output# print(_probas)# 按照出现概率,对所有token倒序排列p_args = _probas.argsort()[::-1][:100]# 排列后的概率顺序p = _probas[p_args]# 先对概率归一p = p / sum(p)# 再按照预测出的概率,随机选择一个词作为预测结果target_index = np.random.choice(len(p), p=p)target = p_args[target_index] + 3# 保存token_ids.append(target)if target == 3:breakreturn tokenizer.decode(token_ids)def generate_acrostic(tokenizer, model, head):"""随机生成一首藏头诗:param tokenizer: 分词器:param model: 用于生成古诗的模型:param head: 藏头诗的头:return: 一个字符串,表示一首古诗"""# 使用空串初始化token_ids,加入[CLS]token_ids = tokenizer.encode('')token_ids = token_ids[:-1]# 标点符号,这里简单的只把逗号和句号作为标点punctuations = [',', '。']punctuation_ids = {tokenizer.token_to_id(token) for token in punctuations}# 缓存生成的诗的listpoetry = []# 对于藏头诗中的每一个字,都生成一个短句for ch in head:# 先记录下这个字poetry.append(ch)# 将藏头诗的字符转成token idtoken_id = tokenizer.token_to_id(ch)# 加入到列表中去token_ids.append(token_id)# 开始生成一个短句while True:# 进行预测,只保留第一个样例(我们输入的样例数只有1)的、最后一个token的预测的、不包含[PAD][UNK][CLS]的概率分布output = model(np.array([token_ids, ], dtype=np.int32))_probas = output.numpy()[0, -1, 3:]del output# 按照出现概率,对所有token倒序排列p_args = _probas.argsort()[::-1][:100]# 排列后的概率顺序p = _probas[p_args]# 先对概率归一p = p / sum(p)# 再按照预测出的概率,随机选择一个词作为预测结果target_index = np.random.choice(len(p), p=p)target = p_args[target_index] + 3# 保存token_ids.append(target)# 只有不是特殊字符时,才保存到poetry里面去if target > 3:poetry.append(tokenizer.id_to_token(target))if target in punctuation_ids:breakreturn ''.join(poetry)

model.py

import tensorflow as tf
from dataset import tokenizer# 构建模型
model = tf.keras.Sequential([# 不定长度的输入tf.keras.layers.Input((None,)),# 词嵌入层tf.keras.layers.Embedding(input_dim=tokenizer.vocab_size, output_dim=128),# 第一个LSTM层,返回序列作为下一层的输入tf.keras.layers.LSTM(128, dropout=0.5, return_sequences=True),# 第二个LSTM层,返回序列作为下一层的输入tf.keras.layers.LSTM(128, dropout=0.5, return_sequences=True),# 对每一个时间点的输出都做softmax,预测下一个词的概率tf.keras.layers.TimeDistributed(tf.keras.layers.Dense(tokenizer.vocab_size, activation='softmax')),
])# 查看模型结构
model.summary()
# 配置优化器和损失函数
model.compile(optimizer=tf.keras.optimizers.Adam(), loss=tf.keras.losses.categorical_crossentropy)

train.py


import tensorflow as tf
from dataset import PoetryDataGenerator, poetry, tokenizer
from model import model
import settings
import utilsclass Evaluate(tf.keras.callbacks.Callback):"""在每个epoch训练完成后,保留最优权重,并随机生成settings.SHOW_NUM首古诗展示"""def __init__(self):super().__init__()# 给loss赋一个较大的初始值self.lowest = 1e10def on_epoch_end(self, epoch, logs=None):# 在每个epoch训练完成后调用# 如果当前loss更低,就保存当前模型参数if logs['loss'] <= self.lowest:self.lowest = logs['loss']model.save(settings.BEST_MODEL_PATH)# 随机生成几首古体诗测试,查看训练效果print()for i in range(settings.SHOW_NUM):print(utils.generate_random_poetry(tokenizer, model))# 创建数据集
data_generator = PoetryDataGenerator(poetry, random=True)
# 开始训练
model.fit_generator(data_generator.for_fit(), steps_per_epoch=data_generator.steps, epochs=settings.TRAIN_EPOCHS,callbacks=[Evaluate()])

eval.py


import tensorflow as tf
from dataset import tokenizer
import settings
import utils# 加载训练好的模型
model = tf.keras.models.load_model(settings.BEST_MODEL_PATH)
# 随机生成一首诗
print(utils.generate_random_poetry(tokenizer, model))
# 给出部分信息的情况下,随机生成剩余部分
print(utils.generate_random_poetry(tokenizer, model, s='床前明月光,'))
# 生成藏头诗
print(utils.generate_acrostic(tokenizer, model, head='海阔天空'))

这篇关于诗词生成--pytorch(1,代码)的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/799210

相关文章

Java调用DeepSeek API的最佳实践及详细代码示例

《Java调用DeepSeekAPI的最佳实践及详细代码示例》:本文主要介绍如何使用Java调用DeepSeekAPI,包括获取API密钥、添加HTTP客户端依赖、创建HTTP请求、处理响应、... 目录1. 获取API密钥2. 添加HTTP客户端依赖3. 创建HTTP请求4. 处理响应5. 错误处理6.

使用 sql-research-assistant进行 SQL 数据库研究的实战指南(代码实现演示)

《使用sql-research-assistant进行SQL数据库研究的实战指南(代码实现演示)》本文介绍了sql-research-assistant工具,该工具基于LangChain框架,集... 目录技术背景介绍核心原理解析代码实现演示安装和配置项目集成LangSmith 配置(可选)启动服务应用场景

Python中顺序结构和循环结构示例代码

《Python中顺序结构和循环结构示例代码》:本文主要介绍Python中的条件语句和循环语句,条件语句用于根据条件执行不同的代码块,循环语句用于重复执行一段代码,文章还详细说明了range函数的使... 目录一、条件语句(1)条件语句的定义(2)条件语句的语法(a)单分支 if(b)双分支 if-else(

浅析如何使用Swagger生成带权限控制的API文档

《浅析如何使用Swagger生成带权限控制的API文档》当涉及到权限控制时,如何生成既安全又详细的API文档就成了一个关键问题,所以这篇文章小编就来和大家好好聊聊如何用Swagger来生成带有... 目录准备工作配置 Swagger权限控制给 API 加上权限注解查看文档注意事项在咱们的开发工作里,API

MySQL数据库函数之JSON_EXTRACT示例代码

《MySQL数据库函数之JSON_EXTRACT示例代码》:本文主要介绍MySQL数据库函数之JSON_EXTRACT的相关资料,JSON_EXTRACT()函数用于从JSON文档中提取值,支持对... 目录前言基本语法路径表达式示例示例 1: 提取简单值示例 2: 提取嵌套值示例 3: 提取数组中的值注意

CSS3中使用flex和grid实现等高元素布局的示例代码

《CSS3中使用flex和grid实现等高元素布局的示例代码》:本文主要介绍了使用CSS3中的Flexbox和Grid布局实现等高元素布局的方法,通过简单的两列实现、每行放置3列以及全部代码的展示,展示了这两种布局方式的实现细节和效果,详细内容请阅读本文,希望能对你有所帮助... 过往的实现方法是使用浮动加

JAVA调用Deepseek的api完成基本对话简单代码示例

《JAVA调用Deepseek的api完成基本对话简单代码示例》:本文主要介绍JAVA调用Deepseek的api完成基本对话的相关资料,文中详细讲解了如何获取DeepSeekAPI密钥、添加H... 获取API密钥首先,从DeepSeek平台获取API密钥,用于身份验证。添加HTTP客户端依赖使用Jav

Java实现状态模式的示例代码

《Java实现状态模式的示例代码》状态模式是一种行为型设计模式,允许对象根据其内部状态改变行为,本文主要介绍了Java实现状态模式的示例代码,文中通过示例代码介绍的非常详细,需要的朋友们下面随着小编来... 目录一、简介1、定义2、状态模式的结构二、Java实现案例1、电灯开关状态案例2、番茄工作法状态案例

Java使用POI-TL和JFreeChart动态生成Word报告

《Java使用POI-TL和JFreeChart动态生成Word报告》本文介绍了使用POI-TL和JFreeChart生成包含动态数据和图表的Word报告的方法,并分享了实际开发中的踩坑经验,通过代码... 目录前言一、需求背景二、方案分析三、 POI-TL + JFreeChart 实现3.1 Maven

nginx-rtmp-module模块实现视频点播的示例代码

《nginx-rtmp-module模块实现视频点播的示例代码》本文主要介绍了nginx-rtmp-module模块实现视频点播,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习... 目录预置条件Nginx点播基本配置点播远程文件指定多个播放位置参考预置条件配置点播服务器 192.