[pytorch] 定义自己的dataloader

2024-01-30 10:36
文章标签 定义 pytorch dataloader

本文主要是介绍[pytorch] 定义自己的dataloader,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

定义自己的dataloader

  • 1 定义datalaoder
    • 1.1 __init__
    • 1.2 __getitem__
    • 1.3 __len__
  • 2 调用dataloader
  • 参考

在使用自己数据集训练网络时,往往需要定义自己的dataloader。

1 定义datalaoder

一般将dataloader封装为一个类,这个类继承自 torch.utils.data.dataset

from torch.utils.data import datasetclass LoadData(Dataset):  # 注意父类的名称,不能写datasetpass

需要注意的是dataset是模块名,而Dataset是类名,在python中模块名和类名是完全独立的命名空间,因此这里的父类需要写成 dataset.Dataset。

在我们定义的LoadData中,至少需要有三个方法:

  • __init__方法,主要用来定义数据的预处理
  • __getitem__方法,返回数据的item和label
  • __len__方法,返回数据个数

整体大致架构:

import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoaderclass LoadData(dDataset):def __init__(self):passdef __getitem__(self,index):passdef __len__(self):passdataset = Loaddata()
train_loader = DataLoader(dataset = dataset,batch_size = 32,shuffle = Ture,num_workers=2)

1.1 init

__init__方法需要传入至少两个参数:

  • 一般数据的地址和标签已经被保存在某个文档中了(这里是txt格式的文档)。因此需要传入这个文档的地址。
  • 因为__init__方法要做预处理,一般用来train的预处理和test的预处理是不同的,因此需要区分二者的参数。
def __init__(self, txt_path, train=True):super(LoadData, self).__init__()self.img_info = self.get_img(txt_path)self.train = train# train预处理self.train_transforms = transforms.Compose([transforms.Resize(20),transforms.RandomHorizontalFlip(),transforms.RandomVerticalFlip(),transforms.ToTensor(),transforms.Normalize(mean=[0.5], std=[0.5])])# test预处理self.test_transforms = transforms.Compose([transforms.Resize(20),transforms.ToTensor(),transforms.Normalize(mean=[0.5], std=[0.5])])# 这个函数是用来读txt文档的def get_img(self, txt_path):with open(txt_path, 'r', encoding='utf-8') as f:imgs_info = f.readlines()imgs_info = list(map(lambda x:x.strip().split('\t'), imgs_info))return imgs_info

1.2 getitem

__getitem__方法只需要根据index返回数据的item和label。

def __getitem__(self, index):img_path, label = self.img_info[index]img = Image.open(img_path)label = int(label)# 注意区分预处理if self.train:img = self.train_transforms(img)else:img = self.test_transforms(img)return img, label

1.3 len

__len__方法最简单,仅返回数据项个数。

def __len__(self):return len(self.img_info)

2 调用dataloader

以训练数据为例,调用dataloader需要两步:

  • 将自定义的LoadData实例化
  • 传入torch.utils.data.dataloader中
from torch.utils.data import Dataloadertrain_dataset = LoadData(txt_path='XXXX', train=True)train_loader = dataloader.Dataloader(dataset=train_dataset,batch_size=8,shuffle=True)

至此,一个最简单的dataloader就完成了!
可以用以下代码测试:

for image, label in train_loader:print(image.shape)print(label)

参考

https://zhuanlan.zhihu.com/p/399447239

这篇关于[pytorch] 定义自己的dataloader的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/659959

相关文章

CSS Anchor Positioning重新定义锚点定位的时代来临(最新推荐)

《CSSAnchorPositioning重新定义锚点定位的时代来临(最新推荐)》CSSAnchorPositioning是一项仍在草案中的新特性,由Chrome125开始提供原生支持需... 目录 css Anchor Positioning:重新定义「锚定定位」的时代来了! 什么是 Anchor Pos

Pytorch介绍与安装过程

《Pytorch介绍与安装过程》PyTorch因其直观的设计、卓越的灵活性以及强大的动态计算图功能,迅速在学术界和工业界获得了广泛认可,成为当前深度学习研究和开发的主流工具之一,本文给大家介绍Pyto... 目录1、Pytorch介绍1.1、核心理念1.2、核心组件与功能1.3、适用场景与优势总结1.4、优

conda安装GPU版pytorch默认却是cpu版本

《conda安装GPU版pytorch默认却是cpu版本》本文主要介绍了遇到Conda安装PyTorchGPU版本却默认安装CPU的问题,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的... 目录一、问题描述二、网上解决方案罗列【此节为反面方案罗列!!!】三、发现的根本原因[独家]3.1 p

PyTorch中cdist和sum函数使用示例详解

《PyTorch中cdist和sum函数使用示例详解》torch.cdist是PyTorch中用于计算**两个张量之间的成对距离(pairwisedistance)**的函数,常用于点云处理、图神经网... 目录基本语法输出示例1. 简单的 2D 欧几里得距离2. 批量形式(3D Tensor)3. 使用不

PyTorch高级特性与性能优化方式

《PyTorch高级特性与性能优化方式》:本文主要介绍PyTorch高级特性与性能优化方式,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录一、自动化机制1.自动微分机制2.动态计算图二、性能优化1.内存管理2.GPU加速3.多GPU训练三、分布式训练1.分布式数据

判断PyTorch是GPU版还是CPU版的方法小结

《判断PyTorch是GPU版还是CPU版的方法小结》PyTorch作为当前最流行的深度学习框架之一,支持在CPU和GPU(NVIDIACUDA)上运行,所以对于深度学习开发者来说,正确识别PyTor... 目录前言为什么需要区分GPU和CPU版本?性能差异硬件要求如何检查PyTorch版本?方法1:使用命

C 语言中enum枚举的定义和使用小结

《C语言中enum枚举的定义和使用小结》在C语言里,enum(枚举)是一种用户自定义的数据类型,它能够让你创建一组具名的整数常量,下面我会从定义、使用、特性等方面详细介绍enum,感兴趣的朋友一起看... 目录1、引言2、基本定义3、定义枚举变量4、自定义枚举常量的值5、枚举与switch语句结合使用6、枚

pytorch自动求梯度autograd的实现

《pytorch自动求梯度autograd的实现》autograd是一个自动微分引擎,它可以自动计算张量的梯度,本文主要介绍了pytorch自动求梯度autograd的实现,具有一定的参考价值,感兴趣... autograd是pytorch构建神经网络的核心。在 PyTorch 中,结合以下代码例子,当你

在PyCharm中安装PyTorch、torchvision和OpenCV详解

《在PyCharm中安装PyTorch、torchvision和OpenCV详解》:本文主要介绍在PyCharm中安装PyTorch、torchvision和OpenCV方式,具有很好的参考价值,... 目录PyCharm安装PyTorch、torchvision和OpenCV安装python安装PyTor

pytorch之torch.flatten()和torch.nn.Flatten()的用法

《pytorch之torch.flatten()和torch.nn.Flatten()的用法》:本文主要介绍pytorch之torch.flatten()和torch.nn.Flatten()的用... 目录torch.flatten()和torch.nn.Flatten()的用法下面举例说明总结torch