深度学习之生成唐诗案例(Pytorch版)

2023-11-21 12:20

本文主要是介绍深度学习之生成唐诗案例(Pytorch版),希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

主要思路:

对于唐诗生成来说,我们定义一个"S" 和 "E"作为开始和结束。

 示例的唐诗大概有40000多首,

首先数据预处理,将唐诗加载到内存,生成对应的word2idx、idx2word、以及唐诗按顺序的字序列。

Dataset_Dataloader.py
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoaderdef deal_tangshi():with open("poems.txt", "r", encoding="utf-8") as fr:lines = fr.read().strip().split("\n")tangshis = []for line in lines:splits = line.split(":")if len(splits) != 2:continuetangshis.append("S" + splits[1] + "E")word2idx = {"S": 0, "E": 1}word2idx_count = 2tangshi_ids = []for tangshi in tangshis:for word in tangshi:if word not in word2idx:word2idx[word] = word2idx_countword2idx_count += 1idx2word = {idx: w for w, idx in word2idx.items()}for tangshi in tangshis:tangshi_ids.extend([word2idx[w] for w in tangshi])return word2idx, idx2word, tangshis, word2idx_count, tangshi_idsword2idx, idx2word, tangshis, word2idx_count, tangshi_ids = deal_tangshi()class TangShiDataset(Dataset):def __init__(self, tangshi_ids, num_chars):# 语料数据self.tangshi_ids = tangshi_ids# 语料长度self.num_chars = num_chars# 词的数量self.word_count = len(self.tangshi_ids)# 句子数量self.number = self.word_count // self.num_charsdef __len__(self):return self.numberdef __getitem__(self, idx):# 修正索引值到: [0, self.word_count - 1]start = min(max(idx, 0), self.word_count - self.num_chars - 2)x = self.tangshi_ids[start: start + self.num_chars]y = self.tangshi_ids[start + 1: start + 1 + self.num_chars]return torch.tensor(x), torch.tensor(y)def __test_Dataset():dataset = TangShiDataset(tangshi_ids, 8)x, y = dataset[0]print(x, y)if __name__ == '__main__':# deal_tangshi()__test_Dataset()
TangShiModel.py:唐诗的模型
import torch
import torch.nn as nn
from Dataset_Dataloader import *
import torch.nn.functional as Fclass TangShiRNN(nn.Module):def __init__(self, vocab_size):super().__init__()# 初始化词嵌入层self.ebd = nn.Embedding(vocab_size, 128)# 循环网络层self.rnn = nn.RNN(128, 128, 1)# 输出层self.out = nn.Linear(128, vocab_size)def forward(self, inputs, hidden):embed = self.ebd(inputs)# 正则化层embed = F.dropout(embed, p=0.2)output, hidden = self.rnn(embed.transpose(0, 1), hidden)# 正则化层embed = F.dropout(output, p=0.2)output = self.out(output.squeeze())return output, hiddendef init_hidden(self):return torch.zeros(1, 64, 128)

 main.py:

import timeimport torchfrom Dataset_Dataloader import *
from TangShiModel import *
import torch.optim as optim
from tqdm import tqdmdevice = torch.device("cuda" if torch.cuda.is_available() else "cpu")def train():dataset = TangShiDataset(tangshi_ids, 128)epochs = 100model = TangShiRNN(word2idx_count).to(device)criterion = nn.CrossEntropyLoss()optimizer = optim.Adam(model.parameters(), lr=1e-3)for idx in range(epochs):dataloader = DataLoader(dataset, batch_size=64, shuffle=True, drop_last=True)start_time = time.time()total_loss = 0total_num = 0total_correct = 0total_correct_num = 0hidden = model.init_hidden()for x, y in tqdm(dataloader):x = x.to(device)y = y.to(device)# 隐藏状态hidden = model.init_hidden()hidden = hidden.to(device)# 模型计算output, hidden = model(x, hidden)# print(output.shape)# print(y.shape)# 计算损失loss = criterion(output.permute(1, 2, 0), y)# 梯度清零optimizer.zero_grad()# 反向传播loss.backward()# 参数更新optimizer.step()total_loss += loss.sum().item()total_num += len(y)total_correct_num += y.shape[0] * y.shape[1]# print(output.shape)total_correct += (torch.argmax(output.permute(1, 0, 2), dim=-1) == y).sum().item()print("epoch : %d average_loss : %.3f average_correct : %.3f use_time : %ds" %(idx + 1, total_loss / total_num, total_correct / total_correct_num, time.time() - start_time))torch.save(model.state_dict(), f"./modules/tangshi_module_{idx + 1}.bin")if __name__ == '__main__':train()

predict.py:

import torch
import torch.nn as nn
from Dataset_Dataloader import *
from TangShiModel import *device = torch.device("cuda" if torch.cuda.is_available() else "cpu")def predict():model = TangShiRNN(word2idx_count)model.load_state_dict(torch.load("./modules/tangshi_module_100.bin", map_location=torch.device('cpu')))model.eval()hidden = torch.zeros(1, 1, 128)start_word = input("输入第一个字:")flag = Nonetangshi_strs = []while True:if not flag:outputs, hidden = model(torch.tensor([[word2idx["S"]]], dtype=torch.long), hidden)tangshi_strs.append("S")flag = Trueelse:tangshi_strs.append(start_word)outputs, hidden = model(torch.tensor([[word2idx[start_word]]], dtype=torch.long), hidden)top_i = torch.argmax(outputs, dim=-1)if top_i.item() == word2idx["E"]:breakprint(top_i)start_word = idx2word[top_i.item()]print(tangshi_strs)if __name__ == '__main__':predict()

完整代码如下:

https://github.com/STZZ-1992/tangshi-generator.giticon-default.png?t=N7T8https://github.com/STZZ-1992/tangshi-generator.git

这篇关于深度学习之生成唐诗案例(Pytorch版)的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/402341

相关文章

Java深度学习库DJL实现Python的NumPy方式

《Java深度学习库DJL实现Python的NumPy方式》本文介绍了DJL库的背景和基本功能,包括NDArray的创建、数学运算、数据获取和设置等,同时,还展示了如何使用NDArray进行数据预处理... 目录1 NDArray 的背景介绍1.1 架构2 JavaDJL使用2.1 安装DJL2.2 基本操

最长公共子序列问题的深度分析与Java实现方式

《最长公共子序列问题的深度分析与Java实现方式》本文详细介绍了最长公共子序列(LCS)问题,包括其概念、暴力解法、动态规划解法,并提供了Java代码实现,暴力解法虽然简单,但在大数据处理中效率较低,... 目录最长公共子序列问题概述问题理解与示例分析暴力解法思路与示例代码动态规划解法DP 表的构建与意义动

浅析如何使用Swagger生成带权限控制的API文档

《浅析如何使用Swagger生成带权限控制的API文档》当涉及到权限控制时,如何生成既安全又详细的API文档就成了一个关键问题,所以这篇文章小编就来和大家好好聊聊如何用Swagger来生成带有... 目录准备工作配置 Swagger权限控制给 API 加上权限注解查看文档注意事项在咱们的开发工作里,API

使用Navicat工具比对两个数据库所有表结构的差异案例详解

《使用Navicat工具比对两个数据库所有表结构的差异案例详解》:本文主要介绍如何使用Navicat工具对比两个数据库test_old和test_new,并生成相应的DDLSQL语句,以便将te... 目录概要案例一、如图两个数据库test_old和test_new进行比较:二、开始比较总结概要公司存在多

Java使用POI-TL和JFreeChart动态生成Word报告

《Java使用POI-TL和JFreeChart动态生成Word报告》本文介绍了使用POI-TL和JFreeChart生成包含动态数据和图表的Word报告的方法,并分享了实际开发中的踩坑经验,通过代码... 目录前言一、需求背景二、方案分析三、 POI-TL + JFreeChart 实现3.1 Maven

Go中sync.Once源码的深度讲解

《Go中sync.Once源码的深度讲解》sync.Once是Go语言标准库中的一个同步原语,用于确保某个操作只执行一次,本文将从源码出发为大家详细介绍一下sync.Once的具体使用,x希望对大家有... 目录概念简单示例源码解读总结概念sync.Once是Go语言标准库中的一个同步原语,用于确保某个操

SpringBoot实现动态插拔的AOP的完整案例

《SpringBoot实现动态插拔的AOP的完整案例》在现代软件开发中,面向切面编程(AOP)是一种非常重要的技术,能够有效实现日志记录、安全控制、性能监控等横切关注点的分离,在传统的AOP实现中,切... 目录引言一、AOP 概述1.1 什么是 AOP1.2 AOP 的典型应用场景1.3 为什么需要动态插

MybatisGenerator文件生成不出对应文件的问题

《MybatisGenerator文件生成不出对应文件的问题》本文介绍了使用MybatisGenerator生成文件时遇到的问题及解决方法,主要步骤包括检查目标表是否存在、是否能连接到数据库、配置生成... 目录MyBATisGenerator 文件生成不出对应文件先在项目结构里引入“targetProje

Golang操作DuckDB实战案例分享

《Golang操作DuckDB实战案例分享》DuckDB是一个嵌入式SQL数据库引擎,它与众所周知的SQLite非常相似,但它是为olap风格的工作负载设计的,DuckDB支持各种数据类型和SQL特性... 目录DuckDB的主要优点环境准备初始化表和数据查询单行或多行错误处理和事务完整代码最后总结Duck

Python使用qrcode库实现生成二维码的操作指南

《Python使用qrcode库实现生成二维码的操作指南》二维码是一种广泛使用的二维条码,因其高效的数据存储能力和易于扫描的特点,广泛应用于支付、身份验证、营销推广等领域,Pythonqrcode库是... 目录一、安装 python qrcode 库二、基本使用方法1. 生成简单二维码2. 生成带 Log