python使用DataLoader对数据集进行批处理

2024-02-10 04:58

本文主要是介绍python使用DataLoader对数据集进行批处理,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

使用DataLoader对数据集进行批处理,转自

https://www.cnblogs.com/JeasonIsCoding/p/10168753.html

第一步:创建torch能够识别的数据集类型

首先建立两个向量X和Y,一个作为输入的数据,一个作为正确的结果:

import torch
import torch.utils.data as DataBATCH_SIZE = 3 		# 批训练的数据个数x = torch.linspace(1,9,9)  # x data (torch tensor)
y = torch.linspace(9,1,9)  # y data (torch tensor)

随后把X和Y组成一个完整的数据集,并转化为pytorch能识别的数据集类型:

# 先转换成 torch 能够识别的 Dataset
torch_dataset = Data.TensorDataset( x, y )

现在来看一下这些数据的数据类型:

In [1]:  type(torch_dataset)
out[1]:  torch.utils.data.dataset.TensorDatasetIn [2]:  type(x)
out[2]:  torch.Tensor

可以看出X和Y通过Data.TensorDataset() 这个函数拼装成了一个数据集,数据集的类型是TensorDataset

第二步:把上一步的数据集放入Data.DataLoader中,生成一个迭代器,从而方便进行批处理

# 把 dataset 放入 Dataloader
loader = Data.DataLoader(dataset = torch_dataset,# torch TensorDataset formatbatch_size = BATCH_SIZE,#mini batch sizeshuffle = True, # 是否打乱数据num_workers = 2, # 多线程来读数据
)

DataLoader中也有很多其他参数:

dataset:		Dataset类型,从其中加载数据 
batch_size:	int,可选。每个batch加载多少样本 
shuffle:		bool,可选。为True时表示每个epoch都对数据进行洗牌 
sampler:		Sampler,可选。从数据集中采样样本的方法。 
num_workers:	int,可选。加载数据时使用多少子进程。默认值为0,表示在主进程中加载数据。 
collate_fn:	callable,可选。 
pin_memory:	bool,可选 
drop_last:		bool,可选。True表示如果最后剩下不完全的batch,丢弃。False表示不丢弃。

第三步:用上面定义好的迭代器进行训练

这里利用print来模拟训练过程:

for epoch in range(5): 		# 训练所有数据5次i = 0for batch_x,batch_y in loader:i = i+1print('Epoch:{}|num:{}|batch_x:{}|batch_y:{}'.format(epoch,i,batch_x,batch_y))

为了便于观察分批的结果,这里设置:

shuffle = False, # 是否打乱数据

即:

# 把 dataset 放入 Dataloader
loader = Data.DataLoader(dataset = torch_dataset,# torch TensorDataset formatbatch_size = BATCH_SIZE,#mini batch sizeshuffle = False, # 是否打乱数据num_workers = 2, # 多线程来读数据
)

输出结果是:

Epoch:0|num:1|batch_x:tensor([1., 2., 3.])|batch_y:tensor([9., 8., 7.])
Epoch:0|num:2|batch_x:tensor([4., 5., 6.])|batch_y:tensor([6., 5., 4.])
Epoch:0|num:3|batch_x:tensor([7., 8., 9.])|batch_y:tensor([3., 2., 1.])
Epoch:1|num:1|batch_x:tensor([1., 2., 3.])|batch_y:tensor([9., 8., 7.])
Epoch:1|num:2|batch_x:tensor([4., 5., 6.])|batch_y:tensor([6., 5., 4.])
Epoch:1|num:3|batch_x:tensor([7., 8., 9.])|batch_y:tensor([3., 2., 1.])
Epoch:2|num:1|batch_x:tensor([1., 2., 3.])|batch_y:tensor([9., 8., 7.])
Epoch:2|num:2|batch_x:tensor([4., 5., 6.])|batch_y:tensor([6., 5., 4.])
Epoch:2|num:3|batch_x:tensor([7., 8., 9.])|batch_y:tensor([3., 2., 1.])
Epoch:3|num:1|batch_x:tensor([1., 2., 3.])|batch_y:tensor([9., 8., 7.])
Epoch:3|num:2|batch_x:tensor([4., 5., 6.])|batch_y:tensor([6., 5., 4.])
Epoch:3|num:3|batch_x:tensor([7., 8., 9.])|batch_y:tensor([3., 2., 1.])
Epoch:4|num:1|batch_x:tensor([1., 2., 3.])|batch_y:tensor([9., 8., 7.])
Epoch:4|num:2|batch_x:tensor([4., 5., 6.])|batch_y:tensor([6., 5., 4.])
Epoch:4|num:3|batch_x:tensor([7., 8., 9.])|batch_y:tensor([3., 2., 1.])

可以看到,所有数据一共训练了5次。数据中一共9组,设置的mini-batch是3,即每一次训练网络的时候送入3组数据。

此外,还可以利用python中的enumerate(),是对所有可以迭代的数据类型(含有很多东西的list等等)进行取操作的函数,用法如下:

for epoch in range(5): 		# 训练所有数据5次i = 0for step,(batch_x,batch_y) in enumerate(loader):# 假设这里在进行训练i = i+1# 打印一些数据print('Epoch:{}|num:{}|batch_x:{}|batch_y:{}'.format(epoch,i,batch_x,batch_y))

这篇关于python使用DataLoader对数据集进行批处理的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/696177

相关文章

Conda与Python venv虚拟环境的区别与使用方法详解

《Conda与Pythonvenv虚拟环境的区别与使用方法详解》随着Python社区的成长,虚拟环境的概念和技术也在不断发展,:本文主要介绍Conda与Pythonvenv虚拟环境的区别与使用... 目录前言一、Conda 与 python venv 的核心区别1. Conda 的特点2. Python v

Spring Boot中WebSocket常用使用方法详解

《SpringBoot中WebSocket常用使用方法详解》本文从WebSocket的基础概念出发,详细介绍了SpringBoot集成WebSocket的步骤,并重点讲解了常用的使用方法,包括简单消... 目录一、WebSocket基础概念1.1 什么是WebSocket1.2 WebSocket与HTTP

C#中Guid类使用小结

《C#中Guid类使用小结》本文主要介绍了C#中Guid类用于生成和操作128位的唯一标识符,用于数据库主键及分布式系统,支持通过NewGuid、Parse等方法生成,感兴趣的可以了解一下... 目录前言一、什么是 Guid二、生成 Guid1. 使用 Guid.NewGuid() 方法2. 从字符串创建

Python使用python-can实现合并BLF文件

《Python使用python-can实现合并BLF文件》python-can库是Python生态中专注于CAN总线通信与数据处理的强大工具,本文将使用python-can为BLF文件合并提供高效灵活... 目录一、python-can 库:CAN 数据处理的利器二、BLF 文件合并核心代码解析1. 基础合

Python使用OpenCV实现获取视频时长的小工具

《Python使用OpenCV实现获取视频时长的小工具》在处理视频数据时,获取视频的时长是一项常见且基础的需求,本文将详细介绍如何使用Python和OpenCV获取视频时长,并对每一行代码进行深入解析... 目录一、代码实现二、代码解析1. 导入 OpenCV 库2. 定义获取视频时长的函数3. 打开视频文

Python中你不知道的gzip高级用法分享

《Python中你不知道的gzip高级用法分享》在当今大数据时代,数据存储和传输成本已成为每个开发者必须考虑的问题,Python内置的gzip模块提供了一种简单高效的解决方案,下面小编就来和大家详细讲... 目录前言:为什么数据压缩如此重要1. gzip 模块基础介绍2. 基本压缩与解压缩操作2.1 压缩文

Spring IoC 容器的使用详解(最新整理)

《SpringIoC容器的使用详解(最新整理)》文章介绍了Spring框架中的应用分层思想与IoC容器原理,通过分层解耦业务逻辑、数据访问等模块,IoC容器利用@Component注解管理Bean... 目录1. 应用分层2. IoC 的介绍3. IoC 容器的使用3.1. bean 的存储3.2. 方法注

MySQL 删除数据详解(最新整理)

《MySQL删除数据详解(最新整理)》:本文主要介绍MySQL删除数据的相关知识,本文通过实例代码给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友参考下吧... 目录一、前言二、mysql 中的三种删除方式1.DELETE语句✅ 基本语法: 示例:2.TRUNCATE语句✅ 基本语

Python设置Cookie永不超时的详细指南

《Python设置Cookie永不超时的详细指南》Cookie是一种存储在用户浏览器中的小型数据片段,用于记录用户的登录状态、偏好设置等信息,下面小编就来和大家详细讲讲Python如何设置Cookie... 目录一、Cookie的作用与重要性二、Cookie过期的原因三、实现Cookie永不超时的方法(一)

Python内置函数之classmethod函数使用详解

《Python内置函数之classmethod函数使用详解》:本文主要介绍Python内置函数之classmethod函数使用方式,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地... 目录1. 类方法定义与基本语法2. 类方法 vs 实例方法 vs 静态方法3. 核心特性与用法(1编程客