深度学习技术之加宽前馈全连接神经网络

2024-05-13 16:12

本文主要是介绍深度学习技术之加宽前馈全连接神经网络,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

深度学习技术

  • 加宽前馈全连接神经网络
    • 1. Functional API 搭建神经网络模型
      • 1.1 利用Functional API编写宽深神经网络模型进行手写数字识别
        • 1.1.1 导入需要的库
        • 1.1.2 加载虹膜(Iris)数据集
        • 1.1.3 分割训练集和测试集
        • 1.1.4 定义模型输入层
        • 1.1.5 添加隐藏层
        • 1.1.6 拼接输入层和第二个隐藏层
        • 1.1.7 添加输出层
        • 1.1.8 创建模型
        • 1.1.9 打印模型的摘要
        • 1.1.10 模型编译并训练
      • 1.2 利用Functional API编写多输入神经网络模型进行手写数字识别
        • 1.2.1 分割子集
        • 1.2.2 定义输入层
        • 1.2.3 定义全连接层
        • 1.2.4 创建模型
        • 1.2.5 编译与训练模型
        • 1.2.6 训练历史数据的可视化
    • 2. SubClassing API 搭建神经网络模型
      • 2.1 前馈全连接神经网络手写数字识别
        • 2.1.1 定义一个Keras模型类
        • 2.1.2 定义方法
        • 2.1.3 初始化模型
        • 2.1.4 通过在初始化中传递参数改变模型元素默认值
        • 2.1.5 编译与训练模型
        • 2.1.6 打印模型摘要

加宽前馈全连接神经网络

1. Functional API 搭建神经网络模型

1.1 利用Functional API编写宽深神经网络模型进行手写数字识别

1.1.1 导入需要的库

利用Sequential API建立一个顺序传播的前馈全连接神经网络,导入numpy、pandas,tensorflow等库,以及导入matplotlib的pyplot模块。从sklearn库的datasets模块中导入load_iris函数,以及从sklearn库的model_selection模块中导入train_test_split函数。从TensorFlow库中导入Keras模块。

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
import tensorflow as tf
from tensorflow import keras
1.1.2 加载虹膜(Iris)数据集

虹膜(Iris)数据集是scikit-learn库中内置的一个样本数据集,它包含了150个样本,分为三个类,每个类有50个样本。这三个类分别是山鸢尾(Iris Setosa)、杂色鸢尾(Iris Versicolour)和维吉尼亚鸢尾(Iris Virginica)。

iris = load_iris()
1.1.3 分割训练集和测试集

将虹膜(Iris)数据集分割为训练集和测试集,得到训练集x_train和y_train,再将分割得到的训练集x_train和y_train分割为新的训练集和验证集。

x_train,x_test,y_train,y_test=train_test_split(iris.data,iris.target, test_size=0.2, random_state=23)
X_train, X_valid, y_train, y_valid=train_test_split(x_train, y_train,test_size=0.2, random_state=12)
1.1.4 定义模型输入层

使用X_train.shape[1:]作为输入层的形状,因为X_train.shape[0]是批量大小,通常在训练过程中改变,而X_train.shape[1:]包含了特征的数量,这些数量在训练过程中保持不变。

inputs = keras.layers.Input(shape=X_train.shape[1:])
1.1.5 添加隐藏层

隐藏层,包含神经元,并使用ReLU激活函数。

hidden1 = keras.layers.Dense(300, activation="relu")(inputs)
hidden2 = keras.layers.Dense(100, activation="relu")(hidden1)
1.1.6 拼接输入层和第二个隐藏层

将输入层和第二个隐藏层的输出进行拼接,得到一个融合了输入和中间层信息的特征向量。

concat = keras.layers.concatenate([inputs, hidden2])
1.1.7 添加输出层

添加了一个输出层,包含10个神经元,使用softmax激活函数,因为模型是用于多类分类任务。

output = keras.layers.Dense(10, activation="softmax")(concat)
1.1.8 创建模型

创建了一个完整的模型,将输入层和输出层连接起来,形成了一个有监督学习的模型结构。
这个模型结构结合了“宽”模型(wide model)和“深”模型(deep model)的特点,通过输入层和隐藏层的拼接来融合这两种模型。

model_fun_WideDeep = keras.models.Model(inputs=[inputs], outputs=[output])

运行结果:
在这里插入图片描述

1.1.9 打印模型的摘要
model_fun_WideDeep.summary()
1.1.10 模型编译并训练

model_fun_WideDeep.fit()方法将开始模型的训练过程,并在每个轮次结束后使用验证数据评估模型的性能。训练过程中,模型将逐渐学习如何将输入特征映射到正确的输出类别。

model_fun_WideDeep.compile(loss="sparse_categorical_crossentropy",optimizer="sgd",metrics=["accuracy"])
h=model_fun_WideDeep.fit(X_train, y_train, batch_size=32, epochs=30, validation_data=(X_valid, y_valid))

运行结果:
在这里插入图片描述

1.2 利用Functional API编写多输入神经网络模型进行手写数字识别

1.2.1 分割子集

将训练集X_train和验证集X_valid分割为两个子集。

X_train_A, X_train_B = X_train[:, :200], X_train[:, 100:]
X_valid_A, X_valid_B = X_valid[:, :200], X_valid[:, 100:]
1.2.2 定义输入层
input_A = keras.layers.Input(shape=X_train_A.shape[1])
input_B = keras.layers.Input(shape=X_train_B.shape[1])
1.2.3 定义全连接层
hidden1 = keras.layers.Dense(300, activation="relu")(input_B)
hidden2 = keras.layers.Dense(100, activation="relu")(hiddenl)
1.2.4 创建模型

将输入层和输出层连接起来。

model_fun_MulIn = keras.models.Model(inputs=[input_A, input_B], outputs=[output])
1.2.5 编译与训练模型

在训练过程中,模型将使用指定的损失函数和优化器来更新权重,并使用准确率作为评估指标来监控性能。

model_fun_MulIn.compile(loss="sparse_categorical_crossentropy",optimizer="sgd",metrics=["accuracy"])

运行结果:
在这里插入图片描述

1.2.6 训练历史数据的可视化

图中显示了训练和验证集上的损失和准确率随轮次的变化情况。

pd.DataFrame(h.history).plot(figsize=(8,5))
plt.grid(True)
plt.gca().set_ylim(0,1)
plt.show()

运行结果:
在这里插入图片描述

2. SubClassing API 搭建神经网络模型

2.1 前馈全连接神经网络手写数字识别

2.1.1 定义一个Keras模型类

定义一个自定义的Keras模型类Model_sub_fnn,继承自keras.models.Model。这个类定义了一个简单的全连接神经网络,它有两个隐藏层和一个输出层。

class Model_sub_fnn(keras.models.Model):def __init__(self, units_1=300, units_2=100, units_out=10, activation='relu'):super().__init__()self.hidden1 = keras.layers.Dense(units_1, activation=activation)self.hidden2 = keras.layers.Dense(units_2, activation=activation)self.main_output = keras.layers.Dense(units_out, activation='softmax')
2.1.2 定义方法

给Model_sub_fnn类定义一个call方法。这个方法是Keras模型中的一个特殊方法,它定义了模型的前向传播过程,它将输入数据通过模型的所有层,并返回最终的输出。

def call(self, data):hidden1 = self.hidden1(data)hidden2 = self.hidden2(hidden1)main_output = self.main_output(hidden2)return main_output
2.1.3 初始化模型
model_sub_fnn = Model_sub_fnn()
2.1.4 通过在初始化中传递参数改变模型元素默认值
model_sub_fnn2 = Model_sub_fnn(300, 100, 10, activation='relu')
2.1.5 编译与训练模型

编译模型,使用训练数据和验证数据进行训练。在训练过程中,模型将使用指定的损失函数和优化器来更新权重,并使用准确率作为评估指标来监控性能。训练完成后,将得到模型的摘要,其中包含了模型的详细信息。

model_sub_fnn.compile(loss='sparse_categorical_crossentropy',optimizer='sgd',metrics=["accuracy")
h= model_sub_fnn.fit(X_train,y_train,batch_size=32,epochs=30,validation_data = (X_valid,y_valid))

运行结果:
在这里插入图片描述

2.1.6 打印模型摘要

打印出模型的摘要,其中包括模型的层结构、每个层的输出形状、层的参数数量以及整个模型的总参数数量。

model_sub_fnn.summary()

运行结果:
在这里插入图片描述

这篇关于深度学习技术之加宽前馈全连接神经网络的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/986185

相关文章

Redis连接失败:客户端IP不在白名单中的问题分析与解决方案

《Redis连接失败:客户端IP不在白名单中的问题分析与解决方案》在现代分布式系统中,Redis作为一种高性能的内存数据库,被广泛应用于缓存、消息队列、会话存储等场景,然而,在实际使用过程中,我们可能... 目录一、问题背景二、错误分析1. 错误信息解读2. 根本原因三、解决方案1. 将客户端IP添加到Re

Mysql 中的多表连接和连接类型详解

《Mysql中的多表连接和连接类型详解》这篇文章详细介绍了MySQL中的多表连接及其各种类型,包括内连接、左连接、右连接、全外连接、自连接和交叉连接,通过这些连接方式,可以将分散在不同表中的相关数据... 目录什么是多表连接?1. 内连接(INNER JOIN)2. 左连接(LEFT JOIN 或 LEFT

Node.js 中 http 模块的深度剖析与实战应用小结

《Node.js中http模块的深度剖析与实战应用小结》本文详细介绍了Node.js中的http模块,从创建HTTP服务器、处理请求与响应,到获取请求参数,每个环节都通过代码示例进行解析,旨在帮... 目录Node.js 中 http 模块的深度剖析与实战应用一、引言二、创建 HTTP 服务器:基石搭建(一

Spring Boot实现多数据源连接和切换的解决方案

《SpringBoot实现多数据源连接和切换的解决方案》文章介绍了在SpringBoot中实现多数据源连接和切换的几种方案,并详细描述了一个使用AbstractRoutingDataSource的实... 目录前言一、多数据源配置与切换方案二、实现步骤总结前言在 Spring Boot 中实现多数据源连接

QT实现TCP客户端自动连接

《QT实现TCP客户端自动连接》这篇文章主要为大家详细介绍了QT中一个TCP客户端自动连接的测试模型,文中的示例代码讲解详细,感兴趣的小伙伴可以跟随小编一起学习一下... 目录版本 1:没有取消按钮 测试效果测试代码版本 2:有取消按钮测试效果测试代码版本 1:没有取消按钮 测试效果缺陷:无法手动停

HarmonyOS学习(七)——UI(五)常用布局总结

自适应布局 1.1、线性布局(LinearLayout) 通过线性容器Row和Column实现线性布局。Column容器内的子组件按照垂直方向排列,Row组件中的子组件按照水平方向排列。 属性说明space通过space参数设置主轴上子组件的间距,达到各子组件在排列上的等间距效果alignItems设置子组件在交叉轴上的对齐方式,且在各类尺寸屏幕上表现一致,其中交叉轴为垂直时,取值为Vert

Ilya-AI分享的他在OpenAI学习到的15个提示工程技巧

Ilya(不是本人,claude AI)在社交媒体上分享了他在OpenAI学习到的15个Prompt撰写技巧。 以下是详细的内容: 提示精确化:在编写提示时,力求表达清晰准确。清楚地阐述任务需求和概念定义至关重要。例:不用"分析文本",而用"判断这段话的情感倾向:积极、消极还是中性"。 快速迭代:善于快速连续调整提示。熟练的提示工程师能够灵活地进行多轮优化。例:从"总结文章"到"用

W外链微信推广短连接怎么做?

制作微信推广链接的难点分析 一、内容创作难度 制作微信推广链接时,首先需要创作有吸引力的内容。这不仅要求内容本身有趣、有价值,还要能够激起人们的分享欲望。对于许多企业和个人来说,尤其是那些缺乏创意和写作能力的人来说,这是制作微信推广链接的一大难点。 二、精准定位难度 微信用户群体庞大,不同用户的需求和兴趣各异。因此,制作推广链接时需要精准定位目标受众,以便更有效地吸引他们点击并分享链接

【前端学习】AntV G6-08 深入图形与图形分组、自定义节点、节点动画(下)

【课程链接】 AntV G6:深入图形与图形分组、自定义节点、节点动画(下)_哔哩哔哩_bilibili 本章十吾老师讲解了一个复杂的自定义节点中,应该怎样去计算和绘制图形,如何给一个图形制作不间断的动画,以及在鼠标事件之后产生动画。(有点难,需要好好理解) <!DOCTYPE html><html><head><meta charset="UTF-8"><title>06

学习hash总结

2014/1/29/   最近刚开始学hash,名字很陌生,但是hash的思想却很熟悉,以前早就做过此类的题,但是不知道这就是hash思想而已,说白了hash就是一个映射,往往灵活利用数组的下标来实现算法,hash的作用:1、判重;2、统计次数;