DL - 图像分割

2024-04-25 15:04
文章标签 图像 分割 dl

本文主要是介绍DL - 图像分割,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!


from transformers import SegformerFeatureExtractor
import PIL.Image#一个把图像转换为数据的工具类
feature_extractor = SegformerFeatureExtractor()#模拟一批数据
pixel_values = [PIL.Image.new('RGB', (200, 100), 'blue'),PIL.Image.new('RGB', (200, 100), 'red')
]value = [PIL.Image.new('L', (200, 100), 150),PIL.Image.new('L', (200, 100), 200)
]#试算
out = feature_extractor(pixel_values, value)
print('keys=', out.keys())
print('type=', type(out['pixel_values']), type(out['labels']))
print('len=', len(out['pixel_values']), len(out['labels']))
print('type0=', type(out['pixel_values'][0]), type(out['labels'][0]))
print('shape0=', out['pixel_values'][0].shape, out['labels'][0].shape)feature_extractor

keys= dict_keys(['pixel_values', 'labels'])
type= <class 'list'> <class 'list'>
len= 2 2
type0= <class 'numpy.ndarray'> <class 'numpy.ndarray'>
shape0= (3, 512, 512) (512, 512)
SegformerFeatureExtractor {"do_normalize": true,"do_resize": true,"feature_extractor_type": "SegformerFeatureExtractor","image_mean": [0.485,0.456,0.406],"image_std": [0.229,0.224,0.225],"reduce_labels": false,"resample": 2,"size": 512
}

from torchvision.transforms import ColorJitter#能对图像进行亮度,对比度,饱和度,色相变换的工具类。其实就是数据增强
jitter = ColorJitter(brightness=0.25, contrast=0.25, saturation=0.25, hue=0.1)print(jitter)jitter(PIL.Image.new('RGB', (200, 100), 'blue'))

ColorJitter(brightness=[0.75, 1.25], contrast=[0.75, 1.25], saturation=[0.75, 1.25], hue=[-0.1, 0.1])


加载数据

from datasets import load_dataset, load_from_disk#一个道路分类数据集
dataset = load_dataset(path='segments/sidewalk-semantic')
# dataset = load_from_disk('datas/segments/sidewalk-semantic')#把图片数据全部转换为数字
def transforms(data):pixel_values = data['pixel_values']label = data['label']#应用数据增强pixel_values = [jitter(i) for i in pixel_values]#编码图片成数字return feature_extractor(pixel_values, label)#切分训练集和测试集
dataset = dataset.shuffle(seed=1)['train'].train_test_split(test_size=0.1)dataset['train'] = dataset['train'].with_transform(transforms)print(dataset['train'][0])dataset

import torchdef collate_fn(data):pixel_values = [i['pixel_values'] for i in data]labels = [i['labels'] for i in data]pixel_values = torch.FloatTensor(pixel_values)labels = torch.LongTensor(labels)return {'pixel_values': pixel_values, 'labels': labels}loader = torch.utils.data.DataLoader(dataset=dataset['train'],batch_size=4,collate_fn=collate_fn,shuffle=True,drop_last=True,
)for i, data in enumerate(loader):breaklen(loader), data['pixel_values'].shape, data['labels'].shape

#因为模型的计算输出是原尺寸除以4,所以需要把结果扩张成原来的大小便于计算正确率什么的
torch.nn.functional.interpolate(torch.randn(4, 35, 128, 128),size=(512, 512),mode='bilinear',align_corners=False).shape

from transformers import SegformerForSemanticSegmentation, SegformerModel#加载模型
#一共35中道路类别,怎么来的不重要
#model = SegformerForSemanticSegmentation.from_pretrained('nvidia/mit-b0',num_labels=35)#定义下游任务模型
class Model(torch.nn.Module):def __init__(self):super().__init__()self.pretrained = SegformerModel.from_pretrained('nvidia/mit-b0')self.linears = torch.nn.ModuleList([torch.nn.Linear(32, 256),torch.nn.Linear(64, 256),torch.nn.Linear(160, 256),torch.nn.Linear(256, 256)])self.classifier = torch.nn.Sequential(torch.nn.Conv2d(in_channels=1024,out_channels=256,kernel_size=1,bias=False),torch.nn.BatchNorm2d(256),torch.nn.ReLU(),torch.nn.Dropout(0.1),torch.nn.Conv2d(256, 35, kernel_size=1),)#加载预训练模型的参数parameters = SegformerForSemanticSegmentation.from_pretrained('nvidia/mit-b0',num_labels=35)for i in range(4):self.linears[i].load_state_dict(parameters.decode_head.linear_c[i].proj.state_dict())self.classifier[0].load_state_dict(parameters.decode_head.linear_fuse.state_dict())self.classifier[1].load_state_dict(parameters.decode_head.batch_norm.state_dict())self.classifier[4].load_state_dict(parameters.decode_head.classifier.state_dict())self.criterion = torch.nn.CrossEntropyLoss(ignore_index=255)def forward(self, pixel_values, labels):#pixel_values -> [4, 3, 512, 512]#labels -> [4, 512, 512]#首先通过预训练模型抽中间特征#[4, 32, 128, 128]#[4, 64, 64, 64]#[4, 160, 32, 32]#[4, 256, 16, 16]features = self.pretrained(pixel_values=pixel_values,output_hidden_states=True)features = features.hidden_states#打平#[4, 32, 16384]#[4, 64, 4096]#[4, 160, 1024]#[4, 256, 256]features = [i.flatten(2) for i in features]#转置,把通道放到最后一个维度#[4, 16384, 32]#[4, 4096, 64]#[4, 1024, 160]#[4, 256, 256]features = [i.transpose(1, 2) for i in features]#线性计算#[4, 16384, 256]#[4, 4096, 256]#[4, 1024, 256]#[4, 256, 256]features = [l(f) for f, l in zip(features, self.linears)]#转置回来,把通道放中间#[4, 256, 16384]#[4, 256, 4096]#[4, 256, 1024]#[4, 256, 256]features = [i.permute(0, 2, 1) for i in features]#变形成二维的图片#[4, 256, 128, 128]#[4, 256, 64, 64]#[4, 256, 32, 32]#[4, 256, 16, 16]features = [f.reshape(pixel_values.shape[0], -1, s, s)for f, s in zip(features, [128, 64, 32, 16])]#拓展到统一的尺寸#[4, 256, 128, 128]#[4, 256, 128, 128]#[4, 256, 128, 128]#[4, 256, 128, 128]features = [torch.nn.functional.interpolate(i,size=(128, 128),mode='bilinear',align_corners=False)for i in features]#逆序,维度不变features = features[::-1]#在通道维度合并成一个张量#[4, 1024, 128, 128]features = torch.cat(features, dim=1)#跑分类网络,其中包括了1024->256->35两步,使用cnn网络实现#[4, 35, 128, 128]features = self.classifier(features)#为了计算loss,要把计算结果放大到和labels一致#[4, 35, 128, 128] -> [4, 35, 512, 512]#计算交叉熵lossloss = self.criterion(torch.nn.functional.interpolate(features,size=(512, 512),mode='bilinear',align_corners=False), labels)return {'loss': loss, 'logits': features}model = Model()#统计参数量
print(sum(i.numel() for i in model.parameters()) / 10000)out = model(**data)out['loss'], out['logits'].shape

from datasets import load_metric#加载评价指标
metric = load_metric('mean_iou')#试算
metric.compute(predictions=torch.ones(4, 10, 10),references=torch.ones(4, 10, 10),#一共35中道路类别,怎么来的不重要num_labels=35,#忽略背景类0ignore_index=0,reduce_labels=False)

from matplotlib import pyplot as pltdef show(image, out, label):plt.figure(figsize=(15, 5))image = image.clone()image = image.permute(1, 2, 0)image = image - image.min().item()image = image / image.max().item()image = image * 255image = PIL.Image.fromarray(image.numpy().astype('uint8'), mode='RGB')image = image.resize((512, 512))plt.subplot(1, 3, 1)plt.imshow(image)plt.axis('off')out = PIL.Image.fromarray(out.numpy().astype('uint8'))plt.subplot(1, 3, 2)plt.imshow(out)plt.axis('off')label = PIL.Image.fromarray(label.numpy().astype('uint8'))plt.subplot(1, 3, 3)plt.imshow(label)plt.axis('off')plt.show()show(data['pixel_values'][0], torch.ones(512, 512), data['labels'][0])

测试

def test():model.eval()dataset['test'] = dataset['test'].shuffle()loader_test = torch.utils.data.DataLoader(dataset=dataset['test'].with_transform(transforms),batch_size=8,collate_fn=collate_fn,shuffle=False,drop_last=True,)labels = []outs = []correct = 0#初始化为1,防止除0total = 1for i, data in enumerate(loader_test):with torch.no_grad():out = model(**data)#运算结果扩张4倍out = torch.nn.functional.interpolate(out['logits'],size=(512, 512),mode='bilinear',align_corners=False)out = out.argmax(dim=1)outs.append(out)label = data['labels']labels.append(label)#统计正确率时排除label中的0select = label != 0correct += (label[select] == out[select]).sum().item()total += len(label[select])if i % 1 == 0:show(data['pixel_values'][0], out[0], label[0])if i == 4:break#计算评价指标metric_out = metric.compute(predictions=torch.cat(outs, dim=0),references=torch.cat(labels, dim=0),num_labels=35,ignore_index=0)#删除这两个结果,不想看metric_out.pop('per_category_iou')metric_out.pop('per_category_accuracy')print(metric_out)print(correct / total)test()

训练

from transformers import AdamW
from transformers.optimization import get_schedulerdef train():optimizer = AdamW(model.parameters(), lr=5e-5)scheduler = get_scheduler(name='linear',num_warmup_steps=0,num_training_steps=len(loader) * 3,optimizer=optimizer)model.train()for i, data in enumerate(loader):out = model(**data)loss = out['loss']loss.backward()torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)optimizer.step()scheduler.step()optimizer.zero_grad()model.zero_grad()if i % 10 == 0:#运算结果扩张4倍out = torch.nn.functional.interpolate(out['logits'],size=(512, 512),mode='bilinear',align_corners=False).argmax(dim=1)label = data['labels']#计算评价指标metric_out = metric.compute(predictions=out,references=label,num_labels=35,ignore_index=0)#删除这两个结果,不想看metric_out.pop('per_category_iou')metric_out.pop('per_category_accuracy')#统计正确率时排除label中的0select = label != 0label = label[select]out = out[select]#防止除0accuracy = (label == out).sum().item() / (len(label) + 1)lr = optimizer.state_dict()['param_groups'][0]['lr']print(i, loss.item(), lr, metric_out, accuracy)torch.save(model, 'models/9.抠图.model')train()

model = torch.load('models/9.抠图.model')
test()

这篇关于DL - 图像分割的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/935019

相关文章

Python中OpenCV与Matplotlib的图像操作入门指南

《Python中OpenCV与Matplotlib的图像操作入门指南》:本文主要介绍Python中OpenCV与Matplotlib的图像操作指南,本文通过实例代码给大家介绍的非常详细,对大家的学... 目录一、环境准备二、图像的基本操作1. 图像读取、显示与保存 使用OpenCV操作2. 像素级操作3.

C/C++的OpenCV 进行图像梯度提取的几种实现

《C/C++的OpenCV进行图像梯度提取的几种实现》本文主要介绍了C/C++的OpenCV进行图像梯度提取的实现,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的... 目录预www.chinasem.cn备知识1. 图像加载与预处理2. Sobel 算子计算 X 和 Y

c/c++的opencv图像金字塔缩放实现

《c/c++的opencv图像金字塔缩放实现》本文主要介绍了c/c++的opencv图像金字塔缩放实现,通过对原始图像进行连续的下采样或上采样操作,生成一系列不同分辨率的图像,具有一定的参考价值,感兴... 目录图像金字塔简介图像下采样 (cv::pyrDown)图像上采样 (cv::pyrUp)C++ O

Python+wxPython构建图像编辑器

《Python+wxPython构建图像编辑器》图像编辑应用是学习GUI编程和图像处理的绝佳项目,本教程中,我们将使用wxPython,一个跨平台的PythonGUI工具包,构建一个简单的... 目录引言环境设置创建主窗口加载和显示图像实现绘制工具矩形绘制箭头绘制文字绘制临时绘制处理缩放和旋转缩放旋转保存编

python+OpenCV反投影图像的实现示例详解

《python+OpenCV反投影图像的实现示例详解》:本文主要介绍python+OpenCV反投影图像的实现示例详解,本文通过实例代码图文并茂的形式给大家介绍的非常详细,感兴趣的朋友一起看看吧... 目录一、前言二、什么是反投影图像三、反投影图像的概念四、反向投影的工作原理一、利用反向投影backproj

Python实现图片分割的多种方法总结

《Python实现图片分割的多种方法总结》图片分割是图像处理中的一个重要任务,它的目标是将图像划分为多个区域或者对象,本文为大家整理了一些常用的分割方法,大家可以根据需求自行选择... 目录1. 基于传统图像处理的分割方法(1) 使用固定阈值分割图片(2) 自适应阈值分割(3) 使用图像边缘检测分割(4)

使用Python实现图像LBP特征提取的操作方法

《使用Python实现图像LBP特征提取的操作方法》LBP特征叫做局部二值模式,常用于纹理特征提取,并在纹理分类中具有较强的区分能力,本文给大家介绍了如何使用Python实现图像LBP特征提取的操作方... 目录一、LBP特征介绍二、LBP特征描述三、一些改进版本的LBP1.圆形LBP算子2.旋转不变的LB

Python如何将大TXT文件分割成4KB小文件

《Python如何将大TXT文件分割成4KB小文件》处理大文本文件是程序员经常遇到的挑战,特别是当我们需要把一个几百MB甚至几个GB的TXT文件分割成小块时,下面我们来聊聊如何用Python自动完成这... 目录为什么需要分割TXT文件基础版:按行分割进阶版:精确控制文件大小完美解决方案:支持UTF-8编码

OpenCV图像形态学的实现

《OpenCV图像形态学的实现》本文主要介绍了OpenCV图像形态学的实现,包括腐蚀、膨胀、开运算、闭运算、梯度运算、顶帽运算和黑帽运算,文中通过示例代码介绍的非常详细,需要的朋友们下面随着小编来一起... 目录一、图像形态学简介二、腐蚀(Erosion)1. 原理2. OpenCV 实现三、膨胀China编程(

C++字符串提取和分割的多种方法

《C++字符串提取和分割的多种方法》在C++编程中,字符串处理是一个常见的任务,尤其是在需要从字符串中提取特定数据时,本文将详细探讨如何使用C++标准库中的工具来提取和分割字符串,并分析不同方法的适用... 目录1. 字符串提取的基本方法1.1 使用 std::istringstream 和 >> 操作符示