1.tensorflow线性回归示例:保存模型,载入模型,打印模型参数,修改模型

本文主要是介绍1.tensorflow线性回归示例:保存模型,载入模型,打印模型参数,修改模型,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

#coding:utf-8
'''
a liner regression by tenosrflow.
input dimension: 1, output dimension: 1.
显示每个epoch的loss
利用模型预测
保存模型
载入模型
打印模型中的参数
修改模型中的参数
'''
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.python.tools.inspect_checkpoint import print_tensors_in_checkpoint_file# data
x_train = np.linspace(-1, 1, 100)
y_train = 10 * x_train + np.random.randn(x_train.shape[0])
# plt.plot(x_train, y_train, "ro", label="data")
# plt.legend()
# plt.show()epochs = 30
display_step = 2
# input, output
x = tf.placeholder(dtype="float", name="input")
y = tf.placeholder(dtype="float", name="label")
# w, b
w = tf.Variable(initial_value=tf.random_normal([1]), name="weight")
b = tf.Variable(initial_value=tf.zeros([1]), name="bias")
# model
z = tf.multiply(x, w) + b
# loss functon
cost = tf.reduce_mean(tf.square(y - z))
# optimizer
optim = tf.train.GradientDescentOptimizer(learning_rate=0.01).minimize(cost)
saver = tf.train.Saver(max_to_keep=4)  # save 4 model
init = tf.global_variables_initializer()
with tf.Session() as sess:sess.run(init)for epoch in range(epochs):for x_batch, y_batch in zip(x_train, y_train):  # batch is all data theresess.run(optim, feed_dict={x:x_batch, y:y_batch})if epoch % display_step ==0:loss = sess.run(cost, feed_dict={x:x_train, y:y_train})print("epoch: %d, loss: %d" %(epoch, loss))# 保存训练过程中的模型saver.save(sess, "line_regression_model/regress.cpkt", global_step=epoch)print("train finished...")# 保存最终的模型saver.save(sess, "line_regression_model/regress.cpkt")print("final loss:", sess.run(cost, feed_dict={x:x_train, y:y_train}))print("weight:", sess.run(w))print("bias:", sess.run(b))# show train data and predict dataplt.plot(x_train, y_train, "ro", label="train")predict = sess.run(w) * x_train + sess.run(b)plt.plot(x_train, predict, "b", label="predict")plt.legend()plt.show()# 载入模型
print("*"*50)
saver = tf.train.Saver()
with tf.Session() as sess2:sess2.run(tf.global_variables_initializer())saver.restore(sess2, "line_regression_model/regress.cpkt")print(sess2.run(w))print(sess2.run(b))predict2 = sess2.run(z, feed_dict={x:0.5})print(predict2)# 打印出模型中的变量及参数
print("-"*50)
print("the params in model:")
print_tensors_in_checkpoint_file("line_regression_model/regress.cpkt", None, True)# 修改模型中的参数,并重新保存
print("-"*50)
# 以上得到了模型中参数名字为weight,bias, 下面对他们进行修改
w_change = tf.Variable(10, name="weight")
b_change = tf.Variable(0.001, name="bias")
# 把他们放到一个字典里并写在saver里
saver = tf.train.Saver({"weighs":w_change, "bias":b_change})
with tf.Session() as sess3:sess3.run(tf.global_variables_initializer())# 保存修改后的参数saver.save(sess3, "line_regression_model/regress.cpkt")
# 发现参数已经被修改
print_tensors_in_checkpoint_file("line_regression_model/regress.cpkt", None, True)

输出:

/usr/local/bin/python2.7 /Users/ming/Downloads/zhangming/tf_demo/liner_regression.py
2018-11-17 16:07:32.138907: I tensorflow/core/platform/cpu_feature_guard.cc:140] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX2 FMA
epoch: 0, loss: 21
epoch: 2, loss: 2
epoch: 4, loss: 1
epoch: 6, loss: 1
epoch: 8, loss: 1
epoch: 10, loss: 1
epoch: 12, loss: 1
epoch: 14, loss: 1
epoch: 16, loss: 1
epoch: 18, loss: 1
epoch: 20, loss: 1
epoch: 22, loss: 1
epoch: 24, loss: 1
epoch: 26, loss: 1
epoch: 28, loss: 1
train finished...
('final loss:', 1.0535882)
('weight:', array([10.063329], dtype=float32))
('bias:', array([0.03052005], dtype=float32))
**************************************************
[10.063329]
[0.03052005]
[5.0621843]
--------------------------------------------------
the params in model:
tensor_name:  bias
[0.03052005]
tensor_name:  weight
[10.063329]
--------------------------------------------------
tensor_name:  bias
0.001
tensor_name:  weighs
10
 

这篇关于1.tensorflow线性回归示例:保存模型,载入模型,打印模型参数,修改模型的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/986052

相关文章

Java调用DeepSeek API的最佳实践及详细代码示例

《Java调用DeepSeekAPI的最佳实践及详细代码示例》:本文主要介绍如何使用Java调用DeepSeekAPI,包括获取API密钥、添加HTTP客户端依赖、创建HTTP请求、处理响应、... 目录1. 获取API密钥2. 添加HTTP客户端依赖3. 创建HTTP请求4. 处理响应5. 错误处理6.

Android 悬浮窗开发示例((动态权限请求 | 前台服务和通知 | 悬浮窗创建 )

《Android悬浮窗开发示例((动态权限请求|前台服务和通知|悬浮窗创建)》本文介绍了Android悬浮窗的实现效果,包括动态权限请求、前台服务和通知的使用,悬浮窗权限需要动态申请并引导... 目录一、悬浮窗 动态权限请求1、动态请求权限2、悬浮窗权限说明3、检查动态权限4、申请动态权限5、权限设置完毕后

在 Spring Boot 中使用 @Autowired和 @Bean注解的示例详解

《在SpringBoot中使用@Autowired和@Bean注解的示例详解》本文通过一个示例演示了如何在SpringBoot中使用@Autowired和@Bean注解进行依赖注入和Bean... 目录在 Spring Boot 中使用 @Autowired 和 @Bean 注解示例背景1. 定义 Stud

oracle DBMS_SQL.PARSE的使用方法和示例

《oracleDBMS_SQL.PARSE的使用方法和示例》DBMS_SQL是Oracle数据库中的一个强大包,用于动态构建和执行SQL语句,DBMS_SQL.PARSE过程解析SQL语句或PL/S... 目录语法示例注意事项DBMS_SQL 是 oracle 数据库中的一个强大包,它允许动态地构建和执行

0基础租个硬件玩deepseek,蓝耘元生代智算云|本地部署DeepSeek R1模型的操作流程

《0基础租个硬件玩deepseek,蓝耘元生代智算云|本地部署DeepSeekR1模型的操作流程》DeepSeekR1模型凭借其强大的自然语言处理能力,在未来具有广阔的应用前景,有望在多个领域发... 目录0基础租个硬件玩deepseek,蓝耘元生代智算云|本地部署DeepSeek R1模型,3步搞定一个应

Python中顺序结构和循环结构示例代码

《Python中顺序结构和循环结构示例代码》:本文主要介绍Python中的条件语句和循环语句,条件语句用于根据条件执行不同的代码块,循环语句用于重复执行一段代码,文章还详细说明了range函数的使... 目录一、条件语句(1)条件语句的定义(2)条件语句的语法(a)单分支 if(b)双分支 if-else(

Deepseek R1模型本地化部署+API接口调用详细教程(释放AI生产力)

《DeepseekR1模型本地化部署+API接口调用详细教程(释放AI生产力)》本文介绍了本地部署DeepSeekR1模型和通过API调用将其集成到VSCode中的过程,作者详细步骤展示了如何下载和... 目录前言一、deepseek R1模型与chatGPT o1系列模型对比二、本地部署步骤1.安装oll

Spring AI Alibaba接入大模型时的依赖问题小结

《SpringAIAlibaba接入大模型时的依赖问题小结》文章介绍了如何在pom.xml文件中配置SpringAIAlibaba依赖,并提供了一个示例pom.xml文件,同时,建议将Maven仓... 目录(一)pom.XML文件:(二)application.yml配置文件(一)pom.xml文件:首

Python中Markdown库的使用示例详解

《Python中Markdown库的使用示例详解》Markdown库是一个用于处理Markdown文本的Python工具,这篇文章主要为大家详细介绍了Markdown库的具体使用,感兴趣的... 目录一、背景二、什么是 Markdown 库三、如何安装这个库四、库函数使用方法1. markdown.mark

MySQL数据库函数之JSON_EXTRACT示例代码

《MySQL数据库函数之JSON_EXTRACT示例代码》:本文主要介绍MySQL数据库函数之JSON_EXTRACT的相关资料,JSON_EXTRACT()函数用于从JSON文档中提取值,支持对... 目录前言基本语法路径表达式示例示例 1: 提取简单值示例 2: 提取嵌套值示例 3: 提取数组中的值注意