基于Pytorch框架的深度学习HRnet网络人像语义分割系统源码

本文主要是介绍基于Pytorch框架的深度学习HRnet网络人像语义分割系统源码,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

  第一步:准备数据

头发分割数据,总共有5711张图片,里面的像素值为0和1,所以看起来全部是黑的,不影响使用

第二步:搭建模型

计算机视觉领域有很多任务是位置敏感的,比如目标检测、语义分割、实例分割等等。为了这些任务位置信息更加精准,很容易想到的做法就是维持高分辨率的feature map,事实上HRNet之前几乎所有的网络都是这么做的,通过下采样得到强语义信息,然后再上采样恢复高分辨率恢复位置信息(如下图所示),然而这种做法,会导致大量的有效信息在不断的上下采样过程中丢失。而HRNet通过并行多个分辨率的分支,加上不断进行不同分支之间的信息交互,同时达到强语义信息和精准位置信息的目的。

recover high resolution

思路在当时来讲,不同分支的信息交互属于很老套的思路(如FPN等),我觉得最大的创新点还是能够从头到尾保持高分辨率,而不同分支的信息交互是为了补充通道数减少带来的信息损耗,这种网络架构设计对于位置敏感的任务会有奇效。

第三步:代码

1)损失函数为:交叉熵损失函数

2)网络代码:

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as Ffrom .backbone import BN_MOMENTUM, hrnet_classificationclass HRnet_Backbone(nn.Module):def __init__(self, backbone = 'hrnetv2_w18', pretrained = False):super(HRnet_Backbone, self).__init__()self.model    = hrnet_classification(backbone = backbone, pretrained = pretrained)del self.model.incre_modulesdel self.model.downsamp_modulesdel self.model.final_layerdel self.model.classifierdef forward(self, x):x = self.model.conv1(x)x = self.model.bn1(x)x = self.model.relu(x)x = self.model.conv2(x)x = self.model.bn2(x)x = self.model.relu(x)x = self.model.layer1(x)x_list = []for i in range(2):if self.model.transition1[i] is not None:x_list.append(self.model.transition1[i](x))else:x_list.append(x)y_list = self.model.stage2(x_list)x_list = []for i in range(3):if self.model.transition2[i] is not None:if i < 2:x_list.append(self.model.transition2[i](y_list[i]))else:x_list.append(self.model.transition2[i](y_list[-1]))else:x_list.append(y_list[i])y_list = self.model.stage3(x_list)x_list = []for i in range(4):if self.model.transition3[i] is not None:if i < 3:x_list.append(self.model.transition3[i](y_list[i]))else:x_list.append(self.model.transition3[i](y_list[-1]))else:x_list.append(y_list[i])y_list = self.model.stage4(x_list)return y_listclass HRnet(nn.Module):def __init__(self, num_classes = 21, backbone = 'hrnetv2_w18', pretrained = False):super(HRnet, self).__init__()self.backbone       = HRnet_Backbone(backbone = backbone, pretrained = pretrained)last_inp_channels   = np.int(np.sum(self.backbone.model.pre_stage_channels))self.last_layer = nn.Sequential(nn.Conv2d(in_channels=last_inp_channels, out_channels=last_inp_channels, kernel_size=1, stride=1, padding=0),nn.BatchNorm2d(last_inp_channels, momentum=BN_MOMENTUM),nn.ReLU(inplace=True),nn.Conv2d(in_channels=last_inp_channels, out_channels=num_classes, kernel_size=1, stride=1, padding=0))def forward(self, inputs):H, W = inputs.size(2), inputs.size(3)x = self.backbone(inputs)# Upsamplingx0_h, x0_w = x[0].size(2), x[0].size(3)x1 = F.interpolate(x[1], size=(x0_h, x0_w), mode='bilinear', align_corners=True)x2 = F.interpolate(x[2], size=(x0_h, x0_w), mode='bilinear', align_corners=True)x3 = F.interpolate(x[3], size=(x0_h, x0_w), mode='bilinear', align_corners=True)x = torch.cat([x[0], x1, x2, x3], 1)x = self.last_layer(x)x = F.interpolate(x, size=(H, W), mode='bilinear', align_corners=True)return x

第四步:统计一些指标(训练过程中的loss和miou)

第五步:搭建GUI界面

第六步:整个工程的内容

有训练代码和训练好的模型以及训练过程,提供数据,提供GUI界面代码

代码见:基于Pytorch框架的深度学习HRnet网络人像语义分割系统源码

有问题可以私信或者留言,有问必答

这篇关于基于Pytorch框架的深度学习HRnet网络人像语义分割系统源码的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1124401

相关文章

Python实现无痛修改第三方库源码的方法详解

《Python实现无痛修改第三方库源码的方法详解》很多时候,我们下载的第三方库是不会有需求不满足的情况,但也有极少的情况,第三方库没有兼顾到需求,本文将介绍几个修改源码的操作,大家可以根据需求进行选择... 目录需求不符合模拟示例 1. 修改源文件2. 继承修改3. 猴子补丁4. 追踪局部变量需求不符合很

Python 中的异步与同步深度解析(实践记录)

《Python中的异步与同步深度解析(实践记录)》在Python编程世界里,异步和同步的概念是理解程序执行流程和性能优化的关键,这篇文章将带你深入了解它们的差异,以及阻塞和非阻塞的特性,同时通过实际... 目录python中的异步与同步:深度解析与实践异步与同步的定义异步同步阻塞与非阻塞的概念阻塞非阻塞同步

Python Dash框架在数据可视化仪表板中的应用与实践记录

《PythonDash框架在数据可视化仪表板中的应用与实践记录》Python的PlotlyDash库提供了一种简便且强大的方式来构建和展示互动式数据仪表板,本篇文章将深入探讨如何使用Dash设计一... 目录python Dash框架在数据可视化仪表板中的应用与实践1. 什么是Plotly Dash?1.1

基于Flask框架添加多个AI模型的API并进行交互

《基于Flask框架添加多个AI模型的API并进行交互》:本文主要介绍如何基于Flask框架开发AI模型API管理系统,允许用户添加、删除不同AI模型的API密钥,感兴趣的可以了解下... 目录1. 概述2. 后端代码说明2.1 依赖库导入2.2 应用初始化2.3 API 存储字典2.4 路由函数2.5 应

Python GUI框架中的PyQt详解

《PythonGUI框架中的PyQt详解》PyQt是Python语言中最强大且广泛应用的GUI框架之一,基于Qt库的Python绑定实现,本文将深入解析PyQt的核心模块,并通过代码示例展示其应用场... 目录一、PyQt核心模块概览二、核心模块详解与示例1. QtCore - 核心基础模块2. QtWid

SpringBoot使用OkHttp完成高效网络请求详解

《SpringBoot使用OkHttp完成高效网络请求详解》OkHttp是一个高效的HTTP客户端,支持同步和异步请求,且具备自动处理cookie、缓存和连接池等高级功能,下面我们来看看SpringB... 目录一、OkHttp 简介二、在 Spring Boot 中集成 OkHttp三、封装 OkHttp

Linux系统之主机网络配置方式

《Linux系统之主机网络配置方式》:本文主要介绍Linux系统之主机网络配置方式,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录一、查看主机的网络参数1、查看主机名2、查看IP地址3、查看网关4、查看DNS二、配置网卡1、修改网卡配置文件2、nmcli工具【通用

Linux系统之dns域名解析全过程

《Linux系统之dns域名解析全过程》:本文主要介绍Linux系统之dns域名解析全过程,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录一、dns域名解析介绍1、DNS核心概念1.1 区域 zone1.2 记录 record二、DNS服务的配置1、正向解析的配置

使用PyTorch实现手写数字识别功能

《使用PyTorch实现手写数字识别功能》在人工智能的世界里,计算机视觉是最具魅力的领域之一,通过PyTorch这一强大的深度学习框架,我们将在经典的MNIST数据集上,见证一个神经网络从零开始学会识... 目录当计算机学会“看”数字搭建开发环境MNIST数据集解析1. 认识手写数字数据库2. 数据预处理的

Redis中高并发读写性能的深度解析与优化

《Redis中高并发读写性能的深度解析与优化》Redis作为一款高性能的内存数据库,广泛应用于缓存、消息队列、实时统计等场景,本文将深入探讨Redis的读写并发能力,感兴趣的小伙伴可以了解下... 目录引言一、Redis 并发能力概述1.1 Redis 的读写性能1.2 影响 Redis 并发能力的因素二、