本文主要是介绍Pytorch之Dataset和DataLoader的注意事项,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!
1、数据集的保存形式:一行一行的。
比如说预测两个值的加法:a+b=c,那么传进Dataset的形式应该是
a1,b1,c1
a2,b2,c2
...
an,bn,cn
2、代码
import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader, Dataset# 创建数据
data_rand = np.random.rand(10, 2)
datas = np.insert(data_rand, 2, data_rand.sum(axis=1), axis=1)
print("\ndatas.shape=", datas.shape)
print("datas=\n", datas)train_data = datas[:int(len(datas) * 0.9)]
test_data = datas[int(len(datas) * 0.9):]debug_flag = False # False,Trueclass PreDataSet(Dataset):def __init__(self, _data):self.x_data = torch.Tensor(_data[:, :-1])self.y_data = torch.Tensor(_data[:, -1])if debug_flag:print(">>self.x_data.shape=", self.x_data.shape)print(">>self.y_data.shape=", self.y_data.shape)self.n_getitem = 0 # 记录进入__getitem__的次数self.n_len = 0 # 记录进入__len__的次数def __getitem__(self, index):self.n_getitem = self.n_getitem + 1if debug_flag:print(">>index=", index, "n_getitem=", self.n_getitem)print(">>x_data[index].shape=", self.x_data[index].shape)print(">>y_data[index].shape=", self.y_data[index].shape)return self.x_data[index], self.y_data[index]def __len__(self):self.n_len = self.n_len + 1if debug_flag:print(">>len(self.x_data)=", len(self.x_data), "n_len=", self.n_len)return len(self.x_data)train_dataset = PreDataSet(train_data)
train_dataloader = DataLoader(train_dataset, batch_size=1, shuffle=False)# 2、输出看结果
for x, y in train_dataloader:print("\nx=", x)print("y=", y)if debug_flag:print("x.shape=", x.shape)print("y.shape=", y.shape)
参考B站视频
【2、数据集加载(Dataset和DataLoader)】
这篇关于Pytorch之Dataset和DataLoader的注意事项的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!