关于Keras里的Sequential(序列模型)转化为Model(函数模型)的问题

2024-01-31 09:50

本文主要是介绍关于Keras里的Sequential(序列模型)转化为Model(函数模型)的问题,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

文章目录

  • 前言
  • 一、序列模型
  • 二、改为函数模型
    • 1.错误代码
  • 总结


前言

想在keras模型上加上注意力机制,于是把keras的序列模型转化为函数模型,结果发现参数维度不一致的问题,结果也变差了。跟踪问题后续发现是转为函数模型后,网络共享层出现了问题。

一、序列模型

该部分采用的是add添加网络层,由于存在多次重复调用相同网络层的情况,因此封装成一个自定义函数:

  def create_base_network(input_dim):seq = Sequential()seq.add(Conv2D(64, 5, activation='relu', padding='same', name='conv1', input_shape=input_dim))seq.add(Conv2D(128, 4, activation='relu', padding='same', name='conv2'))seq.add(Conv2D(256, 4, activation='relu', padding='same', name='conv3'))seq.add(Conv2D(64, 1, activation='relu', padding='same', name='conv4'))seq.add(MaxPooling2D(2, 2, name='pool1'))seq.add(Flatten(name='fla1'))seq.add(Dense(512, activation='relu', name='dense1'))seq.add(Reshape((1, 512), name='reshape'))

整体代码,该模型存在多个输入(6个):

	def create_base_network(input_dim):seq = Sequential()seq.add(Conv2D(64, 5, activation='relu', padding='same', name='conv1', input_shape=input_dim))seq.add(Conv2D(128, 4, activation='relu', padding='same', name='conv2'))seq.add(Conv2D(256, 4, activation='relu', padding='same', name='conv3'))seq.add(Conv2D(64, 1, activation='relu', padding='same', name='conv4'))seq.add(MaxPooling2D(2, 2, name='pool1'))seq.add(Flatten(name='fla1'))seq.add(Dense(512, activation='relu', name='dense1'))seq.add(Reshape((1, 512), name='reshape'))return seqbase_network = create_base_network(img_size)input_1 = Input(shape=img_size)input_2 = Input(shape=img_size)input_3 = Input(shape=img_size)input_4 = Input(shape=img_size)input_5 = Input(shape=img_size)input_6 = Input(shape=img_size)print('the shape of base1:', base_network(input_1).shape)   # (, 1, 512)out_all = Concatenate(axis=1)([base_network(input_1), base_network(input_2), base_network(input_3), base_network(input_4), base_network(input_5), base_network(input_6)])print('****', out_all.shape)   # (, 6, 512)lstm_layer = LSTM(128, name = 'lstm')(out_all)out_puts = Dense(3, activation = 'softmax', name = 'out')(lstm_layer)model = Model([input_1,input_2,input_3,input_4,input_5,input_6], out_puts)model.summary()

网络模型:
在这里插入图片描述

二、改为函数模型

1.错误代码

第一次更改网络模型后,虽然运行未报错,但参数变多,模型性能也下降了,如下:

   def create_base_network(input_dim):x = Conv2D(64, 5, activation='relu', padding='same')(input_dim)x = Conv2D(128, 4, activation='relu', padding='same')(x)x = Conv2D(256, 4, activation='relu', padding='same')(x)x = Conv2D(64, 1, activation='relu', padding='same')(x)x = MaxPooling2D(2, 2)(x)x = Flatten()(x)x = Dense(512, activation='relu')(x)x = Reshape((1, 512))(x)return xinput_1 = Input(shape=img_size)input_2 = Input(shape=img_size)input_3 = Input(shape=img_size)input_4 = Input(shape=img_size)input_5 = Input(shape=img_size)input_6 = Input(shape=img_size)base_network_1 = create_base_network(input_1)base_network_2 = create_base_network(input_2)base_network_3 = create_base_network(input_3)base_network_4 = create_base_network(input_4)base_network_5 = create_base_network(input_5)base_network_6 = create_base_network(input_6)# print('the shape of base1:', base_network(input_1).shape)   # (, 1, 512)out_all = Concatenate(axis = 1)(  # 维度不变, 维度拼接,第一维度变为原来的6[base_network_1, base_network_2, base_network_3, base_network_4, base_network_5, base_network_6])print('****', out_all.shape)   # (, 6, 512)lstm_layer = LSTM(128, name = 'lstm')(out_all)out_puts = Dense(3, activation = 'softmax', name = 'out')(lstm_layer)model = Model(inputs = [input_1, input_2, input_3, input_4, input_5, input_6], outputs = out_puts)  # 6个输入model.summary()

结果模型输出如下:
在这里插入图片描述
可以看到,模型的参数变为了原来的6倍多,改了很多次,后来发现,原来是因为序列模型中的base_network = create_base_network(img_size)相当于已将模型实例化成了一个model,后续调用时只传入参数,而不更改模型结构。

而改为Model API后:
base_network_1 = create_base_network(input_1)
...
base_network_6 = create_base_network(input_6)

前面定义的 def create_base_network( inputs),并未进行实例化,后续相当于创建了6次相关网络层,应该先实例化,应当改为以下部分:

# 建立网络共享层
x1 = Conv2D(64, 5, activation = 'relu', padding = 'same', name= 'conv1')
x2 = Conv2D(128, 4, activation = 'relu', padding = 'same', name = 'conv2')
x3 = Conv2D(256, 4, activation = 'relu', padding = 'same', name = 'conv3')
x4 = Conv2D(64, 1, activation = 'relu', padding = 'same', name = 'conv4')
x5 = MaxPooling2D(2, 2)
x6 = Flatten()
x7 = Dense(512, activation = 'relu')
x8 = Reshape((1, 512))input_1 = Input(shape = img_size)   # 得到6个输入
input_2 = Input(shape = img_size)
input_3 = Input(shape = img_size)
input_4 = Input(shape = img_size)
input_5 = Input(shape = img_size)
input_6 = Input(shape = img_size)base_network_1 = x8(x7(x6(x5(x4(x3(x2(x1(input_1))))))))
base_network_2 = x8(x7(x6(x5(x4(x3(x2(x1(input_2))))))))
base_network_3 = x8(x7(x6(x5(x4(x3(x2(x1(input_3))))))))
base_network_4 = x8(x7(x6(x5(x4(x3(x2(x1(input_4))))))))
base_network_5 = x8(x7(x6(x5(x4(x3(x2(x1(input_5))))))))
base_network_6 = x8(x7(x6(x5(x4(x3(x2(x1(input_6))))))))# 输入连接
out_all = Concatenate(axis = 1)(                            # 维度不变, 维度拼接,第一维度变为原来的6[base_network_1, base_network_2, base_network_3, base_network_4, base_network_5, base_network_6])# lstm layer
lstm_layer = LSTM(128, name = 'lstm3')(out_all)
# dense layer
out_layer = Dense(3, activation = 'softmax', name = 'out')(lstm_layer)
model = Model(inputs = [input_1, input_2, input_3, input_4, input_5, input_6], outputs = out_layer)  # 6个输入
model.summary()

总结

Keras里的函数模型,如果想要多个输入共享多个网络层,
还是得将各个层实例化,不能偷懒。。。

这篇关于关于Keras里的Sequential(序列模型)转化为Model(函数模型)的问题的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/663352

相关文章

MySQL 中的 CAST 函数详解及常见用法

《MySQL中的CAST函数详解及常见用法》CAST函数是MySQL中用于数据类型转换的重要函数,它允许你将一个值从一种数据类型转换为另一种数据类型,本文给大家介绍MySQL中的CAST... 目录mysql 中的 CAST 函数详解一、基本语法二、支持的数据类型三、常见用法示例1. 字符串转数字2. 数字

Python内置函数之classmethod函数使用详解

《Python内置函数之classmethod函数使用详解》:本文主要介绍Python内置函数之classmethod函数使用方式,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地... 目录1. 类方法定义与基本语法2. 类方法 vs 实例方法 vs 静态方法3. 核心特性与用法(1编程客

Python函数作用域示例详解

《Python函数作用域示例详解》本文介绍了Python中的LEGB作用域规则,详细解析了变量查找的四个层级,通过具体代码示例,展示了各层级的变量访问规则和特性,对python函数作用域相关知识感兴趣... 目录一、LEGB 规则二、作用域实例2.1 局部作用域(Local)2.2 闭包作用域(Enclos

怎样通过分析GC日志来定位Java进程的内存问题

《怎样通过分析GC日志来定位Java进程的内存问题》:本文主要介绍怎样通过分析GC日志来定位Java进程的内存问题,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录一、GC 日志基础配置1. 启用详细 GC 日志2. 不同收集器的日志格式二、关键指标与分析维度1.

Java 线程安全与 volatile与单例模式问题及解决方案

《Java线程安全与volatile与单例模式问题及解决方案》文章主要讲解线程安全问题的五个成因(调度随机、变量修改、非原子操作、内存可见性、指令重排序)及解决方案,强调使用volatile关键字... 目录什么是线程安全线程安全问题的产生与解决方案线程的调度是随机的多个线程对同一个变量进行修改线程的修改操

MySQL count()聚合函数详解

《MySQLcount()聚合函数详解》MySQL中的COUNT()函数,它是SQL中最常用的聚合函数之一,用于计算表中符合特定条件的行数,本文给大家介绍MySQLcount()聚合函数,感兴趣的朋... 目录核心功能语法形式重要特性与行为如何选择使用哪种形式?总结深入剖析一下 mysql 中的 COUNT

Redis出现中文乱码的问题及解决

《Redis出现中文乱码的问题及解决》:本文主要介绍Redis出现中文乱码的问题及解决,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录1. 问题的产生2China编程. 问题的解决redihttp://www.chinasem.cns数据进制问题的解决中文乱码问题解决总结

MySQL 中 ROW_NUMBER() 函数最佳实践

《MySQL中ROW_NUMBER()函数最佳实践》MySQL中ROW_NUMBER()函数,作为窗口函数为每行分配唯一连续序号,区别于RANK()和DENSE_RANK(),特别适合分页、去重... 目录mysql 中 ROW_NUMBER() 函数详解一、基础语法二、核心特点三、典型应用场景1. 数据分

全面解析MySQL索引长度限制问题与解决方案

《全面解析MySQL索引长度限制问题与解决方案》MySQL对索引长度设限是为了保持高效的数据检索性能,这个限制不是MySQL的缺陷,而是数据库设计中的权衡结果,下面我们就来看看如何解决这一问题吧... 目录引言:为什么会有索引键长度问题?一、问题根源深度解析mysql索引长度限制原理实际场景示例二、五大解决

Springboot如何正确使用AOP问题

《Springboot如何正确使用AOP问题》:本文主要介绍Springboot如何正确使用AOP问题,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录​一、AOP概念二、切点表达式​execution表达式案例三、AOP通知四、springboot中使用AOP导出