03 U2net

2023-11-20 18:50
文章标签 03 u2net

本文主要是介绍03 U2net,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

目录

一、 理论知识

1. 网络架构

2. RSU

3. 显著特征融合模块

4. 损失计算​编辑

二、 代码实现

0. DUTS数据集

1. transforms.ToTensor()

2. train

3. 验证

4. 预测


原文链接https://blog.csdn.net/qq_37541097/article/details/126255483

一、 理论知识

U2Net是针对Salient Object Detetion(SOD)即显著性目标检测任务。

显著性目标检测任务与语义分割任务相似,只不过显著性目标检测任务是二分类任务,它的任务是将图片中最吸引人的目标或区域分割出来,故只有前景和背景两类。

1. 网络架构

在大的UNet中嵌入了一堆小UNet

2. RSU

En_1En_2En_3En_4De_1De_2De_3De_4采用的是同一种Block,只是深度不同。

Block就是论文中提出的ReSidual U-block简称RSU(也就是小Unet)

  • En_1De_1采用的是RSU-7En_2De_2采用的是RSU-6En_3De_3采用的是RSU-5En_4De_4采用的是RSU-4
  • En_5En_6De_5三个模块采用的是RSU-4FRSU-4FRSU-4两者结构并不相同
  • 带参数d的卷积层全部是膨胀卷积,d为膨胀系数

深度为7的Block ---------- RSU-7

RSU-4F将采样层全部替换成了膨胀卷积:

3. 显著特征融合模块

saliency map fusion module即显著特征融合模块 ---------- 最终输出模块

对照1中网络结构图

  • 收集De_1、De_2、De_3、De_4、De_5以及En_6的输出;
  • 分别通过一个3x3的卷积层得到channel为1的特征图;
  • 双线性插值缩放到输入图片大小得到Sup1、Sup2、Sup3、Sup4、Sup5和Sup6;
  • 6个特征图进行Concat拼接
  • 通过一个1x1的卷积层以及Sigmiod激活函数得到最终的预测概率图

4. 损失计算

 评价指标:

二、 代码实现

显著性目标检测只区分背景和前景,本项目使用的数据集是DUTS数据集

target:黑白图二维(H,W)------- 背景:0 ;   前景:255 ;  边缘地方:小的数值

U2Net没有torch官方实现;需要使用时修改项目中的代码

0. DUTS数据集

目录结构如下:

├── DUTS-TR
│      ├── DUTS-TR-Image: 该文件夹存放所有训练集的图片
│      └── DUTS-TR-Mask: 该文件夹存放对应训练图片的GT标签(Mask蒙板形式)
│
└── DUTS-TE├── DUTS-TE-Image: 该文件夹存放所有测试(验证)集的图片└── DUTS-TE-Mask: 该文件夹存放对应测试(验证)图片的GT标签(Mask蒙板形式)

mask只区分前景背景:(如下图)

 

该数据集没有用到调色板!!!!!

自定义DUTS数据集,详见my_dataset.py文件

1. transforms.ToTensor()

针对image:

  • 转变形状,将H,W,C转变成C,H,W
  • 转变数据格式,返回的tensor格式
  • 归一化,可以将数据的范围变成[0,1]----所有像素值÷255

针对target:

分割类任务,target是二维(H,W)

  • 经过.ToTensor()后,自动增加维度,变为(1,H,W)
  • 转变数据格式,返回的tensor格式
  • 归一化,可以将数据的范围变成[0,1]----所有像素值÷255

2. train

由于torch官方没有实现该模型,加载作者训练好的.pth文件

# ********************************************* 加载.pth参数
# 建立model
model = u2net_full()
# 查看model的初始化参数值
old_dic = model.state_dict()
# 加载.pth参数文件
weight_path = './u2net_full.pth'
pre_dic = torch.load(weight_path, map_location=device)
# 将.pth参数加载到model中,只会将字典名称完全一样以及shape相同的加载进去
# 返回model中未加载成功的参数   以及   .pth中多余的参数(与model不匹配)------strict=False
missing_keys, unexpected_keys = model.load_state_dict(pre_dic, strict=False)
print(missing_keys, unexpected_keys)
# 再次查看model的参数值,检查是否已经更换成功
new_dic = model.state_dict()
model.to(device)

损失计算:

# inputs就是model预测的结果-----多个特征图查看原理部分损失计算
def criterion(inputs, target):losses = [F.binary_cross_entropy_with_logits(inputs[i], target) for i in range(len(inputs))]total_loss = sum(losses)  # 求和return total_loss

3. 验证

需要建立两个评价指标-------  MeanAbsoluteError  和  F1Score

mae_metric, f1_metric = evaluate(model, val_data_loader, device=device)
...
...
def evaluate(,,,,):model.eval()mae_metric = utils.MeanAbsoluteError()f1_metric = utils.F1Score()...with torch.no_grad():for images, targets in data_loader:images, targets = images.to(device), targets.to(device)output = model(images)mae_metric.update(output, targets)f1_metric.update(output, targets)return mae_metric, f1_metric
class MeanAbsoluteError(object):def __init__(self):self.mae_list = []def update(self, pred: torch.Tensor, gt: torch.Tensor):batch_size, c, h, w = gt.shapeassert batch_size == 1, f"validation mode batch_size must be 1, but got batch_size: {batch_size}."resize_pred = F.interpolate(pred, (h, w), mode="bilinear", align_corners=False)error_pixels = torch.sum(torch.abs(resize_pred - gt), dim=(1, 2, 3)) / (h * w)self.mae_list.extend(error_pixels.tolist())def compute(self):mae = sum(self.mae_list) / len(self.mae_list)return maedef gather_from_all_processes(self):if not torch.distributed.is_available():returnif not torch.distributed.is_initialized():returntorch.distributed.barrier()gather_mae_list = []for i in all_gather(self.mae_list):gather_mae_list.extend(i)self.mae_list = gather_mae_listdef __str__(self):mae = self.compute()return f'MAE: {mae:.3f}'class F1Score(object):"""refer: https://github.com/xuebinqin/DIS/blob/main/IS-Net/basics.py"""def __init__(self, threshold: float = 0.5):self.precision_cum = Noneself.recall_cum = Noneself.num_cum = Noneself.threshold = thresholddef update(self, pred: torch.Tensor, gt: torch.Tensor):batch_size, c, h, w = gt.shapeassert batch_size == 1, f"validation mode batch_size must be 1, but got batch_size: {batch_size}."resize_pred = F.interpolate(pred, (h, w), mode="bilinear", align_corners=False)gt_num = torch.sum(torch.gt(gt, self.threshold).float())pp = resize_pred[torch.gt(gt, self.threshold)]  # 对应预测map中GT为前景的区域nn = resize_pred[torch.le(gt, self.threshold)]  # 对应预测map中GT为背景的区域pp_hist = torch.histc(pp, bins=255, min=0.0, max=1.0)nn_hist = torch.histc(nn, bins=255, min=0.0, max=1.0)# Sort according to the prediction probability from large to smallpp_hist_flip = torch.flipud(pp_hist)nn_hist_flip = torch.flipud(nn_hist)pp_hist_flip_cum = torch.cumsum(pp_hist_flip, dim=0)nn_hist_flip_cum = torch.cumsum(nn_hist_flip, dim=0)precision = pp_hist_flip_cum / (pp_hist_flip_cum + nn_hist_flip_cum + 1e-4)recall = pp_hist_flip_cum / (gt_num + 1e-4)if self.precision_cum is None:self.precision_cum = torch.full_like(precision, fill_value=0.)if self.recall_cum is None:self.recall_cum = torch.full_like(recall, fill_value=0.)if self.num_cum is None:self.num_cum = torch.zeros([1], dtype=gt.dtype, device=gt.device)self.precision_cum += precisionself.recall_cum += recallself.num_cum += batch_sizedef compute(self):pre_mean = self.precision_cum / self.num_cumrec_mean = self.recall_cum / self.num_cumf1_mean = (1 + 0.3) * pre_mean * rec_mean / (0.3 * pre_mean + rec_mean + 1e-8)max_f1 = torch.amax(f1_mean).item()return max_f1def reduce_from_all_processes(self):if not torch.distributed.is_available():returnif not torch.distributed.is_initialized():returntorch.distributed.barrier()torch.distributed.all_reduce(self.precision_cum)torch.distributed.all_reduce(self.recall_cum)torch.distributed.all_reduce(self.num_cum)def __str__(self):max_f1 = self.compute()return f'maxF1: {max_f1:.3f}'

4. 预测

...
threshold = 0.5
origin_img = cv2.cvtColor(cv2.imread(img_path, flags=cv2.IMREAD_COLOR), cv2.COLOR_BGR2RGB)
img = data_transform(origin_img)
img = torch.unsqueeze(img, 0).to(device)
...pred = model(img)   
pred = torch.squeeze(pred).to("cpu").numpy()  # [1, 1, H, W] -> [H, W]
# 生成mask结果
pred = cv2.resize(pred, dsize=(w, h), interpolation=cv2.INTER_LINEAR)
pred_mask = np.where(pred > threshold, 1, 0)     
origin_img = np.array(origin_img, dtype=np.uint8)
seg_img = origin_img * pred_mask[..., None]    
# 应该就是将原图中背景像素点变为0,前景像素点不变,  没去仔细查此行功能,待解决
# 保存
cv2.imwrite("pred_result.png", cv2.cvtColor(seg_img.astype(np.uint8), cv2.COLOR_RGB2BGR))

预测一张图片时,生成图如下:

训练以及验证时,数据集中的标签: 

这篇关于03 U2net的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/396699

相关文章

cross-plateform 跨平台应用程序-03-如果只选择一个框架,应该选择哪一个?

跨平台系列 cross-plateform 跨平台应用程序-01-概览 cross-plateform 跨平台应用程序-02-有哪些主流技术栈? cross-plateform 跨平台应用程序-03-如果只选择一个框架,应该选择哪一个? cross-plateform 跨平台应用程序-04-React Native 介绍 cross-plateform 跨平台应用程序-05-Flutte

FreeRTOS内部机制学习03(事件组内部机制)

文章目录 事件组使用的场景事件组的核心以及Set事件API做的事情事件组的特殊之处事件组为什么不关闭中断xEventGroupSetBitsFromISR内部是怎么做的? 事件组使用的场景 学校组织秋游,组长在等待: 张三:我到了 李四:我到了 王五:我到了 组长说:好,大家都到齐了,出发! 秋游回来第二天就要提交一篇心得报告,组长在焦急等待:张三、李四、王五谁先写好就交谁的

Vue day-03

目录 Vue常用特性 一.响应更新 1. 1 v-for更新监测 1.2 v-for就地更新 1.3 什么是虚拟DOM 1.4 diff算法更新虚拟DOM 总结:key值的作用和注意点: 二.过滤器 2.1 vue过滤器-定义使用 2.2 vue过滤器-传参和多过滤器 三. 计算属性(computed) 3.1 计算属性-定义使用 3.2 计算属性-缓存 3.3 计算属

【SpringMVC学习03】-SpringMVC的配置文件详解

在SpringMVC的各个组件中,处理器映射器、处理器适配器、视图解析器称为springmvc的三大组件。其实真正需要程序员开发的就两大块:一个是Handler,一个是jsp。 在springMVC的入门程序中,SpringMVC的核心配置文件——springmvc.xml为: <?xml version="1.0" encoding="UTF-8"?><beans xmlns="http:

浙大数据结构——03-树1 树的同构

这道题我依然采用STL库的map,从而大幅减少了代码量 简单说一下思路,两棵树是否同构,只需比较俩树字母相同的结点是否同构,即是否左==左,右==右或者左==右,右==左。 1、条件准备 atree和btree是存两个数结点字母,第几个就存输入的第几个结点的字母。 map通过结点的字母作为键,从而找到两个子节点的信息 都要用char类型 #include <iostream>#inc

python+selenium2轻量级框架设计-03读取配置文件

任何一个项目,都涉及到了配置文件和管理和读写,Python支持很多配置文件的读写,这里介绍读取ini文件。 以读取url和浏览器作为例子 #浏览器引擎类import configparser,time,osfrom selenium import webdriverfrom framework.logger import Loggerlogger = Logger(logger='

python+selenium2学习笔记unittest-03断言

断言的方法网上归纳的很多主要有以下这些 断言语法解释assertEqual(a, b) 判断a==bassertNotEqual(a, b)判断a!=bassertTrue(x)bool(x) is TrueassertFalse(x)bool(x) is FalseassertIs(a, b)a is bassertIsNot(a, b) a is not bassertIsNone(x) x

C++入门(03)萌新问题多(一)(未完待续)

文章目录 1. 一闪而过使用system("pause")使用cin.get() 1. 一闪而过 .exe 在用户计算机上运行后“一闪而过”,是因为控制台程序没有专门的用户图形界面,程序执行完所有代码后默认完成任务自动关闭 使用system(“pause”) 在程序的结尾处加入 system(“pause”),程序在执行完毕后等待用户按任意键继续。这是最简单的方法。 使

三文带你轻松上手鸿蒙的AI语音03-文本合成声音

三文带你轻松上手鸿蒙的AI语音03-文本合成声音 前言 接上文 三文带你轻松上手鸿蒙的AI语音02-声音文件转文本 HarmonyOS NEXT 提供的AI 文本合并语音功能,可以将一段不超过10000字符的文本合成为语音并进行播报。 场景举例 手机在无网状态下,系统应用无障碍(屏幕朗读)接入文本转语音能力,为视障人士提供播报能力。类似微信读书,可以实现将文章内容通过语音朗读,可以

读软件设计的要素03概念的组合

1. 概念的组合 1.1. 概念不像程序那样,可以用较大的包含较小的 1.1.1. 每个概念对用户来说都是平等的,软件或系统就是一组串联运行的概念组合 1.2. 概念是通过操作来同步组合的 1.2.1. 同步并不增加新的概念操作,但会限制已有的操作,从而消除一些独立概念可能会出现的操作序列 1.3. 在自由组合中,概念彼此独立,仅受一些记录的约束,这些约束是为了确保概念对事物观点的一