PyTorch -- RNN 快速实践

2024-06-20 02:52
文章标签 pytorch rnn 实践 快速

本文主要是介绍PyTorch -- RNN 快速实践,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

  • RNN Layer torch.nn.RNN(input_size,hidden_size,num_layers,batch_first)

    • input_size: 输入的编码维度
    • hidden_size: 隐含层的维数
    • num_layers: 隐含层的层数
    • batch_first: ·True 指定输入的参数顺序为:
      • x:[batch, seq_len, input_size]
      • h0:[batch, num_layers, hidden_size]
  • RNN 的输入

    • x:[seq_len, batch, input_size]
      • seq_len: 输入的序列长度
      • batch: batch size 批大小
    • h0:[num_layers, batch, hidden_size]
  • RNN 的输出

    • y: [seq_len, batch, hidden_size]

在这里插入图片描述


  • 实战之预测 正弦曲线:以下会以此为例,演示 RNN 预测任务的部署
    在这里插入图片描述
    • 步骤一:确定 RNN Layer 相关参数值并基于此创建 Net

      import numpy as np
      from matplotlib import pyplot as pltimport torch
      import torch.nn as nn
      import torch.optim as optimseq_len     = 50
      batch       = 1
      num_time_steps = seq_leninput_size  = 1
      output_size = input_size
      hidden_size = 10  	
      num_layers = 1  	
      batch_first = True class Net(nn.Module):  ## model 定义def __init__(self):super(Net, self).__init__()self.rnn = nn.RNN(input_size=input_size,hidden_size=hidden_size,num_layers=num_layers,batch_first=batch_first)# for p in self.rnn.parameters():# 	nn.init.normal_(p, mean=0.0, std=0.001)self.linear = nn.Linear(hidden_size, output_size)def forward(self, x, hidden_prev):out, hidden_prev = self.rnn(x, hidden_prev)# out: [batch, seq_len, hidden_size]out = out.view(-1, hidden_size)  # [batch*seq_len, hidden_size]out = self.linear(out) 			 # [batch*seq_len, output_size]out = out.unsqueeze(dim=0)    # [1, batch*seq_len, output_size]return out, hidden_prev
      
    • 步骤二:确定 训练流程

      lr=0.01def tarin_RNN():model = Net()print('model:\n',model)criterion = nn.MSELoss()optimizer = optim.Adam(model.parameters(), lr)hidden_prev = torch.zeros(num_layers, batch, hidden_size)  #初始化hl = []for iter in range(100):  # 训练100次start = np.random.randint(10, size=1)[0]  ## 序列起点time_steps = np.linspace(start, start+10, num_time_steps)  ## 序列data = np.sin(time_steps).reshape(num_time_steps, 1)  ## 序列数据x = torch.tensor(data[:-1]).float().view(batch, seq_len-1, input_size)y = torch.tensor(data[1: ]).float().view(batch, seq_len-1, input_size)  # 目标为预测一个新的点output, hidden_prev = model(x, hidden_prev)hidden_prev = hidden_prev.detach()  ## 最后一层隐藏层的状态要 detachloss = criterion(output, y)model.zero_grad()loss.backward()optimizer.step()if iter % 100 == 0:print("Iteration: {} loss {}".format(iter, loss.item()))l.append(loss.item())#############################绘制损失函数#################################plt.plot(l,'r')plt.xlabel('训练次数')plt.ylabel('loss')plt.title('RNN LOSS')plt.savefig('RNN_LOSS.png')return hidden_prev,modelhidden_prev,model = tarin_RNN()
      
    • 步骤三:测试训练结果

      start = np.random.randint(3, size=1)[0]  ## 序列起点
      time_steps = np.linspace(start, start+10, num_time_steps)  ## 序列
      data = np.sin(time_steps).reshape(num_time_steps, 1)  ## 序列数据
      x = torch.tensor(data[:-1]).float().view(batch, seq_len-1, input_size)
      y = torch.tensor(data[1: ]).float().view(batch, seq_len-1, input_size)  # 目标为预测一个新的点    predictions = []  ## 预测结果
      input = x[:,0,:]
      for _ in range(x.shape[1]):input = input.view(1, 1, 1)pred, hidden_prev = model(input, hidden_prev)input = pred  ## 循环获得每个input点输入网络predictions.append(pred.detach().numpy()[0])
      x= x.data.numpy()
      y = y.data.numpy( )
      plt.scatter(time_steps[:-1], x.squeeze(), s=90)
      plt.plot(time_steps[:-1], x.squeeze())
      plt.scatter(time_steps[1:],predictions)  ## 黄色为预测
      plt.show()
      

      在这里插入图片描述


【高阶】上述例子比较简单,便于入门以推理到自己的目标任务,实际 RNN 训练可能更有难度,可以添加

  • 对于梯度爆炸的解决:
    for p in model.parameters()"p.grad.nomr()torch.nn.utils.clip_grad_norm_(p, 10)  ## 其中的 norm 后面的_ 表示 in place
    
  • 对于梯度消失的解决:-> LSTM

  • 另一个很好的实例关于飞行轨迹预测- - RNN-博客链接,可供学习参考
  • B站视频参考资料

这篇关于PyTorch -- RNN 快速实践的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1076898

相关文章

乐鑫 Matter 技术体验日|快速落地 Matter 产品,引领智能家居生态新发展

随着 Matter 协议的推广和普及,智能家居行业正迎来新的发展机遇,众多厂商纷纷投身于 Matter 产品的研发与验证。然而,开发者普遍面临技术门槛高、认证流程繁琐、生产管理复杂等诸多挑战。  乐鑫信息科技 (688018.SH) 凭借深厚的研发实力与行业洞察力,推出了全面的 Matter 解决方案,包含基于乐鑫 SoC 的 Matter 硬件平台、基于开源 ESP-Matter SDK 的一

C++必修:模版的入门到实践

✨✨ 欢迎大家来到贝蒂大讲堂✨✨ 🎈🎈养成好习惯,先赞后看哦~🎈🎈 所属专栏:C++学习 贝蒂的主页:Betty’s blog 1. 泛型编程 首先让我们来思考一个问题,如何实现一个交换函数? void swap(int& x, int& y){int tmp = x;x = y;y = tmp;} 相信大家很快就能写出上面这段代码,但是如果要求这个交换函数支持字符型

亮相WOT全球技术创新大会,揭秘火山引擎边缘容器技术在泛CDN场景的应用与实践

2024年6月21日-22日,51CTO“WOT全球技术创新大会2024”在北京举办。火山引擎边缘计算架构师李志明受邀参与,以“边缘容器技术在泛CDN场景的应用和实践”为主题,与多位行业资深专家,共同探讨泛CDN行业技术架构以及云原生与边缘计算的发展和展望。 火山引擎边缘计算架构师李志明表示:为更好地解决传统泛CDN类业务运行中的问题,火山引擎边缘容器团队参考行业做法,结合实践经验,打造火山

基于CTPN(tensorflow)+CRNN(pytorch)+CTC的不定长文本检测和识别

转发来源:https://swift.ctolib.com/ooooverflow-chinese-ocr.html chinese-ocr 基于CTPN(tensorflow)+CRNN(pytorch)+CTC的不定长文本检测和识别 环境部署 sh setup.sh 使用环境: python 3.6 + tensorflow 1.10 +pytorch 0.4.1 注:CPU环境

LVGL快速入门笔记

目录 一、基础知识 1. 基础对象(lv_obj) 2. 基础对象的大小(size) 3. 基础对象的位置(position) 3.1 直接设置方式 3.2 参照父对象对齐 3.3 获取位置 4. 基础对象的盒子模型(border-box) 5. 基础对象的样式(styles) 5.1 样式的状态和部分 5.1.1 对象可以处于以下状态States的组合: 5.1.2 对象

9 个 GraphQL 安全最佳实践

GraphQL 已被最大的平台采用 - Facebook、Twitter、Github、Pinterest、Walmart - 这些大公司不能在安全性上妥协。但是,尽管 GraphQL 可以成为您的 API 的非常安全的选项,但它并不是开箱即用的。事实恰恰相反:即使是最新手的黑客,所有大门都是敞开的。此外,GraphQL 有自己的一套注意事项,因此如果您来自 REST,您可能会错过一些重要步骤!

PyTorch模型_trace实战:深入理解与应用

pytorch使用trace模型 1、使用trace生成torchscript模型2、使用trace的模型预测 1、使用trace生成torchscript模型 def save_trace(model, input, save_path):traced_script_model = torch.jit.trace(model, input)<

Netty ByteBuf 释放详解:内存管理与最佳实践

Netty ByteBuf 释放详解:内存管理与最佳实践 在Netty中(学习netty请参考:🔗深入浅出Netty:高性能网络应用框架的原理与实践),管理ByteBuf的内存是至关重要的(学习ByteBuf请参考:🔗Netty ByteBuf 详解:高性能数据缓冲区的全面介绍)。未能正确释放ByteBuf可能会导致内存泄漏,进而影响应用的性能和稳定性。本文将详细介绍如何正确地释放ByteB

Clickhouse 的性能优化实践总结

文章目录 前言性能优化的原则数据结构优化内存优化磁盘优化网络优化CPU优化查询优化数据迁移优化 前言 ClickHouse是一个性能很强的OLAP数据库,性能强是建立在专业运维之上的,需要专业运维人员依据不同的业务需求对ClickHouse进行有针对性的优化。同一批数据,在不同的业务下,查询性能可能出现两极分化。 性能优化的原则 在进行ClickHouse性能优化时,有几条

RabbitMQ实践——临时队列

临时队列是一种自动删除队列。当这个队列被创建后,如果没有消费者监听,则会一直存在,还可以不断向其发布消息。但是一旦的消费者开始监听,然后断开监听后,它就会被自动删除。 新建自动删除队列 我们创建一个名字叫queue.auto.delete的临时队列 绑定 我们直接使用默认交换器,所以不用创建新的交换器,也不用建立绑定关系。 实验 发布消息 我们在后台管理页面的默认交换器下向这个队列