【3D图像分割】基于 Pytorch 的 VNet 3D 图像分割3(3D UNet 模型篇)

2023-11-03 10:28

本文主要是介绍【3D图像分割】基于 Pytorch 的 VNet 3D 图像分割3(3D UNet 模型篇),希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

在本文中,主要是对3D UNet 进行一个学习和梳理。对于3D UNet 网上的资料和GitHub直接获取的代码很多,不需要自己从0开始。那么本文的目的是啥呢?

本文就是想拆解下其中的结构,看看对于一个3DUNet,和2DUNet,究竟有什么不同?如果是你自己构建,有什么样的经验和技巧可以学习。

3DUNet的论文地址:3D U-Net: Learning Dense Volumetric Segmentation from Sparse Annotation

对于2DUNet感兴趣的小伙伴,可以先跳转去这里:【BraTS】Brain Tumor Segmentation 脑部肿瘤分割2(UNet的复现);相信阅读完,你会对这个模型,心中已经有了结构。

对本系列的其他篇章,点击下面👇链接:

  • 【3D 图像分割】基于 Pytorch 的 VNet 3D 图像分割1(综述篇)
  • 【3D 图像分割】基于 Pytorch 的 VNet 3D 图像分割2(基础数据流篇)
  • 【3D 图像分割】基于 Pytorch 的 VNet 3D 图像分割6(数据预处理)
  • 【3D 图像分割】基于 Pytorch 的 VNet 3D 图像分割7(数据预处理)
  • 【3D 图像分割】基于 Pytorch 的 VNet 3D 图像分割8(CT肺实质分割)
  • 【3D 图像分割】基于 Pytorch 的 VNet 3D 图像分割9(patch 的 crop 和 merge 操作)

一、 3D UNet 结构剖析

unet无论是2D,还是3D,从整体结构上进行划分,大体可以分位以下两个阶段:

  1. 下采样的阶段,也就是U的左边(encoder),负责对特征提取;
  2. 上采样的阶段,也就是U的右边(decoder),负责对预测恢复。

如下图展示的这样:
1

其中:

  • 蓝色框表示的是特征图;
  • 绿色长箭头,是concat操作;
  • 橘色三角,是conv+bn+relu的组合;
  • 红色的向下箭头,是max pool
  • 黄色的向上箭头,是up conv
  • 最后的紫色三角,是conv,恢复了最终的输出特征图;

对于模型构建这块,可以在论文中,看看作者是如何描述网络结构的:

  1. Like the standard u-net, it has an analysis and a synthesis path each with four resolution steps.
  2. In the analysis path, each layer contains two 3 × 3 × 3 convolutions each followed by a rectified linear unit (ReLu), and then a 2 × 2 × 2 max pooling with strides of two in each dimension.
  3. In the synthesis path, each layer consists of an upconvolution of 2 × 2 × 2 by strides of two in each dimension, followed by two 3 × 3 × 3 convolutions each followed by a ReLu.
  4. Shortcut connections from layers of equal resolution in the analysis path provide the essential high-resolution features to the synthesis path.
  5. In the last layer a 1×1×1 convolution reduces the number of output channels to the number of labels which is 3 in our case.

从论文中的网络结构示意图也可以发现:

  1. 水平看,每一个小块,基本都是三个特征图,最后一层除外;
  2. 水平看,每个特征图之间,都是橘色三角,是conv+bn+relu的组合,最后一层除外;
  3. encoder阶段,连接各个水平块的,是下采样;
  4. decoder阶段,连接各个水平块的,是反卷积(upconvolution);
  5. 还有就是绿色长箭头的concat,和最后的conv输出特征图。

二、 3D UNet 复现

复线在3D UNet前,可以先参照下相对简单,且很深渊源的2D UNet结构。其中被多次使用的一个水平块中,也是两个conv+bn+relu的组合,2D UNet的构建如下所示:

class ConvBlock2d(nn.Module):def __init__(self, in_ch, out_ch):super(ConvBlock2d, self).__init__()# 第1个3*3的卷积层self.conv1 = nn.Sequential(nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(out_ch),nn.ReLU(inplace=True),)# 第2个3*3的卷积层self.conv2 = nn.Sequential(nn.Conv2d(out_ch, out_ch, kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(out_ch),nn.ReLU(inplace=True),)# 定义数据前向流动形式def forward(self, x):x = self.conv1(x)x = self.conv2(x)return x

而在3D UNet的一个水平块中,同样是两个conv+bn+relu的组合,如下所示:

is_elu = False
def activateELU(is_elu, nchan):if is_elu:return nn.ELU(inplace=True)else:return nn.PReLU(nchan)def ConvBnActivate(in_channels, middle_channels, out_channels):# This is a block with 2 convolutions# The first convolution goes from in_channels to middle_channels feature maps# The second convolution goes from middle_channels to out_channels feature mapsconv = nn.Sequential(nn.Conv3d(in_channels, middle_channels, stride=1, kernel_size=3, padding=1),nn.BatchNorm3d(middle_channels),activateELU(is_elu, middle_channels),nn.Conv3d(middle_channels, out_channels, stride=1, kernel_size=3, padding=1),nn.BatchNorm3d(out_channels),activateELU(is_elu, out_channels),)return conv

可以发现,nn.Conv2d变成了nn.Conv3dnn.BatchNorm2d变成了nn.BatchNorm3d。遵照这个规则,构建下采样MaxPool3d、上采样反卷积ConvTranspose3d,以及最后紫色一层卷积,输出特征层FinalConvolution,如下:

def DownSample():# It halves the spatial dimensions on every axes (x,y,z)return nn.MaxPool3d(kernel_size=2, stride=2)def UpSample(in_channels, out_channels):# It doubles the spatial dimensions on every axes (x,y,z)return nn.ConvTranspose3d(in_channels, out_channels, kernel_size=2, stride=2)def FinalConvolution(in_channels, out_channels):return nn.Conv3d(in_channels, out_channels, kernel_size=1)

除此之外,绿色长箭头,concat操作,是在水平方向上,也就是列上进行组合,如下所示:

def CatBlock(x1, x2):return torch.cat((x1, x2), 1)

至此,构建模型所需要的各个组块,都准备完毕了。接下来就是构建模型,将各个组块搭起来。其中有个规律:

  • encoder中第一conv+bn+relu外,每一次前都需要下采样;
  • decoder中,每一个conv+bn+relu前,都需要上采样;
  • 并且,decoder中第一个conv操作,需要进行concat操作;
  • DownSamplechannel不变,特征图尺寸变小;
  • UpSamplechannel不变,特征图尺寸变大;

那就把这些规则,根据图示给加上,组合后的一个类,就如下所示:

import torch
import torch.nn as nn
import torch.nn.functional as Fclass UNet3D(nn.Module):def __init__(self, num_out_classes=2, input_channels=1, init_feat_channels=32):super().__init__()# Encoder layers definitionsself.down_sample = DownSample()self.init_conv = ConvBnActivate(input_channels, init_feat_channels, init_feat_channels*2)self.down_conv1 = ConvBnActivate(init_feat_channels*2, init_feat_channels*2, init_feat_channels*4)self.down_conv2 = ConvBnActivate(init_feat_channels*4, init_feat_channels*4, init_feat_channels*8)self.down_conv3 = ConvBnActivate(init_feat_channels*8, init_feat_channels*8, init_feat_channels*16)# Decoder layers definitionsself.up_sample1 = UpSample(init_feat_channels*16, init_feat_channels*16)self.up_conv1   = ConvBnActivate(init_feat_channels*(16+8), init_feat_channels*8, init_feat_channels*8)self.up_sample2 = UpSample(init_feat_channels*8, init_feat_channels*8)self.up_conv2   = ConvBnActivate(init_feat_channels*(8+4), init_feat_channels*4, init_feat_channels*4)self.up_sample3 = UpSample(init_feat_channels*4, init_feat_channels*4)self.up_conv3   = ConvBnActivate(init_feat_channels*(4+2), init_feat_channels*2, init_feat_channels*2)self.final_conv = FinalConvolution(init_feat_channels*2, num_out_classes)# Softmaxself.softmax = F.softmaxdef forward(self, image):# Encoder Part ## B x  1 x Z x Y x Xlayer_init = self.init_conv(image)# B x 64 x Z x Y x Xmax_pool1  = self.down_sample(layer_init)# B x 64 x Z//2 x Y//2 x X//2layer_down2 = self.down_conv1(max_pool1)# B x 128 x Z//2 x Y//2 x X//2max_pool2   = self.down_sample(layer_down2)# B x 128 x Z//4 x Y//4 x X//4layer_down3 = self.down_conv2(max_pool2)# B x 256 x Z//4 x Y//4 x X//4max_pool_3  = self.down_sample(layer_down3)# B x 256 x Z//8 x Y//8 x X//8layer_down4 = self.down_conv3(max_pool_3)# B x 512 x Z//8 x Y//8 x X//8# Decoder part #layer_up1 = self.up_sample1(layer_down4)# B x 512 x Z//4 x Y//4 x X//4cat_block1 = CatBlock(layer_down3, layer_up1)# B x (256+512) x Z//4 x Y//4 x X//4layer_conv_up1 = self.up_conv1(cat_block1)# B x 256 x Z//4 x Y//4 x X//4layer_up2 = self.up_sample2(layer_conv_up1)# B x 256 x Z//2 x Y//2 x X//2cat_block2 = CatBlock(layer_down2, layer_up2)# B x (128+256) x Z//2 x Y//2 x X//2layer_conv_up2 = self.up_conv2(cat_block2)# B x 128 x Z//2 x Y//2 x X//2layer_up3 = self.up_sample3(layer_conv_up2)# B x 128 x Z x Y x Xcat_block3 = CatBlock(layer_init, layer_up3)# B x (64+128) x Z x Y x Xlayer_conv_up3 = self.up_conv3(cat_block3)# B x 64 x Z x Y x Xfinal_layer = self.final_conv(layer_conv_up3)# B x 2 x Z x Y x Xreturn self.softmax(final_layer, dim=1)

定义好了模型还不算完,分阶段测试下构建的网络是不是和我们所预想的一样。我们给他一个输入,测试下是否与我们最初的想法是一致的,是否报错等等问题,如下这样:

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # 没gpu就用cpu
print(DEVICE)# Tensors for 3D Image Processing in PyTorch
# Batch x Channel x Z x Y x X
# Batch size BY x Number of channels x (BY Z dim) x (BY Y dim) x (BY X dim)if __name__ == '__main__':from torchsummary import summarymodel = UNet3D(num_out_classes=3, input_channels=3, init_feat_channels=32)# print(model)summary(model, input_size=(3, 128, 128, 64), batch_size=-1, device='cpu')

打印的内容如下:

----------------------------------------------------------------Layer (type)               Output Shape         Param #
================================================================Conv3d-1     [-1, 32, 128, 128, 64]           2,624BatchNorm3d-2     [-1, 32, 128, 128, 64]              64PReLU-3     [-1, 32, 128, 128, 64]              32Conv3d-4     [-1, 64, 128, 128, 64]          55,360BatchNorm3d-5     [-1, 64, 128, 128, 64]             128PReLU-6     [-1, 64, 128, 128, 64]              64MaxPool3d-7       [-1, 64, 64, 64, 32]               0Conv3d-8       [-1, 64, 64, 64, 32]         110,656BatchNorm3d-9       [-1, 64, 64, 64, 32]             128PReLU-10       [-1, 64, 64, 64, 32]              64Conv3d-11      [-1, 128, 64, 64, 32]         221,312BatchNorm3d-12      [-1, 128, 64, 64, 32]             256PReLU-13      [-1, 128, 64, 64, 32]             128MaxPool3d-14      [-1, 128, 32, 32, 16]               0Conv3d-15      [-1, 128, 32, 32, 16]         442,496BatchNorm3d-16      [-1, 128, 32, 32, 16]             256PReLU-17      [-1, 128, 32, 32, 16]             128Conv3d-18      [-1, 256, 32, 32, 16]         884,992BatchNorm3d-19      [-1, 256, 32, 32, 16]             512PReLU-20      [-1, 256, 32, 32, 16]             256MaxPool3d-21       [-1, 256, 16, 16, 8]               0Conv3d-22       [-1, 256, 16, 16, 8]       1,769,728BatchNorm3d-23       [-1, 256, 16, 16, 8]             512PReLU-24       [-1, 256, 16, 16, 8]             256Conv3d-25       [-1, 512, 16, 16, 8]       3,539,456BatchNorm3d-26       [-1, 512, 16, 16, 8]           1,024PReLU-27       [-1, 512, 16, 16, 8]             512ConvTranspose3d-28      [-1, 512, 32, 32, 16]       2,097,664Conv3d-29      [-1, 256, 32, 32, 16]       5,308,672BatchNorm3d-30      [-1, 256, 32, 32, 16]             512PReLU-31      [-1, 256, 32, 32, 16]             256Conv3d-32      [-1, 256, 32, 32, 16]       1,769,728BatchNorm3d-33      [-1, 256, 32, 32, 16]             512PReLU-34      [-1, 256, 32, 32, 16]             256ConvTranspose3d-35      [-1, 256, 64, 64, 32]         524,544Conv3d-36      [-1, 128, 64, 64, 32]       1,327,232BatchNorm3d-37      [-1, 128, 64, 64, 32]             256PReLU-38      [-1, 128, 64, 64, 32]             128Conv3d-39      [-1, 128, 64, 64, 32]         442,496BatchNorm3d-40      [-1, 128, 64, 64, 32]             256PReLU-41      [-1, 128, 64, 64, 32]             128ConvTranspose3d-42    [-1, 128, 128, 128, 64]         131,200Conv3d-43     [-1, 64, 128, 128, 64]         331,840BatchNorm3d-44     [-1, 64, 128, 128, 64]             128PReLU-45     [-1, 64, 128, 128, 64]              64Conv3d-46     [-1, 64, 128, 128, 64]         110,656BatchNorm3d-47     [-1, 64, 128, 128, 64]             128PReLU-48     [-1, 64, 128, 128, 64]              64Conv3d-49      [-1, 3, 128, 128, 64]             195
================================================================
Total params: 19,077,859
Trainable params: 19,077,859
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 12.00
Forward/backward pass size (MB): 8544.00
Params size (MB): 72.78
Estimated Total Size (MB): 8628.78
----------------------------------------------------------------

其中,我们测试的参数量是19,077,859,论文中说的参数量:The architecture has 19069955 parameters in total. 有略微的差别。

后面再调用模型,进行一次前向传播,loss运算和反向回归。如果这里都通过了,那么后面构建训练代码,就更简单了很多。如下:

if __name__ == '__main__':input_channels = 3num_out_classes = 2init_feat_channels = 32batch_size = 4model = UNet3D(num_out_classes=num_out_classes, input_channels=input_channels, init_feat_channels=init_feat_channels)# B x C x Z x Y x X# 4 x 1 x 64 x 64 x 64input_batch_size = (batch_size, input_channels, 128, 128, 64)input_example = torch.rand(input_batch_size)unet = model.to(DEVICE)input_example = input_example.to(DEVICE)output = unet(input_example)# output = output.cpu().detach().numpy()# Expected output shape# B x N x Z x Y x X# 4 x 2 x 64 x 64 x 64expected_output_shape = (batch_size, num_out_classes, 128, 128, 64)print("Output shape = {}".format(output.shape))assert output.shape == expected_output_shape, "Unexpected output shape, check the architecture!"expected_gt_shape = (batch_size, 128, 128, 64)ground_truth = torch.ones(expected_gt_shape)ground_truth = ground_truth.long().to(DEVICE)# Defining loss fnce_layer = torch.nn.CrossEntropyLoss()# Calculating lossce_loss = ce_layer(output, ground_truth)print("CE Loss = {}".format(ce_loss))# Back propagationce_loss.backward()

输出内容如下:

Output shape = torch.Size([4, 2, 128, 128, 64])
CE Loss = 0.6823387145996094

一个疑问:什么时候使用softmax?什么时候使用sigmoid
答:

第二个问题:训练阶段是不是不需要softmax/sigmoid?只在推理阶段使用呢?
答:

三、总结

UNet网络的结构,无论是二维的,还是三维的,都是比较容易理解的,这可能也是为什么那么受欢迎的原因之一吧。如果你看过之前那篇关于2D UNet的过程,再看本篇应该就简单的很多。觉得本篇更简单一些呢。

我觉得本篇最大的价值,就是:

  1. 逐模块的分析了结构;
  2. 对后续的模型构建提供了思路;
  3. 构建完模型需要先预测试,两种方式可选;
  4. 对模型的优势和劣势,分析。

如果你阅读的过程中,发现了问题和疑问,欢迎评论区交流。

这篇关于【3D图像分割】基于 Pytorch 的 VNet 3D 图像分割3(3D UNet 模型篇)的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/337526

相关文章

大模型研发全揭秘:客服工单数据标注的完整攻略

在人工智能(AI)领域,数据标注是模型训练过程中至关重要的一步。无论你是新手还是有经验的从业者,掌握数据标注的技术细节和常见问题的解决方案都能为你的AI项目增添不少价值。在电信运营商的客服系统中,工单数据是客户问题和解决方案的重要记录。通过对这些工单数据进行有效标注,不仅能够帮助提升客服自动化系统的智能化水平,还能优化客户服务流程,提高客户满意度。本文将详细介绍如何在电信运营商客服工单的背景下进行

基于人工智能的图像分类系统

目录 引言项目背景环境准备 硬件要求软件安装与配置系统设计 系统架构关键技术代码示例 数据预处理模型训练模型预测应用场景结论 1. 引言 图像分类是计算机视觉中的一个重要任务,目标是自动识别图像中的对象类别。通过卷积神经网络(CNN)等深度学习技术,我们可以构建高效的图像分类系统,广泛应用于自动驾驶、医疗影像诊断、监控分析等领域。本文将介绍如何构建一个基于人工智能的图像分类系统,包括环境

无人叉车3d激光slam多房间建图定位异常处理方案-墙体画线地图切分方案

墙体画线地图切分方案 针对问题:墙体两侧特征混淆误匹配,导致建图和定位偏差,表现为过门跳变、外月台走歪等 ·解决思路:预期的根治方案IGICP需要较长时间完成上线,先使用切分地图的工程化方案,即墙体两侧切分为不同地图,在某一侧只使用该侧地图进行定位 方案思路 切分原理:切分地图基于关键帧位置,而非点云。 理论基础:光照是直线的,一帧点云必定只能照射到墙的一侧,无法同时照到两侧实践考虑:关

Andrej Karpathy最新采访:认知核心模型10亿参数就够了,AI会打破教育不公的僵局

夕小瑶科技说 原创  作者 | 海野 AI圈子的红人,AI大神Andrej Karpathy,曾是OpenAI联合创始人之一,特斯拉AI总监。上一次的动态是官宣创办一家名为 Eureka Labs 的人工智能+教育公司 ,宣布将长期致力于AI原生教育。 近日,Andrej Karpathy接受了No Priors(投资博客)的采访,与硅谷知名投资人 Sara Guo 和 Elad G

Retrieval-based-Voice-Conversion-WebUI模型构建指南

一、模型介绍 Retrieval-based-Voice-Conversion-WebUI(简称 RVC)模型是一个基于 VITS(Variational Inference with adversarial learning for end-to-end Text-to-Speech)的简单易用的语音转换框架。 具有以下特点 简单易用:RVC 模型通过简单易用的网页界面,使得用户无需深入了

透彻!驯服大型语言模型(LLMs)的五种方法,及具体方法选择思路

引言 随着时间的发展,大型语言模型不再停留在演示阶段而是逐步面向生产系统的应用,随着人们期望的不断增加,目标也发生了巨大的变化。在短短的几个月的时间里,人们对大模型的认识已经从对其zero-shot能力感到惊讶,转变为考虑改进模型质量、提高模型可用性。 「大语言模型(LLMs)其实就是利用高容量的模型架构(例如Transformer)对海量的、多种多样的数据分布进行建模得到,它包含了大量的先验

图神经网络模型介绍(1)

我们将图神经网络分为基于谱域的模型和基于空域的模型,并按照发展顺序详解每个类别中的重要模型。 1.1基于谱域的图神经网络         谱域上的图卷积在图学习迈向深度学习的发展历程中起到了关键的作用。本节主要介绍三个具有代表性的谱域图神经网络:谱图卷积网络、切比雪夫网络和图卷积网络。 (1)谱图卷积网络 卷积定理:函数卷积的傅里叶变换是函数傅里叶变换的乘积,即F{f*g}

秋招最新大模型算法面试,熬夜都要肝完它

💥大家在面试大模型LLM这个板块的时候,不知道面试完会不会复盘、总结,做笔记的习惯,这份大模型算法岗面试八股笔记也帮助不少人拿到过offer ✨对于面试大模型算法工程师会有一定的帮助,都附有完整答案,熬夜也要看完,祝大家一臂之力 这份《大模型算法工程师面试题》已经上传CSDN,还有完整版的大模型 AI 学习资料,朋友们如果需要可以微信扫描下方CSDN官方认证二维码免费领取【保证100%免费

【生成模型系列(初级)】嵌入(Embedding)方程——自然语言处理的数学灵魂【通俗理解】

【通俗理解】嵌入(Embedding)方程——自然语言处理的数学灵魂 关键词提炼 #嵌入方程 #自然语言处理 #词向量 #机器学习 #神经网络 #向量空间模型 #Siri #Google翻译 #AlexNet 第一节:嵌入方程的类比与核心概念【尽可能通俗】 嵌入方程可以被看作是自然语言处理中的“翻译机”,它将文本中的单词或短语转换成计算机能够理解的数学形式,即向量。 正如翻译机将一种语言

AI Toolkit + H100 GPU,一小时内微调最新热门文生图模型 FLUX

上个月,FLUX 席卷了互联网,这并非没有原因。他们声称优于 DALLE 3、Ideogram 和 Stable Diffusion 3 等模型,而这一点已被证明是有依据的。随着越来越多的流行图像生成工具(如 Stable Diffusion Web UI Forge 和 ComyUI)开始支持这些模型,FLUX 在 Stable Diffusion 领域的扩展将会持续下去。 自 FLU