pytorch中的dataset和dataloader

2024-03-12 05:04
文章标签 pytorch dataset dataloader

本文主要是介绍pytorch中的dataset和dataloader,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

  PyTorch为我们提供了Dataset和DataLoader类分别负责可被Pytorch使用的数据集的创建以及向训练传递数据的任务。一般在项目中,我们需要根据自己的数据集个性化pytorch中储存数据集的方式和数据传递的方式,需要自己重写一些子类。
  torch.utils.data.Dataset 是一个表示数据集的抽象类。任何自定义的数据集都需要继承这个类并覆写相关方法。数据集主要有两个功能,一个是读取本地的数据集并储存起来,另外一个是负责处理索引(index)到样本(sample)映射。
  因此,在定义自己的数据集时,一般我们需要重写init,getitem,len这三个方法,在init中,执行加载数据集和储存数据集的任务,getitem中执行根据索引(index)从数据集中取出一个样本。len方法则返回数据集中的样本数。
例子:

import torch
import torch.utils.data.dataset as Dataset#创建子类
class mydataset(Dataset.Dataset):#初始化,定义数据内容和标签def __init__(self, Data, Label):self.Data = Dataself.Label = Label#返回数据集大小def __len__(self):return len(self.Data)#得到数据内容和标签def __getitem__(self, index):data = torch.Tensor(self.Data[index])label = torch.Tensor(self.Label[index])return data, label

  在我们创建了自己的dataset后,就可以用这个dataset创建我们的dataloader了,一般来说在PyTorch项目中加载数据集的流程是这样的: 1. 创建Dateset。 2. 将Dataset作为参数传递给DataLoader,创建dataloader。 3. DataLoader迭代按照batch大小产生训练数据提供给模型。dataloader的定义为:

DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,batch_sampler=None, num_workers=0, collate_fn=None,pin_memory=False, drop_last=False, timeout=0,worker_init_fn=None)

  解释其中几个比较关键的:dataset就是我们自己重写后的数据集;batch_size是超参,即设置的batch大小;shuffle是否打乱数据;drop_last是否截断数据(如果样本数量与batchsize不能整除的情况下);num_workers是同时参与数据读取的线程数量,多线程技术可以加快数据读取,但是如果项目不大,设置为0就行,免得引入新问题;collate_fn,数据整理函数,这个方法可以被传递进去,把一个batch的数据变成你想要的样子。
  给一个自定义的collate_fn函数的例子:

def collate_fn(data):img = [i[0][0] for i in data]txt = [i[0][1] for i in data]labels = [i[1] for i in data]img = torch.stack(img).cuda()txt = torch.stack(txt).cuda()labels = torch.LongTensor(labels).cuda()return img, txt ,labels

那么我们的一般流程为:

# 创建Dateset(自定义)dataset = mydataset() 
# Dataset传递给DataLoaderdataloader = DataLoader(dataset,batch_size=64,shuffle=False,num_workers=8,collate_fn=my_collate_fn)
# DataLoader迭代产生训练数据提供给模型for i in range(epoch):for index,(img,txt,label) in enumerate(dataloader):pass

这篇关于pytorch中的dataset和dataloader的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/800212

相关文章

基于CTPN(tensorflow)+CRNN(pytorch)+CTC的不定长文本检测和识别

转发来源:https://swift.ctolib.com/ooooverflow-chinese-ocr.html chinese-ocr 基于CTPN(tensorflow)+CRNN(pytorch)+CTC的不定长文本检测和识别 环境部署 sh setup.sh 使用环境: python 3.6 + tensorflow 1.10 +pytorch 0.4.1 注:CPU环境

PyTorch模型_trace实战:深入理解与应用

pytorch使用trace模型 1、使用trace生成torchscript模型2、使用trace的模型预测 1、使用trace生成torchscript模型 def save_trace(model, input, save_path):traced_script_model = torch.jit.trace(model, input)<

pytorch国内镜像源安装及测试

一、安装命令:  pip install torch torchvision torchaudio -i https://pypi.tuna.tsinghua.edu.cn/simple  二、测试: import torchx = torch.rand(5, 3)print(x)

PyTorch nn.MSELoss() 均方误差损失函数详解和要点提醒

文章目录 nn.MSELoss() 均方误差损失函数参数数学公式元素版本 要点附录 参考链接 nn.MSELoss() 均方误差损失函数 torch.nn.MSELoss(size_average=None, reduce=None, reduction='mean') Creates a criterion that measures the mean squared err

动手学深度学习(Pytorch版)代码实践 -计算机视觉-37微调

37微调 import osimport torchimport torchvisionfrom torch import nnimport liliPytorch as lpimport matplotlib.pyplot as pltfrom d2l import torch as d2l# 获取数据集d2l.DATA_HUB['hotdog'] = (d2l.DATA_U

WSL+Anconda(pytorch深度学习)环境配置

动机 最近在读point cloud相关论文,准备拉github上相应的code跑一下,但是之前没有深度学习的经验,在配置环境方面踩了超级多的坑,依次来记录一下。 一开始我直接将code拉到了windows本地来运行,遇到了数不清的问题(如:torch版本问题、numpy版本、bash命令无法运行等问题),经过请教,决定将project放到linux系统上进行运行。所以安装WSL(Window

动手学深度学习(Pytorch版)代码实践 -计算机视觉-36图像增广

6 图片增广 import matplotlib.pyplot as pltimport numpy as npimport torch import torchvisionfrom d2l import torch as d2lfrom torch import nn from PIL import Imageimport liliPytorch as lpfrom tor

pytorch 使用GPU加速常见的问题

pytorch如何使用gpu加速 print(torch.cuda.is_available())# 设置gpu设备device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')# net使用GPUnet.to(device)# 数据copy到gpuinputData = inputData.to(devi

深度学习:关于损失函数的一些前置知识(PyTorch Loss)

在之前进行实验的时候发现:调用 Pytorch 中的 Loss 函数之前如果对其没有一定的了解,可能会影响实验效果和调试效率。以 CrossEntropyLoss 为例,最初设计实验的时候没有注意到该函数默认返回的是均值,以为是总和,于是最后计算完 Loss 之后,手动做了个均值,导致实际 Loss 被错误缩放,实验效果不佳,在后来 Debug 排除代码模型架构问题的时候才发觉这一点,着实花费了

《PyTorch计算机视觉实战》:一、二章

目录 第一章:人工神经网络基础 比较人工智能和传统机器学习 人工神经网络(Artificial Neural Network,ANN) 是一种受人类大脑运作方式启发而构建的监督学习算法。神经网络与人类大脑中神经元连接和激活的方式比较类似,神经网络接收输入并通过一个函数传递,导致随后的某些神经元被激活,从而产生输出。 有几种标准的 ANN 架构。通用近似定理认为,总是可以找到一