keras 实现dense prediction 逐像素标注 语义分割 像素级语义标注 pixelwise segmention labeling classification 3D数据

本文主要是介绍keras 实现dense prediction 逐像素标注 语义分割 像素级语义标注 pixelwise segmention labeling classification 3D数据,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

主要是keras的示例都是图片分类。而真正的论文代码,又太大了,不适合初学者(比如我)来学习。

所以我查找了一些资料。我在google 上捞的。

其中有个教程让人感觉很好.更完整的教程。另一个教程。

大概就是说,你的输入ground truth label需要是(width*height,class number),然后网络最后需要加个sigmoid,后面用binary_crossentrophy 损失函数。

在说白点就是图片原始标签可能是640,480,1.这样的,你先转成onehot 640,480,13(比如我有13类,一张图片有了一个三维的标注,真是fancy),然后再转成640*480,13这个二维的标注,就是保持深度,图片拉成向量。

然后最后的网络,最后一层的激活函数,要用sigmoid配binary_crossentrophy

或者是softmax 配catahorical_crossentrophy

官网说catagotical_cross rntrophy:

注意: 当使用 categorical_crossentropy 损失时,你的目标值应该是分类格式 (即,如果你有 10 个类,每个样本的目标值应该是一个 10 维的向量,这个向量除了表示类别的那个索引为 1,其他均为 0)。 为了将 整数目标值 转换为 分类目标值,你可以使用 Keras 实用函数 to_categorical
 

from keras.utils.np_utils import to_categorical categorical_labels = to_categorical(int_labels, num_classes=None)

所以,我贴一下我的代码。这个代码最终的输出是原图的1/16大小,毕竟我们只是为了说明代码,而不是真的去发paper,越简单越好。

from __future__ import print_function
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import os
import keras
import PIL
from PIL import Image
from keras import Model, Input, optimizers
from keras.applications import vgg16, inception_v3, resnet50, mobilenet
from keras.layers import Conv2D,Lambda,Reshape
from keras.preprocessing.image import ImageDataGenerator, load_img#数据预处理
#下面将我的label从2284*30*40*1 转成2284*1200*14的onehot编码
#2284是图片数量
#14是类别数量
#img和lab是你的图片和标注图片。
#img大小是2284*480*640*3
#lab是2284*480*640
#trainval_list是你的训练和validation数据序号列表,因为2284张图片包含了900多张测试图片,我需要筛一下
img = img./255
img_trainval = img[trainval_list, :, :, :]
mini_lab = lab[:,::16,::16]sum = np.zeros(shape=(2284, 1200, 14))
for i in range(2284):pic_lab = mini_lab[i, :, :]pic_flatten = np.reshape(pic_lab, (1, 1200))pic_onehot = keras.utils.to_categorical(pic_flatten, 14)sum[i] = pic_onehot
lab_trainval = sum[trainval_list, :, :]#网络结构是非常简单的
os.environ['CUDA_VISIBLE_DEVICES']='0'
resnet_model = resnet50.ResNet50(weights = 'imagenet', include_top=False,input_shape = (480,640,3))
layer_name = 'activation_40'
res16 = Model(inputs=resnet_model.input, outputs=resnet_model.get_layer(layer_name).output)
input_real = Input(shape=(480,640,3))
sgd = optimizers.SGD(lr=0.001, decay=1e-6, momentum=0.9, nesterov=True)
x = res16(input_real)
x = Conv2D(14, (1, 1), activation='relu')(x)
sig_out = Conv2D(14,(1,1),activation = 'sigmoid')(x)
out_reshape = Reshape((1200,14))(sig_out)#配置训练参数
model_simple1 = Model(inputs=input_real, outputs=out_reshape)
model_simple1.summary()
model_simple1.compile(loss="binary_crossentropy", optimizer=sgd, metrics=['accuracy','categorical_accuracy'])
model_simple1.fit(x=img_trainval, y=lab_trainval, epochs=200, shuffle=True, batch_size=2)

训练过程:这里必须说明的是,我把未标注类也加入训练了,所以其实这个代码对于我的数据库还是需要修改的。慢慢来。先解决3D数据的问题好吧。

 

网络结构忘给了:

 warnings.warn('The output shape of `ResNet50(include_top=False)` '
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_2 (InputLayer)         (None, 480, 640, 3)       0         
_________________________________________________________________
model_1 (Model)              (None, 30, 40, 1024)      8589184   
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 30, 40, 14)        14350     
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 30, 40, 14)        210       
_________________________________________________________________
reshape_1 (Reshape)          (None, 1200, 14)          0         
=================================================================
Total params: 8,603,744
Trainable params: 8,573,152
Non-trainable params: 30,592
_________________________________________________________________

 

这篇关于keras 实现dense prediction 逐像素标注 语义分割 像素级语义标注 pixelwise segmention labeling classification 3D数据的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1128439

相关文章

Android实现悬浮按钮功能

《Android实现悬浮按钮功能》在很多场景中,我们希望在应用或系统任意界面上都能看到一个小的“悬浮按钮”(FloatingButton),用来快速启动工具、展示未读信息或快捷操作,所以本文给大家介绍... 目录一、项目概述二、相关技术知识三、实现思路四、整合代码4.1 Java 代码(MainActivi

使用Python实现一个优雅的异步定时器

《使用Python实现一个优雅的异步定时器》在Python中实现定时器功能是一个常见需求,尤其是在需要周期性执行任务的场景下,本文给大家介绍了基于asyncio和threading模块,可扩展的异步定... 目录需求背景代码1. 单例事件循环的实现2. 事件循环的运行与关闭3. 定时器核心逻辑4. 启动与停

基于Python实现读取嵌套压缩包下文件的方法

《基于Python实现读取嵌套压缩包下文件的方法》工作中遇到的问题,需要用Python实现嵌套压缩包下文件读取,本文给大家介绍了详细的解决方法,并有相关的代码示例供大家参考,需要的朋友可以参考下... 目录思路完整代码代码优化思路打开外层zip压缩包并遍历文件:使用with zipfile.ZipFil

Python实现word文档内容智能提取以及合成

《Python实现word文档内容智能提取以及合成》这篇文章主要为大家详细介绍了如何使用Python实现从10个左右的docx文档中抽取内容,再调整语言风格后生成新的文档,感兴趣的小伙伴可以了解一下... 目录核心思路技术路径实现步骤阶段一:准备工作阶段二:内容提取 (python 脚本)阶段三:语言风格调

C#实现将Excel表格转换为图片(JPG/ PNG)

《C#实现将Excel表格转换为图片(JPG/PNG)》Excel表格可能会因为不同设备或字体缺失等问题,导致格式错乱或数据显示异常,转换为图片后,能确保数据的排版等保持一致,下面我们看看如何使用C... 目录通过C# 转换Excel工作表到图片通过C# 转换指定单元格区域到图片知识扩展C# 将 Excel

基于Java实现回调监听工具类

《基于Java实现回调监听工具类》这篇文章主要为大家详细介绍了如何基于Java实现一个回调监听工具类,文中的示例代码讲解详细,感兴趣的小伙伴可以跟随小编一起学习一下... 目录监听接口类 Listenable实际用法打印结果首先,会用到 函数式接口 Consumer, 通过这个可以解耦回调方法,下面先写一个

使用Java将DOCX文档解析为Markdown文档的代码实现

《使用Java将DOCX文档解析为Markdown文档的代码实现》在现代文档处理中,Markdown(MD)因其简洁的语法和良好的可读性,逐渐成为开发者、技术写作者和内容创作者的首选格式,然而,许多文... 目录引言1. 工具和库介绍2. 安装依赖库3. 使用Apache POI解析DOCX文档4. 将解析

Qt中QGroupBox控件的实现

《Qt中QGroupBox控件的实现》QGroupBox是Qt框架中一个非常有用的控件,它主要用于组织和管理一组相关的控件,本文主要介绍了Qt中QGroupBox控件的实现,具有一定的参考价值,感兴趣... 目录引言一、基本属性二、常用方法2.1 构造函数 2.2 设置标题2.3 设置复选框模式2.4 是否

C++使用printf语句实现进制转换的示例代码

《C++使用printf语句实现进制转换的示例代码》在C语言中,printf函数可以直接实现部分进制转换功能,通过格式说明符(formatspecifier)快速输出不同进制的数值,下面给大家分享C+... 目录一、printf 原生支持的进制转换1. 十进制、八进制、十六进制转换2. 显示进制前缀3. 指

springboot整合阿里云百炼DeepSeek实现sse流式打印的操作方法

《springboot整合阿里云百炼DeepSeek实现sse流式打印的操作方法》:本文主要介绍springboot整合阿里云百炼DeepSeek实现sse流式打印,本文给大家介绍的非常详细,对大... 目录1.开通阿里云百炼,获取到key2.新建SpringBoot项目3.工具类4.启动类5.测试类6.测