Tensorflow_tf1简单实现小批量梯度下降

2023-11-02 14:38

本文主要是介绍Tensorflow_tf1简单实现小批量梯度下降,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

TensorFlow 程序通常分为两部分:
第一部分构建计算图谱(这称为构造阶段),
第二部分运行它(这是执行阶段)。

建设阶段通常构建一个表示ML模型的计算图谱,然后对其进行训练计算。执行
阶段通常运行循环,重复地求出训练步骤,逐渐改进模型参数

0、加载包和数据

## 加载包
import numpy as np 
from sklearn.datasets import fetch_california_housing
import tensorflow as tf 
from sklearn.preprocessing import StandardScaler## 加载数据
housing = fetch_california_housing()
m, n = housing.data.shape
scl = StandardScaler()
housing_scl = scl.fit_transform(housing.data)
housing_sc_bias = np.c_[np.ones((m, 1)), housing_scl]

1、构建计算图谱

# 用占位符,不对x的行数做限制
x = tf.placeholder(tf.float32, shape = (None, n + 1), name = 'x')
y = tf.placeholder(tf.float32, shape = (None,  1), name = 'y')
# 给theta 随机初始值
theta = tf.Variable(tf.random_uniform([n + 1, 1], -1.0, 1.0, seed = 42), name='theta')# 计算误差
y_prd = tf.matmul(x, theta, name = 'predictions')
error = y_prd - y
# 类似 from functools import reduce
#reduce(lambda x1, x2: x1 + x2 ,[1,2,3])
mse = tf.reduce_mean(tf.square(error), name = 'mse')
# 梯度下降优化器
learn_rate = 0.01
optimizer = tf.train.GradientDescentOptimizer(learning_rate = learn_rate)
training_op = optimizer.minimize(mse)

2、执行阶段

init = tf.global_variables_initializer() # 初始化变量
n_epochs = 10   
batch_size = 100 
n_batches = np.int(np.ceil(m / batch_size))def fetch_batch(epoch, batch_index, batch_size):# 随机获取小批量数据np.random.seed(epoch * n_batches + batch_index) indices = np.random.randint(m, size = batch_size)return housing_sc_bias[indices] , housing.target.reshape(-1, 1)[indices] with tf.Session() as sess:# 初始化变量sess.run(init)for epoch in range(n_epochs):  # 总共循环次数for batch_index in range(n_batches):x_batch, y_batch = fetch_batch(epoch, batch_index, batch_size)# 数据导入 类似于 sklearn.class.fitsess.run(training_op, feed_dict = {x : x_batch, y : y_batch})best_theta = theta.eval()print(best_theta)

这篇关于Tensorflow_tf1简单实现小批量梯度下降的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/331286

相关文章

Spring Security自定义身份认证的实现方法

《SpringSecurity自定义身份认证的实现方法》:本文主要介绍SpringSecurity自定义身份认证的实现方法,下面对SpringSecurity的这三种自定义身份认证进行详细讲解,... 目录1.内存身份认证(1)创建配置类(2)验证内存身份认证2.JDBC身份认证(1)数据准备 (2)配置依

利用python实现对excel文件进行加密

《利用python实现对excel文件进行加密》由于文件内容的私密性,需要对Excel文件进行加密,保护文件以免给第三方看到,本文将以Python语言为例,和大家讲讲如何对Excel文件进行加密,感兴... 目录前言方法一:使用pywin32库(仅限Windows)方法二:使用msoffcrypto-too

C#使用StackExchange.Redis实现分布式锁的两种方式介绍

《C#使用StackExchange.Redis实现分布式锁的两种方式介绍》分布式锁在集群的架构中发挥着重要的作用,:本文主要介绍C#使用StackExchange.Redis实现分布式锁的... 目录自定义分布式锁获取锁释放锁自动续期StackExchange.Redis分布式锁获取锁释放锁自动续期分布式

springboot使用Scheduling实现动态增删启停定时任务教程

《springboot使用Scheduling实现动态增删启停定时任务教程》:本文主要介绍springboot使用Scheduling实现动态增删启停定时任务教程,具有很好的参考价值,希望对大家有... 目录1、配置定时任务需要的线程池2、创建ScheduledFuture的包装类3、注册定时任务,增加、删

SpringBoot整合mybatisPlus实现批量插入并获取ID详解

《SpringBoot整合mybatisPlus实现批量插入并获取ID详解》这篇文章主要为大家详细介绍了SpringBoot如何整合mybatisPlus实现批量插入并获取ID,文中的示例代码讲解详细... 目录【1】saveBATch(一万条数据总耗时:2478ms)【2】集合方式foreach(一万条数

使用Python实现矢量路径的压缩、解压与可视化

《使用Python实现矢量路径的压缩、解压与可视化》在图形设计和Web开发中,矢量路径数据的高效存储与传输至关重要,本文将通过一个Python示例,展示如何将复杂的矢量路径命令序列压缩为JSON格式,... 目录引言核心功能概述1. 路径命令解析2. 路径数据压缩3. 路径数据解压4. 可视化代码实现详解1

PyQt6/PySide6中QTableView类的实现

《PyQt6/PySide6中QTableView类的实现》本文主要介绍了PyQt6/PySide6中QTableView类的实现,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学... 目录1. 基本概念2. 创建 QTableView 实例3. QTableView 的常用属性和方法

PyQt6/PySide6中QTreeView类的实现

《PyQt6/PySide6中QTreeView类的实现》QTreeView是PyQt6或PySide6库中用于显示分层数据的控件,本文主要介绍了PyQt6/PySide6中QTreeView类的实现... 目录1. 基本概念2. 创建 QTreeView 实例3. QTreeView 的常用属性和方法属性

Android使用ImageView.ScaleType实现图片的缩放与裁剪功能

《Android使用ImageView.ScaleType实现图片的缩放与裁剪功能》ImageView是最常用的控件之一,它用于展示各种类型的图片,为了能够根据需求调整图片的显示效果,Android提... 目录什么是 ImageView.ScaleType?FIT_XYFIT_STARTFIT_CENTE

pandas中位数填充空值的实现示例

《pandas中位数填充空值的实现示例》中位数填充是一种简单而有效的方法,用于填充数据集中缺失的值,本文就来介绍一下pandas中位数填充空值的实现,具有一定的参考价值,感兴趣的可以了解一下... 目录什么是中位数填充?为什么选择中位数填充?示例数据结果分析完整代码总结在数据分析和机器学习过程中,处理缺失数