深度学习进阶:使用keras开发非串行化神经网络

2024-04-30 22:08

本文主要是介绍深度学习进阶:使用keras开发非串行化神经网络,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

我们当前所开发的网络都遵循同一个模式,那就是串行化。多个网络层按照前后次序折叠起来,数据从底层输入,然后从最高层输出,其结构如下图:

1.png

事实上这种形式很不灵活,在很多应用场景中不实用。有些应用场景需要网络同时接收多种输入,有些应用场景要求网络能同时又多种输出,有些需要网络内部的网络层发送分叉,像一颗多叉树那样。有一些更复杂的网络结构是,它同时接收来自不同网络的输出,试想我们想要预测二手车在市场上的售价,此时网络可能要同时接收三种类型的信息,一种是对车辆的描述,例如车的品牌,类型,使用年限,公里数等;一种是用户评价产生的文本资料;一种是车辆的图片。于是我们就可能需要如下形式的网络结构:

2.png

还有一种情况是多类型预测。给定一本小说,我们需要预测这本小说所属类型,是言情类还是历史类,同时还需要预测小说的创作年代,于是网络的输出就必须要有两个以上的分支:

3.png

对于上面问题,我们可以构造两个网络去分别预测小说的类型和创造时间,但由于这两种数据高度相关,知道小说的创作时间很有利于对小说类型的预测,因此把他们整合在一个网络结构里分析显然更为合理。同时随着神经网络应用越来越广泛,应用场景对网络结构的要求也越来越多样化,有一类网络叫Inception network,它的特点是输入数据同时由多个网络层并行处理,然后得到多个处理结果,这些处理结果最后同时归并到同一个网络层,如下图:

4.png

谷歌开发的一种强大图像处理网络就属于上面的结构类型。所有原有的串行化结构无法适应很多复杂的应用场景,因此我们必须使用新的方法构建出类似上面的多样化神经网络,好在keras导出很多API,让我们方便的构建各种类型的深度网络,我们用具体代码来看看如何构造各种形态的网络,

from keras.models import Model
from keras import layers
from keras.utils import plot_model
from keras import Inputtext_vocabulary_size = 10000
question_vocabulary_size = 1000
answer_vocabulary_size = 500text_input = Input(shape=(None, ), dtype='int32', name = 'text')
embedded_text = layers.Embedding(64, text_vocabulary_size)(text_input)
encoded_text = layers.LSTM(32)(embedded_text)question_input = Input(shape = (None, ), dtype='int32', name='question')
embedded_question = layers.Embedding(32, question_vocabulary_size)(question_input)
encoded_question = layers.LSTM(16)(embedded_question)concatenated = layers.concatenate([encoded_text, encoded_question], axis = -1)
answer = layers.Dense(answer_vocabulary_size, activation='softmax')(concatenated)
model = Model([text_input, question_input], answer)
plot_model(model, to_file='model.png', show_shapes=True)

我们无需输入数据运行训练网络,我们只要把握上面网络的拓扑结构即可,上面代码的最后一句会把网络图像绘制出来,为了代码能正确运行,我们需要安装一个插件名为graphviz,通常情况下使用如下命令安装即可:

pip install graphviz

安装插件再运行上面代码后,网络的拓扑结构会绘制在model.png图形文件里,它的结构如下所示:

model.png

我们看到该网络并非我们常见的串行结构,最上层是两个并行分支,其输出的结果在网络层concatenate_19合并后再输入最后一层dens_13。这是一个多输入单输出的网络,当我们需要构建一个网络,它能读入数据并预测多种不同类型的数值时,这类网络就是单输入多输出的情况,一个具体例子如下:

vocabulary_size = 50000
num_income_groups = 10posts_input = Input(shape=(None, ), dtype = 'int32', name = 'posts')
embedded_posts = layers.Embedding(256, vocabulary_size)(posts_input)
x = layers.Conv1D(128, 5, activation='relu')(embedded_posts)
x = layers.MaxPooling1D(5)(x)
x = layers.Conv1D(256, 5, activation='relu')(x)
x = layers.Conv1D(256, 5, activation='relu')(x)
x = layers.MaxPooling1D(5)(x)
x = layers.Conv1D(256, 5, activation='relu')(x)
x = layers.Conv1D(256, 5, activation='relu')(x)
x = layers.GlobalMaxPooling1D()(x)
x = layers.Dense(128, activation='relu')(x)age_prediction = layers.Dense(1, name='age')(x)
income_prediction = layers.Dense(num_income_groups, activation='softmax', name='income')(x)
gender_prediction = layers.Dense(1, activation='sigmoid', name = 'gender')(x)
model = Model(posts_input, [age_prediction, income_prediction, gender_prediction])
model.compile(optimizer='rmsprop', loss=['mse', 'categorical_crossentropy', 'binary_crossentropy'], loss_weights = [0.25, 1. , 10.])
plot_model(model, to_file='model2.png', show_shapes=True)

上面代码构建的网络用语读入个人数据,然后预测该人的年龄,收入以及性别,代码运行后,我们得到网络的拓扑图如下:

model2.png

注意到当网络有多种输出时,我们必须对每种输出定义相应的损失函数,keras会把三种输出结果加总,然后使用梯度下降法修正整个网络的参数。但是这么做会产生一种情况,如果某个分支输出误差较大,那么网络调整参数时就会更多的去照顾这个分支,从而影响其他分支结果的准确性,处理这个问题的办法是为每个输出分支设定一个权重从而影响每个分支在参数调整是所产生的影响。

更多内容,请点击进入csdn学院

更多技术信息,包括操作系统,编译器,面试算法,机器学习,人工智能,请关照我的公众号:
这里写图片描述

这篇关于深度学习进阶:使用keras开发非串行化神经网络的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/950040

相关文章

Java中String字符串使用避坑指南

《Java中String字符串使用避坑指南》Java中的String字符串是我们日常编程中用得最多的类之一,看似简单的String使用,却隐藏着不少“坑”,如果不注意,可能会导致性能问题、意外的错误容... 目录8个避坑点如下:1. 字符串的不可变性:每次修改都创建新对象2. 使用 == 比较字符串,陷阱满

Python使用国内镜像加速pip安装的方法讲解

《Python使用国内镜像加速pip安装的方法讲解》在Python开发中,pip是一个非常重要的工具,用于安装和管理Python的第三方库,然而,在国内使用pip安装依赖时,往往会因为网络问题而导致速... 目录一、pip 工具简介1. 什么是 pip?2. 什么是 -i 参数?二、国内镜像源的选择三、如何

使用C++实现链表元素的反转

《使用C++实现链表元素的反转》反转链表是链表操作中一个经典的问题,也是面试中常见的考题,本文将从思路到实现一步步地讲解如何实现链表的反转,帮助初学者理解这一操作,我们将使用C++代码演示具体实现,同... 目录问题定义思路分析代码实现带头节点的链表代码讲解其他实现方式时间和空间复杂度分析总结问题定义给定

Linux使用nload监控网络流量的方法

《Linux使用nload监控网络流量的方法》Linux中的nload命令是一个用于实时监控网络流量的工具,它提供了传入和传出流量的可视化表示,帮助用户一目了然地了解网络活动,本文给大家介绍了Linu... 目录简介安装示例用法基础用法指定网络接口限制显示特定流量类型指定刷新率设置流量速率的显示单位监控多个

JavaScript中的reduce方法执行过程、使用场景及进阶用法

《JavaScript中的reduce方法执行过程、使用场景及进阶用法》:本文主要介绍JavaScript中的reduce方法执行过程、使用场景及进阶用法的相关资料,reduce是JavaScri... 目录1. 什么是reduce2. reduce语法2.1 语法2.2 参数说明3. reduce执行过程

如何使用Java实现请求deepseek

《如何使用Java实现请求deepseek》这篇文章主要为大家详细介绍了如何使用Java实现请求deepseek功能,文中的示例代码讲解详细,感兴趣的小伙伴可以跟随小编一起学习一下... 目录1.deepseek的api创建2.Java实现请求deepseek2.1 pom文件2.2 json转化文件2.2

python使用fastapi实现多语言国际化的操作指南

《python使用fastapi实现多语言国际化的操作指南》本文介绍了使用Python和FastAPI实现多语言国际化的操作指南,包括多语言架构技术栈、翻译管理、前端本地化、语言切换机制以及常见陷阱和... 目录多语言国际化实现指南项目多语言架构技术栈目录结构翻译工作流1. 翻译数据存储2. 翻译生成脚本

Android 悬浮窗开发示例((动态权限请求 | 前台服务和通知 | 悬浮窗创建 )

《Android悬浮窗开发示例((动态权限请求|前台服务和通知|悬浮窗创建)》本文介绍了Android悬浮窗的实现效果,包括动态权限请求、前台服务和通知的使用,悬浮窗权限需要动态申请并引导... 目录一、悬浮窗 动态权限请求1、动态请求权限2、悬浮窗权限说明3、检查动态权限4、申请动态权限5、权限设置完毕后

C++ Primer 多维数组的使用

《C++Primer多维数组的使用》本文主要介绍了多维数组在C++语言中的定义、初始化、下标引用以及使用范围for语句处理多维数组的方法,具有一定的参考价值,感兴趣的可以了解一下... 目录多维数组多维数组的初始化多维数组的下标引用使用范围for语句处理多维数组指针和多维数组多维数组严格来说,C++语言没

在 Spring Boot 中使用 @Autowired和 @Bean注解的示例详解

《在SpringBoot中使用@Autowired和@Bean注解的示例详解》本文通过一个示例演示了如何在SpringBoot中使用@Autowired和@Bean注解进行依赖注入和Bean... 目录在 Spring Boot 中使用 @Autowired 和 @Bean 注解示例背景1. 定义 Stud