本文主要是介绍DataLoader 的 collate_fn 解释与示例教程,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!
文章目录
- 导包
- 数据
- Dataloader
- collate_fn
导包
import torch
from torch.utils.data import Dataset
from typing import Any
数据
class CustomDataset(Dataset):def __init__(self, length) -> None:super().__init__()self.length = lengthdef __getitem__(self, index=None):w1 = 3.14w2 = 4.27w = torch.tensor([w1, w2])feature = torch.rand(2) * 10noise = torch.randn_like(feature) * 0.01label = torch.matmul(w, feature.t())feature += noise# return feature, label.view(1)return feature, labeldef __len__(self):return self.lengthdataset = CustomDataset(4)
Dataloader
dataloader = torch.utils.data.DataLoader(dataset, batch_size=2, )for feature, label in dataloader:print(feature.shape, label.shape)
下述展示了,默认的 Dataload 的处理结果:
通过 torch.stack(feature)
,构建出 batch 数据;
torch.Size([2, 2]) torch.Size([2])
torch.Size([2, 2]) torch.Size([2])
常量直接拼接;
向量则会在前面添加一个 batch 纬度;
collate_fn
collate_fn
:返回值为最终构建的batch数据;在这一步中处理dataset的数据,将其调整成我们期望的数据格式;
如上述默认的输出结果所示:label.shape
为 torch.Size([2]),笔者想通过 collate_fn
修改 label.shape
为torch.Size([2, 1])
,下述代码实现这个功能:
def collate_fn(item):feature, label = zip(*item)feature = torch.stack(feature)label = torch.stack(label)label = label.view(-1, 1)return feature, labeldataloader = torch.utils.data.DataLoader(dataset, batch_size=2, collate_fn=collate_fn)for feature, label in dataloader:print(feature.shape, label.shape)
输出如下:
torch.Size([2, 2]) torch.Size([2, 1])
torch.Size([2, 2]) torch.Size([2, 1])
在collate_fn(item)
,传入的item的数据为:
[(tensor([[6.9436, 7.2040]]), tensor([[52.6007]])), (tensor([[7.1495, 2.8882]]), tensor([[34.7427]]))]
[(tensor([[1.5311, 9.9278]]), tensor([[47.1995]])), (tensor([[4.9614, 8.6232]]), tensor([[52.3849]]))]
feature, label = zip(*item)
故通过zip(*item)
的方式,拆分出 feature 和 label 各自的数据,再借助torch.stack
方法将其拼接出 batch 形状的数据。
这篇关于DataLoader 的 collate_fn 解释与示例教程的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!