深度学习---三好学生各成绩所占权重问题(2)

2024-01-19 06:10

本文主要是介绍深度学习---三好学生各成绩所占权重问题(2),希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

 🔝🔝🔝🔝🔝🔝🔝🔝🔝🔝🔝🔝

🥰 博客首页:knighthood2001

😗 欢迎点赞👍评论🗨️

❤️ 热爱python,期待与大家一同进步成长!!❤️

👀给大家推荐一款很火爆的刷题、面试求职网站👀

上文中深度学习(初识tensorflow2.版本)之三好学生成绩问题(1) 我们可以发现,搭建的神经网络已经可以运行,但显然还不能真正使用,因为它最终的计算结果是存在误差的。神经网络在投入使用前,都要经过训练的过程。那么,如何来训练神经网络呢?

目录

训练神经网络步骤步骤

代码展示

变化之处


训练神经网络步骤步骤


输入数据:例如例子中输入的x1、x2、x3,也就是两位学生各自的德育、智育、体育3项分数。

计算结果:神经网络根据输入的数据和当前的可变参数值计算出结果,本文例子中为y

计算误差:将计算出来的结果y与我们期待的结果( 或者说标准答案,把它暂时称为yTrain进行比对,看看误差(loss)是多少;在例子中,yTrain 的值也就是两位学生各自已知的总分。

调整神经网络可变参数:根据误差的大小,使用反向传播算法,对神经网络中的可变参数(也就是本章例子中的w1、w2、w3)进行相应的调节。

再次训练:在调整完可变参数后,重复上述步骤重新进行训练,直至误差低于我们的理想水平,神经网络的训练就完成了。


上篇文章编写的程序已经实现了这个流程中的前两个步骤,下面我们来实现剩余的步骤。


代码展示

import tensorflow as tftf.compat.v1.disable_eager_execution()x1 = tf.compat.v1.placeholder(dtype=tf.float32)
x2 = tf.compat.v1.placeholder(dtype=tf.float32)
x3 = tf.compat.v1.placeholder(dtype=tf.float32)# 设置标准答案
yTrain = tf.compat.v1.placeholder(dtype=tf.float32)w1 = tf.Variable(0.1, dtype=tf.float32)
w2 = tf.Variable(0.1, dtype=tf.float32)
w3 = tf.Variable(0.1, dtype=tf.float32)n1 = x1 * w1
n2 = x2 * w2
n3 = x3 * w3y = n1 + n2 + n3loss = tf.abs(y - yTrain)optimizer = tf.compat.v1.train.RMSPropOptimizer(0.001)train = optimizer.minimize(loss)sess = tf.compat.v1.Session()init = tf.compat.v1.global_variables_initializer()sess.run(init)for i in range(10000):result = sess.run([train, x1, x2, x3, w1, w2, w3, y, yTrain, loss], feed_dict={x1: 90, x2: 80, x3: 70, yTrain: 85})print(result)result = sess.run([train, x1, x2, x3, w1, w2, w3, y, yTrain, loss], feed_dict={x1: 98, x2: 95, x3: 87, yTrain: 96})print(result)

变化之处

①定义了占位符yTrain,这是用来在训练时传入争对每一组输入数据我们期待的对应计算结果值的,后面一般把它简称为“目标计算结果”或“目标值”。

# 目标计算结果(目标值)
yTrain = tf.compat.v1.placeholder(dtype=tf.float32)

在计算出结果y后,我们用tf.abs(y-yTrain)来计算误差,

然后定义了一个优化器变量optimizer。所谓优化器,就是用来调整神经网络可变参数的对象。我们采用的是RMSPropOptimizer,参数0.001是这个优化器的学习率(learn rate)。所谓学习率,我们在这里可以先简单的理解为:学习率决定了优化器每次调整参数的幅度大小。

定义完优化器后,我们又定义了一个训练对象train,它代表了我们准备如何来训练这个神经网络。我们把它定义为optimizer.minimize(loss),也就是要求优化器按照把loss最小化的原则来调整可变参数。

loss = tf.abs(y - yTrain)optimizer = tf.compat.v1.train.RMSPropOptimizer(0.001)train = optimizer.minimize(loss)

接下来我们就可以进行训练了,训练的代码和之前计算的很相似。

for i in range(10000):result = sess.run([train, x1, x2, x3, w1, w2, w3, y, yTrain, loss], feed_dict={x1: 90, x2: 80, x3: 70, yTrain: 85})print(result)result = sess.run([train, x1, x2, x3, w1, w2, w3, y, yTrain, loss], feed_dict={x1: 98, x2: 95, x3: 87, yTrain: 96})print(result)

不同之处主要有两个,是在feed_dict参数中多指定一个yTrain的数值,也就是对应每一组输入数据x1,x2,x3,我们指定的目标结果值;是在sess.run函数的第一个参数,也就是我们要求输出的结果数组当中,多加了一个train对象,在结果数组中有train对象,意味着要求程序要执行train对象所包含的训练过程,那么在这个过程中,y、loss等计算结果自然也会被计算出来;所以在结果数组中即使只写一个train,其他的结果也会都被计算出来。只不过我们看不到而已。

只有在结果数组中加上了训练对象,这次sess.run函数的执行才能被称为一次“训练”,否则只是“运行”一次神经网络或者说是用神经网络进行一次“计算”。

尽管两次训练的x1,x2,x3不同,但是神经网络的训练具备适应能力,能够在训练过程中逐步调整可变参数,试图去缩小所有输入数据的计算结果误差。

我们采用for循环,来个5000轮。最后两条结果如下:

loss缩小到0.023246765-0.0332489,w1,w2,w3的数值也很接近我们期待的0.6,0.3,0.1(我们之前假设的权重)。

之后,笔者将会讲解如何优化这里的神经网络模型。

这篇关于深度学习---三好学生各成绩所占权重问题(2)的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/621477

相关文章

Flask解决指定端口无法生效问题

《Flask解决指定端口无法生效问题》文章讲述了在使用PyCharm开发Flask应用时,启动地址与手动指定的IP端口不一致的问题,通过修改PyCharm的运行配置,将Flask项目的运行模式从Fla... 目录android问题重现解决方案问题重现手动指定的IP端口是app.run(host='0.0.

Seata之分布式事务问题及解决方案

《Seata之分布式事务问题及解决方案》:本文主要介绍Seata之分布式事务问题及解决方案,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录Seata–分布式事务解决方案简介同类产品对比环境搭建1.微服务2.SQL3.seata-server4.微服务配置事务模式1

mysql关联查询速度慢的问题及解决

《mysql关联查询速度慢的问题及解决》:本文主要介绍mysql关联查询速度慢的问题及解决方案,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录mysql关联查询速度慢1. 记录原因1.1 在一次线上的服务中1.2 最终发现2. 解决方案3. 具体操作总结mysql

一文教你解决Python不支持中文路径的问题

《一文教你解决Python不支持中文路径的问题》Python是一种广泛使用的高级编程语言,然而在处理包含中文字符的文件路径时,Python有时会表现出一些不友好的行为,下面小编就来为大家介绍一下具体的... 目录问题背景解决方案1. 设置正确的文件编码2. 使用pathlib模块3. 转换路径为Unicod

Spring MVC跨域问题及解决

《SpringMVC跨域问题及解决》:本文主要介绍SpringMVC跨域问题及解决方案,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录跨域问题不同的域同源策略解决方法1.CORS2.jsONP3.局部解决方案4.全局解决方法总结跨域问题不同的域协议、域名、端口

SpringBoot自定义注解如何解决公共字段填充问题

《SpringBoot自定义注解如何解决公共字段填充问题》本文介绍了在系统开发中,如何使用AOP切面编程实现公共字段自动填充的功能,从而简化代码,通过自定义注解和切面类,可以统一处理创建时间和修改时间... 目录1.1 问题分析1.2 实现思路1.3 代码开发1.3.1 步骤一1.3.2 步骤二1.3.3

Redis 内存淘汰策略深度解析(最新推荐)

《Redis内存淘汰策略深度解析(最新推荐)》本文详细探讨了Redis的内存淘汰策略、实现原理、适用场景及最佳实践,介绍了八种内存淘汰策略,包括noeviction、LRU、LFU、TTL、Rand... 目录一、 内存淘汰策略概述二、内存淘汰策略详解2.1 ​noeviction(不淘汰)​2.2 ​LR

基于.NET编写工具类解决JSON乱码问题

《基于.NET编写工具类解决JSON乱码问题》在开发过程中,我们经常会遇到JSON数据处理的问题,尤其是在数据传输和解析过程中,很容易出现编码错误导致的乱码问题,下面我们就来编写一个.NET工具类来解... 目录问题背景核心原理工具类实现使用示例总结在开发过程中,我们经常会遇到jsON数据处理的问题,尤其是

springboot3.4和mybatis plus的版本问题的解决

《springboot3.4和mybatisplus的版本问题的解决》本文主要介绍了springboot3.4和mybatisplus的版本问题的解决,主要由于SpringBoot3.4与MyBat... 报错1:spring-boot-starter/3.4.0/spring-boot-starter-

在 Spring Boot 中使用异步线程时的 HttpServletRequest 复用问题记录

《在SpringBoot中使用异步线程时的HttpServletRequest复用问题记录》文章讨论了在SpringBoot中使用异步线程时,由于HttpServletRequest复用导致... 目录一、问题描述:异步线程操作导致请求复用时 Cookie 解析失败1. 场景背景2. 问题根源二、问题详细分