只谈代码之用pytorch写一个经常用来测试时序模型的简单常规套路(LSTM多步迭代预测)...

本文主要是介绍只谈代码之用pytorch写一个经常用来测试时序模型的简单常规套路(LSTM多步迭代预测)...,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

前言

就一句话。

不谈感情,只谈代码。

本系列的代码可以当作入门,复习,作为模板修改成自己的,都是可以的。

这个系列会长期更新下去,主要与python,机器学习,数据挖掘,tensorflow,pytorch相关,后期自己准备复习一些java,学习一些go相关也会一起分享。

一来是大家都代码实战的诉求其实是比理论多的多,总会有人在我文章下面问:有代码吗?能给份代码嘛?

说实话,很多代码给不了,也没法给,我不太喜欢这样白嫖,还是希望各位自己去做一些事情,这样才是对自身的要求和提高,也是一个程序员的素养。

二来,也算自己对一些基础的积累和巩固,尽量两天一篇,每天保持对代码的嗅觉。

今天讲的是pytorch框架下,写一个平时常用来测试的小例子,关于时间序列,模型用的最简单的LSTM,多步迭代预测~

前情提要:

【PyTorch修炼】三、先做减法,具体例子带你了解torch使用的基本套路(简单分类和时间序列预测小例子)

【PyTorch修炼】二、带你详细了解并使用Dataset以及DataLoader

【PyTorch修炼】三、先做减法,具体例子带你了解torch使用的基本套路(简单分类和时间序列预测小例子)

1. 导入我们需要用到的包,此教程 包含可视化以及模型训练和测试

import numpy as np
import pandas as pd
import matplotlib.pyplot as pltimport torch
import torch.nn as nn

2. 对于demo,尝试模型,我们可以用自己模拟的数据或者公开数据集,这里为了方便,采用自己模拟sin函数,并可视化

x = torch.linspace(0, 999, 1000)
y = torch.sin(x*2*3.1415926/70)
plt.xlim(-5, 1005)
plt.xlabel('x')
plt.ylabel('sin(x)')
plt.title("sin")
plt.plot(y.numpy(), color='#800080')
plt.show()
5ea57435f93537b8cf10934bdef760df.png

3. 分训练集和测试集,并且需要对数据进行time windows分割

# len(test):50 
train_y= y[:-70]
test_y = y[-70:]

滑窗创建数据集

def create_data_seq(seq, time_window):out = []l = len(seq)for i in range(l-time_window):x_tw = seq[i:i+time_window]y_label = seq[i+time_window:i+time_window+1]out.append((x_tw, y_label))return out
time_window = 60
train_data = create_data_seq(train_y, time_window)

4. 定义lstm模型

https://pytorch.org/docs/stable/generated/torch.nn.LSTM.html#torch.nn.LSTM

https://pytorch.org/docs/stable/generated/torch.nn.Linear.html#torch.nn.Linear

【Deep Learning】通俗大白话详述RNN理论和LSTM理论

【Deep Learning】详细解读LSTM与GRU单元的各个公式和区别

class MyLstm(nn.Module):def __init__(self, input_size=1, hidden_size=128, out_size=1):super(MyLstm, self).__init__()self.hidden_size = hidden_sizeself.lstm = nn.LSTM(input_size=input_size, hidden_size=self.hidden_size, num_layers=1, bidirectional=False)self.linear = nn.Linear(in_features=self.hidden_size, out_features=out_size, bias=True)self.hidden_state = (torch.zeros(1, 1, self.hidden_size), torch.zeros(1, 1, self.hidden_size))def forward(self, x):out, self.hidden_state = self.lstm(x.view(len(x), 1, -1), self.hidden_state)pred = self.linear(out.view(len(x), -1))return pred[-1]

5. 训练准备工作

(1) 超参数

(2) 定义loss,优化器,实例化模型

(3) 训练模型,为了更加直观,加入对test_y的预测最终可视化

learning_rate = 0.00001
epoch = 10
multi_step = 70
model = MyLstm()
mse_loss = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, betas=(0.5,0.999))device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")model.to(device)
for i in range(epoch):for x_seq, y_label in train_data:x_seq = x_seq.to(device)y_label = y_label.to(device)model.hidden_state = (torch.zeros(1, 1, model.hidden_size).to(device), torch.zeros(1, 1, model.hidden_size).to(device))pred = model(x_seq)loss = mse_loss(y_label, pred)optimizer.zero_grad()loss.backward()optimizer.step()print(f"Epoch {i} Loss: {loss.item()}")preds = []labels = []preds = train_y[-time_window:].tolist()for j in range(multi_step):test_seq = torch.FloatTensor(preds[-time_window:]).to(device)with torch.no_grad():model.hidden_state = (torch.zeros(1, 1, model.hidden_size).to(device), torch.zeros(1, 1, model.hidden_size).to(device))preds.append(model(test_seq).item())loss = mse_loss(torch.tensor(preds[-multi_step:]), torch.tensor(test_y))print(f"Performance on test range: {loss}")plt.figure(figsize=(12,4))plt.xlim(700,999)plt.grid(True)plt.plot(y.numpy(),color='#8000ff')plt.plot(range(999-multi_step,999),preds[-multi_step:],color='#ff8000')plt.show()

结果

0a6e22e56e2b1d6e2dfca926e85bbc71.png 71a8cf5a501e2c36e737f188add823ba.png完整代码:https://github.com/chehongshu/AIwoniuche_Learning/blob/master/Pytorch_LSTM_examples/demo-timeseries.ipynb

这篇关于只谈代码之用pytorch写一个经常用来测试时序模型的简单常规套路(LSTM多步迭代预测)...的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/412386

相关文章

大模型研发全揭秘:客服工单数据标注的完整攻略

在人工智能(AI)领域,数据标注是模型训练过程中至关重要的一步。无论你是新手还是有经验的从业者,掌握数据标注的技术细节和常见问题的解决方案都能为你的AI项目增添不少价值。在电信运营商的客服系统中,工单数据是客户问题和解决方案的重要记录。通过对这些工单数据进行有效标注,不仅能够帮助提升客服自动化系统的智能化水平,还能优化客户服务流程,提高客户满意度。本文将详细介绍如何在电信运营商客服工单的背景下进行

性能测试介绍

性能测试是一种测试方法,旨在评估系统、应用程序或组件在现实场景中的性能表现和可靠性。它通常用于衡量系统在不同负载条件下的响应时间、吞吐量、资源利用率、稳定性和可扩展性等关键指标。 为什么要进行性能测试 通过性能测试,可以确定系统是否能够满足预期的性能要求,找出性能瓶颈和潜在的问题,并进行优化和调整。 发现性能瓶颈:性能测试可以帮助发现系统的性能瓶颈,即系统在高负载或高并发情况下可能出现的问题

字节面试 | 如何测试RocketMQ、RocketMQ?

字节面试:RocketMQ是怎么测试的呢? 答: 首先保证消息的消费正确、设计逆向用例,在验证消息内容为空等情况时的消费正确性; 推送大批量MQ,通过Admin控制台查看MQ消费的情况,是否出现消费假死、TPS是否正常等等问题。(上述都是临场发挥,但是RocketMQ真正的测试点,还真的需要探讨) 01 先了解RocketMQ 作为测试也是要简单了解RocketMQ。简单来说,就是一个分

csu 1446 Problem J Modified LCS (扩展欧几里得算法的简单应用)

这是一道扩展欧几里得算法的简单应用题,这题是在湖南多校训练赛中队友ac的一道题,在比赛之后请教了队友,然后自己把它a掉 这也是自己独自做扩展欧几里得算法的题目 题意:把题意转变下就变成了:求d1*x - d2*y = f2 - f1的解,很明显用exgcd来解 下面介绍一下exgcd的一些知识点:求ax + by = c的解 一、首先求ax + by = gcd(a,b)的解 这个

hdu2289(简单二分)

虽说是简单二分,但是我还是wa死了  题意:已知圆台的体积,求高度 首先要知道圆台体积怎么求:设上下底的半径分别为r1,r2,高为h,V = PI*(r1*r1+r1*r2+r2*r2)*h/3 然后以h进行二分 代码如下: #include<iostream>#include<algorithm>#include<cstring>#include<stack>#includ

Andrej Karpathy最新采访:认知核心模型10亿参数就够了,AI会打破教育不公的僵局

夕小瑶科技说 原创  作者 | 海野 AI圈子的红人,AI大神Andrej Karpathy,曾是OpenAI联合创始人之一,特斯拉AI总监。上一次的动态是官宣创办一家名为 Eureka Labs 的人工智能+教育公司 ,宣布将长期致力于AI原生教育。 近日,Andrej Karpathy接受了No Priors(投资博客)的采访,与硅谷知名投资人 Sara Guo 和 Elad G

活用c4d官方开发文档查询代码

当你问AI助手比如豆包,如何用python禁止掉xpresso标签时候,它会提示到 这时候要用到两个东西。https://developers.maxon.net/论坛搜索和开发文档 比如这里我就在官方找到正确的id描述 然后我就把参数标签换过来

usaco 1.3 Prime Cryptarithm(简单哈希表暴搜剪枝)

思路: 1. 用一个 hash[ ] 数组存放输入的数字,令 hash[ tmp ]=1 。 2. 一个自定义函数 check( ) ,检查各位是否为输入的数字。 3. 暴搜。第一行数从 100到999,第二行数从 10到99。 4. 剪枝。 代码: /*ID: who jayLANG: C++TASK: crypt1*/#include<stdio.h>bool h

Retrieval-based-Voice-Conversion-WebUI模型构建指南

一、模型介绍 Retrieval-based-Voice-Conversion-WebUI(简称 RVC)模型是一个基于 VITS(Variational Inference with adversarial learning for end-to-end Text-to-Speech)的简单易用的语音转换框架。 具有以下特点 简单易用:RVC 模型通过简单易用的网页界面,使得用户无需深入了

poj 1258 Agri-Net(最小生成树模板代码)

感觉用这题来当模板更适合。 题意就是给你邻接矩阵求最小生成树啦。~ prim代码:效率很高。172k...0ms。 #include<stdio.h>#include<algorithm>using namespace std;const int MaxN = 101;const int INF = 0x3f3f3f3f;int g[MaxN][MaxN];int n