Pytorch之Dataset和DataLoader的注意事项

2024-03-21 14:12

本文主要是介绍Pytorch之Dataset和DataLoader的注意事项,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

1、数据集的保存形式:一行一行的。

比如说预测两个值的加法:a+b=c,那么传进Dataset的形式应该是

a1,b1,c1

a2,b2,c2

...

an,bn,cn

2、代码

import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader, Dataset# 创建数据
data_rand = np.random.rand(10, 2)
datas = np.insert(data_rand, 2, data_rand.sum(axis=1), axis=1)
print("\ndatas.shape=", datas.shape)
print("datas=\n", datas)train_data = datas[:int(len(datas) * 0.9)]
test_data = datas[int(len(datas) * 0.9):]debug_flag = False  # False,Trueclass PreDataSet(Dataset):def __init__(self, _data):self.x_data = torch.Tensor(_data[:, :-1])self.y_data = torch.Tensor(_data[:, -1])if debug_flag:print(">>self.x_data.shape=", self.x_data.shape)print(">>self.y_data.shape=", self.y_data.shape)self.n_getitem = 0  # 记录进入__getitem__的次数self.n_len = 0  # 记录进入__len__的次数def __getitem__(self, index):self.n_getitem = self.n_getitem + 1if debug_flag:print(">>index=", index, "n_getitem=", self.n_getitem)print(">>x_data[index].shape=", self.x_data[index].shape)print(">>y_data[index].shape=", self.y_data[index].shape)return self.x_data[index], self.y_data[index]def __len__(self):self.n_len = self.n_len + 1if debug_flag:print(">>len(self.x_data)=", len(self.x_data), "n_len=", self.n_len)return len(self.x_data)train_dataset = PreDataSet(train_data)
train_dataloader = DataLoader(train_dataset, batch_size=1, shuffle=False)# 2、输出看结果
for x, y in train_dataloader:print("\nx=", x)print("y=", y)if debug_flag:print("x.shape=", x.shape)print("y.shape=", y.shape)

参考B站视频

【2、数据集加载(Dataset和DataLoader)】

这篇关于Pytorch之Dataset和DataLoader的注意事项的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/832856

相关文章

bytes.split的用法和注意事项

当然,我很乐意详细介绍 bytes.Split 的用法和注意事项。这个函数是 Go 标准库中 bytes 包的一个重要组成部分,用于分割字节切片。 基本用法 bytes.Split 的函数签名如下: func Split(s, sep []byte) [][]byte s 是要分割的字节切片sep 是用作分隔符的字节切片返回值是一个二维字节切片,包含分割后的结果 基本使用示例: pa

HTML5自定义属性对象Dataset

原文转自HTML5自定义属性对象Dataset简介 一、html5 自定义属性介绍 之前翻译的“你必须知道的28个HTML5特征、窍门和技术”一文中对于HTML5中自定义合法属性data-已经做过些介绍,就是在HTML5中我们可以使用data-前缀设置我们需要的自定义属性,来进行一些数据的存放,例如我们要在一个文字按钮上存放相对应的id: <a href="javascript:" d

Vue项目开发各种注意事项

1、eCharts引入方式(单页面) import * as echarts from 'echarts'Vue.prototype.$echarts = echarts 2、Sass引入 sass和node-sass 中 node-sass不要引入最新版本  引入@7.x 否则会报错 可能是语法规则改变 3、严格模式不要随意开启、将eslint文件中 extends: 中的vue去除

Exchange 服务器地址列表的配置方法与注意事项

Exchange Server 是微软推出的一款企业级邮件服务器软件,广泛应用于企业内部邮件系统的搭建与管理。配置 Exchange 服务器地址列表是其中一个关键环节。本文将详细介绍 Exchange 服务器地址列表的配置方法与注意事项,帮助系统管理员顺利完成这一任务。 内容目录 1. 引言 2. 准备工作 3. 配置地址列表 3.1 创建地址列表 3.2 使用 Exchange

API28_OKgo_get注意事项

1: implementation 'com.lzy.net:okgo:2.1.4' 2:在BaseApplication中onCreate()中初始化initOKgo() private void initOKgo() {//---------这里给出的是示例代码,告诉你可以这么传,实际使用的时候,根据需要传,不需要就不传-------------//HttpHeaders headers

Nn criterions don’t compute the gradient w.r.t. targets error「pytorch」 (debug笔记)

Nn criterions don’t compute the gradient w.r.t. targets error「pytorch」 ##一、 缘由及解决方法 把这个pytorch-ddpg|github搬到jupyter notebook上运行时,出现错误Nn criterions don’t compute the gradient w.r.t. targets error。注:我用

论文精读-Supervised Raw Video Denoising with a Benchmark Dataset on Dynamic Scenes

论文精读-Supervised Raw Video Denoising with a Benchmark Dataset on Dynamic Scenes 优势 1、构建了一个用于监督原始视频去噪的基准数据集。为了多次捕捉瞬间,我们手动为对象s创建运动。在高ISO模式下捕获每一时刻的噪声帧,并通过对多个噪声帧进行平均得到相应的干净帧。 2、有效的原始视频去噪网络(RViDeNet),通过探

【超级干货】2天速成PyTorch深度学习入门教程,缓解研究生焦虑

3、cnn基础 卷积神经网络 输入层 —输入图片矩阵 输入层一般是 RGB 图像或单通道的灰度图像,图片像素值在[0,255],可以用矩阵表示图片 卷积层 —特征提取 人通过特征进行图像识别,根据左图直的笔画判断X,右图曲的笔画判断圆 卷积操作 激活层 —加强特征 池化层 —压缩数据 全连接层 —进行分类 输出层 —输出分类概率 4、基于LeNet

pytorch torch.nn.functional.one_hot函数介绍

torch.nn.functional.one_hot 是 PyTorch 中用于生成独热编码(one-hot encoding)张量的函数。独热编码是一种常用的编码方式,特别适用于分类任务或对离散的类别标签进行处理。该函数将整数张量的每个元素转换为一个独热向量。 函数签名 torch.nn.functional.one_hot(tensor, num_classes=-1) 参数 t

pytorch计算网络参数量和Flops

from torchsummary import summarysummary(net, input_size=(3, 256, 256), batch_size=-1) 输出的参数是除以一百万(/1000000)M, from fvcore.nn import FlopCountAnalysisinputs = torch.randn(1, 3, 256, 256).cuda()fl