[pytorch] 定义自己的dataloader

2024-01-30 10:36
文章标签 定义 pytorch dataloader

本文主要是介绍[pytorch] 定义自己的dataloader,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

定义自己的dataloader

  • 1 定义datalaoder
    • 1.1 __init__
    • 1.2 __getitem__
    • 1.3 __len__
  • 2 调用dataloader
  • 参考

在使用自己数据集训练网络时,往往需要定义自己的dataloader。

1 定义datalaoder

一般将dataloader封装为一个类,这个类继承自 torch.utils.data.dataset

from torch.utils.data import datasetclass LoadData(Dataset):  # 注意父类的名称,不能写datasetpass

需要注意的是dataset是模块名,而Dataset是类名,在python中模块名和类名是完全独立的命名空间,因此这里的父类需要写成 dataset.Dataset。

在我们定义的LoadData中,至少需要有三个方法:

  • __init__方法,主要用来定义数据的预处理
  • __getitem__方法,返回数据的item和label
  • __len__方法,返回数据个数

整体大致架构:

import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoaderclass LoadData(dDataset):def __init__(self):passdef __getitem__(self,index):passdef __len__(self):passdataset = Loaddata()
train_loader = DataLoader(dataset = dataset,batch_size = 32,shuffle = Ture,num_workers=2)

1.1 init

__init__方法需要传入至少两个参数:

  • 一般数据的地址和标签已经被保存在某个文档中了(这里是txt格式的文档)。因此需要传入这个文档的地址。
  • 因为__init__方法要做预处理,一般用来train的预处理和test的预处理是不同的,因此需要区分二者的参数。
def __init__(self, txt_path, train=True):super(LoadData, self).__init__()self.img_info = self.get_img(txt_path)self.train = train# train预处理self.train_transforms = transforms.Compose([transforms.Resize(20),transforms.RandomHorizontalFlip(),transforms.RandomVerticalFlip(),transforms.ToTensor(),transforms.Normalize(mean=[0.5], std=[0.5])])# test预处理self.test_transforms = transforms.Compose([transforms.Resize(20),transforms.ToTensor(),transforms.Normalize(mean=[0.5], std=[0.5])])# 这个函数是用来读txt文档的def get_img(self, txt_path):with open(txt_path, 'r', encoding='utf-8') as f:imgs_info = f.readlines()imgs_info = list(map(lambda x:x.strip().split('\t'), imgs_info))return imgs_info

1.2 getitem

__getitem__方法只需要根据index返回数据的item和label。

def __getitem__(self, index):img_path, label = self.img_info[index]img = Image.open(img_path)label = int(label)# 注意区分预处理if self.train:img = self.train_transforms(img)else:img = self.test_transforms(img)return img, label

1.3 len

__len__方法最简单,仅返回数据项个数。

def __len__(self):return len(self.img_info)

2 调用dataloader

以训练数据为例,调用dataloader需要两步:

  • 将自定义的LoadData实例化
  • 传入torch.utils.data.dataloader中
from torch.utils.data import Dataloadertrain_dataset = LoadData(txt_path='XXXX', train=True)train_loader = dataloader.Dataloader(dataset=train_dataset,batch_size=8,shuffle=True)

至此,一个最简单的dataloader就完成了!
可以用以下代码测试:

for image, label in train_loader:print(image.shape)print(label)

参考

https://zhuanlan.zhihu.com/p/399447239

这篇关于[pytorch] 定义自己的dataloader的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/659959

相关文章

通俗范畴论4 范畴的定义

注:由于CSDN无法显示本文章源文件的公式,因此部分下标、字母花体、箭头表示可能会不正常,请读者谅解 范畴的正式定义 上一节我们在没有引入范畴这个数学概念的情况下,直接体验了一个“苹果1”范畴,建立了一个对范畴的直观。本节我们正式学习范畴的定义和基本性质。 一个范畴(Category) C𝐶,由以下部分组成: 数据: 对象(Objects):包含若干个对象(Objects),这些

基于CTPN(tensorflow)+CRNN(pytorch)+CTC的不定长文本检测和识别

转发来源:https://swift.ctolib.com/ooooverflow-chinese-ocr.html chinese-ocr 基于CTPN(tensorflow)+CRNN(pytorch)+CTC的不定长文本检测和识别 环境部署 sh setup.sh 使用环境: python 3.6 + tensorflow 1.10 +pytorch 0.4.1 注:CPU环境

利用结构体作为函数参数时结构体指针的定义

在利用结构体作为函数的参数进行传递时,容易犯的一个错误是将一个野指针传给函数导致错误。 #include <stdio.h>#include <math.h>#include <malloc.h>#define MAXSIZE 10typedef struct {int r[MAXSIZE]; //用于存储要排序的数组,r[0]作为哨兵或者临时变量int length;

PyTorch模型_trace实战:深入理解与应用

pytorch使用trace模型 1、使用trace生成torchscript模型2、使用trace的模型预测 1、使用trace生成torchscript模型 def save_trace(model, input, save_path):traced_script_model = torch.jit.trace(model, input)<

linux cron /etc/crontab 及 /var/spool/cron/$USER 中定义定时任务

简介 定时任务在linux上主要体现在两个地方,一个是/etc/crontab ,另一个就是定义了任务计划的用户/var/spool/cron/$USER 1、crontab -e 或者直接编辑/etc/crontab文件,这种方式用的人比较多,/etc/crontab是系统调度的配置文件,只有root用户可以使用,使用时需root权限,而且必须指定运行用户,才会执行 * * * * * *

vue+elementui搭建后台管理界面(5递归生成侧栏路由) vue定义定义多级路由菜单

有一个菜单树,顶层菜单下面有多个子菜单,子菜单下还有子菜单。。。 这时候就要用递归处理 1 定义多级菜单 修改 src/router/index.js 的 / 路由 {path: '/',redirect: '/dashboard',name: 'Container',component: Container,children: [{path: 'dashboard', name: '首

pytorch国内镜像源安装及测试

一、安装命令:  pip install torch torchvision torchaudio -i https://pypi.tuna.tsinghua.edu.cn/simple  二、测试: import torchx = torch.rand(5, 3)print(x)

PyTorch nn.MSELoss() 均方误差损失函数详解和要点提醒

文章目录 nn.MSELoss() 均方误差损失函数参数数学公式元素版本 要点附录 参考链接 nn.MSELoss() 均方误差损失函数 torch.nn.MSELoss(size_average=None, reduce=None, reduction='mean') Creates a criterion that measures the mean squared err

动手学深度学习(Pytorch版)代码实践 -计算机视觉-37微调

37微调 import osimport torchimport torchvisionfrom torch import nnimport liliPytorch as lpimport matplotlib.pyplot as pltfrom d2l import torch as d2l# 获取数据集d2l.DATA_HUB['hotdog'] = (d2l.DATA_U

基础C语言知识串串香11☞宏定义与预处理、函数和函数库

​ 六、C语言宏定义与预处理、函数和函数库 6.1 编译工具链 源码.c ——> (预处理)——>预处理过的.i文件——>(编译)——>汇编文件.S——>(汇编)——>目标文件.o->(链接)——>elf可执行程序 预处理用预处理器,编译用编译器,汇编用汇编器,链接用链接器,这几个工具再加上其他一些额外的会用到的可用工具,合起来叫编译工具链(gcc就是一个编译工具链)。 gcc中各选项