Faster R-CNN Keras版源码史上最详细解读系列之RPN训练数据处理一

本文主要是介绍Faster R-CNN Keras版源码史上最详细解读系列之RPN训练数据处理一,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

Faster R-CNN Keras版源码史上最详细解读系列之RPN训练数据处理一

  • 训练数据处理
    • 训练数据处理

训练数据处理

前面我们将了RPN模型,同时包含特征提取的,输入是图片,输出是分类和回归,我们现在有了模型的预测输出,因为做的是有监督学习,所以我们还需要真实值输出,也就是标注框相关的分类和回归部分,以便于去计算损失。还是train_frcnn.py

            # 图片,rpn的分类和回归,增强后的图片数据X, Y, img_data = next(data_gen_train)# 返回三个损失 总得loss rpn_loss_cls  rpn_loss_regrloss_rpn = model_rpn.train_on_batch(X, Y)

上面的Y就是真实的分类和回归,因为要统一成RPN模型的输出格式才可以进行损失计算,所以我们需要把他们预处理一下,我们通过data_gen_train迭代器来获取预处理后的数据,每次就一张图片。

# 获取真实的标注训练数据
data_gen_train = data_generators.get_anchor_gt(train_imgs, classes_count, C, nn.get_img_output_length,K.image_dim_ordering(), mode='train')
# 获取真实的标注测试数据
data_gen_val = data_generators.get_anchor_gt(val_imgs, classes_count, C, nn.get_img_output_length,K.image_dim_ordering(), mode='val')

可以看到迭代器其实是这个函数data_generators.get_anchor_gt,就是获取真实框的预处理信息,下面我们来看看这个方法。

训练数据处理

来看看这个文件data_generators.pySampleSelector

# 样本选择器
class SampleSelector:def __init__(self, class_count):# ignore classes that have zero samples# 获取所有类别名的序列,除去个数是0的,针对bgself.classes = [b for b in class_count.keys() if class_count[b] > 0]# 把传入的序列无限重复下去 比如序列 ABC ,重复就是 ABCABCBC... 这样是为了实现样本均衡,所有类别比例都均衡,按ABCABC这样的序列下去self.class_cycle = itertools.cycle(self.classes)# 依次迭代获取下一个类别self.curr_class = next(self.class_cycle)# 判断图片中是否含有采样器的当前类,为了实现样本均衡,没有就不处理了,有才处理def skip_sample_for_balanced_class(self, img_data):class_in_img = Falsefor bbox in img_data['bboxes']:cls_name = bbox['class']#只要图片中包含类别就够了,几个没关系if cls_name == self.curr_class:class_in_img = Trueself.curr_class = next(self.class_cycle)break# 包含了这个类别就可以处理,不包含就这个图片就没用了if class_in_img:return Falseelse:return True

这个样本选择器,主要是为了样本均衡的时候用的,他的目的就是为了保持样本均衡,要一个迭代器不停的迭代出样本的类别的序列,比如ABCABC…这样循环下去,以保证样本的比例是均衡的。skip_sample_for_balanced_class这个方法就是在筛选图片是否符合样本均衡的要求,具体在get_anchor_gt这个方法里会看到。如果我现在需要的是类别A的框,你图片里没有,那对不起,你这张图片我不要了,继续检查下一张,如果有,我才去处理。然后我继续迭代下一个需要的是类别B,继续检查图片。这样就强制实现了样本均衡,但是会丢掉很多不符合他类别序列顺序的样本了,其实不太合理,比如如果我的样本序列是AABBCC明显也是符合样本均衡的,但是强制那么多,就把一般的样本丢了,这样就浪费了,所以貌似这个样本均衡的机制也没启动,可以看到配置里是self.balanced_classes = False

好了,其实这个选择器没啥用,因为样本均衡没启动,但是我也讲一下这个干嘛用的,便于理解。接下来要讲get_anchor_gt这个方法了,怎么预处理标注框:

'''
获取真实的标注框信息
'''
def get_anchor_gt(all_img_data, class_count, C, img_length_calc_function, backend, mode='train'):''':param all_img_data: 所有的图片数据:param class_count: 类别数量的字典:param C: 配置:param img_length_calc_function: 特征图的尺寸:param backend: 后台是tf还是th:param mode: 是否训练:return:'''# The following line is not useful with Python 3.5, it is kept for the legacy# all_img_data = sorted(all_img_data)sample_selector = SampleSelector(class_count)while True:#训练的时候混洗一下if mode == 'train':np.random.shuffle(all_img_data)# 迭代所有的图片信息for img_data in all_img_data:try:# 是否要实现样本均衡,就是按照sample_selector迭代的序列进行样本的提取,否则就不要这个样本,# 比如样本迭代是A B C A B C... 如果图片中有这个类别的框,就处理,如果没有就不处理这个图片,直接看下一个图片了if C.balanced_classes and sample_selector.skip_sample_for_balanced_class(img_data):continue# read in image, and optionally add augmentationif mode == 'train':img_data_aug, x_img = data_augment.augment(img_data, C, augment=True)else:img_data_aug, x_img = data_augment.augment(img_data, C, augment=False)# 原始图像的宽高(width, height) = (img_data_aug['width'], img_data_aug['height'])(rows, cols, _) = x_img.shapeassert cols == widthassert rows == height# get image dimensions for resizing# 获取原图按照规定尺寸缩放后的宽高 默认是以最大600的长度,可以设置(resized_width, resized_height) = get_new_img_size(width, height, C.im_size)# resize the image so that smalles side is length = 600px# 将原图缩放到规定尺寸x_img = cv2.resize(x_img, (resized_width, resized_height), interpolation=cv2.INTER_CUBIC)try:# 计算RPN分类和回归y_rpn_cls, y_rpn_regr = calc_rpn(C, img_data_aug, width, height, resized_width, resized_height, img_length_calc_function)except:continue# Zero-center by mean pixel, and preprocess image# 更改维度顺序,转成RGB,cv默认是BGRx_img = x_img[:,:, (2, 1, 0)]  # BGR -> RGBx_img = x_img.astype(np.float32)# 做自定义的标准化x_img[:, :, 0] -= C.img_channel_mean[0]x_img[:, :, 1] -= C.img_channel_mean[1]x_img[:, :, 2] -= C.img_channel_mean[2]x_img /= C.img_scaling_factor# 转置 通道放最前面了x_img = np.transpose(x_img, (2, 0, 1)) # (3,600,1000)x_img = np.expand_dims(x_img, axis=0) # (1,3,600,1000)# 将回归误差后半部分误差值进行缩放y_rpn_regr[:, y_rpn_regr.shape[1]//2:, :, :] *= C.std_scaling# tf的话通道放最后if backend == 'tf':x_img = np.transpose(x_img, (0, 2, 3, 1))y_rpn_cls = np.transpose(y_rpn_cls, (0, 2, 3, 1))y_rpn_regr = np.transpose(y_rpn_regr, (0, 2, 3, 1))yield np.copy(x_img), [np.copy(y_rpn_cls), np.copy(y_rpn_regr)], img_data_augexcept Exception as e:print(e)continue

从头开始看,初始化样本选择器,其实没啥用,如果是训练就混洗图片数据,看一下图片数据的格式:
在这里插入图片描述
然后遍历所有的图片数据,如果开启了样本均衡,就要判断样本选择器是否选这个样本了,不选就直接遍历下一个样本了,这里没开启,所以也不用管,就处理样本就好了。如果训练集的话可能要进行数据增强,也就是data_augment.py里的augment方法,我们先来看看这个方法吧,不然上面的代码不好理解:

# 图片增强 翻转,旋转
def augment(img_data, config, augment=True):assert 'filepath' in img_dataassert 'bboxes' in img_dataassert 'width' in img_dataassert 'height' in img_data# 深拷贝,不然会修改原图img_data_aug = copy.deepcopy(img_data)# 图片信息 cv读出来的是BGRimg = cv2.imread(img_data_aug['filepath'])# 如果要进行数据增强的话,其实也就是旋转 翻转 然后更新一些信息if augment:# 高和宽rows, cols = img.shape[:2]# 水平翻转 50%概率if config.use_horizontal_flips and np.random.randint(0, 2) == 0:img = cv2.flip(img, 1)# 修正xfor bbox in img_data_aug['bboxes']:x1 = bbox['x1']x2 = bbox['x2']bbox['x2'] = cols - x1bbox['x1'] = cols - x2# 竖直翻转 50%概率if config.use_vertical_flips and np.random.randint(0, 2) == 0:img = cv2.flip(img, 0)# 修正yfor bbox in img_data_aug['bboxes']:y1 = bbox['y1']y2 = bbox['y2']bbox['y2'] = rows - y1bbox['y1'] = rows - y2# 旋转 顺时针,转置可以看成图片主对角线对称过来的样子if config.rot_90:angle = np.random.choice([0,90,180,270],1)[0]if angle == 270:img = np.transpose(img, (1,0,2))# 垂直翻转img = cv2.flip(img, 0)elif angle == 180:# 水平垂直翻转img = cv2.flip(img, -1)elif angle == 90:img = np.transpose(img, (1,0,2))# 水平翻转img = cv2.flip(img, 1)elif angle == 0:pass# 旋转后坐标修正for bbox in img_data_aug['bboxes']:x1 = bbox['x1']x2 = bbox['x2']y1 = bbox['y1']y2 = bbox['y2']if angle == 270:bbox['x1'] = y1bbox['x2'] = y2bbox['y1'] = cols - x2bbox['y2'] = cols - x1elif angle == 180:bbox['x2'] = cols - x1bbox['x1'] = cols - x2bbox['y2'] = rows - y1bbox['y1'] = rows - y2elif angle == 90:bbox['x1'] = rows - y2bbox['x2'] = rows - y1bbox['y1'] = x1bbox['y2'] = x2        elif angle == 0:pass# 旋转过后可能宽高有变化img_data_aug['width'] = img.shape[1]img_data_aug['height'] = img.shape[0]return img_data_aug, img

上面的代码也比较好理解,数据增强后,坐标肯定就变啦,具体可以自己画个图算算,光脑子想想不清楚,画个图就知道坐标怎么回事了,还有就是图片转置其实就是沿着颜色矩阵的主对角线进行翻转,然后配合图片本身的水平和竖直翻转就可以等价于角度的旋转,只是取了90,180,270这些比较好算的角度,否则就可能要进行复杂了。最后结果返回增强后的图片信息,和图片颜色信息。

然后我们继续看get_anchor_gt,后面获取了原始图片的高和宽,进行了缩放,把短边强制缩放成600,长边跟着比例缩放,可以看这个函数get_new_img_size比较简单不多说了,看代码就好了:

# 获得新的图片尺寸,短边长设置为600,等比例缩放比如500x300 变为 1000x600
def get_new_img_size(width, height, img_min_side=600):if width <= height:f = float(img_min_side) / widthresized_height = int(f * height)resized_width = img_min_sideelse:f = float(img_min_side) / heightresized_width = int(f * width)resized_height = img_min_sidereturn resized_width, resized_height

然后就用cv把图片给缩放了,之后我们要对图片真实数据进行RPN网络的分类和回归梯度的计算,主要是为了就是让标注数据处理成RPN输出的格式,好计算误差,用的是这个函数calc_rpn,因为这个方法比较复杂,所以我打算用新的篇章去讲。
在这里插入图片描述

好了,今天就到这里了,希望对学习理解有帮助,大神看见勿喷,仅为自己的学习理解,能力有限,请多包涵,部分图片来自网络,侵删。

这篇关于Faster R-CNN Keras版源码史上最详细解读系列之RPN训练数据处理一的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/657218

相关文章

Redis与缓存解读

《Redis与缓存解读》文章介绍了Redis作为缓存层的优势和缺点,并分析了六种缓存更新策略,包括超时剔除、先删缓存再更新数据库、旁路缓存、先更新数据库再删缓存、先更新数据库再更新缓存、读写穿透和异步... 目录缓存缓存优缺点缓存更新策略超时剔除先删缓存再更新数据库旁路缓存(先更新数据库,再删缓存)先更新数

最新版IDEA配置 Tomcat的详细过程

《最新版IDEA配置Tomcat的详细过程》本文介绍如何在IDEA中配置Tomcat服务器,并创建Web项目,首先检查Tomcat是否安装完成,然后在IDEA中创建Web项目并添加Web结构,接着,... 目录配置tomcat第一步,先给项目添加Web结构查看端口号配置tomcat    先检查自己的to

使用Nginx来共享文件的详细教程

《使用Nginx来共享文件的详细教程》有时我们想共享电脑上的某些文件,一个比较方便的做法是,开一个HTTP服务,指向文件所在的目录,这次我们用nginx来实现这个需求,本文将通过代码示例一步步教你使用... 在本教程中,我们将向您展示如何使用开源 Web 服务器 Nginx 设置文件共享服务器步骤 0 —

Java汇编源码如何查看环境搭建

《Java汇编源码如何查看环境搭建》:本文主要介绍如何在IntelliJIDEA开发环境中搭建字节码和汇编环境,以便更好地进行代码调优和JVM学习,首先,介绍了如何配置IntelliJIDEA以方... 目录一、简介二、在IDEA开发环境中搭建汇编环境2.1 在IDEA中搭建字节码查看环境2.1.1 搭建步

SpringBoot集成SOL链的详细过程

《SpringBoot集成SOL链的详细过程》Solanaj是一个用于与Solana区块链交互的Java库,它为Java开发者提供了一套功能丰富的API,使得在Java环境中可以轻松构建与Solana... 目录一、什么是solanaj?二、Pom依赖三、主要类3.1 RpcClient3.2 Public

手把手教你idea中创建一个javaweb(webapp)项目详细图文教程

《手把手教你idea中创建一个javaweb(webapp)项目详细图文教程》:本文主要介绍如何使用IntelliJIDEA创建一个Maven项目,并配置Tomcat服务器进行运行,过程包括创建... 1.启动idea2.创建项目模板点击项目-新建项目-选择maven,显示如下页面输入项目名称,选择

Python基于火山引擎豆包大模型搭建QQ机器人详细教程(2024年最新)

《Python基于火山引擎豆包大模型搭建QQ机器人详细教程(2024年最新)》:本文主要介绍Python基于火山引擎豆包大模型搭建QQ机器人详细的相关资料,包括开通模型、配置APIKEY鉴权和SD... 目录豆包大模型概述开通模型付费安装 SDK 环境配置 API KEY 鉴权Ark 模型接口Prompt

在 VSCode 中配置 C++ 开发环境的详细教程

《在VSCode中配置C++开发环境的详细教程》本文详细介绍了如何在VisualStudioCode(VSCode)中配置C++开发环境,包括安装必要的工具、配置编译器、设置调试环境等步骤,通... 目录如何在 VSCode 中配置 C++ 开发环境:详细教程1. 什么是 VSCode?2. 安装 VSCo

Spring Boot 中整合 MyBatis-Plus详细步骤(最新推荐)

《SpringBoot中整合MyBatis-Plus详细步骤(最新推荐)》本文详细介绍了如何在SpringBoot项目中整合MyBatis-Plus,包括整合步骤、基本CRUD操作、分页查询、批... 目录一、整合步骤1. 创建 Spring Boot 项目2. 配置项目依赖3. 配置数据源4. 创建实体类

python与QT联合的详细步骤记录

《python与QT联合的详细步骤记录》:本文主要介绍python与QT联合的详细步骤,文章还展示了如何在Python中调用QT的.ui文件来实现GUI界面,并介绍了多窗口的应用,文中通过代码介绍... 目录一、文章简介二、安装pyqt5三、GUI页面设计四、python的使用python文件创建pytho