pytorch基于Unet的铁轨缺陷语义分割

2023-11-05 05:59

本文主要是介绍pytorch基于Unet的铁轨缺陷语义分割,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

pytorch基于Unet的铁轨缺陷语义分割

  • 分割效果
  • Unet网络
  • 数据读取器
  • 训练及保存模型
  • 遇到问题

分割效果

在这里插入图片描述

Unet网络

""" Full assembly of the parts to form the complete network """
"""Refer https://github.com/milesial/Pytorch-UNet/blob/master/unet/unet_model.py"""import torch.nn.functional as Ffrom unet_parts import *class UNet(nn.Module):def __init__(self, n_channels, n_classes, bilinear=False):super(UNet, self).__init__()self.n_channels = n_channelsself.n_classes = n_classesself.bilinear = bilinearself.inc = DoubleConv(n_channels, 64)self.down1 = Down(64, 128)self.down2 = Down(128, 256)self.down3 = Down(256, 512)self.down4 = Down(512, 1024)self.up1 = Up(1024, 512, bilinear)self.up2 = Up(512, 256, bilinear)self.up3 = Up(256, 128, bilinear)self.up4 = Up(128, 64, bilinear)self.outc = OutConv(64, n_classes)def forward(self, x):x1 = self.inc(x)x2 = self.down1(x1)x3 = self.down2(x2)x4 = self.down3(x3)x5 = self.down4(x4)x = self.up1(x5, x4)x = self.up2(x, x3)x = self.up3(x, x2)x = self.up4(x, x1)logits = self.outc(x)return logits# if __name__ == '__main__':
#     net = UNet(n_channels=3, n_classes=1)
#     print(net)
""" Parts of the U-Net model """
"""https://github.com/milesial/Pytorch-UNet/blob/master/unet/unet_parts.py"""import torch
import torch.nn as nn
import torch.nn.functional as Fclass DoubleConv(nn.Module):"""(convolution => [BN] => ReLU) * 2"""def __init__(self, in_channels, out_channels):super().__init__()self.double_conv = nn.Sequential(nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),nn.BatchNorm2d(out_channels),nn.ReLU(inplace=True),nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),nn.BatchNorm2d(out_channels),nn.ReLU(inplace=True))def forward(self, x):return self.double_conv(x)class Down(nn.Module):"""Downscaling with maxpool then double conv"""def __init__(self, in_channels, out_channels):super().__init__()self.maxpool_conv = nn.Sequential(nn.MaxPool2d(2),DoubleConv(in_channels, out_channels))def forward(self, x):return self.maxpool_conv(x)class Up(nn.Module):"""Upscaling then double conv"""def __init__(self, in_channels, out_channels, bilinear=True):super().__init__()# if bilinear, use the normal convolutions to reduce the number of channelsif bilinear:self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)else:self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)self.conv = DoubleConv(in_channels, out_channels)def forward(self, x1, x2):x1 = self.up(x1)# input is CHWdiffY = torch.tensor([x2.size()[2] - x1.size()[2]])diffX = torch.tensor([x2.size()[3] - x1.size()[3]])x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,diffY // 2, diffY - diffY // 2])# if you have padding issues, see# https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a# https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bdx = torch.cat([x2, x1], dim=1)return self.conv(x)class OutConv(nn.Module):def __init__(self, in_channels, out_channels):super(OutConv, self).__init__()self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)def forward(self, x):return self.conv(x)

数据读取器

import torch
import cv2
import os
import glob
from torch.utils.data import Dataset
from PIL import Image
import random
import numpy as npclass ISBI_Loader(Dataset):def __init__(self, data_path, label_path):# 初始化函数,读取所有data_path下的图片self.images_path = glob.glob(os.path.join(data_path, '*.jpg'))self.labels_path = glob.glob(os.path.join(label_path, '*.jpg'))print(self.images_path)print(self.labels_path)def augment(self, image, flipCode):# 使用cv2.flip进行数据增强,filpCode为1水平翻转,0垂直翻转,-1水平+垂直翻转flip = cv2.flip(image, flipCode)return flipdef __getitem__(self, index):# 根据index读取图片image_path = self.images_path[index]# 根据image_path生成label_pathlabel_path = self.labels_path[index]# 读取训练图片和标签图片image = cv2.imread(image_path)label = cv2.imread(label_path)# image = cv2.resize(image, dsize=(1000, 160))# label = cv2.resize(label, dsize=(1000, 160))# 将数据转为单通道的图片image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)label = cv2.cvtColor(label, cv2.COLOR_BGR2GRAY)image = image.reshape(1, image.shape[0], image.shape[1])label = label.reshape(1, label.shape[0], label.shape[1])# 处理标签,将像素值为255的改为1if label.max() > 1:label = label / 255# 随机进行数据增强,为2时不做处理flipCode = random.choice([-1, 0, 1, 2])if flipCode != 2:image = self.augment(image, flipCode)label = self.augment(label, flipCode)return image, labeldef __len__(self):# 返回训练集大小return len(self.images_path)if __name__ == "__main__":isbi_dataset = ISBI_Loader("Railsurfaceimages/", "GroundTruth/")print("数据个数:", len(isbi_dataset))train_loader = torch.utils.data.DataLoader(dataset=isbi_dataset,batch_size=2,shuffle=False)for image, label in train_loader:print(image.shape)

训练及保存模型

在这里插入图片描述

import torch.nn as nn
import torch
# from model import Unet
from data import ISBI_Loader
import numpy as np
from unet_model import UNetdef train_net(net, device, train_loader, epochs=40, batch_size=1, lr=0.00001):# 加载训练集train_loader = train_loader# 定义RMSprop算法optimizer = torch.optim.RMSprop(net.parameters(), lr=lr, weight_decay=1e-8, momentum=0.9)# 定义Loss算法criterion = nn.BCEWithLogitsLoss()# best_loss统计,初始化为正无穷best_loss = float('inf')# 训练epochs次for epoch in range(epochs):# 训练模式net.train()# 按照batch_size开始训练for image, label in train_loader:optimizer.zero_grad()# 将数据拷贝到device中image = image.to(device=device, dtype=torch.float32)label = label.to(device=device, dtype=torch.float32)# 使用网络参数,输出预测结果pred = net(image)# 计算lossloss = criterion(pred, label)print('Loss/train', loss.item())# 保存loss值最小的网络参数if loss < best_loss:best_loss = losstorch.save(net.state_dict(), 'best_model.pth')# 更新参数loss.backward()optimizer.step()if __name__ == '__main__':# 加载训练数据isbi_dataset = ISBI_Loader("Railsurfaceimages/", "GroundTruth/")train_loader = torch.utils.data.DataLoader(dataset=isbi_dataset,batch_size=2,shuffle=True)device = 'cuda'net = UNet(n_channels=1, n_classes=1)net.to(device=device)train_net(net, device, train_loader, epochs=40, batch_size=1, lr=0.00001)

遇到问题

在这里插入图片描述
padding值设为1,卷积时自动填充,不然会因尺寸为问题导致错误。

这篇关于pytorch基于Unet的铁轨缺陷语义分割的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/347882

相关文章

C#中字符串分割的多种方式

《C#中字符串分割的多种方式》在C#编程语言中,字符串处理是日常开发中不可或缺的一部分,字符串分割是处理文本数据时常用的操作,它允许我们将一个长字符串分解成多个子字符串,本文给大家介绍了C#中字符串分... 目录1. 使用 string.Split2. 使用正则表达式 (Regex.Split)3. 使用

理解分类器(linear)为什么可以做语义方向的指导?(解纠缠)

Attribute Manipulation(属性编辑)、disentanglement(解纠缠)常用的两种做法:线性探针和PCA_disentanglement和alignment-CSDN博客 在解纠缠的过程中,有一种非常简单的方法来引导G向某个方向进行生成,然后我们通过向不同的方向进行行走,那么就会得到这个属性上的图像。那么你利用多个方向进行生成,便得到了各种方向的图像,每个方向对应了很多

SAM2POINT:以zero-shot且快速的方式将任何 3D 视频分割为视频

摘要 我们介绍 SAM2POINT,这是一种采用 Segment Anything Model 2 (SAM 2) 进行零样本和快速 3D 分割的初步探索。 SAM2POINT 将任何 3D 数据解释为一系列多向视频,并利用 SAM 2 进行 3D 空间分割,无需进一步训练或 2D-3D 投影。 我们的框架支持各种提示类型,包括 3D 点、框和掩模,并且可以泛化到不同的场景,例如 3D 对象、室

Nn criterions don’t compute the gradient w.r.t. targets error「pytorch」 (debug笔记)

Nn criterions don’t compute the gradient w.r.t. targets error「pytorch」 ##一、 缘由及解决方法 把这个pytorch-ddpg|github搬到jupyter notebook上运行时,出现错误Nn criterions don’t compute the gradient w.r.t. targets error。注:我用

基于YOLO8的图片实例分割系统

文章目录 在线体验快速开始一、项目介绍篇1.1 YOLO81.2 ultralytics1.3 模块介绍1.3.1 scan_task1.3.2 scan_taskflow.py1.3.3 segment_app.py 二、核心代码介绍篇2.1 segment_app.py2.2 scan_taskflow.py 三、结语 代码资源:计算机视觉领域YOLO8技术的图片实例分割实

【超级干货】2天速成PyTorch深度学习入门教程,缓解研究生焦虑

3、cnn基础 卷积神经网络 输入层 —输入图片矩阵 输入层一般是 RGB 图像或单通道的灰度图像,图片像素值在[0,255],可以用矩阵表示图片 卷积层 —特征提取 人通过特征进行图像识别,根据左图直的笔画判断X,右图曲的笔画判断圆 卷积操作 激活层 —加强特征 池化层 —压缩数据 全连接层 —进行分类 输出层 —输出分类概率 4、基于LeNet

pytorch torch.nn.functional.one_hot函数介绍

torch.nn.functional.one_hot 是 PyTorch 中用于生成独热编码(one-hot encoding)张量的函数。独热编码是一种常用的编码方式,特别适用于分类任务或对离散的类别标签进行处理。该函数将整数张量的每个元素转换为一个独热向量。 函数签名 torch.nn.functional.one_hot(tensor, num_classes=-1) 参数 t

pytorch计算网络参数量和Flops

from torchsummary import summarysummary(net, input_size=(3, 256, 256), batch_size=-1) 输出的参数是除以一百万(/1000000)M, from fvcore.nn import FlopCountAnalysisinputs = torch.randn(1, 3, 256, 256).cuda()fl

如何将卷积神经网络(CNN)应用于医学图像分析:从分类到分割和检测的实用指南

引言 在现代医疗领域,医学图像已经成为疾病诊断和治疗规划的重要工具。医学图像的类型繁多,包括但不限于X射线、CT(计算机断层扫描)、MRI(磁共振成像)和超声图像。这些图像提供了对身体内部结构的详细视图,有助于医生在进行准确诊断和制定个性化治疗方案时获取关键的信息。 1. 医学图像分析的挑战 医学图像分析面临诸多挑战,其中包括: 图像数据的复杂性:医学图像通常具有高维度和复杂的结构

图像分割分析效果2

这次加了结构化损失 # 训练集dice: 0.9219 - iou: 0.8611 - loss: 0.0318 - mae: 0.0220 - total: 0.8915  # dropout后:dice: 0.9143 - iou: 0.8488 - loss: 0.0335 - mae: 0.0236 - total: 0.8816 # 加了结构化损失后:avg_score: 0.89