pytorch中的dataset和dataloader

2024-03-12 05:04
文章标签 pytorch dataset dataloader

本文主要是介绍pytorch中的dataset和dataloader,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

  PyTorch为我们提供了Dataset和DataLoader类分别负责可被Pytorch使用的数据集的创建以及向训练传递数据的任务。一般在项目中,我们需要根据自己的数据集个性化pytorch中储存数据集的方式和数据传递的方式,需要自己重写一些子类。
  torch.utils.data.Dataset 是一个表示数据集的抽象类。任何自定义的数据集都需要继承这个类并覆写相关方法。数据集主要有两个功能,一个是读取本地的数据集并储存起来,另外一个是负责处理索引(index)到样本(sample)映射。
  因此,在定义自己的数据集时,一般我们需要重写init,getitem,len这三个方法,在init中,执行加载数据集和储存数据集的任务,getitem中执行根据索引(index)从数据集中取出一个样本。len方法则返回数据集中的样本数。
例子:

import torch
import torch.utils.data.dataset as Dataset#创建子类
class mydataset(Dataset.Dataset):#初始化,定义数据内容和标签def __init__(self, Data, Label):self.Data = Dataself.Label = Label#返回数据集大小def __len__(self):return len(self.Data)#得到数据内容和标签def __getitem__(self, index):data = torch.Tensor(self.Data[index])label = torch.Tensor(self.Label[index])return data, label

  在我们创建了自己的dataset后,就可以用这个dataset创建我们的dataloader了,一般来说在PyTorch项目中加载数据集的流程是这样的: 1. 创建Dateset。 2. 将Dataset作为参数传递给DataLoader,创建dataloader。 3. DataLoader迭代按照batch大小产生训练数据提供给模型。dataloader的定义为:

DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,batch_sampler=None, num_workers=0, collate_fn=None,pin_memory=False, drop_last=False, timeout=0,worker_init_fn=None)

  解释其中几个比较关键的:dataset就是我们自己重写后的数据集;batch_size是超参,即设置的batch大小;shuffle是否打乱数据;drop_last是否截断数据(如果样本数量与batchsize不能整除的情况下);num_workers是同时参与数据读取的线程数量,多线程技术可以加快数据读取,但是如果项目不大,设置为0就行,免得引入新问题;collate_fn,数据整理函数,这个方法可以被传递进去,把一个batch的数据变成你想要的样子。
  给一个自定义的collate_fn函数的例子:

def collate_fn(data):img = [i[0][0] for i in data]txt = [i[0][1] for i in data]labels = [i[1] for i in data]img = torch.stack(img).cuda()txt = torch.stack(txt).cuda()labels = torch.LongTensor(labels).cuda()return img, txt ,labels

那么我们的一般流程为:

# 创建Dateset(自定义)dataset = mydataset() 
# Dataset传递给DataLoaderdataloader = DataLoader(dataset,batch_size=64,shuffle=False,num_workers=8,collate_fn=my_collate_fn)
# DataLoader迭代产生训练数据提供给模型for i in range(epoch):for index,(img,txt,label) in enumerate(dataloader):pass

这篇关于pytorch中的dataset和dataloader的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/800212

相关文章

HTML5自定义属性对象Dataset

原文转自HTML5自定义属性对象Dataset简介 一、html5 自定义属性介绍 之前翻译的“你必须知道的28个HTML5特征、窍门和技术”一文中对于HTML5中自定义合法属性data-已经做过些介绍,就是在HTML5中我们可以使用data-前缀设置我们需要的自定义属性,来进行一些数据的存放,例如我们要在一个文字按钮上存放相对应的id: <a href="javascript:" d

Nn criterions don’t compute the gradient w.r.t. targets error「pytorch」 (debug笔记)

Nn criterions don’t compute the gradient w.r.t. targets error「pytorch」 ##一、 缘由及解决方法 把这个pytorch-ddpg|github搬到jupyter notebook上运行时,出现错误Nn criterions don’t compute the gradient w.r.t. targets error。注:我用

论文精读-Supervised Raw Video Denoising with a Benchmark Dataset on Dynamic Scenes

论文精读-Supervised Raw Video Denoising with a Benchmark Dataset on Dynamic Scenes 优势 1、构建了一个用于监督原始视频去噪的基准数据集。为了多次捕捉瞬间,我们手动为对象s创建运动。在高ISO模式下捕获每一时刻的噪声帧,并通过对多个噪声帧进行平均得到相应的干净帧。 2、有效的原始视频去噪网络(RViDeNet),通过探

【超级干货】2天速成PyTorch深度学习入门教程,缓解研究生焦虑

3、cnn基础 卷积神经网络 输入层 —输入图片矩阵 输入层一般是 RGB 图像或单通道的灰度图像,图片像素值在[0,255],可以用矩阵表示图片 卷积层 —特征提取 人通过特征进行图像识别,根据左图直的笔画判断X,右图曲的笔画判断圆 卷积操作 激活层 —加强特征 池化层 —压缩数据 全连接层 —进行分类 输出层 —输出分类概率 4、基于LeNet

pytorch torch.nn.functional.one_hot函数介绍

torch.nn.functional.one_hot 是 PyTorch 中用于生成独热编码(one-hot encoding)张量的函数。独热编码是一种常用的编码方式,特别适用于分类任务或对离散的类别标签进行处理。该函数将整数张量的每个元素转换为一个独热向量。 函数签名 torch.nn.functional.one_hot(tensor, num_classes=-1) 参数 t

pytorch计算网络参数量和Flops

from torchsummary import summarysummary(net, input_size=(3, 256, 256), batch_size=-1) 输出的参数是除以一百万(/1000000)M, from fvcore.nn import FlopCountAnalysisinputs = torch.randn(1, 3, 256, 256).cuda()fl

Python(TensorFlow和PyTorch)两种显微镜成像重建算法模型(显微镜学)

🎯要点 🎯受激发射损耗显微镜算法模型:🖊恢复嘈杂二维和三维图像 | 🖊模型架构:恢复上下文信息和超分辨率图像 | 🖊使用嘈杂和高信噪比的图像训练模型 | 🖊准备半合成训练集 | 🖊优化沙邦尼尔损失和边缘损失 | 🖊使用峰值信噪比、归一化均方误差和多尺度结构相似性指数量化结果 | 🎯训练荧光显微镜模型和对抗网络图形转换模型 🍪语言内容分比 🍇Python图像归一化

Pytorch环境搭建时的各种问题

1 问题 1.一直soving environment,跳不出去。网络解决方案有:配置清华源,更新conda等,没起作用。2.下载完后,有3个要done的东西,最后那个exe开头的(可能吧),总是报错。网络解决方案有:用管理员权限打开prompt等,没起作用。3.有时候配置完源,安装包的时候显示什么https之类的东西,去c盘的用户那个文件夹里找到".condarc"文件把里面的网址都改成htt

【PyTorch】使用容器(Containers)进行网络层管理(Module)

文章目录 前言一、Sequential二、ModuleList三、ModuleDict四、ParameterList & ParameterDict总结 前言 当深度学习模型逐渐变得复杂,在编写代码时便会遇到诸多麻烦,此时便需要Containers的帮助。Containers的作用是将一部分网络层模块化,从而更方便地管理和调用。本文介绍PyTorch库常用的nn.Sequen

【python pytorch】Pytorch实现逻辑回归

pytorch 逻辑回归学习demo: import torchimport torch.nn as nnimport torchvision.datasets as dsetsimport torchvision.transforms as transformsfrom torch.autograd import Variable# Hyper Parameters input_si