informer辅助笔记:exp/exp_informer.py

2023-12-01 14:04
文章标签 笔记 py 辅助 exp informer

本文主要是介绍informer辅助笔记:exp/exp_informer.py,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

0 导入库

from data.data_loader import Dataset_ETT_hour, Dataset_ETT_minute, Dataset_Custom, Dataset_Pred
from exp.exp_basic import Exp_Basic
from models.model import Informer, InformerStackfrom utils.tools import EarlyStopping, adjust_learning_rate
from utils.metrics import metricimport numpy as npimport torch
import torch.nn as nn
from torch import optim
from torch.utils.data import DataLoaderimport os
import timeimport warnings
warnings.filterwarnings('ignore')

1 Exp_Informer

class Exp_Informer(Exp_Basic):def __init__(self, args):super(Exp_Informer, self).__init__(args)

1.1 build_model

'''
用于构建模型。它根据提供的参数来实例化特定类型的模型
'''
def _build_model(self):model_dict = {'informer':Informer,'informerstack':InformerStack,}if self.args.model=='informer' or self.args.model=='informerstack':e_layers = self.args.e_layers if self.args.model=='informer' else self.args.s_layersmodel = model_dict[self.args.model](self.args.enc_in,self.args.dec_in, self.args.c_out, self.args.seq_len, self.args.label_len,self.args.pred_len, self.args.factor,self.args.d_model, self.args.n_heads, e_layers, # self.args.e_layers,self.args.d_layers, self.args.d_ff,self.args.dropout, self.args.attn,self.args.embed,self.args.freq,self.args.activation,self.args.output_attention,self.args.distil,self.args.mix,self.device).float()#用提供的参数实例化模型if self.args.use_multi_gpu and self.args.use_gpu:model = nn.DataParallel(model, device_ids=self.args.device_ids)#如果设置为使用多 GPU,那么模型将被包装在 nn.DataParallel 中,以便在多个 GPU 上并行运行。return model

1.2 get_data

'''
根据指定的模式(如训练、测试或预测)获取数据
'''
def _get_data(self, flag):args = self.argsdata_dict = {'ETTh1':Dataset_ETT_hour,'ETTh2':Dataset_ETT_hour,'ETTm1':Dataset_ETT_minute,'ETTm2':Dataset_ETT_minute,'WTH':Dataset_Custom,'ECL':Dataset_Custom,'Solar':Dataset_Custom,'custom':Dataset_Custom,}'''定义了一个字典,映射不同的数据集名称到相应的数据集类。例如,'ETTh1' 和 'ETTh2' 映射到 Dataset_ETT_hour 类。'''Data = data_dict[self.args.data]#根据参数中指定的数据集名称选择相应的数据集类timeenc = 0 if args.embed!='timeF' else 1    #设置时间编码标志。如果嵌入类型不是 'timeF',则 timeenc 设置为 0,否则设置为 1。if flag == 'test':shuffle_flag = False; drop_last = True; batch_size = args.batch_size; freq=args.freqelif flag=='pred':shuffle_flag = False; drop_last = False; batch_size = 1; freq=args.detail_freqData = Dataset_Predelse:shuffle_flag = True; drop_last = True; batch_size = args.batch_size; freq=args.freq'''根据 flag 参数(指示数据集用途,如 'test', 'pred', 或其他)设置不同的参数:shuffle_flag:是否打乱数据。drop_last:在数据批次不足时是否丢弃最后一批数据。batch_size:每批数据的大小。freq:数据频率,用于确定数据处理的时间间隔。'''data_set = Data(root_path=args.root_path,data_path=args.data_path,flag=flag,size=[args.seq_len, args.label_len, args.pred_len],features=args.features,target=args.target,inverse=args.inverse,timeenc=timeenc,freq=freq,cols=args.cols)'''使用指定参数实例化数据集。这里包括了数据路径标志(如 'train', 'test')序列长度、标签长度、预测长度特征类型 (M,S,MS)目标列时间编码标志频率需要使用的列'''print(flag, len(data_set))data_loader = DataLoader(data_set,batch_size=batch_size,shuffle=shuffle_flag,num_workers=args.num_workers,drop_last=drop_last)'''使用 DataLoader 创建一个数据加载器,用于批量加载数据同时指定是否打乱、是否丢弃最后一个批次、使用的工作进程数量等。'''return data_set, data_loader#返回数据集和数据加载器的实例

1.3 optimizer & criterion

def _select_optimizer(self):model_optim = optim.Adam(self.model.parameters(), lr=self.args.learning_rate)return model_optimdef _select_criterion(self):criterion =  nn.MSELoss()return criterion#选择优化器和损失函数

1.4 vali

'''
在验证集上评估模型
'''
def vali(self, vali_data, vali_loader, criterion):self.model.eval() #将模型设置为评估模式total_loss = []for i, (batch_x,batch_y,batch_x_mark,batch_y_mark) in enumerate(vali_loader):#遍历验证数据加载器中的每个批次pred, true = self._process_one_batch(vali_data, batch_x, batch_y, batch_x_mark, batch_y_mark)#调用 _process_one_batch 方法处理一个批次的数据。这个方法会返回预测值(pred)和真实值(true)loss = criterion(pred.detach().cpu(), true.detach().cpu())#计算预测值和真实值之间的损失total_loss.append(loss)#将计算出的损失添加到 total_loss 列表中total_loss = np.average(total_loss)#计算所有批次损失的平均值。这个平均损失表示在验证数据集上模型的整体性能。self.model.train()#将模型重新设置为训练模式,继续训练模型return total_loss#返回计算出的平均损失值

1.5 train

'''
训练模型
'''
def train(self, setting):train_data, train_loader = self._get_data(flag = 'train')vali_data, vali_loader = self._get_data(flag = 'val')test_data, test_loader = self._get_data(flag = 'test')#使用 _get_data 方法加载训练、验证和测试数据集。path = os.path.join(self.args.checkpoints, setting)if not os.path.exists(path):os.makedirs(path)#创建用于保存模型检查点的目录time_now = time.time()train_steps = len(train_loader)early_stopping = EarlyStopping(patience=self.args.patience, verbose=True)#使用EarlyStopping  检查是否应停止训练model_optim = self._select_optimizer()criterion =  self._select_criterion()if self.args.use_amp:scaler = torch.cuda.amp.GradScaler()'''初始化一些变量:train_steps:训练数据加载器中的批次总数。early_stopping:如果验证损失在一定迭代次数后没有改善,则停止训练。model_optim:选择优化器。criterion:选择损失函数。如果启用了自动混合精度(AMP),则初始化 scaler。'''for epoch in range(self.args.train_epochs):iter_count = 0train_loss = []self.model.train()epoch_time = time.time()for i, (batch_x,batch_y,batch_x_mark,batch_y_mark) in enumerate(train_loader):#遍历训练数据加载器中的所有批次iter_count += 1model_optim.zero_grad() #清除模型优化器的梯度pred, true = self._process_one_batch(train_data, batch_x, batch_y, batch_x_mark, batch_y_mark)#使用 _process_one_batch 处理批次数据,计算损失loss = criterion(pred, true)#计算这一个batch预测值和实际值的差距train_loss.append(loss.item())if (i+1) % 100==0:print("\titers: {0}, epoch: {1} | loss: {2:.7f}".format(i + 1, epoch + 1, loss.item()))speed = (time.time()-time_now)/iter_countleft_time = speed*((self.args.train_epochs - epoch)*train_steps - i)print('\tspeed: {:.4f}s/iter; left time: {:.4f}s'.format(speed, left_time))iter_count = 0time_now = time.time()#每100次迭代打印损失和预计剩余时间if self.args.use_amp:scaler.scale(loss).backward()scaler.step(model_optim)scaler.update()else:loss.backward()model_optim.step()#损失后向传播和优化器步骤,如果启用了 AMP,则使用 scaler 进行这些步骤print("Epoch: {} cost time: {}".format(epoch+1, time.time()-epoch_time))train_loss = np.average(train_loss)vali_loss = self.vali(vali_data, vali_loader, criterion)#对模型进行validationtest_loss = self.vali(test_data, test_loader, criterion)print("Epoch: {0}, Steps: {1} | Train Loss: {2:.7f} Vali Loss: {3:.7f} Test Loss: {4:.7f}".format(epoch + 1, train_steps, train_loss, vali_loss, test_loss))early_stopping(vali_loss, self.model, path)if early_stopping.early_stop:print("Early stopping")breakadjust_learning_rate(model_optim, epoch+1, self.args)best_model_path = path+'/'+'checkpoint.pth'self.model.load_state_dict(torch.load(best_model_path))#在训练结束后,加载表现最好的模型状态return self.model

1.6 test

'''
在测试集上评估模型
'''
def test(self, setting):test_data, test_loader = self._get_data(flag='test')#加载测试数据集self.model.eval()preds = []trues = []#存储模型的预测和相应的真实值for i, (batch_x,batch_y,batch_x_mark,batch_y_mark) in enumerate(test_loader):pred, true = self._process_one_batch(test_data, batch_x, batch_y, batch_x_mark, batch_y_mark)preds.append(pred.detach().cpu().numpy())trues.append(true.detach().cpu().numpy())'''遍历测试数据加载器中的每个批次。使用 _process_one_batch 方法处理每个批次的数据。将预测值和真实值添加到各自的列表中。'''preds = np.array(preds)trues = np.array(trues)print('test shape:', preds.shape, trues.shape)preds = preds.reshape(-1, preds.shape[-2], preds.shape[-1])trues = trues.reshape(-1, trues.shape[-2], trues.shape[-1])print('test shape:', preds.shape, trues.shape)# result savefolder_path = './results/' + setting +'/'if not os.path.exists(folder_path):os.makedirs(folder_path)#创建一个文件夹来存储测试结果mae, mse, rmse, mape, mspe = metric(preds, trues)#使用自定义的 metric 函数计算各种性能指标,如 MAE(平均绝对误差)、MSE(均方误差)、RMSE(均方根误差)、MAPE(平均绝对百分比误差)和 MSPE(均方百分比误差)。print('mse:{}, mae:{}'.format(mse, mae))np.save(folder_path+'metrics.npy', np.array([mae, mse, rmse, mape, mspe]))np.save(folder_path+'pred.npy', preds)np.save(folder_path+'true.npy', trues)return

1.7 predict

#在新数据上进行模型预测
def predict(self, setting, load=False):pred_data, pred_loader = self._get_data(flag='pred')#加载预测数据集if load:path = os.path.join(self.args.checkpoints, setting)best_model_path = path+'/'+'checkpoint.pth'self.model.load_state_dict(torch.load(best_model_path))#如果 load 为 True,则从保存的路径加载最佳模型的状态。self.model.eval()preds = []for i, (batch_x,batch_y,batch_x_mark,batch_y_mark) in enumerate(pred_loader):pred, true = self._process_one_batch(pred_data, batch_x, batch_y, batch_x_mark, batch_y_mark)preds.append(pred.detach().cpu().numpy())'''遍历预测数据加载器中的每个批次。使用 _process_one_batch 方法处理每个批次的数据。将预测值添加到 preds 列表中。'''preds = np.array(preds)preds = preds.reshape(-1, preds.shape[-2], preds.shape[-1])# result savefolder_path = './results/' + setting +'/'if not os.path.exists(folder_path):os.makedirs(folder_path)np.save(folder_path+'real_prediction.npy', preds)#保存预测结果return

1.8 process_one_batch

'''
处理一个数据批次
'''
def _process_one_batch(self, dataset_object, batch_x, batch_y, batch_x_mark, batch_y_mark):batch_x = batch_x.float().to(self.device)batch_y = batch_y.float()batch_x_mark = batch_x_mark.float().to(self.device)batch_y_mark = batch_y_mark.float().to(self.device)# decoder inputif self.args.padding==0:dec_inp = torch.zeros([batch_y.shape[0], self.args.pred_len, batch_y.shape[-1]]).float()elif self.args.padding==1:dec_inp = torch.ones([batch_y.shape[0], self.args.pred_len, batch_y.shape[-1]]).float()#根据 self.args.padding 的值创建一个全零或全一的张量作为解码器的初始输入dec_inp = torch.cat([batch_y[:,:self.args.label_len,:], dec_inp], dim=1).float().to(self.device)#将这个张量与 batch_y 的一部分拼接,形成完整的解码器输入# encoder - decoderif self.args.use_amp:with torch.cuda.amp.autocast():if self.args.output_attention:outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)[0]else:outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)else:if self.args.output_attention:outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)[0]else:outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)if self.args.inverse:outputs = dataset_object.inverse_transform(outputs)#encoder-decoder的输出f_dim = -1 if self.args.features=='MS' else 0batch_y = batch_y[:,-self.args.pred_len:,f_dim:].to(self.device)#从 batch_y 中选择与预测长度相对应的部分,并移动到指定设备。#f_dim 变量用于确定特征维度。return outputs, batch_y

这篇关于informer辅助笔记:exp/exp_informer.py的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/441448

相关文章

python: 多模块(.py)中全局变量的导入

文章目录 global关键字可变类型和不可变类型数据的内存地址单模块(单个py文件)的全局变量示例总结 多模块(多个py文件)的全局变量from x import x导入全局变量示例 import x导入全局变量示例 总结 global关键字 global 的作用范围是模块(.py)级别: 当你在一个模块(文件)中使用 global 声明变量时,这个变量只在该模块的全局命名空

【学习笔记】 陈强-机器学习-Python-Ch15 人工神经网络(1)sklearn

系列文章目录 监督学习:参数方法 【学习笔记】 陈强-机器学习-Python-Ch4 线性回归 【学习笔记】 陈强-机器学习-Python-Ch5 逻辑回归 【课后题练习】 陈强-机器学习-Python-Ch5 逻辑回归(SAheart.csv) 【学习笔记】 陈强-机器学习-Python-Ch6 多项逻辑回归 【学习笔记 及 课后题练习】 陈强-机器学习-Python-Ch7 判别分析 【学

系统架构师考试学习笔记第三篇——架构设计高级知识(20)通信系统架构设计理论与实践

本章知识考点:         第20课时主要学习通信系统架构设计的理论和工作中的实践。根据新版考试大纲,本课时知识点会涉及案例分析题(25分),而在历年考试中,案例题对该部分内容的考查并不多,虽在综合知识选择题目中经常考查,但分值也不高。本课时内容侧重于对知识点的记忆和理解,按照以往的出题规律,通信系统架构设计基础知识点多来源于教材内的基础网络设备、网络架构和教材外最新时事热点技术。本课时知识

论文阅读笔记: Segment Anything

文章目录 Segment Anything摘要引言任务模型数据引擎数据集负责任的人工智能 Segment Anything Model图像编码器提示编码器mask解码器解决歧义损失和训练 Segment Anything 论文地址: https://arxiv.org/abs/2304.02643 代码地址:https://github.com/facebookresear

数学建模笔记—— 非线性规划

数学建模笔记—— 非线性规划 非线性规划1. 模型原理1.1 非线性规划的标准型1.2 非线性规划求解的Matlab函数 2. 典型例题3. matlab代码求解3.1 例1 一个简单示例3.2 例2 选址问题1. 第一问 线性规划2. 第二问 非线性规划 非线性规划 非线性规划是一种求解目标函数或约束条件中有一个或几个非线性函数的最优化问题的方法。运筹学的一个重要分支。2

【C++学习笔记 20】C++中的智能指针

智能指针的功能 在上一篇笔记提到了在栈和堆上创建变量的区别,使用new关键字创建变量时,需要搭配delete关键字销毁变量。而智能指针的作用就是调用new分配内存时,不必自己去调用delete,甚至不用调用new。 智能指针实际上就是对原始指针的包装。 unique_ptr 最简单的智能指针,是一种作用域指针,意思是当指针超出该作用域时,会自动调用delete。它名为unique的原因是这个

查看提交历史 —— Git 学习笔记 11

查看提交历史 查看提交历史 不带任何选项的git log-p选项--stat 选项--pretty=oneline选项--pretty=format选项git log常用选项列表参考资料 在提交了若干更新,又或者克隆了某个项目之后,你也许想回顾下提交历史。 完成这个任务最简单而又有效的 工具是 git log 命令。 接下来的例子会用一个用于演示的 simplegit

记录每次更新到仓库 —— Git 学习笔记 10

记录每次更新到仓库 文章目录 文件的状态三个区域检查当前文件状态跟踪新文件取消跟踪(un-tracking)文件重新跟踪(re-tracking)文件暂存已修改文件忽略某些文件查看已暂存和未暂存的修改提交更新跳过暂存区删除文件移动文件参考资料 咱们接着很多天以前的 取得Git仓库 这篇文章继续说。 文件的状态 不管是通过哪种方法,现在我们已经有了一个仓库,并从这个仓

忽略某些文件 —— Git 学习笔记 05

忽略某些文件 忽略某些文件 通过.gitignore文件其他规则源如何选择规则源参考资料 对于某些文件,我们不希望把它们纳入 Git 的管理,也不希望它们总出现在未跟踪文件列表。通常它们都是些自动生成的文件,比如日志文件、编译过程中创建的临时文件等。 通过.gitignore文件 假设我们要忽略 lib.a 文件,那我们可以在 lib.a 所在目录下创建一个名为 .gi

取得 Git 仓库 —— Git 学习笔记 04

取得 Git 仓库 —— Git 学习笔记 04 我认为, Git 的学习分为两大块:一是工作区、索引、本地版本库之间的交互;二是本地版本库和远程版本库之间的交互。第一块是基础,第二块是难点。 下面,我们就围绕着第一部分内容来学习,先不考虑远程仓库,只考虑本地仓库。 怎样取得项目的 Git 仓库? 有两种取得 Git 项目仓库的方法。第一种是在本地创建一个新的仓库,第二种是把其他地方的某个