基于掩蔽生成知识蒸馏(MGD)的钢铁表面缺陷检测

2023-11-03 05:21

本文主要是介绍基于掩蔽生成知识蒸馏(MGD)的钢铁表面缺陷检测,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

★★★ 本文源自AlStudio社区精品项目,【点击此处】查看更多精品内容 >>>
本项目首先对模型压缩领域中的知识蒸馏理论做了简单的介绍,然后基于PaddleDetection套件对目标检测知识蒸馏的最新方法(MGD)进行复现,对目标检测知识蒸馏的流程进行了细致的讲解。最后结果表明该方法具有较好的效果,可以明显提高学生模型的精度,甚至超越了教师模型。

模型mAP(IOU=0.5:0.95)AP(S)AP(M)AP(L)
teacher(retinanet-r101)41.345.133.156.5
student(retinanet-r50)40.034.933.144.1
distill(retinanet-r50+MGD)41.5(+1.5)46.2(+11.3)34.0(+0.9)50.1(+6.0)

如表格所示,蒸馏后涨点效果明显,大家可以按照本项目的流程对自己的数据集进行蒸馏训练。

一.项目背景

深度卷积神经网络以其突出的性能被广泛应用在目标检测任务上,然而庞大的模型参数和沉重的计算负担严重限制目标检测算法在移动机器人、车载摄像头等边缘设备上的应用,尤其在实时性要求较高的工业领域,过于复杂的模型必然会带来推理延时高的问题。随着深度学习技术的发展,采用知识蒸馏技术对模型进行压缩,可以实现知识迁移与网络精简。在人工智能逐步从理论研究走向大规模应用的背景下,如何利用知识蒸馏进行有效模型压缩已成为倍受关注且具有挑战性的研究热点。

1.1 基于logits的知识蒸馏

知识蒸馏最早是针对分类任务提出并广泛应用的,该方法以较小的精度损失为代价,将较大的教师模型的知识传递给较小的学生模型。2014年,Hinton等人首次提出基于logits的蒸馏方法,该论文给出了知识蒸馏的明确定义,即——将大模型或集成模型中的“暗知识”通过蒸馏的方式,迁移到小模型中,以达到缩小模型或提高精确度的目的。

1.2 基于特征的知识蒸馏

当把知识蒸馏直接应用于目标检测任务上时,目标区域的差异性会被淹没在过多的非目标区域中(背景),使得优化目标被掩盖,模型难以收敛,传统知识蒸馏方法不再行之有效。于是目前在目标检测上主要使用的是基于特征匹配的知识蒸馏,即别提取教师和学生网络Backbone或neck层的特征图,让学生模型模仿教师模型的特征图,从而优化学生模型的表现。

1.3 掩蔽生成知识蒸馏(Masked Generative Distillation)

MGD是ECCV 2022关于知识蒸馏的论文: Masked Generative Distillation所提出的方法,方法适用于分类,检测与分割任务。作者认为提升学生的表征能力并不一定需要通过直接模仿教师实现。从这点出发,把模仿任务修改成了生成任务:让学生凭借自己较弱的特征去生成教师较强的特征。在蒸馏过程中,对学生特征进行了随机mask,强制学生仅用自己的部分特征去生成教师的所有特征,以提升学生的表征能力。整体架构如下图所示:

论文在COCO2017上使用RetinaNet(ResNeXt101)蒸馏RetinaNet(Res50)的结果如下:

二. 数据集介绍

原论文使用的是COCO2017数据集,由于算力成本以及时间限制,本项目使用的是由东北大学(NEU)发布的钢铁表面缺陷数据集,收集了热轧钢带的六种典型表面缺陷,即轧制氧化皮(RS),斑块(Pa),开裂(Cr),点蚀表面( PS),内含物(In)和划痕(Sc),每种缺陷类别300张。下图为六种典型表面缺陷的示例,每幅图像的分辨率为200 * 200像素,本项目中挂载的数据集已按照7:2:1的比例划分好。

三. 算法实现

得益于PaddleDetection的模块化设计,本项目实现了MGD算法。在PaddleDetection/ppdet/slim/distill.py中创建MGDDistillModel类MGDFeatureLoss类,目前仅支持retinanet模型之间进行蒸馏。

class MGDDistillModel(nn.Layer):"""Build MGD distill model.Args:cfg: The student config.slim_cfg: The teacher and distill config."""def __init__(self, cfg, slim_cfg):super(MGDDistillModel, self).__init__()self.is_inherit = True# build student model before load slim configself.student_model = create(cfg.architecture)self.arch = cfg.architecturestu_pretrain = cfg['pretrain_weights']slim_cfg = load_config(slim_cfg)self.teacher_cfg = slim_cfgself.loss_cfg = slim_cfgtea_pretrain = cfg['pretrain_weights']self.teacher_model = create(self.teacher_cfg.architecture)self.teacher_model.eval()for param in self.teacher_model.parameters():param.trainable = Falseif 'pretrain_weights' in cfg and stu_pretrain:if self.is_inherit and 'pretrain_weights' in self.teacher_cfg and self.teacher_cfg.pretrain_weights:load_pretrain_weight(self.student_model,self.teacher_cfg.pretrain_weights)logger.debug("Inheriting! loading teacher weights to student model!")load_pretrain_weight(self.student_model, stu_pretrain)if 'pretrain_weights' in self.teacher_cfg and self.teacher_cfg.pretrain_weights:load_pretrain_weight(self.teacher_model,self.teacher_cfg.pretrain_weights)self.mgd_loss_dic = self.build_loss(self.loss_cfg.distill_loss,name_list=self.loss_cfg['distill_loss_name'])def build_loss(self,cfg,name_list=['neck_f_4', 'neck_f_3', 'neck_f_2', 'neck_f_1','neck_f_0']):loss_func = dict()for idx, k in enumerate(name_list):loss_func[k] = create(cfg)return loss_funcdef forward(self, inputs):if self.training:s_body_feats = self.student_model.backbone(inputs)s_neck_feats = self.student_model.neck(s_body_feats)with paddle.no_grad():t_body_feats = self.teacher_model.backbone(inputs)t_neck_feats = self.teacher_model.neck(t_body_feats)loss_dict = {}for idx, k in enumerate(self.mgd_loss_dic):loss_dict[k] = self.mgd_loss_dic[k](s_neck_feats[idx],t_neck_feats[idx])if self.arch == "RetinaNet":loss = self.student_model.head(s_neck_feats, inputs)elif self.arch == "PicoDet":head_outs = self.student_model.head(s_neck_feats, self.student_model.export_post_process)loss_gfl = self.student_model.head.get_loss(head_outs, inputs)total_loss = paddle.add_n(list(loss_gfl.values()))loss = {}loss.update(loss_gfl)loss.update({'loss': total_loss})else:raise ValueError(f"Unsupported model {self.arch}")for k in loss_dict:loss['loss'] += loss_dict[k]loss[k] = loss_dict[k]return losselse:body_feats = self.student_model.backbone(inputs)neck_feats = self.student_model.neck(body_feats)head_outs = self.student_model.head(neck_feats)if self.arch == "RetinaNet":bbox, bbox_num = self.student_model.head.post_process(head_outs, inputs['im_shape'], inputs['scale_factor'])return {'bbox': bbox, 'bbox_num': bbox_num}elif self.arch == "PicoDet":head_outs = self.student_model.head(neck_feats, self.student_model.export_post_process)scale_factor = inputs['scale_factor']bboxes, bbox_num = self.student_model.head.post_process(head_outs,scale_factor,export_nms=self.student_model.export_nms)return {'bbox': bboxes, 'bbox_num': bbox_num}else:raise ValueError(f"Unsupported model {self.arch}")
@register
class MGDFeatureLoss(nn.Layer):"""Paddle version of `Masked Generative Distillation`Args:student_channels(int): Number of channels in the student's feature map.teacher_channels(int): Number of channels in the teacher's feature map. name (str): the loss name of the layeralpha_mgd (float, optional): Weight of dis_loss. Defaults to 0.00002lambda_mgd (float, optional): masked ratio. Defaults to 0.65"""def __init__(self,student_channels=256,teacher_channels=256,alpha_mgd=0.00002,lambda_mgd=0.65,):super(MGDFeatureLoss, self).__init__()self.alpha_mgd = alpha_mgdself.lambda_mgd = lambda_mgdif student_channels != teacher_channels:self.align = nn.Conv2D(student_channels,teacher_channels,kernel_size=1,stride=1,padding=0)student_channels = teacher_channelselse:self.align = Noneself.generation = nn.Sequential(nn.Conv2D(teacher_channels, teacher_channels, kernel_size=3, padding=1),nn.ReLU(), nn.Conv2D(teacher_channels, teacher_channels, kernel_size=3, padding=1))def forward(self,preds_S,preds_T):"""Forward function.Args:preds_S(Tensor): Bs*C*H*W, student's feature mappreds_T(Tensor): Bs*C*H*W, teacher's feature map"""assert preds_S.shape[-2:] == preds_T.shape[-2:]if self.align is not None:preds_S = self.align(preds_S)loss = self.get_dis_loss(preds_S, preds_T)*self.alpha_mgdreturn lossdef get_dis_loss(self, preds_S, preds_T):N, C, H, W = preds_T.shapemat = paddle.rand((N,1,H,W))mat = paddle.where(mat>1-self.lambda_mgd, 0, 1)mat=paddle.cast(mat,'float32')masked_fea = paddle.multiply(preds_S, mat)new_fea = self.generation(masked_fea)dis_loss = F.mse_loss(new_fea, preds_T,reduction="sum")/Nreturn dis_loss

此外,还需要在slim文件夹下的_init_.py中作如下修改,以通过配置文件来创建蒸馏模型。

    if slim_load_cfg['slim'] == 'Distill':if "slim_method" in slim_load_cfg and slim_load_cfg['slim_method'] == "FGD":model = FGDDistillModel(cfg, slim_cfg)elif "slim_method" in slim_load_cfg and slim_load_cfg['slim_method'] == "MGD":model = MGDDistillModel(cfg, slim_cfg)

四. 环境配置

4.1 解压数据集

!tar -zxvf /home/aistudio/data/data218435/NEU-DET-COCO.tar.gz -C /home/aistudio/data/

4.2 下载PaddleDetection并安装依赖项

从github拉取PaddleDetection,或者在左侧的套件管理中直接快速下载PaddleDetection-2.5,下载完毕需要重命名文件夹为PaddleDetection。

!git clone -b release/2.5  https://github.com/PaddlePaddle/PaddleDetection.git

将下列文件拷贝到PaddleDetection中

!cp -r work/demo work/output work/ppdet PaddleDetection/
%cd PaddleDetection
!pip install -r requirements.txt

4.3 从源码编译安装PaddleDetection

后续若对源码进行改动,务必再次执行下列命令重新编译

!python setup.py install

五. 开始训练

5.1和5.2主要内容是训练教师模型和学生模型,若已有训练好的模型,可直接跳到5.3开始蒸馏训练

5.1 训练教师模型

这里选择的是retinanet_r101_fpn作为教师模型,训练时通过加载PaddleDetection官方在coco上训练好的模型作为预训练模型,再微调训练三十几个epoch即可达到收敛。按照如下图所示修改retinanet_r101_fpn_2x_coco.yml配置文件,为方便操作,本项目把所有需要用到的配置文件全都放到了根目录中,后续修改好所需的文件后,通过命令一键导入所有配置到相应位置

导入coco预训练模型后,需在optimizer_2x.yml配置文件中降低学习率,这里把原本的0.01降低了10倍

其余的配置文件按照自己需要修改即可,修改完毕使用下面的命令一键导入

!cp ../runtime.yml configs/runtime.yml
!cp ../coco_detection.yml configs/datasets/coco_detection.yml
!cp ../retinanet_r101_fpn.yml configs/retinanet/_base_/retinanet_r101_fpn.yml
!cp ../optimizer_2x.yml configs/retinanet/_base_/optimizer_2x.yml
!cp ../retinanet_reader.yml configs/retinanet/_base_/retinanet_reader.yml
!cp ../retinanet_r101_fpn_2x_coco.yml configs/retinanet/retinanet_r101_fpn_2x_coco.yml

执行下面的命令开始训练

!python tools/train.py -c configs/retinanet/retinanet_r101_fpn_2x_coco.yml --use_vdl=True --vdl_log_dir=./teacher/retinanet_r101/ --eval 

我训练的教师模型验证集mAP最高为0.413

5.2 训练学生模型

单独训练学生模型的目的是为了与蒸馏训练后的模型进行对比,与训练教师模型类似,加载预训练模型和修改学习率,这里不再赘述

导入配置文件

!cp ../runtime.yml configs/runtime.yml
!cp ../coco_detection.yml configs/datasets/coco_detection.yml
!cp ../retinanet_r50_fpn.yml configs/retinanet/_base_/retinanet_r50_fpn.yml
!cp ../optimizer_2x.yml configs/retinanet/_base_/optimizer_2x.yml
!cp ../retinanet_reader.yml configs/retinanet/_base_/retinanet_reader.yml
!cp ../retinanet_r50_fpn_2x_coco.yml configs/retinanet/retinanet_r50_fpn_2x_coco.yml

开始训练

!python tools/train.py -c configs/retinanet/retinanet_r50_fpn_2x_coco.yml --use_vdl=True --vdl_log_dir=./student/retinanet_r50/ --eval 

训练的学生模型验证集mAP最高为0.40

5.3 蒸馏训练

修改蒸馏的配置文件retinanet_resnet101_coco_mgd_distill.yml,pretrain_weights路径选择训练好的教师模型路径

导入配置文件

!cp ../runtime.yml configs/runtime.yml
!cp ../coco_detection.yml configs/datasets/coco_detection.yml
!cp ../retinanet_r50_fpn.yml configs/retinanet/_base_/retinanet_r50_fpn.yml
!cp ../optimizer_2x.yml configs/retinanet/_base_/optimizer_2x.yml
!cp ../retinanet_reader.yml configs/retinanet/_base_/retinanet_reader.yml
!cp ../retinanet_r50_fpn_2x_coco.yml configs/retinanet/retinanet_r50_fpn_2x_coco.yml
!cp ../retinanet_resnet101_coco_mgd_distill.yml configs/slim/distill/retinanet_resnet101_coco_mgd_distill.yml

单卡训练

!python tools/train.py -c configs/retinanet/retinanet_r50_fpn_2x_coco.yml --slim_config configs/slim/distill/retinanet_resnet101_coco_mgd_distill.yml \
--use_vdl=True --vdl_log_dir=./distill/retinanet_r50/ --eval 

多卡训练时需要注意batch_size的大小,将retinanet_reader.yml中的batch_size修改为2,这样四卡总的batch_size还是8

!python -m paddle.distributed.launch --log_dir=logs/ --gpus 0,1,2,3 tools/train.py -c configs/retinanet/retinanet_r50_fpn_2x_coco.yml --slim_config configs/slim/distill/retinanet_resnet101_coco_mgd_distill.yml \
--use_vdl=True --vdl_log_dir=./distill/retinanet_r50/ --eval

使用MGD方法进行蒸馏训练后的模型结果,可以看出各项指标都有明显提升

5.4 模型评估

!python tools/eval.py -c configs/retinanet/retinanet_r50_fpn_2x_coco.yml -o weights=output/retinanet_resnet101_coco_mgd_distill/best_model.pdparams

5.5 推理预测

!python tools/infer.py -c configs/retinanet/retinanet_r50_fpn_2x_coco.yml -o weights=output/retinanet_resnet101_coco_mgd_distill/best_model.pdparams --infer_img=demo/patches_237.jpg

从测试集中选取了几张图片,预测结果如下,可以看出效果还是不错的





六. 总结

本项目对模型压缩领域中的知识蒸馏做了简单的介绍,并基于PaddleDetection套件对目标检测知识蒸馏的最新方法(MGD)进行复现,结果表明该方法具有较好的效果,可以明显提高学生模型的精度,甚至超越了教师模型,尤其是对小目标的检测提升较大。使用较小的模型却能获得接近甚至超越更复杂模型的性能,这就是知识蒸馏的意义所在。

模型mAP(IOU=0.5:0.95)AP(S)AP(M)AP(L)
teacher(retinanet-r101)41.345.133.156.5
student(retinanet-r50)40.034.933.144.1
distill(retinanet-r50+MGD)41.5(+1.5)46.2(+11.3)34.0(+0.9)50.1(+6.0)

后续工作:

  • 目前仅支持retinanet,考虑增加适配更多模型,如PPYOLOE。

此文章为搬运
原项目链接

这篇关于基于掩蔽生成知识蒸馏(MGD)的钢铁表面缺陷检测的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/335875

相关文章

Java编译生成多个.class文件的原理和作用

《Java编译生成多个.class文件的原理和作用》作为一名经验丰富的开发者,在Java项目中执行编译后,可能会发现一个.java源文件有时会产生多个.class文件,从技术实现层面详细剖析这一现象... 目录一、内部类机制与.class文件生成成员内部类(常规内部类)局部内部类(方法内部类)匿名内部类二、

使用Jackson进行JSON生成与解析的新手指南

《使用Jackson进行JSON生成与解析的新手指南》这篇文章主要为大家详细介绍了如何使用Jackson进行JSON生成与解析处理,文中的示例代码讲解详细,感兴趣的小伙伴可以跟随小编一起学习一下... 目录1. 核心依赖2. 基础用法2.1 对象转 jsON(序列化)2.2 JSON 转对象(反序列化)3.

java中使用POI生成Excel并导出过程

《java中使用POI生成Excel并导出过程》:本文主要介绍java中使用POI生成Excel并导出过程,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录需求说明及实现方式需求完成通用代码版本1版本2结果展示type参数为atype参数为b总结注:本文章中代码均为

在java中如何将inputStream对象转换为File对象(不生成本地文件)

《在java中如何将inputStream对象转换为File对象(不生成本地文件)》:本文主要介绍在java中如何将inputStream对象转换为File对象(不生成本地文件),具有很好的参考价... 目录需求说明问题解决总结需求说明在后端中通过POI生成Excel文件流,将输出流(outputStre

C/C++随机数生成的五种方法

《C/C++随机数生成的五种方法》C++作为一种古老的编程语言,其随机数生成的方法已经经历了多次的变革,早期的C++版本使用的是rand()函数和RAND_MAX常量,这种方法虽然简单,但并不总是提供... 目录C/C++ 随机数生成方法1. 使用 rand() 和 srand()2. 使用 <random

Flask 验证码自动生成的实现示例

《Flask验证码自动生成的实现示例》本文主要介绍了Flask验证码自动生成的实现示例,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习... 目录生成图片以及结果处理验证码蓝图html页面展示想必验证码大家都有所了解,但是可以自己定义图片验证码

Python如何在Word中生成多种不同类型的图表

《Python如何在Word中生成多种不同类型的图表》Word文档中插入图表不仅能直观呈现数据,还能提升文档的可读性和专业性,本文将介绍如何使用Python在Word文档中创建和自定义各种图表,需要的... 目录在Word中创建柱形图在Word中创建条形图在Word中创建折线图在Word中创建饼图在Word

国内环境搭建私有知识问答库踩坑记录(ollama+deepseek+ragflow)

《国内环境搭建私有知识问答库踩坑记录(ollama+deepseek+ragflow)》本文给大家利用deepseek模型搭建私有知识问答库的详细步骤和遇到的问题及解决办法,感兴趣的朋友一起看看吧... 目录1. 第1步大家在安装完ollama后,需要到系统环境变量中添加两个变量2. 第3步 “在cmd中

nginx生成自签名SSL证书配置HTTPS的实现

《nginx生成自签名SSL证书配置HTTPS的实现》本文主要介绍在Nginx中生成自签名SSL证书并配置HTTPS,包括安装Nginx、创建证书、配置证书以及测试访问,具有一定的参考价值,感兴趣的可... 目录一、安装nginx二、创建证书三、配置证书并验证四、测试一、安装nginxnginx必须有"-

Java实战之利用POI生成Excel图表

《Java实战之利用POI生成Excel图表》ApachePOI是Java生态中处理Office文档的核心工具,这篇文章主要为大家详细介绍了如何在Excel中创建折线图,柱状图,饼图等常见图表,需要的... 目录一、环境配置与依赖管理二、数据源准备与工作表构建三、图表生成核心步骤1. 折线图(Line Ch