DataLoader 的 collate_fn 解释与示例教程

2024-04-09 04:28

本文主要是介绍DataLoader 的 collate_fn 解释与示例教程,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

文章目录

    • 导包
    • 数据
    • Dataloader
    • collate_fn

导包

import torch
from torch.utils.data import Dataset
from typing import Any

数据

class CustomDataset(Dataset):def __init__(self, length) -> None:super().__init__()self.length = lengthdef __getitem__(self, index=None):w1 = 3.14w2 = 4.27w = torch.tensor([w1, w2])feature = torch.rand(2) * 10noise = torch.randn_like(feature) * 0.01label = torch.matmul(w, feature.t())feature += noise# return feature, label.view(1)return feature, labeldef __len__(self):return self.lengthdataset = CustomDataset(4)

Dataloader

dataloader = torch.utils.data.DataLoader(dataset, batch_size=2, )for feature, label in dataloader:print(feature.shape, label.shape)

下述展示了,默认的 Dataload 的处理结果:
通过 torch.stack(feature),构建出 batch 数据;

torch.Size([2, 2]) torch.Size([2])
torch.Size([2, 2]) torch.Size([2])

常量直接拼接;
向量则会在前面添加一个 batch 纬度;

collate_fn

collate_fn:返回值为最终构建的batch数据;在这一步中处理dataset的数据,将其调整成我们期望的数据格式;

如上述默认的输出结果所示:label.shape 为 torch.Size([2]),笔者想通过 collate_fn 修改 label.shapetorch.Size([2, 1]),下述代码实现这个功能:

def collate_fn(item):feature, label = zip(*item)feature = torch.stack(feature)label = torch.stack(label)label = label.view(-1, 1)return feature, labeldataloader = torch.utils.data.DataLoader(dataset, batch_size=2, collate_fn=collate_fn)for feature, label in dataloader:print(feature.shape, label.shape)

输出如下:

torch.Size([2, 2]) torch.Size([2, 1])
torch.Size([2, 2]) torch.Size([2, 1])

collate_fn(item),传入的item的数据为:

[(tensor([[6.9436, 7.2040]]), tensor([[52.6007]])), (tensor([[7.1495, 2.8882]]), tensor([[34.7427]]))]
[(tensor([[1.5311, 9.9278]]), tensor([[47.1995]])), (tensor([[4.9614, 8.6232]]), tensor([[52.3849]]))]

feature, label = zip(*item) 故通过zip(*item)的方式,拆分出 feature 和 label 各自的数据,再借助torch.stack方法将其拼接出 batch 形状的数据。

这篇关于DataLoader 的 collate_fn 解释与示例教程的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/887174

相关文章

使用Redis实现会话管理的示例代码

《使用Redis实现会话管理的示例代码》文章介绍了如何使用Redis实现会话管理,包括会话的创建、读取、更新和删除操作,通过设置会话超时时间并重置,可以确保会话在用户持续活动期间不会过期,此外,展示了... 目录1. 会话管理的基本概念2. 使用Redis实现会话管理2.1 引入依赖2.2 会话管理基本操作

mybatis-plus分表实现案例(附示例代码)

《mybatis-plus分表实现案例(附示例代码)》MyBatis-Plus是一个MyBatis的增强工具,在MyBatis的基础上只做增强不做改变,为简化开发、提高效率而生,:本文主要介绍my... 目录文档说明数据库水平分表思路1. 为什么要水平分表2. 核心设计要点3.基于数据库水平分表注意事项示例

Mybatis的mapper文件中#和$的区别示例解析

《Mybatis的mapper文件中#和$的区别示例解析》MyBatis的mapper文件中,#{}和${}是两种参数占位符,核心差异在于参数解析方式、SQL注入风险、适用场景,以下从底层原理、使用场... 目录MyBATis 中 mapper 文件里 #{} 与 ${} 的核心区别一、核心区别对比表二、底

Python中Request的安装以及简单的使用方法图文教程

《Python中Request的安装以及简单的使用方法图文教程》python里的request库经常被用于进行网络爬虫,想要学习网络爬虫的同学必须得安装request这个第三方库,:本文主要介绍P... 目录1.Requests 安装cmd 窗口安装为pycharm安装在pycharm设置中为项目安装req

HTML5的input标签的`type`属性值详解和代码示例

《HTML5的input标签的`type`属性值详解和代码示例》HTML5的`input`标签提供了多种`type`属性值,用于创建不同类型的输入控件,满足用户输入的多样化需求,从文本输入、密码输入、... 目录一、引言二、文本类输入类型2.1 text2.2 password2.3 textarea(严格

MySQL中between and的基本用法、范围查询示例详解

《MySQL中betweenand的基本用法、范围查询示例详解》BETWEENAND操作符在MySQL中用于选择在两个值之间的数据,包括边界值,它支持数值和日期类型,示例展示了如何使用BETWEEN... 目录一、between and语法二、使用示例2.1、betwphpeen and数值查询2.2、be

python中的flask_sqlalchemy的使用及示例详解

《python中的flask_sqlalchemy的使用及示例详解》文章主要介绍了在使用SQLAlchemy创建模型实例时,通过元类动态创建实例的方式,并说明了如何在实例化时执行__init__方法,... 目录@orm.reconstructorSQLAlchemy的回滚关联其他模型数据库基本操作将数据添

Java数组动态扩容的实现示例

《Java数组动态扩容的实现示例》本文主要介绍了Java数组动态扩容的实现示例,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧... 目录1 问题2 方法3 结语1 问题实现动态的给数组添加元素效果,实现对数组扩容,原始数组使用静态分配

JAVA项目swing转javafx语法规则以及示例代码

《JAVA项目swing转javafx语法规则以及示例代码》:本文主要介绍JAVA项目swing转javafx语法规则以及示例代码的相关资料,文中详细讲解了主类继承、窗口创建、布局管理、控件替换、... 目录最常用的“一行换一行”速查表(直接全局替换)实际转换示例(JFramejs → JavaFX)迁移建

JavaWeb项目创建、部署、连接数据库保姆级教程(tomcat)

《JavaWeb项目创建、部署、连接数据库保姆级教程(tomcat)》:本文主要介绍如何在IntelliJIDEA2020.1中创建和部署一个JavaWeb项目,包括创建项目、配置Tomcat服务... 目录简介:一、创建项目二、tomcat部署1、将tomcat解压在一个自己找得到路径2、在idea中添加