DataLoader自定义数据集制作

2024-04-10 03:12

本文主要是介绍DataLoader自定义数据集制作,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

如何自定义数据集:

- 1.数据和标签的目录结构先搞定(得知道到哪读数据)
- 2.写好读取数据和标签路径的函数(根据自己数据集情况来写)
- 3.完成单个数据与标签读取函数(给dataloader举一个例子)

以花朵数据集为例:

- 原来数据集都是以文件夹为类别ID,现在咱们换一个套路,用txt文件指定数据路径与标签(实际情况基本都这样)
- 这回咱们的任务就是在txt文件中获取图像路径与标签,然后把他们交给dataloader
- 核心代码非常简单,按照对应格式传递需要的数据和标签就可以啦

任务1:读取txt文件中的路径和标签

  • 第一个小任务,从标注文件中读取数据和标签
  • 至于你准备存成什么格式,都可以的,一会能取出来东西就行

任务2:分别把数据和标签都存在list里

 - 不是我非让你存list里,因为dataloader到时候会在这里取数据
- 按照人家要求来,不要耍个性,让整list咱就给人家整

 任务3:图像数据路径得完整
- 因为一会咱得用这个路径去读数据,所以路径得加上前缀
- 以后大家任务不同,数据不同,怎么加你看着来就行,反正得能读到图像

任务4:把上面那几个事得写在一起

- 1.注意要使用from torch.utils.data import Dataset, DataLoader
- 2.类名定义class FlowerDataset(Dataset),其中FlowerDataset可以改成自己的名字
- 3.def __init__(self, root_dir, ann_file, transform=None):咱们要根据自己任务重写
- 4.def __getitem__(self, idx):根据自己任务,返回图像数据和标签数据

任务5:数据预处理(transform)

- 1.预处理的事都在上面的__getitem__中完成,需要对图像和标签咋咋地的,要整啥事,都在上面整
- 2.返回的数据和标签就是建模时模型的输入和损失函数中标签的输入,一定整明白自己模型要啥
- 3.预处理这个事是你定的,不同的数据需要的方法也不一样,下面给出的是比较通用的方法

 任务6:根据写好的class FlowerDataset(Dataset):来实例化咱们的dataloader

- 1.构建数据集:分别创建训练和验证用的数据集(如果需要测试集也一样的方法)
- 2.用Torch给的DataLoader方法来实例化(batch啥的自己定,根据你的显存来选合适的)
- 3.打印看看数据里面是不是有东西了

 

任务7:用之前先试试,整个数据和标签对应下

- 1.别着急往模型里传,对不对都不知道呢
- 2.用这个方法:iter(train_loader).next()来试试,得到的数据和标签是啥
- 3.看不出来就把图画出来,标签打印出来,确保自己整的数据集没啥问题

 代码实现

import osfrom matplotlib import pyplot as plt
from torchvision import transforms, models, datasets
import numpy as np
import torch
from PIL import Imagedef load_annotations(ann_file):data_infos = {}with open(ann_file) as f:samples = [x.strip().split(' ') for x in f.readlines()]for filename, gt_label in samples:data_infos[filename] = np.array(gt_label, dtype=np.int64)return data_infosimg_label =load_annotations('./flower_data/train.txt')
image_name = list(img_label.keys())
label = list(img_label.values())data_dir = './flower_data/'
train_dir = data_dir + '/train_filelist'
valid_dir = data_dir + '/val_filelist'image_path = [os.path.join(train_dir,img) for img in image_name]from torch.utils.data import Dataset, DataLoaderclass FlowerDataset(Dataset):def __init__(self, root_dir, ann_file, transform=None):self.ann_file = ann_fileself.root_dir = root_dirself.img_label = self.load_annotations()self.img = [os.path.join(self.root_dir, img) for img in list(self.img_label.keys())]self.label = [label for label in list(self.img_label.values())]self.transform = transformdef __len__(self):return len(self.img)def __getitem__(self, idx):image = Image.open(self.img[idx])label = self.label[idx]if self.transform:image = self.transform(image)label = torch.from_numpy(np.array(label))return image, labeldef load_annotations(self):data_infos = {}with open(self.ann_file) as f:samples = [x.strip().split(' ') for x in f.readlines()]for filename, gt_label in samples:data_infos[filename] = np.array(gt_label, dtype=np.int64)return data_infosdata_transforms = {'train':transforms.Compose([transforms.Resize(64),transforms.RandomRotation(45),#随机旋转,-45到45度之间随机选transforms.CenterCrop(64),#从中心开始裁剪transforms.RandomHorizontalFlip(p=0.5),#随机水平翻转 选择一个概率概率transforms.RandomVerticalFlip(p=0.5),#随机垂直翻转transforms.ColorJitter(brightness=0.2, contrast=0.1, saturation=0.1, hue=0.1),#参数1为亮度,参数2为对比度,参数3为饱和度,参数4为色相transforms.RandomGrayscale(p=0.025),#概率转换成灰度率,3通道就是R=G=Btransforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])#均值,标准差]),'valid':transforms.Compose([transforms.Resize(64),transforms.CenterCrop(64),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
}train_dataset = FlowerDataset(root_dir=train_dir, ann_file = './flower_data/train.txt', transform=data_transforms['train'])
val_dataset = FlowerDataset(root_dir=valid_dir, ann_file = './flower_data/val.txt', transform=data_transforms['valid'])
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=True)image, label = next(iter(train_loader))
sample = image[0].squeeze()
sample = sample.permute((1, 2, 0)).numpy()
sample *= [0.229, 0.224, 0.225]
sample += [0.485, 0.456, 0.406]
plt.imshow(sample)
plt.show()
print('Label is: {}'.format(label[0].numpy()))

 

这篇关于DataLoader自定义数据集制作的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/889929

相关文章

使用Python制作一个PDF批量加密工具

《使用Python制作一个PDF批量加密工具》PDF批量加密‌是一种保护PDF文件安全性的方法,通过为多个PDF文件设置相同的密码,防止未经授权的用户访问这些文件,下面我们来看看如何使用Python制... 目录1.简介2.运行效果3.相关源码1.简介一个python写的PDF批量加密工具。PDF批量加密

用Unity2D制作一个人物,实现移动、跳起、人物静止和动起来时的动画:中(人物移动、跳起、静止动作)

上回我们学到创建一个地形和一个人物,今天我们实现一下人物实现移动和跳起,依次点击,我们准备创建一个C#文件 创建好我们点击进去,就会跳转到我们的Vision Studio,然后输入这些代码 using UnityEngine;public class Move : MonoBehaviour // 定义一个名为Move的类,继承自MonoBehaviour{private Rigidbo

react笔记 8-16 JSX语法 定义数据 数据绑定

1、jsx语法 和vue一样  只能有一个根标签 一行代码写法 return <div>hello world</div> 多行代码返回必须加括号 return (<div><div>hello world</div><div>aaaaaaa</div></div>) 2、定义数据 数据绑定 constructor(){super()this.state={na

OpenStack离线Train版安装系列—0制作yum源

本系列文章包含从OpenStack离线源制作到完成OpenStack安装的全部过程。 在本系列教程中使用的OpenStack的安装版本为第20个版本Train(简称T版本),2020年5月13日,OpenStack社区发布了第21个版本Ussuri(简称U版本)。 OpenStack部署系列文章 OpenStack Victoria版 安装部署系列教程 OpenStack Ussuri版

OpenStack镜像制作系列5—Linux镜像

本系列文章主要对如何制作OpenStack镜像的过程进行描述记录 CSDN:OpenStack镜像制作教程指导(全) OpenStack镜像制作系列1—环境准备 OpenStack镜像制作系列2—Windows7镜像 OpenStack镜像制作系列3—Windows10镜像 OpenStack镜像制作系列4—Windows Server2019镜像 OpenStack镜像制作

OpenStack镜像制作系列4—Windows Server2019镜像

本系列文章主要对如何制作OpenStack镜像的过程进行描述记录  CSDN:OpenStack镜像制作教程指导(全) OpenStack镜像制作系列1—环境准备 OpenStack镜像制作系列2—Windows7镜像 OpenStack镜像制作系列3—Windows10镜像 OpenStack镜像制作系列4—Windows Server2019镜像 OpenStack镜像制作系

OpenStack镜像制作系列2—Windows7镜像

本系列文章主要对如何制作OpenStack镜像的过程进行描述记录 CSDN:OpenStack镜像制作教程指导(全) OpenStack镜像制作系列1—环境准备 OpenStack镜像制作系列2—Windows7镜像 OpenStack镜像制作系列3—Windows10镜像 OpenStack镜像制作系列4—Windows Server2019镜像 OpenStack镜像制作系列

OpenStack镜像制作系列1—环境准备

本系列文章主要对如何制作OpenStack镜像的过程进行描述记录 CSDN:OpenStack镜像制作教程指导(全) OpenStack镜像制作系列1—环境准备 OpenStack镜像制作系列2—Windows7镜像 OpenStack镜像制作系列3—Windows10镜像 OpenStack镜像制作系列4—Windows Server2019镜像 OpenStack镜像制作

CSDN:OpenStack镜像制作教程指导(全)

本系列文章主要对如何制作OpenStack镜像的过程进行描述记录,涉及基本环境准备、常见类型操作系统的镜像制作。 让你可以从零开始安装一个操作系统,并支持个性化制作OpenStack镜像。 CSDN:OpenStack镜像制作教程指导(全) OpenStack镜像制作系列1—环境准备 OpenStack镜像制作系列2—Windows7镜像 OpenStack镜像制作系列3—Windows

docker学习系列(四)制作基础的base项目镜像--jdk+tomcat

前面已经完成了docker的安装以及使用,现在我们要将自己的javaweb项目与docker结合 1.1准备jdk+tomcat软件 ​​我下载了apache-tomcat-7.0.68.tar.gz和jdk-7u79-linux-x64.tar.gz,存储于Linux机器的本地目录/usr/ect/wt/下(利用xshell上传)。利用linux命令 tar -zxvf apache-tom