基于掩蔽生成知识蒸馏(MGD)的钢铁表面缺陷检测

2023-11-03 05:21

本文主要是介绍基于掩蔽生成知识蒸馏(MGD)的钢铁表面缺陷检测,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

★★★ 本文源自AlStudio社区精品项目,【点击此处】查看更多精品内容 >>>
本项目首先对模型压缩领域中的知识蒸馏理论做了简单的介绍,然后基于PaddleDetection套件对目标检测知识蒸馏的最新方法(MGD)进行复现,对目标检测知识蒸馏的流程进行了细致的讲解。最后结果表明该方法具有较好的效果,可以明显提高学生模型的精度,甚至超越了教师模型。

模型mAP(IOU=0.5:0.95)AP(S)AP(M)AP(L)
teacher(retinanet-r101)41.345.133.156.5
student(retinanet-r50)40.034.933.144.1
distill(retinanet-r50+MGD)41.5(+1.5)46.2(+11.3)34.0(+0.9)50.1(+6.0)

如表格所示,蒸馏后涨点效果明显,大家可以按照本项目的流程对自己的数据集进行蒸馏训练。

一.项目背景

深度卷积神经网络以其突出的性能被广泛应用在目标检测任务上,然而庞大的模型参数和沉重的计算负担严重限制目标检测算法在移动机器人、车载摄像头等边缘设备上的应用,尤其在实时性要求较高的工业领域,过于复杂的模型必然会带来推理延时高的问题。随着深度学习技术的发展,采用知识蒸馏技术对模型进行压缩,可以实现知识迁移与网络精简。在人工智能逐步从理论研究走向大规模应用的背景下,如何利用知识蒸馏进行有效模型压缩已成为倍受关注且具有挑战性的研究热点。

1.1 基于logits的知识蒸馏

知识蒸馏最早是针对分类任务提出并广泛应用的,该方法以较小的精度损失为代价,将较大的教师模型的知识传递给较小的学生模型。2014年,Hinton等人首次提出基于logits的蒸馏方法,该论文给出了知识蒸馏的明确定义,即——将大模型或集成模型中的“暗知识”通过蒸馏的方式,迁移到小模型中,以达到缩小模型或提高精确度的目的。

1.2 基于特征的知识蒸馏

当把知识蒸馏直接应用于目标检测任务上时,目标区域的差异性会被淹没在过多的非目标区域中(背景),使得优化目标被掩盖,模型难以收敛,传统知识蒸馏方法不再行之有效。于是目前在目标检测上主要使用的是基于特征匹配的知识蒸馏,即别提取教师和学生网络Backbone或neck层的特征图,让学生模型模仿教师模型的特征图,从而优化学生模型的表现。

1.3 掩蔽生成知识蒸馏(Masked Generative Distillation)

MGD是ECCV 2022关于知识蒸馏的论文: Masked Generative Distillation所提出的方法,方法适用于分类,检测与分割任务。作者认为提升学生的表征能力并不一定需要通过直接模仿教师实现。从这点出发,把模仿任务修改成了生成任务:让学生凭借自己较弱的特征去生成教师较强的特征。在蒸馏过程中,对学生特征进行了随机mask,强制学生仅用自己的部分特征去生成教师的所有特征,以提升学生的表征能力。整体架构如下图所示:

论文在COCO2017上使用RetinaNet(ResNeXt101)蒸馏RetinaNet(Res50)的结果如下:

二. 数据集介绍

原论文使用的是COCO2017数据集,由于算力成本以及时间限制,本项目使用的是由东北大学(NEU)发布的钢铁表面缺陷数据集,收集了热轧钢带的六种典型表面缺陷,即轧制氧化皮(RS),斑块(Pa),开裂(Cr),点蚀表面( PS),内含物(In)和划痕(Sc),每种缺陷类别300张。下图为六种典型表面缺陷的示例,每幅图像的分辨率为200 * 200像素,本项目中挂载的数据集已按照7:2:1的比例划分好。

三. 算法实现

得益于PaddleDetection的模块化设计,本项目实现了MGD算法。在PaddleDetection/ppdet/slim/distill.py中创建MGDDistillModel类MGDFeatureLoss类,目前仅支持retinanet模型之间进行蒸馏。

class MGDDistillModel(nn.Layer):"""Build MGD distill model.Args:cfg: The student config.slim_cfg: The teacher and distill config."""def __init__(self, cfg, slim_cfg):super(MGDDistillModel, self).__init__()self.is_inherit = True# build student model before load slim configself.student_model = create(cfg.architecture)self.arch = cfg.architecturestu_pretrain = cfg['pretrain_weights']slim_cfg = load_config(slim_cfg)self.teacher_cfg = slim_cfgself.loss_cfg = slim_cfgtea_pretrain = cfg['pretrain_weights']self.teacher_model = create(self.teacher_cfg.architecture)self.teacher_model.eval()for param in self.teacher_model.parameters():param.trainable = Falseif 'pretrain_weights' in cfg and stu_pretrain:if self.is_inherit and 'pretrain_weights' in self.teacher_cfg and self.teacher_cfg.pretrain_weights:load_pretrain_weight(self.student_model,self.teacher_cfg.pretrain_weights)logger.debug("Inheriting! loading teacher weights to student model!")load_pretrain_weight(self.student_model, stu_pretrain)if 'pretrain_weights' in self.teacher_cfg and self.teacher_cfg.pretrain_weights:load_pretrain_weight(self.teacher_model,self.teacher_cfg.pretrain_weights)self.mgd_loss_dic = self.build_loss(self.loss_cfg.distill_loss,name_list=self.loss_cfg['distill_loss_name'])def build_loss(self,cfg,name_list=['neck_f_4', 'neck_f_3', 'neck_f_2', 'neck_f_1','neck_f_0']):loss_func = dict()for idx, k in enumerate(name_list):loss_func[k] = create(cfg)return loss_funcdef forward(self, inputs):if self.training:s_body_feats = self.student_model.backbone(inputs)s_neck_feats = self.student_model.neck(s_body_feats)with paddle.no_grad():t_body_feats = self.teacher_model.backbone(inputs)t_neck_feats = self.teacher_model.neck(t_body_feats)loss_dict = {}for idx, k in enumerate(self.mgd_loss_dic):loss_dict[k] = self.mgd_loss_dic[k](s_neck_feats[idx],t_neck_feats[idx])if self.arch == "RetinaNet":loss = self.student_model.head(s_neck_feats, inputs)elif self.arch == "PicoDet":head_outs = self.student_model.head(s_neck_feats, self.student_model.export_post_process)loss_gfl = self.student_model.head.get_loss(head_outs, inputs)total_loss = paddle.add_n(list(loss_gfl.values()))loss = {}loss.update(loss_gfl)loss.update({'loss': total_loss})else:raise ValueError(f"Unsupported model {self.arch}")for k in loss_dict:loss['loss'] += loss_dict[k]loss[k] = loss_dict[k]return losselse:body_feats = self.student_model.backbone(inputs)neck_feats = self.student_model.neck(body_feats)head_outs = self.student_model.head(neck_feats)if self.arch == "RetinaNet":bbox, bbox_num = self.student_model.head.post_process(head_outs, inputs['im_shape'], inputs['scale_factor'])return {'bbox': bbox, 'bbox_num': bbox_num}elif self.arch == "PicoDet":head_outs = self.student_model.head(neck_feats, self.student_model.export_post_process)scale_factor = inputs['scale_factor']bboxes, bbox_num = self.student_model.head.post_process(head_outs,scale_factor,export_nms=self.student_model.export_nms)return {'bbox': bboxes, 'bbox_num': bbox_num}else:raise ValueError(f"Unsupported model {self.arch}")
@register
class MGDFeatureLoss(nn.Layer):"""Paddle version of `Masked Generative Distillation`Args:student_channels(int): Number of channels in the student's feature map.teacher_channels(int): Number of channels in the teacher's feature map. name (str): the loss name of the layeralpha_mgd (float, optional): Weight of dis_loss. Defaults to 0.00002lambda_mgd (float, optional): masked ratio. Defaults to 0.65"""def __init__(self,student_channels=256,teacher_channels=256,alpha_mgd=0.00002,lambda_mgd=0.65,):super(MGDFeatureLoss, self).__init__()self.alpha_mgd = alpha_mgdself.lambda_mgd = lambda_mgdif student_channels != teacher_channels:self.align = nn.Conv2D(student_channels,teacher_channels,kernel_size=1,stride=1,padding=0)student_channels = teacher_channelselse:self.align = Noneself.generation = nn.Sequential(nn.Conv2D(teacher_channels, teacher_channels, kernel_size=3, padding=1),nn.ReLU(), nn.Conv2D(teacher_channels, teacher_channels, kernel_size=3, padding=1))def forward(self,preds_S,preds_T):"""Forward function.Args:preds_S(Tensor): Bs*C*H*W, student's feature mappreds_T(Tensor): Bs*C*H*W, teacher's feature map"""assert preds_S.shape[-2:] == preds_T.shape[-2:]if self.align is not None:preds_S = self.align(preds_S)loss = self.get_dis_loss(preds_S, preds_T)*self.alpha_mgdreturn lossdef get_dis_loss(self, preds_S, preds_T):N, C, H, W = preds_T.shapemat = paddle.rand((N,1,H,W))mat = paddle.where(mat>1-self.lambda_mgd, 0, 1)mat=paddle.cast(mat,'float32')masked_fea = paddle.multiply(preds_S, mat)new_fea = self.generation(masked_fea)dis_loss = F.mse_loss(new_fea, preds_T,reduction="sum")/Nreturn dis_loss

此外,还需要在slim文件夹下的_init_.py中作如下修改,以通过配置文件来创建蒸馏模型。

    if slim_load_cfg['slim'] == 'Distill':if "slim_method" in slim_load_cfg and slim_load_cfg['slim_method'] == "FGD":model = FGDDistillModel(cfg, slim_cfg)elif "slim_method" in slim_load_cfg and slim_load_cfg['slim_method'] == "MGD":model = MGDDistillModel(cfg, slim_cfg)

四. 环境配置

4.1 解压数据集

!tar -zxvf /home/aistudio/data/data218435/NEU-DET-COCO.tar.gz -C /home/aistudio/data/

4.2 下载PaddleDetection并安装依赖项

从github拉取PaddleDetection,或者在左侧的套件管理中直接快速下载PaddleDetection-2.5,下载完毕需要重命名文件夹为PaddleDetection。

!git clone -b release/2.5  https://github.com/PaddlePaddle/PaddleDetection.git

将下列文件拷贝到PaddleDetection中

!cp -r work/demo work/output work/ppdet PaddleDetection/
%cd PaddleDetection
!pip install -r requirements.txt

4.3 从源码编译安装PaddleDetection

后续若对源码进行改动,务必再次执行下列命令重新编译

!python setup.py install

五. 开始训练

5.1和5.2主要内容是训练教师模型和学生模型,若已有训练好的模型,可直接跳到5.3开始蒸馏训练

5.1 训练教师模型

这里选择的是retinanet_r101_fpn作为教师模型,训练时通过加载PaddleDetection官方在coco上训练好的模型作为预训练模型,再微调训练三十几个epoch即可达到收敛。按照如下图所示修改retinanet_r101_fpn_2x_coco.yml配置文件,为方便操作,本项目把所有需要用到的配置文件全都放到了根目录中,后续修改好所需的文件后,通过命令一键导入所有配置到相应位置

导入coco预训练模型后,需在optimizer_2x.yml配置文件中降低学习率,这里把原本的0.01降低了10倍

其余的配置文件按照自己需要修改即可,修改完毕使用下面的命令一键导入

!cp ../runtime.yml configs/runtime.yml
!cp ../coco_detection.yml configs/datasets/coco_detection.yml
!cp ../retinanet_r101_fpn.yml configs/retinanet/_base_/retinanet_r101_fpn.yml
!cp ../optimizer_2x.yml configs/retinanet/_base_/optimizer_2x.yml
!cp ../retinanet_reader.yml configs/retinanet/_base_/retinanet_reader.yml
!cp ../retinanet_r101_fpn_2x_coco.yml configs/retinanet/retinanet_r101_fpn_2x_coco.yml

执行下面的命令开始训练

!python tools/train.py -c configs/retinanet/retinanet_r101_fpn_2x_coco.yml --use_vdl=True --vdl_log_dir=./teacher/retinanet_r101/ --eval 

我训练的教师模型验证集mAP最高为0.413

5.2 训练学生模型

单独训练学生模型的目的是为了与蒸馏训练后的模型进行对比,与训练教师模型类似,加载预训练模型和修改学习率,这里不再赘述

导入配置文件

!cp ../runtime.yml configs/runtime.yml
!cp ../coco_detection.yml configs/datasets/coco_detection.yml
!cp ../retinanet_r50_fpn.yml configs/retinanet/_base_/retinanet_r50_fpn.yml
!cp ../optimizer_2x.yml configs/retinanet/_base_/optimizer_2x.yml
!cp ../retinanet_reader.yml configs/retinanet/_base_/retinanet_reader.yml
!cp ../retinanet_r50_fpn_2x_coco.yml configs/retinanet/retinanet_r50_fpn_2x_coco.yml

开始训练

!python tools/train.py -c configs/retinanet/retinanet_r50_fpn_2x_coco.yml --use_vdl=True --vdl_log_dir=./student/retinanet_r50/ --eval 

训练的学生模型验证集mAP最高为0.40

5.3 蒸馏训练

修改蒸馏的配置文件retinanet_resnet101_coco_mgd_distill.yml,pretrain_weights路径选择训练好的教师模型路径

导入配置文件

!cp ../runtime.yml configs/runtime.yml
!cp ../coco_detection.yml configs/datasets/coco_detection.yml
!cp ../retinanet_r50_fpn.yml configs/retinanet/_base_/retinanet_r50_fpn.yml
!cp ../optimizer_2x.yml configs/retinanet/_base_/optimizer_2x.yml
!cp ../retinanet_reader.yml configs/retinanet/_base_/retinanet_reader.yml
!cp ../retinanet_r50_fpn_2x_coco.yml configs/retinanet/retinanet_r50_fpn_2x_coco.yml
!cp ../retinanet_resnet101_coco_mgd_distill.yml configs/slim/distill/retinanet_resnet101_coco_mgd_distill.yml

单卡训练

!python tools/train.py -c configs/retinanet/retinanet_r50_fpn_2x_coco.yml --slim_config configs/slim/distill/retinanet_resnet101_coco_mgd_distill.yml \
--use_vdl=True --vdl_log_dir=./distill/retinanet_r50/ --eval 

多卡训练时需要注意batch_size的大小,将retinanet_reader.yml中的batch_size修改为2,这样四卡总的batch_size还是8

!python -m paddle.distributed.launch --log_dir=logs/ --gpus 0,1,2,3 tools/train.py -c configs/retinanet/retinanet_r50_fpn_2x_coco.yml --slim_config configs/slim/distill/retinanet_resnet101_coco_mgd_distill.yml \
--use_vdl=True --vdl_log_dir=./distill/retinanet_r50/ --eval

使用MGD方法进行蒸馏训练后的模型结果,可以看出各项指标都有明显提升

5.4 模型评估

!python tools/eval.py -c configs/retinanet/retinanet_r50_fpn_2x_coco.yml -o weights=output/retinanet_resnet101_coco_mgd_distill/best_model.pdparams

5.5 推理预测

!python tools/infer.py -c configs/retinanet/retinanet_r50_fpn_2x_coco.yml -o weights=output/retinanet_resnet101_coco_mgd_distill/best_model.pdparams --infer_img=demo/patches_237.jpg

从测试集中选取了几张图片,预测结果如下,可以看出效果还是不错的





六. 总结

本项目对模型压缩领域中的知识蒸馏做了简单的介绍,并基于PaddleDetection套件对目标检测知识蒸馏的最新方法(MGD)进行复现,结果表明该方法具有较好的效果,可以明显提高学生模型的精度,甚至超越了教师模型,尤其是对小目标的检测提升较大。使用较小的模型却能获得接近甚至超越更复杂模型的性能,这就是知识蒸馏的意义所在。

模型mAP(IOU=0.5:0.95)AP(S)AP(M)AP(L)
teacher(retinanet-r101)41.345.133.156.5
student(retinanet-r50)40.034.933.144.1
distill(retinanet-r50+MGD)41.5(+1.5)46.2(+11.3)34.0(+0.9)50.1(+6.0)

后续工作:

  • 目前仅支持retinanet,考虑增加适配更多模型,如PPYOLOE。

此文章为搬运
原项目链接

这篇关于基于掩蔽生成知识蒸馏(MGD)的钢铁表面缺陷检测的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/335875

相关文章

Java使用Javassist动态生成HelloWorld类

《Java使用Javassist动态生成HelloWorld类》Javassist是一个非常强大的字节码操作和定义库,它允许开发者在运行时创建新的类或者修改现有的类,本文将简单介绍如何使用Javass... 目录1. Javassist简介2. 环境准备3. 动态生成HelloWorld类3.1 创建CtC

Python从Word文档中提取图片并生成PPT的操作代码

《Python从Word文档中提取图片并生成PPT的操作代码》在日常办公场景中,我们经常需要从Word文档中提取图片,并将这些图片整理到PowerPoint幻灯片中,手动完成这一任务既耗时又容易出错,... 目录引言背景与需求解决方案概述代码解析代码核心逻辑说明总结引言在日常办公场景中,我们经常需要从 W

Unity新手入门学习殿堂级知识详细讲解(图文)

《Unity新手入门学习殿堂级知识详细讲解(图文)》Unity是一款跨平台游戏引擎,支持2D/3D及VR/AR开发,核心功能模块包括图形、音频、物理等,通过可视化编辑器与脚本扩展实现开发,项目结构含A... 目录入门概述什么是 UnityUnity引擎基础认知编辑器核心操作Unity 编辑器项目模式分类工程

Python脚本轻松实现检测麦克风功能

《Python脚本轻松实现检测麦克风功能》在进行音频处理或开发需要使用麦克风的应用程序时,确保麦克风功能正常是非常重要的,本文将介绍一个简单的Python脚本,能够帮助我们检测本地麦克风的功能,需要的... 目录轻松检测麦克风功能脚本介绍一、python环境准备二、代码解析三、使用方法四、知识扩展轻松检测麦

C#使用Spire.XLS快速生成多表格Excel文件

《C#使用Spire.XLS快速生成多表格Excel文件》在日常开发中,我们经常需要将业务数据导出为结构清晰的Excel文件,本文将手把手教你使用Spire.XLS这个强大的.NET组件,只需几行C#... 目录一、Spire.XLS核心优势清单1.1 性能碾压:从3秒到0.5秒的质变1.2 批量操作的优雅

Python使用python-pptx自动化操作和生成PPT

《Python使用python-pptx自动化操作和生成PPT》这篇文章主要为大家详细介绍了如何使用python-pptx库实现PPT自动化,并提供实用的代码示例和应用场景,感兴趣的小伙伴可以跟随小编... 目录使用python-pptx操作PPT文档安装python-pptx基础概念创建新的PPT文档查看

在ASP.NET项目中如何使用C#生成二维码

《在ASP.NET项目中如何使用C#生成二维码》二维码(QRCode)已广泛应用于网址分享,支付链接等场景,本文将以ASP.NET为示例,演示如何实现输入文本/URL,生成二维码,在线显示与下载的完整... 目录创建前端页面(Index.cshtml)后端二维码生成逻辑(Index.cshtml.cs)总结

Python实现数据可视化图表生成(适合新手入门)

《Python实现数据可视化图表生成(适合新手入门)》在数据科学和数据分析的新时代,高效、直观的数据可视化工具显得尤为重要,下面:本文主要介绍Python实现数据可视化图表生成的相关资料,文中通过... 目录前言为什么需要数据可视化准备工作基本图表绘制折线图柱状图散点图使用Seaborn创建高级图表箱线图热

SQLServer中生成雪花ID(Snowflake ID)的实现方法

《SQLServer中生成雪花ID(SnowflakeID)的实现方法》:本文主要介绍在SQLServer中生成雪花ID(SnowflakeID)的实现方法,文中通过示例代码介绍的非常详细,... 目录前言认识雪花ID雪花ID的核心特点雪花ID的结构(64位)雪花ID的优势雪花ID的局限性雪花ID的应用场景

Django HTTPResponse响应体中返回openpyxl生成的文件过程

《DjangoHTTPResponse响应体中返回openpyxl生成的文件过程》Django返回文件流时需通过Content-Disposition头指定编码后的文件名,使用openpyxl的sa... 目录Django返回文件流时使用指定文件名Django HTTPResponse响应体中返回openp