tensorflow tf.train.Saver()在网络模型参数保存以及提取时的用法

2023-10-20 21:10

本文主要是介绍tensorflow tf.train.Saver()在网络模型参数保存以及提取时的用法,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

前言

在辛辛苦苦跑了几个小时甚至几天之后,你训练出了几十万个或者更多的参数,那么你肯定不想只使用这些参数仅仅一次,那么就涉及到这些参数的保存以及提取,幸运的是,tensorflow已经帮我们集成好了相关函数,就是接下来要介绍的tf.train.Saver() 类。

tf.train.Saver()

一 . 用于保存权重和偏重(参数)
在使用之前要先实例化一个类,例如以下代码:

saver = tf.train.Saver()

如何保存?

import tensorflow as tf
import numpy as np# Save to file
#remember to define the same dtype and shape when restore
W = tf.Variable([[1,2,3],[3,4,5]], dtype=tf.float32, name='weights')
b = tf.Variable([[1,2,3]], dtype=tf.float32, name='biases')# init= tf.initialize_all_variables() # tf 马上就要废弃这种写法
# 替换成下面的写法:
init = tf.global_variables_initializer()saver = tf.train.Saver()with tf.Session() as sess:sess.run(init)save_path = saver.save(sess, "my_net/save_net.ckpt")#print("Save to path: ", save_path)

这里Saver()类有一个save方法,其参数为(会话名称,要保存的文件路径以及具体文件)保存后的结果如下:
在这里插入图片描述
这里会自动生成一个“checkpoint”文件以及其他几个.ckpt文件,用来存储参数。

保存了以后如何提取,或者说读取参数?
见以下代码:

w = tf.Variable(np.arange(6).reshape((2, 3)), dtype=tf.float32, name="weights")
b = tf.Variable(np.arange(3).reshape((1, 3)), dtype=tf.float32, name="biases")# 这里不需要初始化步骤 init= tf.initialize_all_variables()saver = tf.train.Saver()
with tf.Session() as sess:# 提取变量saver.restore(sess, "my_net/save_net.ckpt")print("weights:", sess.run(w))print("biases:", sess.run(b))

这里Saver() 提供了一个restore方法,其参数为(会话,需要提取的文件)

这里有几点需要说明一下:

  1. Saver() 类只能存储和提取神经网络的参数,现在还不能存储整个网络架构,这个比较操蛋(不过我相信以后肯定会出现类似的存储整个训练好的架构的函数),现如今想要使用已经训练好的参数,还是需要重新定义一个一模一样的参数变量,无论是在数据类型上,还是shape上

这篇关于tensorflow tf.train.Saver()在网络模型参数保存以及提取时的用法的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/249735

相关文章

Linux内核参数配置与验证详细指南

《Linux内核参数配置与验证详细指南》在Linux系统运维和性能优化中,内核参数(sysctl)的配置至关重要,本文主要来聊聊如何配置与验证这些Linux内核参数,希望对大家有一定的帮助... 目录1. 引言2. 内核参数的作用3. 如何设置内核参数3.1 临时设置(重启失效)3.2 永久设置(重启仍生效

一文详解如何在Python中从字符串中提取部分内容

《一文详解如何在Python中从字符串中提取部分内容》:本文主要介绍如何在Python中从字符串中提取部分内容的相关资料,包括使用正则表达式、Pyparsing库、AST(抽象语法树)、字符串操作... 目录前言解决方案方法一:使用正则表达式方法二:使用 Pyparsing方法三:使用 AST方法四:使用字

C#中async await异步关键字用法和异步的底层原理全解析

《C#中asyncawait异步关键字用法和异步的底层原理全解析》:本文主要介绍C#中asyncawait异步关键字用法和异步的底层原理全解析,本文给大家介绍的非常详细,对大家的学习或工作具有一... 目录C#异步编程一、异步编程基础二、异步方法的工作原理三、代码示例四、编译后的底层实现五、总结C#异步编程

python3 gunicorn配置文件的用法解读

《python3gunicorn配置文件的用法解读》:本文主要介绍python3gunicorn配置文件的使用,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录python3 gunicorn配置文件配置文件服务启动、重启、关闭启动重启关闭总结python3 gun

SpringMVC获取请求参数的方法

《SpringMVC获取请求参数的方法》:本文主要介绍SpringMVC获取请求参数的方法,本文通过实例代码给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友可以参考下... 目录1、通过ServletAPI获取2、通过控制器方法的形参获取请求参数3、@RequestParam4、@

MySQL 中的 LIMIT 语句及基本用法

《MySQL中的LIMIT语句及基本用法》LIMIT语句用于限制查询返回的行数,常用于分页查询或取部分数据,提高查询效率,:本文主要介绍MySQL中的LIMIT语句,需要的朋友可以参考下... 目录mysql 中的 LIMIT 语句1. LIMIT 语法2. LIMIT 基本用法(1) 获取前 N 行数据(

C#中DrawCurve的用法小结

《C#中DrawCurve的用法小结》本文主要介绍了C#中DrawCurve的用法小结,通常用于绘制一条平滑的曲线通过一系列给定的点,具有一定的参考价值,感兴趣的可以了解一下... 目录1. 如何使用 DrawCurve 方法(不带弯曲程度)2. 如何使用 DrawCurve 方法(带弯曲程度)3.使用Dr

C++ vector的常见用法超详细讲解

《C++vector的常见用法超详细讲解》:本文主要介绍C++vector的常见用法,包括C++中vector容器的定义、初始化方法、访问元素、常用函数及其时间复杂度,通过代码介绍的非常详细,... 目录1、vector的定义2、vector常用初始化方法1、使编程用花括号直接赋值2、使用圆括号赋值3、ve

MySQL中FIND_IN_SET函数与INSTR函数用法解析

《MySQL中FIND_IN_SET函数与INSTR函数用法解析》:本文主要介绍MySQL中FIND_IN_SET函数与INSTR函数用法解析,本文通过实例代码给大家介绍的非常详细,感兴趣的朋友一... 目录一、功能定义与语法1、FIND_IN_SET函数2、INSTR函数二、本质区别对比三、实际场景案例分

Spring Boot项目部署命令java -jar的各种参数及作用详解

《SpringBoot项目部署命令java-jar的各种参数及作用详解》:本文主要介绍SpringBoot项目部署命令java-jar的各种参数及作用的相关资料,包括设置内存大小、垃圾回收... 目录前言一、基础命令结构二、常见的 Java 命令参数1. 设置内存大小2. 配置垃圾回收器3. 配置线程栈大小