【DETR系列目标检测算法代码精讲】01 DETR算法03 Dataloader代码精讲

2024-04-02 02:04

本文主要是介绍【DETR系列目标检测算法代码精讲】01 DETR算法03 Dataloader代码精讲,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

与一般的Dataloader的区别在于我们对图像进行了随机裁剪,需要进行额外的操作才能将其打包到dataloader里面

在这里插入图片描述
这一段的代码如下:

    if args.distributed:sampler_train = DistributedSampler(dataset_train)sampler_val = DistributedSampler(dataset_val, shuffle=False)else:sampler_train = torch.utils.data.RandomSampler(dataset_train)sampler_val = torch.utils.data.SequentialSampler(dataset_val)batch_sampler_train = torch.utils.data.BatchSampler(sampler_train, args.batch_size, drop_last=True)data_loader_train = DataLoader(dataset_train,batch_sampler=batch_sampler_train,collate_fn=utils.collate_fn,# num_workers=args.num_workers)data_loader_val = DataLoader(dataset_val,args.batch_size,sampler=sampler_val,drop_last=False,collate_fn=utils.collate_fn,# num_workers=args.num_workers)

对于训练数据集,使用RandowSampler类进行随机采样
对于验证数据集,使用SequentialSampler进行顺序采样

采样以后,使用BatchSampler打包成batch
然后再使用Dataloader

Dataloader中有个函数
collate_fn

在这里插入图片描述
这个函数又调用了一个函数nested_tensor_from_tensor_list
这个函数重新定义了我们输入数据的格式

在这里插入图片描述

默认的batch为2,我们输入的就是包含了两个元素的list,其中每个元素都是我们从dataset的__getitem__方法获得的输出

然后通过zip函数进行解析
在这里插入图片描述
可以看到之前的形式是一个元素中是img+target
现在变成了一个元素里面都是img,另一个元素里面都是target
batch[0] = nested_tensor_from_tensor_list(batch[0])
然后通过索引0取出图像部分

传入到nested_tensor_from_tensor_list方法中
这个方法的全部代码如下

def nested_tensor_from_tensor_list(tensor_list: List[Tensor]):# TODO make this more generalif tensor_list[0].ndim == 3:if torchvision._is_tracing():# nested_tensor_from_tensor_list() does not export well to ONNX# call _onnx_nested_tensor_from_tensor_list() insteadreturn _onnx_nested_tensor_from_tensor_list(tensor_list)# TODO make it support different-sized imagesmax_size = _max_by_axis([list(img.shape) for img in tensor_list])batch_shape = [len(tensor_list)] + max_sizeb, c, h, w = batch_shapedtype = tensor_list[0].dtypedevice = tensor_list[0].devicetensor = torch.zeros(batch_shape, dtype=dtype, device=device)mask = torch.ones((b, h, w), dtype=torch.bool, device=device)for img, pad_img, m in zip(tensor_list, tensor, mask):pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)m[: img.shape[1], :img.shape[2]] = Falseelse:raise ValueError('not supported')return NestedTensor(tensor, mask)

定义了一个函数nested_tensor_from_tensor_list,该函数接受一个Tensor列表(tensor_list)作为输入,并返回一个NestedTensor对象。NestedTensor是一个特殊的数据结构,通常用于表示图像数据,其中可以包含不同大小的图像,并且有一个与之对应的掩码(mask)来表示每个图像的实际大小。

检查tensor_list中的第一个Tensor的维度是否为3。

通过_max_by_axis函数来确定tensor_list中所有图像的最大尺寸。这意味着最终生成的NestedTensor将包含所有图像的最大高度和宽度。

这个_max_by_axis函数的代码如下:

def _max_by_axis(the_list):# type: (List[List[int]]) -> List[int]maxes = the_list[0]for sublist in the_list[1:]:for index, item in enumerate(sublist):maxes[index] = max(maxes[index], item)return maxes

在这里插入图片描述
对于归一化处理后的图像,我们的输入是一个三维的矩阵

在这里插入图片描述
在这里插入图片描述
将两个三维矩阵传入这个方法

这个list有两个元素,maxes列表首先被初始化为第一个子列表
遍历the_list中除了第一个子列表之外的所有子列表,
对于当前子列表中的每个元素,我们将其与maxes列表中对应索引的当前最大值进行比较。我们使用max函数来确定这两个值中的较大值,并将其赋值给maxes列表的相应位置。

进行比较的这个索引值就是宽和高的值

比如
在这里插入图片描述
这里第二个子列表的的index是2,值是512
就要与第一个子列表的第2个值进行比较,就是911

所以输出的就是这个batch里面所有图像中最长的宽度和高度

这个尺寸就是这个batch最终的目标尺寸

接下来的操作就是需要将这个batch中的每一张图像加上padding
使得它们的尺寸都满足这个要求

在这里插入图片描述

在batch的维度前面加上batch中图像的个数

然后创造一个这个尺寸的底图 值全为0

将所有图像按照左上角点对齐的方式填充到这个底图上

在这里插入图片描述
再生成一个batch_size为2,宽和高分别为最大宽和高的全1矩阵

它的作用是记录图像中哪些部分是图像 哪些部分是padding
接下来通过循环记录图像中的每个位置,图像部分都记为false
表示这个位置不是padding

在这里插入图片描述
然后用输出的结果替换掉batch的第一个元素,也就是image的部分

在这里插入图片描述

这个时候输出的就是
在这里插入图片描述
以上就是dataloader的部分

这篇关于【DETR系列目标检测算法代码精讲】01 DETR算法03 Dataloader代码精讲的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/868691

相关文章

Spring Security 从入门到进阶系列教程

Spring Security 入门系列 《保护 Web 应用的安全》 《Spring-Security-入门(一):登录与退出》 《Spring-Security-入门(二):基于数据库验证》 《Spring-Security-入门(三):密码加密》 《Spring-Security-入门(四):自定义-Filter》 《Spring-Security-入门(五):在 Sprin

不懂推荐算法也能设计推荐系统

本文以商业化应用推荐为例,告诉我们不懂推荐算法的产品,也能从产品侧出发, 设计出一款不错的推荐系统。 相信很多新手产品,看到算法二字,多是懵圈的。 什么排序算法、最短路径等都是相对传统的算法(注:传统是指科班出身的产品都会接触过)。但对于推荐算法,多数产品对着网上搜到的资源,都会无从下手。特别当某些推荐算法 和 “AI”扯上关系后,更是加大了理解的难度。 但,不了解推荐算法,就无法做推荐系

康拓展开(hash算法中会用到)

康拓展开是一个全排列到一个自然数的双射(也就是某个全排列与某个自然数一一对应) 公式: X=a[n]*(n-1)!+a[n-1]*(n-2)!+...+a[i]*(i-1)!+...+a[1]*0! 其中,a[i]为整数,并且0<=a[i]<i,1<=i<=n。(a[i]在不同应用中的含义不同); 典型应用: 计算当前排列在所有由小到大全排列中的顺序,也就是说求当前排列是第

csu 1446 Problem J Modified LCS (扩展欧几里得算法的简单应用)

这是一道扩展欧几里得算法的简单应用题,这题是在湖南多校训练赛中队友ac的一道题,在比赛之后请教了队友,然后自己把它a掉 这也是自己独自做扩展欧几里得算法的题目 题意:把题意转变下就变成了:求d1*x - d2*y = f2 - f1的解,很明显用exgcd来解 下面介绍一下exgcd的一些知识点:求ax + by = c的解 一、首先求ax + by = gcd(a,b)的解 这个

综合安防管理平台LntonAIServer视频监控汇聚抖动检测算法优势

LntonAIServer视频质量诊断功能中的抖动检测是一个专门针对视频稳定性进行分析的功能。抖动通常是指视频帧之间的不必要运动,这种运动可能是由于摄像机的移动、传输中的错误或编解码问题导致的。抖动检测对于确保视频内容的平滑性和观看体验至关重要。 优势 1. 提高图像质量 - 清晰度提升:减少抖动,提高图像的清晰度和细节表现力,使得监控画面更加真实可信。 - 细节增强:在低光条件下,抖

【数据结构】——原来排序算法搞懂这些就行,轻松拿捏

前言:快速排序的实现最重要的是找基准值,下面让我们来了解如何实现找基准值 基准值的注释:在快排的过程中,每一次我们要取一个元素作为枢纽值,以这个数字来将序列划分为两部分。 在此我们采用三数取中法,也就是取左端、中间、右端三个数,然后进行排序,将中间数作为枢纽值。 快速排序实现主框架: //快速排序 void QuickSort(int* arr, int left, int rig

活用c4d官方开发文档查询代码

当你问AI助手比如豆包,如何用python禁止掉xpresso标签时候,它会提示到 这时候要用到两个东西。https://developers.maxon.net/论坛搜索和开发文档 比如这里我就在官方找到正确的id描述 然后我就把参数标签换过来

poj 3974 and hdu 3068 最长回文串的O(n)解法(Manacher算法)

求一段字符串中的最长回文串。 因为数据量比较大,用原来的O(n^2)会爆。 小白上的O(n^2)解法代码:TLE啦~ #include<stdio.h>#include<string.h>const int Maxn = 1000000;char s[Maxn];int main(){char e[] = {"END"};while(scanf("%s", s) != EO

hdu 2602 and poj 3624(01背包)

01背包的模板题。 hdu2602代码: #include<stdio.h>#include<string.h>const int MaxN = 1001;int max(int a, int b){return a > b ? a : b;}int w[MaxN];int v[MaxN];int dp[MaxN];int main(){int T;int N, V;s

poj 1258 Agri-Net(最小生成树模板代码)

感觉用这题来当模板更适合。 题意就是给你邻接矩阵求最小生成树啦。~ prim代码:效率很高。172k...0ms。 #include<stdio.h>#include<algorithm>using namespace std;const int MaxN = 101;const int INF = 0x3f3f3f3f;int g[MaxN][MaxN];int n