论文辅助笔记:TEMPO 之 dataset.py

2024-05-03 06:04

本文主要是介绍论文辅助笔记:TEMPO 之 dataset.py,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

0 导入库

import os
import pandas as pd
import torch
from torch.utils.data import Dataset
from .utils import StandardScaler, decompose
from .features import time_features

1 Dataset_ETT_hour

1.1 构造函数

class Dataset_ETT_hour(Dataset):def __init__(self,root_path,flag="train",size=None,features="S",data_path="ETTh1.csv",target="OT",scale=True,inverse=False,timeenc=0,freq="h",cols=None,period=24,):if size == None:self.seq_len = 24 * 4 * 4self.pred_len = 24 * 4else:self.seq_len = size[0]self.pred_len = size[1]#输入sequence和输出sequence的长度assert flag in ["train", "test", "val"]type_map = {"train": 0, "val": 1, "test": 2}self.set_type = type_map[flag]'''指定数据集的用途,可以是 "train"、"test" 或 "val",分别对应训练集、测试集和验证集'''self.features = features#指定数据集包含的特征类型,默认为 "S",表示单一特征self.target = target#指定预测的目标特征self.scale = scale#一个布尔值,用于确定数据是否需要归一化处理self.inverse = inverse#一个布尔值,用于决定是否进行逆变换self.timeenc = timeenc#用于确定是否对时间进行编码【原始模样 or -0.5~0.5区间】self.freq = freq#定义时间序列的频率,如 "h" 表示小时级别的频率self.period = period#定义时间序列的周期,默认为 24self.root_path = root_pathself.data_path = data_pathself.__read_data__()#用于读取并初始化数据集

1.2 __read_data__

def __read_data__(self):self.scaler = StandardScaler()#初始化一个 StandardScaler 对象,用于数据的标准化处理df_raw = pd.read_csv(os.path.join(self.root_path, self.data_path))#读取数据集文件,将其存储为 DataFrame 对象 df_rawborder1s = [0,12 * 30 * 24 - self.seq_len,12 * 30 * 24 + 4 * 30 * 24 - self.seq_len,]#定义了三个区间的起始位置,分别对应训练集、验证集和测试集border2s = [12 * 30 * 24,12 * 30 * 24 + 4 * 30 * 24,12 * 30 * 24 + 8 * 30 * 24,]#定义了每个区间的结束位置border1 = border1s[self.set_type]border2 = border2s[self.set_type]'''通过 self.set_type 确定当前数据集类型并从 border1s 和 border2s 中获取对应的起始和结束位置 border1 和 border2'''if self.features == "M" or self.features == "MS":cols_data = df_raw.columns[1:]df_data = df_raw[cols_data]elif self.features == "S":df_data = df_raw[[self.target]]'''选择特征数据:多特征 "M" 或 "MS":选择所有数据列,除去日期列。单一特征 "S":只选择目标特征列(由 self.target 指定)。'''if self.scale:train_data = df_data[border1s[0] : border2s[0]]self.scaler.fit(train_data.values)data = self.scaler.transform(df_data.values)else:data = df_data.values'''如果 self.scale 为 True,则执行数据归一化:train_data:选择训练集的数据,用于拟合 self.scaler。data:对整个 df_data 进行转换。'''df_stamp = df_raw[["date"]][border1:border2]df_stamp["date"] = pd.to_datetime(df_stamp.date)data_stamp = time_features(df_stamp, timeenc=self.timeenc, freq=self.freq)'''时间特征处理:提取日期列 df_stamp,并将其转换为时间特征:pd.to_datetime:将日期转换为 datetime 对象。time_features:用于生成时间特征。'''self.data_x = data[border1:border2]if self.inverse:self.data_y = df_data.values[border1:border2]else:self.data_y = data[border1:border2]self.data_stamp = data_stamp'''将转换后的数据和时间特征赋值给 self.data_x、self.data_y 和 self.data_stamp:self.data_x 取 data 中的对应区间数据。self.data_y 根据 self.inverse 决定是从 data 还是 df_data 中获取。self.data_stamp 取生成的时间特征。'''

1.3 __getitem__

def __getitem__(self, index):s_begin = index#设置序列的起始点s_end = s_begin + self.seq_len#计算序列的结束点r_begin = s_end#设置预测序列的起始点r_end = r_begin + self.pred_len#计算预测序列的结束点seq_x = self.data_x[s_begin:s_end]#从 data_x 中提取序列部分seq_y = self.data_y[r_begin:r_end]# 从 data_y 中提取预测部分[ground-truth]x = torch.tensor(seq_x, dtype=torch.float).transpose(1, 0)  # [1, seq_len]y = torch.tensor(seq_y, dtype=torch.float).transpose(1, 0)  # [1, pred_len](trend, seasonal, residual) = decompose(x, period=self.period)#对序列 x 进行时间序列分解,返回趋势、季节性和残差三部分components = torch.cat((trend, seasonal, residual), dim=0)  # [3, seq_len]#将分解后的三部分按 0 维(纵向)拼接,形成一个包含三种特征的张量return components, y

1.3__len__

    def __len__(self):return len(self.data_x) - self.seq_len - self.pred_len + 1

1.4  inverse_transform

将数据进行逆转换,还原到原始尺度

    def inverse_transform(self, data):return self.scaler.inverse_transform(data)

2 Dataset_ETT_minute

基本上和hour 的一样,几个地方不一样:

  • __init__
    • data_path="ETTm1.csv",
    • freq="t",
    • period: int = 60,
  • __read_data__
    • border1s = [0,12 * 30 * 24 * 4 - self.seq_len,12 * 30 * 24 * 4 + 4 * 30 * 24 * 4 - self.seq_len,]
      border2s = [12 * 30 * 24 * 4,12 * 30 * 24 * 4 + 4 * 30 * 24 * 4,12 * 30 * 24 * 4 + 8 * 30 * 24 * 4,]

这篇关于论文辅助笔记:TEMPO 之 dataset.py的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/955989

相关文章

Python调用另一个py文件并传递参数常见的方法及其应用场景

《Python调用另一个py文件并传递参数常见的方法及其应用场景》:本文主要介绍在Python中调用另一个py文件并传递参数的几种常见方法,包括使用import语句、exec函数、subproce... 目录前言1. 使用import语句1.1 基本用法1.2 导入特定函数1.3 处理文件路径2. 使用ex

python: 多模块(.py)中全局变量的导入

文章目录 global关键字可变类型和不可变类型数据的内存地址单模块(单个py文件)的全局变量示例总结 多模块(多个py文件)的全局变量from x import x导入全局变量示例 import x导入全局变量示例 总结 global关键字 global 的作用范围是模块(.py)级别: 当你在一个模块(文件)中使用 global 声明变量时,这个变量只在该模块的全局命名空

AI hospital 论文Idea

一、Benchmarking Large Language Models on Communicative Medical Coaching: A Dataset and a Novel System论文地址含代码 大多数现有模型和工具主要迎合以患者为中心的服务。这项工作深入探讨了LLMs在提高医疗专业人员的沟通能力。目标是构建一个模拟实践环境,人类医生(即医学学习者)可以在其中与患者代理进行医学

【学习笔记】 陈强-机器学习-Python-Ch15 人工神经网络(1)sklearn

系列文章目录 监督学习:参数方法 【学习笔记】 陈强-机器学习-Python-Ch4 线性回归 【学习笔记】 陈强-机器学习-Python-Ch5 逻辑回归 【课后题练习】 陈强-机器学习-Python-Ch5 逻辑回归(SAheart.csv) 【学习笔记】 陈强-机器学习-Python-Ch6 多项逻辑回归 【学习笔记 及 课后题练习】 陈强-机器学习-Python-Ch7 判别分析 【学

系统架构师考试学习笔记第三篇——架构设计高级知识(20)通信系统架构设计理论与实践

本章知识考点:         第20课时主要学习通信系统架构设计的理论和工作中的实践。根据新版考试大纲,本课时知识点会涉及案例分析题(25分),而在历年考试中,案例题对该部分内容的考查并不多,虽在综合知识选择题目中经常考查,但分值也不高。本课时内容侧重于对知识点的记忆和理解,按照以往的出题规律,通信系统架构设计基础知识点多来源于教材内的基础网络设备、网络架构和教材外最新时事热点技术。本课时知识

论文翻译:arxiv-2024 Benchmark Data Contamination of Large Language Models: A Survey

Benchmark Data Contamination of Large Language Models: A Survey https://arxiv.org/abs/2406.04244 大规模语言模型的基准数据污染:一项综述 文章目录 大规模语言模型的基准数据污染:一项综述摘要1 引言 摘要 大规模语言模型(LLMs),如GPT-4、Claude-3和Gemini的快

论文阅读笔记: Segment Anything

文章目录 Segment Anything摘要引言任务模型数据引擎数据集负责任的人工智能 Segment Anything Model图像编码器提示编码器mask解码器解决歧义损失和训练 Segment Anything 论文地址: https://arxiv.org/abs/2304.02643 代码地址:https://github.com/facebookresear

数学建模笔记—— 非线性规划

数学建模笔记—— 非线性规划 非线性规划1. 模型原理1.1 非线性规划的标准型1.2 非线性规划求解的Matlab函数 2. 典型例题3. matlab代码求解3.1 例1 一个简单示例3.2 例2 选址问题1. 第一问 线性规划2. 第二问 非线性规划 非线性规划 非线性规划是一种求解目标函数或约束条件中有一个或几个非线性函数的最优化问题的方法。运筹学的一个重要分支。2

【C++学习笔记 20】C++中的智能指针

智能指针的功能 在上一篇笔记提到了在栈和堆上创建变量的区别,使用new关键字创建变量时,需要搭配delete关键字销毁变量。而智能指针的作用就是调用new分配内存时,不必自己去调用delete,甚至不用调用new。 智能指针实际上就是对原始指针的包装。 unique_ptr 最简单的智能指针,是一种作用域指针,意思是当指针超出该作用域时,会自动调用delete。它名为unique的原因是这个

查看提交历史 —— Git 学习笔记 11

查看提交历史 查看提交历史 不带任何选项的git log-p选项--stat 选项--pretty=oneline选项--pretty=format选项git log常用选项列表参考资料 在提交了若干更新,又或者克隆了某个项目之后,你也许想回顾下提交历史。 完成这个任务最简单而又有效的 工具是 git log 命令。 接下来的例子会用一个用于演示的 simplegit