U2Net for Video Image Segmentation (From Principle to Practice)

2023-11-10 15:30

This article introduces how to implement video image segmentation with U2Net, from the underlying principle to a working implementation. Hopefully it offers a useful reference for developers facing the same problem.

I. A Brief Introduction to U2Net

1. The U2Net network structure

[Figure: the overall U2Net architecture]

The whole network is a symmetric U shape built on the classic encoder-decoder design; inside each Sup (side-output stage) there is another U-shaped structure, and deep supervision is used so that shallow and deep semantic information are combined effectively. The network performs five downsampling and five upsampling steps: upsampling is implemented with torch.nn.functional.interpolate(), and downsampling with torch.nn.MaxPool2d() using stride-2 max pooling (a small sketch of both operations follows). Each En_x stage uses an RSU module, whose structure is shown in the figure after the sketch:
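As a quick illustration of these two operations (a minimal sketch with dummy tensors, independent of the network itself):

import torch
import torch.nn as nn
import torch.nn.functional as F

x = torch.randn(1, 3, 320, 320)                          # an NCHW feature map
down = nn.MaxPool2d(2, stride=2, ceil_mode=True)(x)      # stride-2 max pooling -> (1, 3, 160, 160)
up = F.interpolate(down, size=x.shape[2:], mode='bilinear',
                   align_corners=True)                   # bilinear upsampling back to (1, 3, 320, 320)
print(down.shape, up.shape)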

[Figure: the RSU-L module structure]

The RSU module extracts multi-scale features at each stage (L is the number of layers in the encoder, Cin and Cout are the input and output channel counts, and M is the channel count of the RSU's internal layers). The structure consists of three main parts:

(1) an input convolution layer, which converts the input feature map into an intermediate map with the same channel count as the output, used for local feature extraction;

(2) a symmetric encoder-decoder structure of height L, which takes the intermediate map as input and extracts and learns multi-scale semantic information;

(3) a residual connection that fuses the local features with the multi-scale features.

U2Net uses both element-wise addition (add) and channel concatenation (concat), as the sketch below illustrates.
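Both fusion operations can be demonstrated with dummy tensors (a minimal sketch; the tensor names are illustrative, not from the real code):

import torch

hxin = torch.randn(1, 64, 288, 288)      # stand-in for F1(x), the local feature from the input conv
hx_multi = torch.randn(1, 64, 288, 288)  # stand-in for U(F1(x)), the multi-scale feature from the inner U-block
fused = hx_multi + hxin                  # residual add: channel count and spatial size must match
skip = torch.cat((hx_multi, hxin), 1)    # concat along channels doubles them: (1, 128, 288, 288)
print(fused.shape, skip.shape)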

2. The loss function

L = Σ_{m=1}^{6} w_side^(m) · ℓ_side^(m) + w_fuse · ℓ_fuse

Since there are six side outputs (Sup1–Sup6) plus the fused output, the total loss is a weighted sum of seven terms; each term ℓ is the standard binary cross-entropy loss:

ℓ = −Σ_{(r,c)} [ P_G(r,c) · log P_S(r,c) + (1 − P_G(r,c)) · log(1 − P_S(r,c)) ]

where P_G(r,c) is the ground-truth value at pixel (r,c) and P_S(r,c) is the predicted probability.

II. The Code

The network code is fairly easy to follow alongside the architecture figure; most of the other files are annotated with comments to make them easier to revisit later.

1. U2net.py

import torch
import torch.nn as nn
from torchvision import models
import torch.nn.functional as F


class REBNCONV(nn.Module):                                 # CBL: Conv + BN + ReLU block
    def __init__(self, in_ch=3, out_ch=3, dirate=1):
        super(REBNCONV, self).__init__()
        self.conv_s1 = nn.Conv2d(in_ch, out_ch, 3, padding=1*dirate, dilation=1*dirate)
        self.bn_s1 = nn.BatchNorm2d(out_ch)
        self.relu_s1 = nn.ReLU(inplace=True)

    def forward(self, x):
        hx = x
        xout = self.relu_s1(self.bn_s1(self.conv_s1(hx)))
        return xout


def _upsample_like(src, tar):
    # upsample src to the spatial size of tar
    src = F.interpolate(src, size=tar.shape[2:], mode='bilinear', align_corners=True)   # https://www.cnblogs.com/wanghui-garcia/p/11399034.html
    return src


### RSU-7 ###
class RSU7(nn.Module):  # UNet07DRES; used as En_1
    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU7, self).__init__()
        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)            # CBR1
        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)            # CBR2
        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool5 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.rebnconv7 = REBNCONV(mid_ch, mid_ch, dirate=2)
        self.rebnconv6d = REBNCONV(mid_ch*2, mid_ch, dirate=1)
        self.rebnconv5d = REBNCONV(mid_ch*2, mid_ch, dirate=1)
        self.rebnconv4d = REBNCONV(mid_ch*2, mid_ch, dirate=1)
        self.rebnconv3d = REBNCONV(mid_ch*2, mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch*2, mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch*2, out_ch, dirate=1)

    def forward(self, x):
        hx = x
        hxin = self.rebnconvin(hx)
        hx1 = self.rebnconv1(hxin)
        hx = self.pool1(hx1)
        hx2 = self.rebnconv2(hx)
        hx = self.pool2(hx2)
        hx3 = self.rebnconv3(hx)
        hx = self.pool3(hx3)
        hx4 = self.rebnconv4(hx)
        hx = self.pool4(hx4)
        hx5 = self.rebnconv5(hx)
        hx = self.pool5(hx5)
        hx6 = self.rebnconv6(hx)
        hx7 = self.rebnconv7(hx6)                                      # dilation=2
        hx6d = self.rebnconv6d(torch.cat((hx7, hx6), 1))
        hx6dup = _upsample_like(hx6d, hx5)
        hx5d = self.rebnconv5d(torch.cat((hx6dup, hx5), 1))
        hx5dup = _upsample_like(hx5d, hx4)
        hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
        hx4dup = _upsample_like(hx4d, hx3)
        hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
        hx3dup = _upsample_like(hx3d, hx2)
        hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
        hx2dup = _upsample_like(hx2d, hx1)
        hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
        return hx1d + hxin                                             # residual fusion


### RSU-6 ###
class RSU6(nn.Module):  # UNet06DRES
    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU6, self).__init__()
        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=2)
        self.rebnconv5d = REBNCONV(mid_ch*2, mid_ch, dirate=1)
        self.rebnconv4d = REBNCONV(mid_ch*2, mid_ch, dirate=1)
        self.rebnconv3d = REBNCONV(mid_ch*2, mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch*2, mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch*2, out_ch, dirate=1)

    def forward(self, x):
        hx = x
        hxin = self.rebnconvin(hx)
        hx1 = self.rebnconv1(hxin)
        hx = self.pool1(hx1)
        hx2 = self.rebnconv2(hx)
        hx = self.pool2(hx2)
        hx3 = self.rebnconv3(hx)
        hx = self.pool3(hx3)
        hx4 = self.rebnconv4(hx)
        hx = self.pool4(hx4)
        hx5 = self.rebnconv5(hx)
        hx6 = self.rebnconv6(hx5)
        hx5d = self.rebnconv5d(torch.cat((hx6, hx5), 1))
        hx5dup = _upsample_like(hx5d, hx4)
        hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
        hx4dup = _upsample_like(hx4d, hx3)
        hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
        hx3dup = _upsample_like(hx3d, hx2)
        hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
        hx2dup = _upsample_like(hx2d, hx1)
        hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
        return hx1d + hxin


### RSU-5 ###
class RSU5(nn.Module):  # UNet05DRES
    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU5, self).__init__()
        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=2)
        self.rebnconv4d = REBNCONV(mid_ch*2, mid_ch, dirate=1)
        self.rebnconv3d = REBNCONV(mid_ch*2, mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch*2, mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch*2, out_ch, dirate=1)

    def forward(self, x):
        hx = x
        hxin = self.rebnconvin(hx)
        hx1 = self.rebnconv1(hxin)
        hx = self.pool1(hx1)
        hx2 = self.rebnconv2(hx)
        hx = self.pool2(hx2)
        hx3 = self.rebnconv3(hx)
        hx = self.pool3(hx3)
        hx4 = self.rebnconv4(hx)
        hx5 = self.rebnconv5(hx4)
        hx4d = self.rebnconv4d(torch.cat((hx5, hx4), 1))
        hx4dup = _upsample_like(hx4d, hx3)
        hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
        hx3dup = _upsample_like(hx3d, hx2)
        hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
        hx2dup = _upsample_like(hx2d, hx1)
        hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
        return hx1d + hxin


### RSU-4 ###
class RSU4(nn.Module):  # UNet04DRES
    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU4, self).__init__()
        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=2)
        self.rebnconv3d = REBNCONV(mid_ch*2, mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch*2, mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch*2, out_ch, dirate=1)

    def forward(self, x):
        hx = x
        hxin = self.rebnconvin(hx)
        hx1 = self.rebnconv1(hxin)
        hx = self.pool1(hx1)
        hx2 = self.rebnconv2(hx)
        hx = self.pool2(hx2)
        hx3 = self.rebnconv3(hx)
        hx4 = self.rebnconv4(hx3)
        hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
        hx3dup = _upsample_like(hx3d, hx2)
        hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
        hx2dup = _upsample_like(hx2d, hx1)
        hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
        return hx1d + hxin


### RSU-4F ###
class RSU4F(nn.Module):  # UNet04FRES
    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU4F, self).__init__()
        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=2)    # dilated convolutions instead of pooling
        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=4)
        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=8)
        self.rebnconv3d = REBNCONV(mid_ch*2, mid_ch, dirate=4)
        self.rebnconv2d = REBNCONV(mid_ch*2, mid_ch, dirate=2)
        self.rebnconv1d = REBNCONV(mid_ch*2, out_ch, dirate=1)

    def forward(self, x):
        hx = x
        hxin = self.rebnconvin(hx)
        hx1 = self.rebnconv1(hxin)
        hx2 = self.rebnconv2(hx1)
        hx3 = self.rebnconv3(hx2)
        hx4 = self.rebnconv4(hx3)
        hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
        hx2d = self.rebnconv2d(torch.cat((hx3d, hx2), 1))
        hx1d = self.rebnconv1d(torch.cat((hx2d, hx1), 1))
        return hx1d + hxin


##### U^2-Net ####
class U2NET(nn.Module):
    def __init__(self, in_ch=3, out_ch=1):
        super(U2NET, self).__init__()
        self.stage1 = RSU7(in_ch, 32, 64)
        self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        self.stage2 = RSU6(64, 32, 128)
        self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        self.stage3 = RSU5(128, 64, 256)
        self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        self.stage4 = RSU4(256, 128, 512)
        self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        self.stage5 = RSU4F(512, 256, 512)
        self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        self.stage6 = RSU4F(512, 256, 512)
        # decoder
        self.stage5d = RSU4F(1024, 256, 512)
        self.stage4d = RSU4(1024, 128, 256)
        self.stage3d = RSU5(512, 64, 128)
        self.stage2d = RSU6(256, 32, 64)
        self.stage1d = RSU7(128, 16, 64)
        self.side1 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side2 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side3 = nn.Conv2d(128, out_ch, 3, padding=1)
        self.side4 = nn.Conv2d(256, out_ch, 3, padding=1)
        self.side5 = nn.Conv2d(512, out_ch, 3, padding=1)
        self.side6 = nn.Conv2d(512, out_ch, 3, padding=1)
        self.outconv = nn.Conv2d(6, out_ch, 1)

    def forward(self, x):
        hx = x
        # stage 1
        hx1 = self.stage1(hx)
        hx = self.pool12(hx1)
        # stage 2
        hx2 = self.stage2(hx)
        hx = self.pool23(hx2)
        # stage 3
        hx3 = self.stage3(hx)
        hx = self.pool34(hx3)
        # stage 4
        hx4 = self.stage4(hx)
        hx = self.pool45(hx4)
        # stage 5
        hx5 = self.stage5(hx)
        hx = self.pool56(hx5)
        # stage 6
        hx6 = self.stage6(hx)
        hx6up = _upsample_like(hx6, hx5)
        # -------------------- decoder --------------------
        hx5d = self.stage5d(torch.cat((hx6up, hx5), 1))
        hx5dup = _upsample_like(hx5d, hx4)
        hx4d = self.stage4d(torch.cat((hx5dup, hx4), 1))
        hx4dup = _upsample_like(hx4d, hx3)
        hx3d = self.stage3d(torch.cat((hx4dup, hx3), 1))
        hx3dup = _upsample_like(hx3d, hx2)
        hx2d = self.stage2d(torch.cat((hx3dup, hx2), 1))
        hx2dup = _upsample_like(hx2d, hx1)
        hx1d = self.stage1d(torch.cat((hx2dup, hx1), 1))
        # side output
        d1 = self.side1(hx1d)
        d2 = self.side2(hx2d)
        d2 = _upsample_like(d2, d1)
        d3 = self.side3(hx3d)
        d3 = _upsample_like(d3, d1)
        d4 = self.side4(hx4d)
        d4 = _upsample_like(d4, d1)
        d5 = self.side5(hx5d)
        d5 = _upsample_like(d5, d1)
        d6 = self.side6(hx6)
        d6 = _upsample_like(d6, d1)
        d0 = self.outconv(torch.cat((d1, d2, d3, d4, d5, d6), 1))
        return torch.sigmoid(d0), torch.sigmoid(d1), torch.sigmoid(d2), torch.sigmoid(d3), torch.sigmoid(d4), torch.sigmoid(d5), torch.sigmoid(d6)


### U^2-Net small ###
class U2NETP(nn.Module):
    def __init__(self, in_ch=3, out_ch=1):
        super(U2NETP, self).__init__()
        self.stage1 = RSU7(in_ch, 16, 64)
        self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        self.stage2 = RSU6(64, 16, 64)
        self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        self.stage3 = RSU5(64, 16, 64)
        self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        self.stage4 = RSU4(64, 16, 64)
        self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        self.stage5 = RSU4F(64, 16, 64)
        self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        self.stage6 = RSU4F(64, 16, 64)
        # decoder
        self.stage5d = RSU4F(128, 16, 64)
        self.stage4d = RSU4(128, 16, 64)
        self.stage3d = RSU5(128, 16, 64)
        self.stage2d = RSU6(128, 16, 64)
        self.stage1d = RSU7(128, 16, 64)
        self.side1 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side2 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side3 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side4 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side5 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side6 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.outconv = nn.Conv2d(6, out_ch, 1)

    def forward(self, x):
        hx = x
        # stage 1
        hx1 = self.stage1(hx)
        hx = self.pool12(hx1)
        # stage 2
        hx2 = self.stage2(hx)
        hx = self.pool23(hx2)
        # stage 3
        hx3 = self.stage3(hx)
        hx = self.pool34(hx3)
        # stage 4
        hx4 = self.stage4(hx)
        hx = self.pool45(hx4)
        # stage 5
        hx5 = self.stage5(hx)
        hx = self.pool56(hx5)
        # stage 6
        hx6 = self.stage6(hx)
        hx6up = _upsample_like(hx6, hx5)
        # decoder
        hx5d = self.stage5d(torch.cat((hx6up, hx5), 1))
        hx5dup = _upsample_like(hx5d, hx4)
        hx4d = self.stage4d(torch.cat((hx5dup, hx4), 1))
        hx4dup = _upsample_like(hx4d, hx3)
        hx3d = self.stage3d(torch.cat((hx4dup, hx3), 1))
        hx3dup = _upsample_like(hx3d, hx2)
        hx2d = self.stage2d(torch.cat((hx3dup, hx2), 1))
        hx2dup = _upsample_like(hx2d, hx1)
        hx1d = self.stage1d(torch.cat((hx2dup, hx1), 1))
        # side output
        d1 = self.side1(hx1d)
        d2 = self.side2(hx2d)
        d2 = _upsample_like(d2, d1)
        d3 = self.side3(hx3d)
        d3 = _upsample_like(d3, d1)
        d4 = self.side4(hx4d)
        d4 = _upsample_like(d4, d1)
        d5 = self.side5(hx5d)
        d5 = _upsample_like(d5, d1)
        d6 = self.side6(hx6)
        d6 = _upsample_like(d6, d1)
        d0 = self.outconv(torch.cat((d1, d2, d3, d4, d5, d6), 1))
        return torch.sigmoid(d0), torch.sigmoid(d1), torch.sigmoid(d2), torch.sigmoid(d3), torch.sigmoid(d4), torch.sigmoid(d5), torch.sigmoid(d6)
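A quick sanity check (a sketch, assuming the file is importable as my_U2_Net.model, the path used by the scripts below) confirms that all seven outputs share the input's spatial size:

import torch
from my_U2_Net.model import U2NET

net = U2NET(3, 1)
net.eval()
with torch.no_grad():
    outs = net(torch.randn(1, 3, 320, 320))     # (d0, d1, ..., d6), each already passed through sigmoid
print([tuple(o.shape) for o in outs])           # seven tensors, each (1, 1, 320, 320)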

2. data_loader.py

# data loader
from __future__ import print_function, division
import glob
import torch
from skimage import io, transform, color                  # scikit-image is a scipy-based image-processing package that treats images as numpy arrays
import numpy as np
import random
import math
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
from PIL import Image

# ========================== dataset load ==========================

class RescaleT(object):                                   # resize the sample to a fixed output size
    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        self.output_size = output_size                    # target size of the output image

    def __call__(self, sample):
        imidx, image, label = sample['imidx'], sample['image'], sample['label']   # image index, image, and label
        h, w = image.shape[:2]                            # original image shape
        if isinstance(self.output_size, int):
            if h > w:
                new_h, new_w = self.output_size*h/w, self.output_size             # recompute height and width from the output size
            else:
                new_h, new_w = self.output_size, self.output_size*w/h
        else:
            new_h, new_w = self.output_size
        new_h, new_w = int(new_h), int(new_w)
        # resize the image and convert it from range [0,255] to [0,1]
        img = transform.resize(image, (self.output_size, self.output_size), mode='constant')     # resize the image to the fixed square output size
        lbl = transform.resize(label, (self.output_size, self.output_size), mode='constant', order=0, preserve_range=True)   # resize the label the same way
        return {'imidx': imidx, 'image': img, 'label': lbl}


class Rescale(object):                                    # rescale to the given size
    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        self.output_size = output_size

    def __call__(self, sample):                           # __call__ lets an instance be used like a function
        imidx, image, label = sample['imidx'], sample['image'], sample['label']
        if random.random() >= 0.5:
            image = image[::-1]
            label = label[::-1]
        h, w = image.shape[:2]
        if isinstance(self.output_size, int):
            if h > w:
                new_h, new_w = self.output_size*h/w, self.output_size
            else:
                new_h, new_w = self.output_size, self.output_size*w/h
        else:
            new_h, new_w = self.output_size
        new_h, new_w = int(new_h), int(new_w)
        # resize the image to new_h x new_w and convert it from range [0,255] to [0,1]
        img = transform.resize(image, (new_h, new_w), mode='constant')
        lbl = transform.resize(label, (new_h, new_w), mode='constant', order=0, preserve_range=True)
        return {'imidx': imidx, 'image': img, 'label': lbl}


class RandomCrop(object):                                 # randomly crop image and label to the given size
    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        if isinstance(output_size, int):
            self.output_size = (output_size, output_size)
        else:
            assert len(output_size) == 2
            self.output_size = output_size

    def __call__(self, sample):
        imidx, image, label = sample['imidx'], sample['image'], sample['label']
        if random.random() >= 0.5:
            image = image[::-1]
            label = label[::-1]
        h, w = image.shape[:2]
        new_h, new_w = self.output_size
        top = np.random.randint(0, h - new_h)
        left = np.random.randint(0, w - new_w)
        image = image[top: top + new_h, left: left + new_w]   # random crop of the original image
        label = label[top: top + new_h, left: left + new_w]   # the same crop applied to the label
        return {'imidx': imidx, 'image': image, 'label': label}


class ToTensor(object):                                   # normalize image and label and convert to tensors
    """Convert ndarrays in sample to Tensors."""
    def __call__(self, sample):
        imidx, image, label = sample['imidx'], sample['image'], sample['label']
        tmpImg = np.zeros((image.shape[0], image.shape[1], 3))
        tmpLbl = np.zeros(label.shape)
        image = image/np.max(image)                       # scale the image to [0,1]
        if np.max(label) < 1e-6:
            label = label
        else:
            label = label/np.max(label)
        if image.shape[2] == 1:
            tmpImg[:, :, 0] = (image[:, :, 0]-0.485)/0.229
            tmpImg[:, :, 1] = (image[:, :, 0]-0.485)/0.229
            tmpImg[:, :, 2] = (image[:, :, 0]-0.485)/0.229
        else:
            tmpImg[:, :, 0] = (image[:, :, 0]-0.485)/0.229
            tmpImg[:, :, 1] = (image[:, :, 1]-0.456)/0.224
            tmpImg[:, :, 2] = (image[:, :, 2]-0.406)/0.225
        tmpLbl[:, :, 0] = label[:, :, 0]
        # equivalent to transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
        tmpImg = tmpImg.transpose((2, 0, 1))
        tmpLbl = label.transpose((2, 0, 1))
        return {'imidx': torch.from_numpy(imidx), 'image': torch.from_numpy(tmpImg), 'label': torch.from_numpy(tmpLbl)}


class ToTensorLab(object):
    """Convert ndarrays in sample to Tensors."""
    def __init__(self, flag=0):
        self.flag = flag

    def __call__(self, sample):
        imidx, image, label = sample['imidx'], sample['image'], sample['label']
        tmpLbl = np.zeros(label.shape)
        if np.max(label) < 1e-6:
            label = label
        else:
            label = label/np.max(label)
        # change the color space
        if self.flag == 2:  # with rgb and Lab colors
            tmpImg = np.zeros((image.shape[0], image.shape[1], 6))
            tmpImgt = np.zeros((image.shape[0], image.shape[1], 3))
            if image.shape[2] == 1:
                tmpImgt[:, :, 0] = image[:, :, 0]
                tmpImgt[:, :, 1] = image[:, :, 0]
                tmpImgt[:, :, 2] = image[:, :, 0]
            else:
                tmpImgt = image
            tmpImgtl = color.rgb2lab(tmpImgt)
            # normalize image to range [0,1]
            tmpImg[:, :, 0] = (tmpImgt[:, :, 0]-np.min(tmpImgt[:, :, 0]))/(np.max(tmpImgt[:, :, 0])-np.min(tmpImgt[:, :, 0]))
            tmpImg[:, :, 1] = (tmpImgt[:, :, 1]-np.min(tmpImgt[:, :, 1]))/(np.max(tmpImgt[:, :, 1])-np.min(tmpImgt[:, :, 1]))
            tmpImg[:, :, 2] = (tmpImgt[:, :, 2]-np.min(tmpImgt[:, :, 2]))/(np.max(tmpImgt[:, :, 2])-np.min(tmpImgt[:, :, 2]))
            tmpImg[:, :, 3] = (tmpImgtl[:, :, 0]-np.min(tmpImgtl[:, :, 0]))/(np.max(tmpImgtl[:, :, 0])-np.min(tmpImgtl[:, :, 0]))
            tmpImg[:, :, 4] = (tmpImgtl[:, :, 1]-np.min(tmpImgtl[:, :, 1]))/(np.max(tmpImgtl[:, :, 1])-np.min(tmpImgtl[:, :, 1]))
            tmpImg[:, :, 5] = (tmpImgtl[:, :, 2]-np.min(tmpImgtl[:, :, 2]))/(np.max(tmpImgtl[:, :, 2])-np.min(tmpImgtl[:, :, 2]))
            tmpImg[:, :, 0] = (tmpImg[:, :, 0]-np.mean(tmpImg[:, :, 0]))/np.std(tmpImg[:, :, 0])
            tmpImg[:, :, 1] = (tmpImg[:, :, 1]-np.mean(tmpImg[:, :, 1]))/np.std(tmpImg[:, :, 1])
            tmpImg[:, :, 2] = (tmpImg[:, :, 2]-np.mean(tmpImg[:, :, 2]))/np.std(tmpImg[:, :, 2])
            tmpImg[:, :, 3] = (tmpImg[:, :, 3]-np.mean(tmpImg[:, :, 3]))/np.std(tmpImg[:, :, 3])
            tmpImg[:, :, 4] = (tmpImg[:, :, 4]-np.mean(tmpImg[:, :, 4]))/np.std(tmpImg[:, :, 4])
            tmpImg[:, :, 5] = (tmpImg[:, :, 5]-np.mean(tmpImg[:, :, 5]))/np.std(tmpImg[:, :, 5])
        elif self.flag == 1:  # with Lab color
            tmpImg = np.zeros((image.shape[0], image.shape[1], 3))
            if image.shape[2] == 1:
                tmpImg[:, :, 0] = image[:, :, 0]
                tmpImg[:, :, 1] = image[:, :, 0]
                tmpImg[:, :, 2] = image[:, :, 0]
            else:
                tmpImg = image
            tmpImg = color.rgb2lab(tmpImg)
            tmpImg[:, :, 0] = (tmpImg[:, :, 0]-np.min(tmpImg[:, :, 0]))/(np.max(tmpImg[:, :, 0])-np.min(tmpImg[:, :, 0]))
            tmpImg[:, :, 1] = (tmpImg[:, :, 1]-np.min(tmpImg[:, :, 1]))/(np.max(tmpImg[:, :, 1])-np.min(tmpImg[:, :, 1]))
            tmpImg[:, :, 2] = (tmpImg[:, :, 2]-np.min(tmpImg[:, :, 2]))/(np.max(tmpImg[:, :, 2])-np.min(tmpImg[:, :, 2]))
            tmpImg[:, :, 0] = (tmpImg[:, :, 0]-np.mean(tmpImg[:, :, 0]))/np.std(tmpImg[:, :, 0])
            tmpImg[:, :, 1] = (tmpImg[:, :, 1]-np.mean(tmpImg[:, :, 1]))/np.std(tmpImg[:, :, 1])
            tmpImg[:, :, 2] = (tmpImg[:, :, 2]-np.mean(tmpImg[:, :, 2]))/np.std(tmpImg[:, :, 2])
        else:  # with rgb color
            tmpImg = np.zeros((image.shape[0], image.shape[1], 3))
            image = image/np.max(image)
            if image.shape[2] == 1:
                tmpImg[:, :, 0] = (image[:, :, 0]-0.485)/0.229
                tmpImg[:, :, 1] = (image[:, :, 0]-0.485)/0.229
                tmpImg[:, :, 2] = (image[:, :, 0]-0.485)/0.229
            else:
                tmpImg[:, :, 0] = (image[:, :, 0]-0.485)/0.229
                tmpImg[:, :, 1] = (image[:, :, 1]-0.456)/0.224
                tmpImg[:, :, 2] = (image[:, :, 2]-0.406)/0.225
        tmpLbl[:, :, 0] = label[:, :, 0]
        # equivalent to transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
        tmpImg = tmpImg.transpose((2, 0, 1))
        tmpLbl = label.transpose((2, 0, 1))
        return {'imidx': torch.from_numpy(imidx), 'image': torch.from_numpy(tmpImg), 'label': torch.from_numpy(tmpLbl)}


class SalObjDataset(Dataset):                             # returns the normalized image index, image, and label
    def __init__(self, img_name_list, lbl_name_list, transform=None):
        self.image_name_list = img_name_list              # absolute paths of all images
        self.label_name_list = lbl_name_list              # absolute paths of all labels
        self.transform = transform                        # transform: rescale, crop, and to-tensor

    def __len__(self):
        return len(self.image_name_list)

    def __getitem__(self, idx):
        image = io.imread(self.image_name_list[idx])      # read each image via its absolute path (numpy.ndarray, 3-channel, size varies)
        imname = self.image_name_list[idx]
        imidx = np.array([idx])                           # image index as a numpy array
        if 0 == len(self.label_name_list):                # if there is no label, create an all-zero label
            label_3 = np.zeros(image.shape)
        else:                                             # otherwise read the matching label
            label_3 = io.imread(self.label_name_list[idx])
        label = np.zeros(label_3.shape[0:2])
        if 3 == len(label_3.shape):
            label = label_3[:, :, 0]
        elif 2 == len(label_3.shape):
            label = label_3
        if 3 == len(image.shape) and 2 == len(label.shape):
            label = label[:, :, np.newaxis]               # np.newaxis inserts a new axis of length one at its position
        elif 2 == len(image.shape) and 2 == len(label.shape):
            image = image[:, :, np.newaxis]
            label = label[:, :, np.newaxis]
        sample = {'imidx': imidx, 'image': image, 'label': label}
        if self.transform:
            sample = self.transform(sample)               # apply the transforms to the sample
        return sample
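As a quick check of the pipeline (a sketch; the file paths are placeholders), one sample can be pulled through the transforms to see the resulting tensor shapes:

from torchvision import transforms
from my_U2_Net.data_loader import RescaleT, RandomCrop, ToTensorLab, SalObjDataset

dataset = SalObjDataset(img_name_list=['some_image.jpg'],      # placeholder path
                        lbl_name_list=['some_label.png'],      # placeholder path
                        transform=transforms.Compose([RescaleT(320), RandomCrop(288), ToTensorLab(flag=0)]))
sample = dataset[0]
print(sample['image'].shape, sample['label'].shape)            # torch.Size([3, 288, 288]) torch.Size([1, 288, 288])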

3. u2net_train.py

import torch
import torchvision
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
import torch.optim as optim
import torchvision.transforms as standard_transforms
import numpy as np
import glob
import os

from my_U2_Net.data_loader import RescaleT, RandomCrop, ToTensorLab, SalObjDataset, ToTensor, Rescale
from my_U2_Net.model import U2NET, U2NETP

# ------- 1. define loss function --------

bce_loss = nn.BCELoss(reduction='mean')

def muti_bce_loss_fusion(d0, d1, d2, d3, d4, d5, d6, labels_v):
    loss0 = bce_loss(d0, labels_v)
    loss1 = bce_loss(d1, labels_v)
    loss2 = bce_loss(d2, labels_v)
    loss3 = bce_loss(d3, labels_v)
    loss4 = bce_loss(d4, labels_v)
    loss5 = bce_loss(d5, labels_v)
    loss6 = bce_loss(d6, labels_v)
    loss = loss0 + loss1 + loss2 + loss3 + loss4 + loss5 + loss6    # six side-output (Sup) losses plus the fused-output loss
    print("l0: %3f, l1: %3f, l2: %3f, l3: %3f, l4: %3f, l5: %3f, l6: %3f\n" % (
        loss0.item(), loss1.item(), loss2.item(), loss3.item(),
        loss4.item(), loss5.item(), loss6.item()))
    return loss0, loss

def main():
    # ------- 2. set the directory of training dataset --------
    model_name = 'u2net'  # 'u2netp'                               # which network variant to use
    data_dir = r'F:\PASCAL_VOC\VOCdevkit\VOC2007\my_segmentations' # parent directory of the training images and segmentation labels
    tra_image_dir = 'JPEGImages'                                   # directory containing the images
    tra_label_dir = 'SegmentationClass'                            # directory containing the labels
    image_ext = '.jpg'
    label_ext = '.png'                                             # label file extension
    model_dir = './saved_models/' + model_name + '/'               # directory for the saved model parameters
    params_path = model_dir + model_name + ".pth"                  # relative path of the saved parameters

    epoch_num = 100000                                             # total number of training epochs
    batch_size_train = 2                                           # training batch size
    batch_size_val = 1                                             # validation batch size
    train_num = 0
    val_num = 0

    tra_img_name_list = glob.glob(os.path.join(data_dir, tra_image_dir, '*'))   # glob.glob() collects the absolute paths of the training images into a list

    tra_lbl_name_list = []
    for img_path in tra_img_name_list:                             # iterate over every image path
        img_name = img_path.split("\\")[-1]                        # file name, e.g. 003000.jpg
        aaa = img_name.split(".")
        bbb = aaa[0:-1]                                            # name without extension, e.g. ['000032']
        imidx = bbb[0]
        for i in range(1, len(bbb)):                               # rejoin names that themselves contain dots
            imidx = imidx + "." + bbb[i]
        tra_lbl_name_list.append(data_dir + "\\" + tra_label_dir + "\\" + imidx + label_ext)   # label paths correspond one-to-one with image paths

    print("---")
    print("train images: ", len(tra_img_name_list))
    print("train labels: ", len(tra_lbl_name_list))
    print("---")

    train_num = len(tra_img_name_list)                             # total number of training images

    salobj_dataset = SalObjDataset(
        img_name_list=tra_img_name_list,
        lbl_name_list=tra_lbl_name_list,
        transform=transforms.Compose([RescaleT(320), RandomCrop(288), ToTensorLab(flag=0)]))   # RescaleT(320) resizes, RandomCrop(288) randomly crops
    salobj_dataloader = DataLoader(salobj_dataset, batch_size=batch_size_train, shuffle=True, num_workers=1)

    # ------- 3. define model --------
    if model_name == 'u2net':
        net = U2NET(3, 1)                                          # input and output channel counts
    elif model_name == 'u2netp':
        net = U2NETP(3, 1)
    if torch.cuda.is_available():
        net.cuda()                                                 # move the network to the GPU
    if os.path.exists(params_path):                                # load previously trained parameters if present
        net.load_state_dict(torch.load(params_path))
    else:
        print("No parameters!")

    # ------- 4. define optimizer --------
    print("---define optimizer...")
    optimizer = optim.Adam(net.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)

    # ------- 5. training process --------
    print("---start training...")
    ite_num = 0
    running_loss = 0.0
    running_tar_loss = 0.0
    ite_num4val = 0
    save_frq = 2000                                                # save the model every 2000 iterations

    for epoch in range(0, epoch_num):
        net.train()                                                # training mode
        for i, data in enumerate(salobj_dataloader):
            ite_num = ite_num + 1
            ite_num4val = ite_num4val + 1
            inputs, labels = data['image'], data['label']          # get images and labels
            inputs = inputs.type(torch.FloatTensor)                # convert to float tensors
            labels = labels.type(torch.FloatTensor)
            # wrap them in Variable
            if torch.cuda.is_available():                          # move to CUDA
                inputs_v, labels_v = Variable(inputs.cuda(), requires_grad=False), Variable(labels.cuda(), requires_grad=False)
            else:
                inputs_v, labels_v = Variable(inputs, requires_grad=False), Variable(labels, requires_grad=False)
            # zero the parameter gradients
            optimizer.zero_grad()
            # forward + backward + optimize
            d0, d1, d2, d3, d4, d5, d6 = net(inputs_v)             # network outputs
            loss2, loss = muti_bce_loss_fusion(d0, d1, d2, d3, d4, d5, d6, labels_v)   # loss over all side outputs
            loss.backward()                                        # backpropagate
            optimizer.step()                                       # update the parameters
            # print statistics
            running_loss += loss.item()                            # total loss
            running_tar_loss += loss2.item()
            # delete temporary outputs and loss
            del d0, d1, d2, d3, d4, d5, d6, loss2, loss
            print("[epoch: %3d/%3d, batch: %5d/%5d, ite: %d] train loss: %3f, tar: %3f " % (
                epoch + 1, epoch_num, (i + 1) * batch_size_train, train_num, ite_num,
                running_loss / ite_num4val, running_tar_loss / ite_num4val))
            if ite_num % save_frq == 0:
                torch.save(net.state_dict(), params_path)
                running_loss = 0.0
                running_tar_loss = 0.0
                net.train()  # resume training
                ite_num4val = 0

if __name__ == "__main__":
    main()

4. u2net_test.py

import os
from skimage import io, transform
import torch
import torchvision
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms#, utils
# import torch.optim as optim
import numpy as np
from PIL import Image
import glob

from my_U2_Net.data_loader import RescaleT
from my_U2_Net.data_loader import ToTensor
from my_U2_Net.data_loader import ToTensorLab
from my_U2_Net.data_loader import SalObjDataset        # data loading (returns image index, normalized image, label)
from my_U2_Net.model import U2NET   # full size version, 173.6 MB
from my_U2_Net.model import U2NETP  # small version of u2net, 4.7 MB

# normalize the predicted SOD probability map
def normPRED(d):
    ma = torch.max(d)
    mi = torch.min(d)
    dn = (d - mi) / (ma - mi)
    return dn

def save_output(image_name, pred, d_dir):
    predict = pred.squeeze()                           # drop singleton dimensions
    predict_np = predict.cpu().data.numpy()            # move to the CPU
    im = Image.fromarray(predict_np * 255).convert('RGB')   # back to PIL, rescaled from [0,1] to [0,255]
    img_name = image_name.split("\\")[-1]              # file name with extension
    image = io.imread(image_name)                      # io.imread returns a uint8 RGB numpy array with values in 0-255
    imo = im.resize((image.shape[1], image.shape[0]), resample=Image.BILINEAR)
    aaa = img_name.split(".")                          # file name split into parts, e.g. ['5', 'jpg']
    bbb = aaa[0:-1]                                    # name without extension, e.g. ['5']
    imidx = bbb[0]
    for i in range(1, len(bbb)):
        imidx = imidx + "." + bbb[i]
    imo.save(d_dir + imidx + '.png')                   # save the prediction to the target directory

def main():
    # --------- 1. get image path and name ---------
    model_name = 'u2net'  # u2netp                     # name of the saved model
    image_dir = './test_data/test_images/'             # directory of the images to predict
    prediction_dir = './test_data/' + model_name + '_results/'   # directory for the prediction results
    model_dir = r"\my_U2_Net\saved_models\u2net\u2net.pth"       # path of the model parameters
    img_name_list = glob.glob(image_dir + '*')         # all files (with paths) in the image directory
    print(img_name_list)

    # --------- 2. dataloader ---------
    test_salobj_dataset = SalObjDataset(img_name_list=img_name_list,
                                        lbl_name_list=[],
                                        transform=transforms.Compose([RescaleT(320), ToTensorLab(flag=0)]))
    test_salobj_dataloader = DataLoader(test_salobj_dataset, batch_size=1, shuffle=False, num_workers=1)   # load the data

    # --------- 3. model define ---------
    if model_name == 'u2net':                          # pick the network that matches the parameters
        print("...load U2NET---173.6 MB")
        net = U2NET(3, 1)
    elif model_name == 'u2netp':
        print("...load U2NEP---4.7 MB")
        net = U2NETP(3, 1)
    net.load_state_dict(torch.load(model_dir))         # load the trained model
    if torch.cuda.is_available():
        net.cuda()                                     # move the network to the GPU
    net.eval()                                         # evaluation mode

    # --------- 4. inference for each image ---------
    for i_test, data_test in enumerate(test_salobj_dataloader):
        print("inferencing:", img_name_list[i_test].split("/")[-1])
        inputs_test = data_test['image']               # the image to test
        inputs_test = inputs_test.type(torch.FloatTensor)   # convert to float
        if torch.cuda.is_available():
            inputs_test = Variable(inputs_test.cuda())
            # Variable wraps a Tensor and supports the same operations; unlike a plain
            # tensor it records the computation graph so that one backward pass can
            # compute the gradients of all Variables at once
        else:
            inputs_test = Variable(inputs_test)
        d1, d2, d3, d4, d5, d6, d7 = net(inputs_test)  # feed the image through the network
        # normalization
        pred = d1[:, 0, :, :]
        pred = normPRED(pred)                          # normalize the prediction
        # save results to test_results folder
        save_output(img_name_list[i_test], pred, prediction_dir)   # original image name, prediction, output directory
        del d1, d2, d3, d4, d5, d6, d7                 # free the temporary outputs

if __name__ == "__main__":
    main()

Note: training here used the VOC2007 dataset, with the pretrained model loaded as a starting point.

The output is as follows.

The original image:

[Original image: OIP (2).jpg]

The resulting mask:

[Mask: OIP (2).png]

Extracting the target:

crop.py

# -*- coding: utf-8 -*-
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import os

def crop(img_file, mask_file):
    name, *_ = img_file.split(".")                        # split off the extension
    img_array = np.array(Image.open(img_file))            # original image, PIL -> numpy
    mask = np.array(Image.open(mask_file))                # open the mask image
    res = np.concatenate((img_array, mask[:, :, [0]]), -1)    # append the mask as an alpha channel
    img = Image.fromarray(res.astype('uint8'), mode='RGBA')   # numpy -> PIL (RGBA)
    img.show()
    return img

if __name__ == "__main__":
    model = "u2net"
    # model = "u2netp"
    img_root = "test_data/test_images"                    # directory of the original images
    mask_root = "test_data/{}_results".format(model)      # root directory of the masks
    crop_root = "test_data/{}_crops".format(model)        # directory for the cropped results
    os.makedirs(crop_root, mode=0o775, exist_ok=True)     # create the output directory
    # name: directory to create; mode: permission bits, default 0o777 (octal);
    # exist_ok: if False (the default), raise FileExistsError when the target
    # directory already exists; if True, an existing directory is not an error.
    for img_file in os.listdir(img_root):                 # iterate over all source images
        print("crop image {}".format(img_file))
        name, *_ = img_file.split(".")                    # split the name from the extension
        res = crop(img_file=os.path.join(img_root, img_file),
                   mask_file=os.path.join(mask_root, name + ".png"))    # call the crop() helper
        res.save(os.path.join(crop_root, name + "_crop.png"))           # save to the output directory

The resulting image:

[Cropped result: OIP (2)_crop.png]

5. video_for_video.py

To segment video, the files above are combined into a single script. The code has not been optimized yet, but it works well enough for a quick video-segmentation test; it can be optimized later if needed. The sketch below shows the core per-frame loop; the full script follows after it.
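The essence of the per-frame pipeline, stripped of the Dataset/DataLoader machinery, is roughly the following (a simplified sketch, assuming the u2net.pth weights trained above; it runs on the CPU for clarity):

import cv2
import numpy as np
import torch
from skimage import transform
from my_U2_Net.model import U2NET

net = U2NET(3, 1)
net.load_state_dict(torch.load("./saved_models/u2net/u2net.pth", map_location="cpu"))
net.eval()

capture = cv2.VideoCapture(0)
while True:
    ret, frame = capture.read()                          # grab one BGR frame from the camera
    if not ret:
        break
    img = transform.resize(frame, (320, 320), mode='constant')         # resize; values now in [0,1]
    img = (img - (0.485, 0.456, 0.406)) / (0.229, 0.224, 0.225)        # ImageNet normalization
    inp = torch.from_numpy(img.transpose(2, 0, 1)).unsqueeze(0).float()   # HWC -> NCHW
    with torch.no_grad():
        d0 = net(inp)[0]                                 # fused output, shape (1, 1, 320, 320)
    pred = d0[0, 0].numpy()
    pred = (pred - pred.min()) / (pred.max() - pred.min() + 1e-8)      # normalize to [0,1]
    mask = cv2.resize(pred, (frame.shape[1], frame.shape[0]))          # back to the frame size
    out = np.uint8(frame * mask[:, :, None])             # keep the foreground, suppress the background
    cv2.imshow("u2net", out)
    if cv2.waitKey(1) & 0xFF == ord('q'):
        break
capture.release()
cv2.destroyAllWindows()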

# data loader
from __future__ import print_function, division
import glob
import torch
from skimage import io, transform, color                 # scikit-image is a scipy-based image-processing package that treats images as numpy arrays
import numpy as np
import random
import math
import matplotlib.pyplot as plt
from torchvision import transforms, utils
from PIL import Image
import os
from skimage import io, transform
import torch
import torchvision
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms#, utils
# import torch.optim as optim
import glob
from my_U2_Net.model import U2NET, U2NETP                 # import both networks
import cv2

capture = cv2.VideoCapture(0)


class RescaleT(object):                                   # resize the sample to the fixed output size
    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        self.output_size = output_size

    def __call__(self, sample):
        imidx, image, label, frame = sample['imidx'], sample['image'], sample['label'], sample['frame']   # index, image, label, and raw frame
        h, w = image.shape[:2]
        if isinstance(self.output_size, int):
            if h > w:
                new_h, new_w = self.output_size*h/w, self.output_size
            else:
                new_h, new_w = self.output_size, self.output_size*w/h
        else:
            new_h, new_w = self.output_size
        new_h, new_w = int(new_h), int(new_w)
        img = transform.resize(image, (self.output_size, self.output_size), mode='constant')
        lbl = transform.resize(label, (self.output_size, self.output_size), mode='constant', order=0, preserve_range=True)
        return {'imidx': imidx, 'image': img, 'label': lbl, 'frame': frame}


# Rescale, RandomCrop, and ToTensor are verbatim copies of the versions in
# data_loader.py and are not used by main() below, so they are omitted here;
# only RescaleT above and ToTensorLab below are actually used.

class ToTensorLab(object):
    """Convert ndarrays in sample to Tensors. Only the default rgb branch
    (flag=0) is kept here; the Lab branches are identical to data_loader.py."""
    def __init__(self, flag=0):
        self.flag = flag

    def __call__(self, sample):
        imidx, image, label, frame = sample['imidx'], sample['image'], sample['label'], sample['frame']
        tmpLbl = np.zeros(label.shape)
        if np.max(label) < 1e-6:
            label = label
        else:
            label = label/np.max(label)
        tmpImg = np.zeros((image.shape[0], image.shape[1], 3))
        image = image/np.max(image)
        if image.shape[2] == 1:
            tmpImg[:, :, 0] = (image[:, :, 0]-0.485)/0.229
            tmpImg[:, :, 1] = (image[:, :, 0]-0.485)/0.229
            tmpImg[:, :, 2] = (image[:, :, 0]-0.485)/0.229
        else:
            tmpImg[:, :, 0] = (image[:, :, 0]-0.485)/0.229
            tmpImg[:, :, 1] = (image[:, :, 1]-0.456)/0.224
            tmpImg[:, :, 2] = (image[:, :, 2]-0.406)/0.225
        tmpLbl[:, :, 0] = label[:, :, 0]
        tmpImg = tmpImg.transpose((2, 0, 1))
        tmpLbl = label.transpose((2, 0, 1))
        return {'imidx': torch.from_numpy(imidx), 'image': torch.from_numpy(tmpImg),
                'label': torch.from_numpy(tmpLbl), 'frame': frame}


class SalObjDataset(Dataset):                             # returns the index, normalized image, label, and raw frame
    def __init__(self, img_name_list, lbl_name_list, transform=None):
        self.image_name_list = img_name_list
        self.label_name_list = lbl_name_list
        self.transform = transform

    def __len__(self):
        return len(self.image_name_list)

    def __getitem__(self, idx):
        while True:
            ref, frame = capture.read()                   # grab one frame from the camera
            image = frame
            # image = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)   # BGR to RGB
            print(image.shape)                            # (480, 640, 3)
            imidx = np.array([idx])                       # index as a numpy array
            if 0 == len(self.label_name_list):            # no label: create an all-zero label
                label_3 = np.zeros(image.shape)
            else:
                label_3 = io.imread(self.label_name_list[idx])
            label = np.zeros(label_3.shape[0:2])
            if 3 == len(label_3.shape):
                label = label_3[:, :, 0]
            elif 2 == len(label_3.shape):
                label = label_3
            if 3 == len(image.shape) and 2 == len(label.shape):
                label = label[:, :, np.newaxis]
            elif 2 == len(image.shape) and 2 == len(label.shape):
                image = image[:, :, np.newaxis]
                label = label[:, :, np.newaxis]
            sample = {'imidx': imidx, 'image': image, 'label': label, 'frame': frame}
            if self.transform:
                sample = self.transform(sample)           # apply the transforms
            return sample


def main():
    model_name = 'u2net'  # u2netp                        # name of the saved model
    model_dir = r"\my_U2_Net\saved_models\u2net\u2net.pth"    # path of the model parameters
    img_name_list = [i for i in range(10000)]             # dummy list: one entry per frame to process

    test_salobj_dataset = SalObjDataset(img_name_list=img_name_list,
                                        lbl_name_list=[],
                                        transform=transforms.Compose([RescaleT(320), ToTensorLab(flag=0)]))
    test_salobj_dataloader = DataLoader(test_salobj_dataset, batch_size=1, shuffle=False, num_workers=1)   # load the data

    if model_name == 'u2net':                             # pick the network that matches the parameters
        print("...load U2NET---173.6 MB")
        net = U2NET(3, 1)
    elif model_name == 'u2netp':
        print("...load U2NEP---4.7 MB")
        net = U2NETP(3, 1)
    net.load_state_dict(torch.load(model_dir))            # load the trained model
    if torch.cuda.is_available():
        net.cuda()                                        # move the network to the GPU
    net.eval()                                            # evaluation mode

    for i_test, data_test in enumerate(test_salobj_dataloader):
        inputs_test = data_test['image']                  # the preprocessed frame
        inputs_test = inputs_test.type(torch.FloatTensor)
        if torch.cuda.is_available():
            inputs_test = Variable(inputs_test.cuda())
        else:
            inputs_test = Variable(inputs_test)
        d1, d2, d3, d4, d5, d6, d7 = net(inputs_test)     # feed the frame through the network
        pred = d1[:, 0, :, :]
        pred = (pred - torch.min(pred)) / (torch.max(pred) - torch.min(pred))   # normalize the prediction
        predict = pred.squeeze()                          # drop singleton dimensions
        predict_np = predict.cpu().data.numpy()           # move to the CPU
        im = Image.fromarray(predict_np * 255).convert('RGB')     # back to PIL, rescaled to [0,255]
        imo = im.resize((640, 480), resample=Image.BILINEAR)      # the mask at the original frame size
        img_array = np.uint8(data_test["frame"][0])       # the raw BGR frame
        mask = np.asarray(Image.fromarray(np.uint8(imo)))
        img = Image.fromarray(np.uint8(img_array * (mask / 255)))   # apply the mask to the frame
        cv2.imshow("", np.uint8(img))
        if cv2.waitKey(1) & 0xFF == ord('q'):
            break
        del d1, d2, d3, d4, d5, d6, d7                    # free the temporary outputs

if __name__ == "__main__":
    main()

The result is shown below (originally a video; only a single frame is shown here).

The frame shows a keyboard on a desk; the background has been separated out.

[Figure: segmented video frame]

References:

https://arxiv.org/pdf/2005.09007.pdf

https://github.com/NathanUA/U-2-Net

https://zhuanlan.zhihu.com/p/44958351

That concludes this article on video image segmentation with U2Net, from principle to practice; hopefully it proves helpful.



