Pytorch打怪路(三)Pytorch创建自己的数据集1

2024-03-23 06:32
文章标签 数据 创建 pytorch 打怪

本文主要是介绍Pytorch打怪路(三)Pytorch创建自己的数据集1,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

之前讲的例子,程序都是调用的datasets方法,下载的torchvision本身就提供的数据,那么如果想导入自己的数据应该怎么办呢?

本篇就讲解一下如何创建自己的数据集。

还有第二篇……Pytorch打怪路(三)Pytorch创建自己的数据集2

1.用于分类的数据集

以mnist数据集为例

这里的mnist数据集并不是torchvision里面的,而是我自己的以图片格式保存的数据集,因为我在测试STN时,希望自己再把这些手写体做一些形变,

所以就先把MNIST数据集转化成了jpg图片格式,然后做了一些形变,当然这不是重点。首先我们看一下我的数据集的情况:

如图所示,我的图片数据集确实是jpg图片

 

再看我的存储图片名和label信息的文本:

 

 

如图所示,我的mnist.txt文本每一行分为两部分,第一部分是具体路径+图片名.jpg

第二部分就是label信息,因为前面这部分图片都是0 ,所以他们的分类的label信息就是0

要创建你自己的 用于分类的 数据集,也要包含上述两个部分,1.图片数据集,2.文本信息(这个txt文件可以用python或者C++轻易创建,再此不详述)

2.代码

 

主要代码

from PIL import Image
import torchclass MyDataset(torch.utils.data.Dataset): #创建自己的类:MyDataset,这个类是继承的torch.utils.data.Datasetdef __init__(self,root, datatxt, transform=None, target_transform=None): #初始化一些需要传入的参数fh = open(root + datatxt, 'r') #按照传入的路径和txt文本参数,打开这个文本,并读取内容imgs = []                      #创建一个名为img的空列表,一会儿用来装东西for line in fh:                #按行循环txt文本中的内容line = line.rstrip()       # 删除 本行string 字符串末尾的指定字符,这个方法的详细介绍自己查询pythonwords = line.split()   #通过指定分隔符对字符串进行切片,默认为所有的空字符,包括空格、换行、制表符等imgs.append((words[0],int(words[1]))) #把txt里的内容读入imgs列表保存,具体是words几要看txt内容而定# 很显然,根据我刚才截图所示txt的内容,words[0]是图片信息,words[1]是lableself.imgs = imgsself.transform = transformself.target_transform = target_transformdef __getitem__(self, index):    #这个方法是必须要有的,用于按照索引读取每个元素的具体内容fn, label = self.imgs[index] #fn是图片path #fn和label分别获得imgs[index]也即是刚才每行中word[0]和word[1]的信息img = Image.open(root+fn).convert('RGB') #按照path读入图片from PIL import Image # 按照路径读取图片if self.transform is not None:img = self.transform(img) #是否进行transformreturn img,label  #return很关键,return回哪些内容,那么我们在训练时循环读取每个batch时,就能获得哪些内容def __len__(self): #这个函数也必须要写,它返回的是数据集的长度,也就是多少张图片,要和loader的长度作区分return len(self.imgs)#根据自己定义的那个勒MyDataset来创建数据集!注意是数据集!而不是loader迭代器
train_data=MyDataset(txt=root+'train.txt', transform=transforms.ToTensor())
test_data=MyDataset(txt=root+'test.txt', transform=transforms.ToTensor())
#然后就是调用DataLoader和刚刚创建的数据集,来创建dataloader,这里提一句,loader的长度是有多少个batch,所以和batch_size有关
train_loader = DataLoader(dataset=train_data, batch_size=64, shuffle=True)
test_loader = DataLoader(dataset=test_data, batch_size=64)

 

再补充一点代码,以便更好的理解 __getitem__这个方法

 

 

for batch_index, data, target in test_loader:if use_cuda:data, target = data.cuda(), target.cuda()data, target = Variable(data, volatile=True), Variable(target)

这段代码是我从测试的部分中截取出来的,为什么直接能用for data, target In test_loader这样的语句呢?

其实这个语句还可以这么写:

for batch_index, batch in train_loader

        data, target = batch

这样就好理解了,因为这个迭代器每一次循环所得的batch里面装的东西,就是我在__getitem__方法最后return回来的

所以你想在训练或者测试的时候还得到其他信息的话,就去增加一些返回值即可,只要是能return出来的,就能在每个batch中读取到!

###############################################################################

有朋友可能想问,如果我的label信息不是数字而是图像呢?比如分割任务,它的label就是图像,这样的数据集的建立,也参考我的下一篇博文:

Pytorch打怪路(三)Pytorch创建自己的数据集2

这篇关于Pytorch打怪路(三)Pytorch创建自己的数据集1的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/837439

相关文章

Python获取中国节假日数据记录入JSON文件

《Python获取中国节假日数据记录入JSON文件》项目系统内置的日历应用为了提升用户体验,特别设置了在调休日期显示“休”的UI图标功能,那么问题是这些调休数据从哪里来呢?我尝试一种更为智能的方法:P... 目录节假日数据获取存入jsON文件节假日数据读取封装完整代码项目系统内置的日历应用为了提升用户体验,

Java利用JSONPath操作JSON数据的技术指南

《Java利用JSONPath操作JSON数据的技术指南》JSONPath是一种强大的工具,用于查询和操作JSON数据,类似于SQL的语法,它为处理复杂的JSON数据结构提供了简单且高效... 目录1、简述2、什么是 jsONPath?3、Java 示例3.1 基本查询3.2 过滤查询3.3 递归搜索3.4

idea中创建新类时自动添加注释的实现

《idea中创建新类时自动添加注释的实现》在每次使用idea创建一个新类时,过了一段时间发现看不懂这个类是用来干嘛的,为了解决这个问题,我们可以设置在创建一个新类时自动添加注释,帮助我们理解这个类的用... 目录前言:详细操作:步骤一:点击上方的 文件(File),点击&nbmyHIgsp;设置(Setti

MySQL大表数据的分区与分库分表的实现

《MySQL大表数据的分区与分库分表的实现》数据库的分区和分库分表是两种常用的技术方案,本文主要介绍了MySQL大表数据的分区与分库分表的实现,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有... 目录1. mysql大表数据的分区1.1 什么是分区?1.2 分区的类型1.3 分区的优点1.4 分

Mysql删除几亿条数据表中的部分数据的方法实现

《Mysql删除几亿条数据表中的部分数据的方法实现》在MySQL中删除一个大表中的数据时,需要特别注意操作的性能和对系统的影响,本文主要介绍了Mysql删除几亿条数据表中的部分数据的方法实现,具有一定... 目录1、需求2、方案1. 使用 DELETE 语句分批删除2. 使用 INPLACE ALTER T

Python Dash框架在数据可视化仪表板中的应用与实践记录

《PythonDash框架在数据可视化仪表板中的应用与实践记录》Python的PlotlyDash库提供了一种简便且强大的方式来构建和展示互动式数据仪表板,本篇文章将深入探讨如何使用Dash设计一... 目录python Dash框架在数据可视化仪表板中的应用与实践1. 什么是Plotly Dash?1.1

Redis 中的热点键和数据倾斜示例详解

《Redis中的热点键和数据倾斜示例详解》热点键是指在Redis中被频繁访问的特定键,这些键由于其高访问频率,可能导致Redis服务器的性能问题,尤其是在高并发场景下,本文给大家介绍Redis中的热... 目录Redis 中的热点键和数据倾斜热点键(Hot Key)定义特点应对策略示例数据倾斜(Data S

Python实现将MySQL中所有表的数据都导出为CSV文件并压缩

《Python实现将MySQL中所有表的数据都导出为CSV文件并压缩》这篇文章主要为大家详细介绍了如何使用Python将MySQL数据库中所有表的数据都导出为CSV文件到一个目录,并压缩为zip文件到... python将mysql数据库中所有表的数据都导出为CSV文件到一个目录,并压缩为zip文件到另一个

使用PyTorch实现手写数字识别功能

《使用PyTorch实现手写数字识别功能》在人工智能的世界里,计算机视觉是最具魅力的领域之一,通过PyTorch这一强大的深度学习框架,我们将在经典的MNIST数据集上,见证一个神经网络从零开始学会识... 目录当计算机学会“看”数字搭建开发环境MNIST数据集解析1. 认识手写数字数据库2. 数据预处理的

SpringBoot整合jasypt实现重要数据加密

《SpringBoot整合jasypt实现重要数据加密》Jasypt是一个专注于简化Java加密操作的开源工具,:本文主要介绍详细介绍了如何使用jasypt实现重要数据加密,感兴趣的小伙伴可... 目录jasypt简介 jasypt的优点SpringBoot使用jasypt创建mapper接口配置文件加密