本文主要是介绍03 U2net,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!
目录
一、 理论知识
1. 网络架构
2. RSU
3. 显著特征融合模块
4. 损失计算编辑
二、 代码实现
0. DUTS数据集
1. transforms.ToTensor()
2. train
3. 验证
4. 预测
原文链接https://blog.csdn.net/qq_37541097/article/details/126255483
一、 理论知识
U2Net是针对Salient Object Detetion(SOD)即显著性目标检测任务。
显著性目标检测任务与语义分割任务相似,只不过显著性目标检测任务是二分类任务,它的任务是将图片中最吸引人的目标或区域分割出来,故只有前景和背景两类。
1. 网络架构
在大的UNet中嵌入了一堆小UNet
2. RSU
En_1
、En_2
、En_3
、En_4
、De_1
、De_2
、De_3
、De_4
采用的是同一种Block,
只是深度不同。
Block
就是论文中提出的ReSidual U-block
简称RSU(也就是小Unet)
En_1
和De_1
采用的是RSU-7
,En_2
和De_2
采用的是RSU-6
,En_3
和De_3
采用的是RSU-5
,En_4
和De_4
采用的是RSU-4
En_5
、En_6
和De_5
三个模块采用的是RSU-4F
,RSU-4F
和RSU-4
两者结构并不相同。- 带参数
d
的卷积层全部是膨胀卷积,d
为膨胀系数
深度为7的Block ---------- RSU-7
RSU-4F
将采样层全部替换成了膨胀卷积:
3. 显著特征融合模块
saliency map fusion module
即显著特征融合模块 ---------- 最终输出模块
对照1中网络结构图看
- 收集De_1、De_2、De_3、De_4、De_5以及En_6的输出;
- 分别通过一个3x3的卷积层得到channel为1的特征图;
- 双线性插值缩放到输入图片大小得到Sup1、Sup2、Sup3、Sup4、Sup5和Sup6;
- 6个特征图进行Concat拼接;
- 通过一个1x1的卷积层以及Sigmiod激活函数得到最终的预测概率图
4. 损失计算
评价指标:
二、 代码实现
显著性目标检测只区分背景和前景,本项目使用的数据集是DUTS数据集
target:黑白图二维(H,W)------- 背景:0 ; 前景:255 ; 边缘地方:小的数值
U2Net没有torch官方实现;需要使用时修改项目中的代码
0. DUTS数据集
目录结构如下:
├── DUTS-TR
│ ├── DUTS-TR-Image: 该文件夹存放所有训练集的图片
│ └── DUTS-TR-Mask: 该文件夹存放对应训练图片的GT标签(Mask蒙板形式)
│
└── DUTS-TE├── DUTS-TE-Image: 该文件夹存放所有测试(验证)集的图片└── DUTS-TE-Mask: 该文件夹存放对应测试(验证)图片的GT标签(Mask蒙板形式)
mask只区分前景背景:(如下图)
该数据集没有用到调色板!!!!!
自定义DUTS数据集,详见my_dataset.py文件
1. transforms.ToTensor()
针对image:
- 转变形状,将H,W,C转变成C,H,W
- 转变数据格式,返回的tensor格式
- 归一化,可以将数据的范围变成[0,1]----所有像素值÷255
针对target:
分割类任务,target是二维(H,W)
- 经过.ToTensor()后,自动增加维度,变为(1,H,W)
- 转变数据格式,返回的tensor格式
- 归一化,可以将数据的范围变成[0,1]----所有像素值÷255
2. train
由于torch官方没有实现该模型,加载作者训练好的.pth文件
# ********************************************* 加载.pth参数
# 建立model
model = u2net_full()
# 查看model的初始化参数值
old_dic = model.state_dict()
# 加载.pth参数文件
weight_path = './u2net_full.pth'
pre_dic = torch.load(weight_path, map_location=device)
# 将.pth参数加载到model中,只会将字典名称完全一样以及shape相同的加载进去
# 返回model中未加载成功的参数 以及 .pth中多余的参数(与model不匹配)------strict=False
missing_keys, unexpected_keys = model.load_state_dict(pre_dic, strict=False)
print(missing_keys, unexpected_keys)
# 再次查看model的参数值,检查是否已经更换成功
new_dic = model.state_dict()
model.to(device)
损失计算:
# inputs就是model预测的结果-----多个特征图查看原理部分损失计算
def criterion(inputs, target):losses = [F.binary_cross_entropy_with_logits(inputs[i], target) for i in range(len(inputs))]total_loss = sum(losses) # 求和return total_loss
3. 验证
需要建立两个评价指标------- MeanAbsoluteError 和 F1Score
mae_metric, f1_metric = evaluate(model, val_data_loader, device=device)
...
...
def evaluate(,,,,):model.eval()mae_metric = utils.MeanAbsoluteError()f1_metric = utils.F1Score()...with torch.no_grad():for images, targets in data_loader:images, targets = images.to(device), targets.to(device)output = model(images)mae_metric.update(output, targets)f1_metric.update(output, targets)return mae_metric, f1_metric
class MeanAbsoluteError(object):def __init__(self):self.mae_list = []def update(self, pred: torch.Tensor, gt: torch.Tensor):batch_size, c, h, w = gt.shapeassert batch_size == 1, f"validation mode batch_size must be 1, but got batch_size: {batch_size}."resize_pred = F.interpolate(pred, (h, w), mode="bilinear", align_corners=False)error_pixels = torch.sum(torch.abs(resize_pred - gt), dim=(1, 2, 3)) / (h * w)self.mae_list.extend(error_pixels.tolist())def compute(self):mae = sum(self.mae_list) / len(self.mae_list)return maedef gather_from_all_processes(self):if not torch.distributed.is_available():returnif not torch.distributed.is_initialized():returntorch.distributed.barrier()gather_mae_list = []for i in all_gather(self.mae_list):gather_mae_list.extend(i)self.mae_list = gather_mae_listdef __str__(self):mae = self.compute()return f'MAE: {mae:.3f}'class F1Score(object):"""refer: https://github.com/xuebinqin/DIS/blob/main/IS-Net/basics.py"""def __init__(self, threshold: float = 0.5):self.precision_cum = Noneself.recall_cum = Noneself.num_cum = Noneself.threshold = thresholddef update(self, pred: torch.Tensor, gt: torch.Tensor):batch_size, c, h, w = gt.shapeassert batch_size == 1, f"validation mode batch_size must be 1, but got batch_size: {batch_size}."resize_pred = F.interpolate(pred, (h, w), mode="bilinear", align_corners=False)gt_num = torch.sum(torch.gt(gt, self.threshold).float())pp = resize_pred[torch.gt(gt, self.threshold)] # 对应预测map中GT为前景的区域nn = resize_pred[torch.le(gt, self.threshold)] # 对应预测map中GT为背景的区域pp_hist = torch.histc(pp, bins=255, min=0.0, max=1.0)nn_hist = torch.histc(nn, bins=255, min=0.0, max=1.0)# Sort according to the prediction probability from large to smallpp_hist_flip = torch.flipud(pp_hist)nn_hist_flip = torch.flipud(nn_hist)pp_hist_flip_cum = torch.cumsum(pp_hist_flip, dim=0)nn_hist_flip_cum = torch.cumsum(nn_hist_flip, dim=0)precision = pp_hist_flip_cum / (pp_hist_flip_cum + nn_hist_flip_cum + 1e-4)recall = pp_hist_flip_cum / (gt_num + 1e-4)if self.precision_cum is None:self.precision_cum = torch.full_like(precision, fill_value=0.)if self.recall_cum is None:self.recall_cum = torch.full_like(recall, fill_value=0.)if self.num_cum is None:self.num_cum = torch.zeros([1], dtype=gt.dtype, device=gt.device)self.precision_cum += precisionself.recall_cum += recallself.num_cum += batch_sizedef compute(self):pre_mean = self.precision_cum / self.num_cumrec_mean = self.recall_cum / self.num_cumf1_mean = (1 + 0.3) * pre_mean * rec_mean / (0.3 * pre_mean + rec_mean + 1e-8)max_f1 = torch.amax(f1_mean).item()return max_f1def reduce_from_all_processes(self):if not torch.distributed.is_available():returnif not torch.distributed.is_initialized():returntorch.distributed.barrier()torch.distributed.all_reduce(self.precision_cum)torch.distributed.all_reduce(self.recall_cum)torch.distributed.all_reduce(self.num_cum)def __str__(self):max_f1 = self.compute()return f'maxF1: {max_f1:.3f}'
4. 预测
...
threshold = 0.5
origin_img = cv2.cvtColor(cv2.imread(img_path, flags=cv2.IMREAD_COLOR), cv2.COLOR_BGR2RGB)
img = data_transform(origin_img)
img = torch.unsqueeze(img, 0).to(device)
...pred = model(img)
pred = torch.squeeze(pred).to("cpu").numpy() # [1, 1, H, W] -> [H, W]
# 生成mask结果
pred = cv2.resize(pred, dsize=(w, h), interpolation=cv2.INTER_LINEAR)
pred_mask = np.where(pred > threshold, 1, 0)
origin_img = np.array(origin_img, dtype=np.uint8)
seg_img = origin_img * pred_mask[..., None]
# 应该就是将原图中背景像素点变为0,前景像素点不变, 没去仔细查此行功能,待解决
# 保存
cv2.imwrite("pred_result.png", cv2.cvtColor(seg_img.astype(np.uint8), cv2.COLOR_RGB2BGR))
预测一张图片时,生成图如下:
训练以及验证时,数据集中的标签:
这篇关于03 U2net的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!