DL - 图像分割

2024-04-25 15:04
文章标签 图像 分割 dl

本文主要是介绍DL - 图像分割,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!


from transformers import SegformerFeatureExtractor
import PIL.Image#一个把图像转换为数据的工具类
feature_extractor = SegformerFeatureExtractor()#模拟一批数据
pixel_values = [PIL.Image.new('RGB', (200, 100), 'blue'),PIL.Image.new('RGB', (200, 100), 'red')
]value = [PIL.Image.new('L', (200, 100), 150),PIL.Image.new('L', (200, 100), 200)
]#试算
out = feature_extractor(pixel_values, value)
print('keys=', out.keys())
print('type=', type(out['pixel_values']), type(out['labels']))
print('len=', len(out['pixel_values']), len(out['labels']))
print('type0=', type(out['pixel_values'][0]), type(out['labels'][0]))
print('shape0=', out['pixel_values'][0].shape, out['labels'][0].shape)feature_extractor

keys= dict_keys(['pixel_values', 'labels'])
type= <class 'list'> <class 'list'>
len= 2 2
type0= <class 'numpy.ndarray'> <class 'numpy.ndarray'>
shape0= (3, 512, 512) (512, 512)
SegformerFeatureExtractor {"do_normalize": true,"do_resize": true,"feature_extractor_type": "SegformerFeatureExtractor","image_mean": [0.485,0.456,0.406],"image_std": [0.229,0.224,0.225],"reduce_labels": false,"resample": 2,"size": 512
}

from torchvision.transforms import ColorJitter#能对图像进行亮度,对比度,饱和度,色相变换的工具类。其实就是数据增强
jitter = ColorJitter(brightness=0.25, contrast=0.25, saturation=0.25, hue=0.1)print(jitter)jitter(PIL.Image.new('RGB', (200, 100), 'blue'))

ColorJitter(brightness=[0.75, 1.25], contrast=[0.75, 1.25], saturation=[0.75, 1.25], hue=[-0.1, 0.1])


加载数据

from datasets import load_dataset, load_from_disk#一个道路分类数据集
dataset = load_dataset(path='segments/sidewalk-semantic')
# dataset = load_from_disk('datas/segments/sidewalk-semantic')#把图片数据全部转换为数字
def transforms(data):pixel_values = data['pixel_values']label = data['label']#应用数据增强pixel_values = [jitter(i) for i in pixel_values]#编码图片成数字return feature_extractor(pixel_values, label)#切分训练集和测试集
dataset = dataset.shuffle(seed=1)['train'].train_test_split(test_size=0.1)dataset['train'] = dataset['train'].with_transform(transforms)print(dataset['train'][0])dataset

import torchdef collate_fn(data):pixel_values = [i['pixel_values'] for i in data]labels = [i['labels'] for i in data]pixel_values = torch.FloatTensor(pixel_values)labels = torch.LongTensor(labels)return {'pixel_values': pixel_values, 'labels': labels}loader = torch.utils.data.DataLoader(dataset=dataset['train'],batch_size=4,collate_fn=collate_fn,shuffle=True,drop_last=True,
)for i, data in enumerate(loader):breaklen(loader), data['pixel_values'].shape, data['labels'].shape

#因为模型的计算输出是原尺寸除以4,所以需要把结果扩张成原来的大小便于计算正确率什么的
torch.nn.functional.interpolate(torch.randn(4, 35, 128, 128),size=(512, 512),mode='bilinear',align_corners=False).shape

from transformers import SegformerForSemanticSegmentation, SegformerModel#加载模型
#一共35中道路类别,怎么来的不重要
#model = SegformerForSemanticSegmentation.from_pretrained('nvidia/mit-b0',num_labels=35)#定义下游任务模型
class Model(torch.nn.Module):def __init__(self):super().__init__()self.pretrained = SegformerModel.from_pretrained('nvidia/mit-b0')self.linears = torch.nn.ModuleList([torch.nn.Linear(32, 256),torch.nn.Linear(64, 256),torch.nn.Linear(160, 256),torch.nn.Linear(256, 256)])self.classifier = torch.nn.Sequential(torch.nn.Conv2d(in_channels=1024,out_channels=256,kernel_size=1,bias=False),torch.nn.BatchNorm2d(256),torch.nn.ReLU(),torch.nn.Dropout(0.1),torch.nn.Conv2d(256, 35, kernel_size=1),)#加载预训练模型的参数parameters = SegformerForSemanticSegmentation.from_pretrained('nvidia/mit-b0',num_labels=35)for i in range(4):self.linears[i].load_state_dict(parameters.decode_head.linear_c[i].proj.state_dict())self.classifier[0].load_state_dict(parameters.decode_head.linear_fuse.state_dict())self.classifier[1].load_state_dict(parameters.decode_head.batch_norm.state_dict())self.classifier[4].load_state_dict(parameters.decode_head.classifier.state_dict())self.criterion = torch.nn.CrossEntropyLoss(ignore_index=255)def forward(self, pixel_values, labels):#pixel_values -> [4, 3, 512, 512]#labels -> [4, 512, 512]#首先通过预训练模型抽中间特征#[4, 32, 128, 128]#[4, 64, 64, 64]#[4, 160, 32, 32]#[4, 256, 16, 16]features = self.pretrained(pixel_values=pixel_values,output_hidden_states=True)features = features.hidden_states#打平#[4, 32, 16384]#[4, 64, 4096]#[4, 160, 1024]#[4, 256, 256]features = [i.flatten(2) for i in features]#转置,把通道放到最后一个维度#[4, 16384, 32]#[4, 4096, 64]#[4, 1024, 160]#[4, 256, 256]features = [i.transpose(1, 2) for i in features]#线性计算#[4, 16384, 256]#[4, 4096, 256]#[4, 1024, 256]#[4, 256, 256]features = [l(f) for f, l in zip(features, self.linears)]#转置回来,把通道放中间#[4, 256, 16384]#[4, 256, 4096]#[4, 256, 1024]#[4, 256, 256]features = [i.permute(0, 2, 1) for i in features]#变形成二维的图片#[4, 256, 128, 128]#[4, 256, 64, 64]#[4, 256, 32, 32]#[4, 256, 16, 16]features = [f.reshape(pixel_values.shape[0], -1, s, s)for f, s in zip(features, [128, 64, 32, 16])]#拓展到统一的尺寸#[4, 256, 128, 128]#[4, 256, 128, 128]#[4, 256, 128, 128]#[4, 256, 128, 128]features = [torch.nn.functional.interpolate(i,size=(128, 128),mode='bilinear',align_corners=False)for i in features]#逆序,维度不变features = features[::-1]#在通道维度合并成一个张量#[4, 1024, 128, 128]features = torch.cat(features, dim=1)#跑分类网络,其中包括了1024->256->35两步,使用cnn网络实现#[4, 35, 128, 128]features = self.classifier(features)#为了计算loss,要把计算结果放大到和labels一致#[4, 35, 128, 128] -> [4, 35, 512, 512]#计算交叉熵lossloss = self.criterion(torch.nn.functional.interpolate(features,size=(512, 512),mode='bilinear',align_corners=False), labels)return {'loss': loss, 'logits': features}model = Model()#统计参数量
print(sum(i.numel() for i in model.parameters()) / 10000)out = model(**data)out['loss'], out['logits'].shape

from datasets import load_metric#加载评价指标
metric = load_metric('mean_iou')#试算
metric.compute(predictions=torch.ones(4, 10, 10),references=torch.ones(4, 10, 10),#一共35中道路类别,怎么来的不重要num_labels=35,#忽略背景类0ignore_index=0,reduce_labels=False)

from matplotlib import pyplot as pltdef show(image, out, label):plt.figure(figsize=(15, 5))image = image.clone()image = image.permute(1, 2, 0)image = image - image.min().item()image = image / image.max().item()image = image * 255image = PIL.Image.fromarray(image.numpy().astype('uint8'), mode='RGB')image = image.resize((512, 512))plt.subplot(1, 3, 1)plt.imshow(image)plt.axis('off')out = PIL.Image.fromarray(out.numpy().astype('uint8'))plt.subplot(1, 3, 2)plt.imshow(out)plt.axis('off')label = PIL.Image.fromarray(label.numpy().astype('uint8'))plt.subplot(1, 3, 3)plt.imshow(label)plt.axis('off')plt.show()show(data['pixel_values'][0], torch.ones(512, 512), data['labels'][0])

测试

def test():model.eval()dataset['test'] = dataset['test'].shuffle()loader_test = torch.utils.data.DataLoader(dataset=dataset['test'].with_transform(transforms),batch_size=8,collate_fn=collate_fn,shuffle=False,drop_last=True,)labels = []outs = []correct = 0#初始化为1,防止除0total = 1for i, data in enumerate(loader_test):with torch.no_grad():out = model(**data)#运算结果扩张4倍out = torch.nn.functional.interpolate(out['logits'],size=(512, 512),mode='bilinear',align_corners=False)out = out.argmax(dim=1)outs.append(out)label = data['labels']labels.append(label)#统计正确率时排除label中的0select = label != 0correct += (label[select] == out[select]).sum().item()total += len(label[select])if i % 1 == 0:show(data['pixel_values'][0], out[0], label[0])if i == 4:break#计算评价指标metric_out = metric.compute(predictions=torch.cat(outs, dim=0),references=torch.cat(labels, dim=0),num_labels=35,ignore_index=0)#删除这两个结果,不想看metric_out.pop('per_category_iou')metric_out.pop('per_category_accuracy')print(metric_out)print(correct / total)test()

训练

from transformers import AdamW
from transformers.optimization import get_schedulerdef train():optimizer = AdamW(model.parameters(), lr=5e-5)scheduler = get_scheduler(name='linear',num_warmup_steps=0,num_training_steps=len(loader) * 3,optimizer=optimizer)model.train()for i, data in enumerate(loader):out = model(**data)loss = out['loss']loss.backward()torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)optimizer.step()scheduler.step()optimizer.zero_grad()model.zero_grad()if i % 10 == 0:#运算结果扩张4倍out = torch.nn.functional.interpolate(out['logits'],size=(512, 512),mode='bilinear',align_corners=False).argmax(dim=1)label = data['labels']#计算评价指标metric_out = metric.compute(predictions=out,references=label,num_labels=35,ignore_index=0)#删除这两个结果,不想看metric_out.pop('per_category_iou')metric_out.pop('per_category_accuracy')#统计正确率时排除label中的0select = label != 0label = label[select]out = out[select]#防止除0accuracy = (label == out).sum().item() / (len(label) + 1)lr = optimizer.state_dict()['param_groups'][0]['lr']print(i, loss.item(), lr, metric_out, accuracy)torch.save(model, 'models/9.抠图.model')train()

model = torch.load('models/9.抠图.model')
test()

这篇关于DL - 图像分割的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/935019

相关文章

leetcode刷题(95)——416. 分割等和子集

给定一个只包含正整数的非空数组。是否可以将这个数组分割成两个子集,使得两个子集的元素和相等。 注意: 每个数组中的元素不会超过 100 数组的大小不会超过 200 示例 1: 输入: [1, 5, 11, 5]输出: true解释: 数组可以分割成 [1, 5, 5] 和 [11]. 示例 2: 输入: [1, 2, 3, 5]输出: false解释: 数组不能分割成两个元素和相等的子

音视频开发基础知识(1)——图像基本概念

像素 **像素是图像的基本单元,一个个像素就组成了图像。你可以认为像素就是图像中的一个点。**在下面这张图中,你可以看到一个个方块,这些方块就是像素。 分辨率 图像(或视频)的分辨率是指图像的大小或尺寸。我们一般用像素个数来表示图像的尺寸。比如说一张1920x1080的图像,前者1920指的是该图像的宽度方向上有1920个像素点,而后者1080指的是图像的高 度方向上有1080个像素点。

【Python机器学习】NMF——将NMF应用于人脸图像

将NMF应用于之前用过的Wild数据集中的Labeled Faces。NMF的主要参数是我们想要提取的分量个数。通常来说,这个数字要小于输入特征的个数(否则的话,将每个像素作为单独的分量就可以对数据进行解释)。 首先,观察分类个数如何影响NMF重建数据的好坏: import mglearn.plotsimport numpy as npimport matplotlib.pyplot as

AIGC-Animate Anyone阿里的图像到视频 角色合成的框架-论文解读

Animate Anyone: Consistent and Controllable Image-to-Video Synthesis for Character Animation 论文:https://arxiv.org/pdf/2311.17117 网页:https://humanaigc.github.io/animate-anyone/ MOTIVATION 角色动画的

什么是图像频率?

经常听到图像低频成份、高频成份等等,没有细想过,今天突然一想发现真的不明白是怎么回事,在知乎上发现某答案,引用如下: 首先说说图像频率的物理意义。图像可以看做是一个定义为二维平面上的信号,该信号的幅值对应于像素的灰度(对于彩色图像则是RGB三个分量),如果我们仅仅考虑图像上某一行像素,则可以将之视为一个定义在一维空间上信号,这个信号在形式上与传统的信号处理领域的时变信号是相似的。不过是一个是

【LocalAI】(13):LocalAI最新版本支持Stable diffusion 3,20亿参数图像更加细腻了,可以继续研究下

最新版本v2.17.1 https://github.com/mudler/LocalAI/releases Stable diffusion 3 You can use Stable diffusion 3 by installing the model in the gallery (stable-diffusion-3-medium) or by placing this YAML fi

matplotlib之常见图像种类

Matplotlib 是一个用于绘制图表和数据可视化的 Python 库。它支持多种不同类型的图形,以满足各种数据可视化需求。以下是一些 Matplotlib 支持的主要图形种类: 折线图(Line Plot): 用于显示数据随时间或其他连续变量的变化趋势。特点:能够显示数据的变化趋势,反映事物的变化情况。(变化)plt.plot() 函数用于创建折线图。  示例:

细粒度图像分类论文阅读笔记

细粒度图像分类论文阅读笔记 摘要Abstract1. 用于细粒度图像分类的聚合注意力模块1.1 文献摘要1.2 研究背景1.3 本文创新点1.4 计算机视觉中的注意力机制1.5 模型方法1.5.1 聚合注意力模块1.5.2 通道注意力模块通道注意力代码实现 1.5.3 空间注意力模块空间注意力代码实现 1.5.4 CBAM注意力机制CBAM注意力代码实现 1.5.5 本文模型整体架构 1.6

DL理论笔记与理解

gradient的方向代表函数值增大的方向(这个方向由沿着各个轴方向偏导方向综合的方向),大小代表函数值变化的快慢。导数概念很大,偏导是沿着某方向上的导,梯度是沿着各个方向数偏导的向量。softmax函数叫这个的原因,把原来较大的数值压缩成相对的大数,把原来较小的数压缩在密集的空间,把数据间的margin压缩得越来越大,这就类似金字塔效应,你能力比别人强一些,得到的收益可能比别人强太多。CNN中卷

动手学深度学习(Pytorch版)代码实践 -计算机视觉-36图像增广

6 图片增广 import matplotlib.pyplot as pltimport numpy as npimport torch import torchvisionfrom d2l import torch as d2lfrom torch import nn from PIL import Imageimport liliPytorch as lpfrom tor