python使用DataLoader对数据集进行批处理

2024-02-10 04:58

本文主要是介绍python使用DataLoader对数据集进行批处理,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

使用DataLoader对数据集进行批处理,转自

https://www.cnblogs.com/JeasonIsCoding/p/10168753.html

第一步:创建torch能够识别的数据集类型

首先建立两个向量X和Y,一个作为输入的数据,一个作为正确的结果:

import torch
import torch.utils.data as DataBATCH_SIZE = 3 		# 批训练的数据个数x = torch.linspace(1,9,9)  # x data (torch tensor)
y = torch.linspace(9,1,9)  # y data (torch tensor)

随后把X和Y组成一个完整的数据集,并转化为pytorch能识别的数据集类型:

# 先转换成 torch 能够识别的 Dataset
torch_dataset = Data.TensorDataset( x, y )

现在来看一下这些数据的数据类型:

In [1]:  type(torch_dataset)
out[1]:  torch.utils.data.dataset.TensorDatasetIn [2]:  type(x)
out[2]:  torch.Tensor

可以看出X和Y通过Data.TensorDataset() 这个函数拼装成了一个数据集,数据集的类型是TensorDataset

第二步:把上一步的数据集放入Data.DataLoader中,生成一个迭代器,从而方便进行批处理

# 把 dataset 放入 Dataloader
loader = Data.DataLoader(dataset = torch_dataset,# torch TensorDataset formatbatch_size = BATCH_SIZE,#mini batch sizeshuffle = True, # 是否打乱数据num_workers = 2, # 多线程来读数据
)

DataLoader中也有很多其他参数:

dataset:		Dataset类型,从其中加载数据 
batch_size:	int,可选。每个batch加载多少样本 
shuffle:		bool,可选。为True时表示每个epoch都对数据进行洗牌 
sampler:		Sampler,可选。从数据集中采样样本的方法。 
num_workers:	int,可选。加载数据时使用多少子进程。默认值为0,表示在主进程中加载数据。 
collate_fn:	callable,可选。 
pin_memory:	bool,可选 
drop_last:		bool,可选。True表示如果最后剩下不完全的batch,丢弃。False表示不丢弃。

第三步:用上面定义好的迭代器进行训练

这里利用print来模拟训练过程:

for epoch in range(5): 		# 训练所有数据5次i = 0for batch_x,batch_y in loader:i = i+1print('Epoch:{}|num:{}|batch_x:{}|batch_y:{}'.format(epoch,i,batch_x,batch_y))

为了便于观察分批的结果,这里设置:

shuffle = False, # 是否打乱数据

即:

# 把 dataset 放入 Dataloader
loader = Data.DataLoader(dataset = torch_dataset,# torch TensorDataset formatbatch_size = BATCH_SIZE,#mini batch sizeshuffle = False, # 是否打乱数据num_workers = 2, # 多线程来读数据
)

输出结果是:

Epoch:0|num:1|batch_x:tensor([1., 2., 3.])|batch_y:tensor([9., 8., 7.])
Epoch:0|num:2|batch_x:tensor([4., 5., 6.])|batch_y:tensor([6., 5., 4.])
Epoch:0|num:3|batch_x:tensor([7., 8., 9.])|batch_y:tensor([3., 2., 1.])
Epoch:1|num:1|batch_x:tensor([1., 2., 3.])|batch_y:tensor([9., 8., 7.])
Epoch:1|num:2|batch_x:tensor([4., 5., 6.])|batch_y:tensor([6., 5., 4.])
Epoch:1|num:3|batch_x:tensor([7., 8., 9.])|batch_y:tensor([3., 2., 1.])
Epoch:2|num:1|batch_x:tensor([1., 2., 3.])|batch_y:tensor([9., 8., 7.])
Epoch:2|num:2|batch_x:tensor([4., 5., 6.])|batch_y:tensor([6., 5., 4.])
Epoch:2|num:3|batch_x:tensor([7., 8., 9.])|batch_y:tensor([3., 2., 1.])
Epoch:3|num:1|batch_x:tensor([1., 2., 3.])|batch_y:tensor([9., 8., 7.])
Epoch:3|num:2|batch_x:tensor([4., 5., 6.])|batch_y:tensor([6., 5., 4.])
Epoch:3|num:3|batch_x:tensor([7., 8., 9.])|batch_y:tensor([3., 2., 1.])
Epoch:4|num:1|batch_x:tensor([1., 2., 3.])|batch_y:tensor([9., 8., 7.])
Epoch:4|num:2|batch_x:tensor([4., 5., 6.])|batch_y:tensor([6., 5., 4.])
Epoch:4|num:3|batch_x:tensor([7., 8., 9.])|batch_y:tensor([3., 2., 1.])

可以看到,所有数据一共训练了5次。数据中一共9组,设置的mini-batch是3,即每一次训练网络的时候送入3组数据。

此外,还可以利用python中的enumerate(),是对所有可以迭代的数据类型(含有很多东西的list等等)进行取操作的函数,用法如下:

for epoch in range(5): 		# 训练所有数据5次i = 0for step,(batch_x,batch_y) in enumerate(loader):# 假设这里在进行训练i = i+1# 打印一些数据print('Epoch:{}|num:{}|batch_x:{}|batch_y:{}'.format(epoch,i,batch_x,batch_y))

这篇关于python使用DataLoader对数据集进行批处理的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/696177

相关文章

Python脚本实现自动删除C盘临时文件夹

《Python脚本实现自动删除C盘临时文件夹》在日常使用电脑的过程中,临时文件夹往往会积累大量的无用数据,占用宝贵的磁盘空间,下面我们就来看看Python如何通过脚本实现自动删除C盘临时文件夹吧... 目录一、准备工作二、python脚本编写三、脚本解析四、运行脚本五、案例演示六、注意事项七、总结在日常使用

java图像识别工具类(ImageRecognitionUtils)使用实例详解

《java图像识别工具类(ImageRecognitionUtils)使用实例详解》:本文主要介绍如何在Java中使用OpenCV进行图像识别,包括图像加载、预处理、分类、人脸检测和特征提取等步骤... 目录前言1. 图像识别的背景与作用2. 设计目标3. 项目依赖4. 设计与实现 ImageRecogni

Python将大量遥感数据的值缩放指定倍数的方法(推荐)

《Python将大量遥感数据的值缩放指定倍数的方法(推荐)》本文介绍基于Python中的gdal模块,批量读取大量多波段遥感影像文件,分别对各波段数据加以数值处理,并将所得处理后数据保存为新的遥感影像... 本文介绍基于python中的gdal模块,批量读取大量多波段遥感影像文件,分别对各波段数据加以数值处

python管理工具之conda安装部署及使用详解

《python管理工具之conda安装部署及使用详解》这篇文章详细介绍了如何安装和使用conda来管理Python环境,它涵盖了从安装部署、镜像源配置到具体的conda使用方法,包括创建、激活、安装包... 目录pytpshheraerUhon管理工具:conda部署+使用一、安装部署1、 下载2、 安装3

Mysql虚拟列的使用场景

《Mysql虚拟列的使用场景》MySQL虚拟列是一种在查询时动态生成的特殊列,它不占用存储空间,可以提高查询效率和数据处理便利性,本文给大家介绍Mysql虚拟列的相关知识,感兴趣的朋友一起看看吧... 目录1. 介绍mysql虚拟列1.1 定义和作用1.2 虚拟列与普通列的区别2. MySQL虚拟列的类型2

Python进阶之Excel基本操作介绍

《Python进阶之Excel基本操作介绍》在现实中,很多工作都需要与数据打交道,Excel作为常用的数据处理工具,一直备受人们的青睐,本文主要为大家介绍了一些Python中Excel的基本操作,希望... 目录概述写入使用 xlwt使用 XlsxWriter读取修改概述在现实中,很多工作都需要与数据打交

使用MongoDB进行数据存储的操作流程

《使用MongoDB进行数据存储的操作流程》在现代应用开发中,数据存储是一个至关重要的部分,随着数据量的增大和复杂性的增加,传统的关系型数据库有时难以应对高并发和大数据量的处理需求,MongoDB作为... 目录什么是MongoDB?MongoDB的优势使用MongoDB进行数据存储1. 安装MongoDB

关于@MapperScan和@ComponentScan的使用问题

《关于@MapperScan和@ComponentScan的使用问题》文章介绍了在使用`@MapperScan`和`@ComponentScan`时可能会遇到的包扫描冲突问题,并提供了解决方法,同时,... 目录@MapperScan和@ComponentScan的使用问题报错如下原因解决办法课外拓展总结@

mysql数据库分区的使用

《mysql数据库分区的使用》MySQL分区技术通过将大表分割成多个较小片段,提高查询性能、管理效率和数据存储效率,本文就来介绍一下mysql数据库分区的使用,感兴趣的可以了解一下... 目录【一】分区的基本概念【1】物理存储与逻辑分割【2】查询性能提升【3】数据管理与维护【4】扩展性与并行处理【二】分区的

使用Python实现在Word中添加或删除超链接

《使用Python实现在Word中添加或删除超链接》在Word文档中,超链接是一种将文本或图像连接到其他文档、网页或同一文档中不同部分的功能,本文将为大家介绍一下Python如何实现在Word中添加或... 在Word文档中,超链接是一种将文本或图像连接到其他文档、网页或同一文档中不同部分的功能。通过添加超