学习TensorFlow,保存学习到的网络结构参数并调用

2024-05-07 16:08

本文主要是介绍学习TensorFlow,保存学习到的网络结构参数并调用,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

在深度学习中,不管使用那种学习框架,我们会遇到一个很重要的问题,那就是在训练完之后,如何存储学习到的深度网络的参数?在测试时,如何调用这些网络参数?针对这两个问题,本篇博文主要探索TensorFlow如何解决他们?本篇博文分为三个部分,第一是讲解tensorflow相关的函数,第二是代码例程,第三是运行结果。

一 tensorflow相关的函数

我们说的这两个功能主要由一个类来完成,class tf.train.Saver

[plain] view plain copy 在CODE上查看代码片 派生到我的代码片
  1. saver = tf.train.Saver()  
  2. save_path = saver.save(sess, model_path)  
  3. load_path = saver.restore(sess, model_path)  
saver = tf.train.Saver() 由类创建对象saver,用于保存和调用学习到的网络参数,参数保存在checkpoints里

save_path = saver.save(sess, model_path) 保存学习到的网络参数到model_path路径中

load_path = saver.restore(sess, model_path) 调用model_path路径中的保存的网络参数到graph中


二 代码例程

[python] view plain copy 在CODE上查看代码片 派生到我的代码片
  1. ''''' 
  2. Save and Restore a model using TensorFlow. 
  3. This example is using the MNIST database of handwritten digits 
  4. (http://yann.lecun.com/exdb/mnist/) 
  5.  
  6. Author: Aymeric Damien 
  7. Project: https://github.com/aymericdamien/TensorFlow-Examples/ 
  8. '''  
  9.   
  10. # Import MINST data  
  11. from tensorflow.examples.tutorials.mnist import input_data  
  12. mnist = input_data.read_data_sets("/tmp/data/", one_hot=True)  
  13.   
  14. import tensorflow as tf  
  15.   
  16. # Parameters  
  17. learning_rate = 0.001  
  18. batch_size = 100  
  19. display_step = 1  
  20. model_path = "/home/lei/TensorFlow-Examples-master/examples/4_Utils/model.ckpt"  
  21.   
  22. # Network Parameters  
  23. n_hidden_1 = 256 # 1st layer number of features  
  24. n_hidden_2 = 256 # 2nd layer number of features  
  25. n_input = 784 # MNIST data input (img shape: 28*28)  
  26. n_classes = 10 # MNIST total classes (0-9 digits)  
  27.   
  28. # tf Graph input  
  29. x = tf.placeholder("float", [None, n_input])  
  30. y = tf.placeholder("float", [None, n_classes])  
  31.   
  32.   
  33. # Create model  
  34. def multilayer_perceptron(x, weights, biases):  
  35.     # Hidden layer with RELU activation  
  36.     layer_1 = tf.add(tf.matmul(x, weights['h1']), biases['b1'])  
  37.     layer_1 = tf.nn.relu(layer_1)  
  38.     # Hidden layer with RELU activation  
  39.     layer_2 = tf.add(tf.matmul(layer_1, weights['h2']), biases['b2'])  
  40.     layer_2 = tf.nn.relu(layer_2)  
  41.     # Output layer with linear activation  
  42.     out_layer = tf.matmul(layer_2, weights['out']) + biases['out']  
  43.     return out_layer  
  44.   
  45. # Store layers weight & bias  
  46. weights = {  
  47.     'h1': tf.Variable(tf.random_normal([n_input, n_hidden_1])),  
  48.     'h2': tf.Variable(tf.random_normal([n_hidden_1, n_hidden_2])),  
  49.     'out': tf.Variable(tf.random_normal([n_hidden_2, n_classes]))  
  50. }  
  51. biases = {  
  52.     'b1': tf.Variable(tf.random_normal([n_hidden_1])),  
  53.     'b2': tf.Variable(tf.random_normal([n_hidden_2])),  
  54.     'out': tf.Variable(tf.random_normal([n_classes]))  
  55. }  
  56.   
  57. # Construct model  
  58. pred = multilayer_perceptron(x, weights, biases)  
  59.   
  60. # Define loss and optimizer  
  61. cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(pred, y))  
  62. optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(cost)  
  63.   
  64. # Initializing the variables  
  65. init = tf.initialize_all_variables()  
  66.   
  67. # 'Saver' op to save and restore all the variables  
  68. saver = tf.train.Saver()  
  69.   
  70. # Running first session  
  71. print "Starting 1st session..."  
  72. with tf.Session() as sess:  
  73.     # Initialize variables  
  74.     sess.run(init)  
  75.   
  76.     # Training cycle  
  77.     for epoch in range(3):  
  78.         avg_cost = 0.  
  79.         total_batch = int(mnist.train.num_examples/batch_size)  
  80.         # Loop over all batches  
  81.         for i in range(total_batch):  
  82.             batch_x, batch_y = mnist.train.next_batch(batch_size)  
  83.             # Run optimization op (backprop) and cost op (to get loss value)  
  84.             _, c = sess.run([optimizer, cost], feed_dict={x: batch_x,  
  85.                                                           y: batch_y})  
  86.             # Compute average loss  
  87.             avg_cost += c / total_batch  
  88.         # Display logs per epoch step  
  89.         if epoch % display_step == 0:  
  90.             print "Epoch:"'%04d' % (epoch+1), "cost=", \  
  91.                 "{:.9f}".format(avg_cost)  
  92.     print "First Optimization Finished!"  
  93.   
  94.     # Test model  
  95.     correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))  
  96.     # Calculate accuracy  
  97.     accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))  
  98.     print "Accuracy:", accuracy.eval({x: mnist.test.images, y: mnist.test.labels})  
  99.   
  100.     # Save model weights to disk  
  101.     save_path = saver.save(sess, model_path)  
  102.     print "Model saved in file: %s" % save_path  
  103.   
  104. # Running a new session  
  105. print "Starting 2nd session..."  
  106. with tf.Session() as sess:  
  107.     # Initialize variables  
  108.     sess.run(init)  
  109.   
  110.     # Restore model weights from previously saved model  
  111.     load_path = saver.restore(sess, model_path)  
  112.     print "Model restored from file: %s" % save_path  
  113.   
  114.     # Resume training  
  115.     for epoch in range(7):  
  116.         avg_cost = 0.  
  117.         total_batch = int(mnist.train.num_examples / batch_size)  
  118.         # Loop over all batches  
  119.         for i in range(total_batch):  
  120.             batch_x, batch_y = mnist.train.next_batch(batch_size)  
  121.             # Run optimization op (backprop) and cost op (to get loss value)  
  122.             _, c = sess.run([optimizer, cost], feed_dict={x: batch_x,  
  123.                                                           y: batch_y})  
  124.             # Compute average loss  
  125.             avg_cost += c / total_batch  
  126.         # Display logs per epoch step  
  127.         if epoch % display_step == 0:  
  128.             print "Epoch:"'%04d' % (epoch + 1), "cost=", \  
  129.                 "{:.9f}".format(avg_cost)  
  130.     print "Second Optimization Finished!"  
  131.   
  132.     # Test model  
  133.     correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))  
  134.     # Calculate accuracy  
  135.     accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))  
  136.     print "Accuracy:", accuracy.eval(  
  137.         {x: mnist.test.images, y: mnist.test.labels})  


三 运行结果



参考资料:

https://github.com/aymericdamien/TensorFlow-Examples/blob/master/examples/4_Utils/save_restore_model.py

https://www.tensorflow.org/versions/r0.9/api_docs/python/state_ops.html#Saver

这篇关于学习TensorFlow,保存学习到的网络结构参数并调用的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/967818

相关文章

Deepseek R1模型本地化部署+API接口调用详细教程(释放AI生产力)

《DeepseekR1模型本地化部署+API接口调用详细教程(释放AI生产力)》本文介绍了本地部署DeepSeekR1模型和通过API调用将其集成到VSCode中的过程,作者详细步骤展示了如何下载和... 目录前言一、deepseek R1模型与chatGPT o1系列模型对比二、本地部署步骤1.安装oll

Java深度学习库DJL实现Python的NumPy方式

《Java深度学习库DJL实现Python的NumPy方式》本文介绍了DJL库的背景和基本功能,包括NDArray的创建、数学运算、数据获取和设置等,同时,还展示了如何使用NDArray进行数据预处理... 目录1 NDArray 的背景介绍1.1 架构2 JavaDJL使用2.1 安装DJL2.2 基本操

一分钟带你上手Python调用DeepSeek的API

《一分钟带你上手Python调用DeepSeek的API》最近DeepSeek非常火,作为一枚对前言技术非常关注的程序员来说,自然都想对接DeepSeek的API来体验一把,下面小编就来为大家介绍一下... 目录前言免费体验API-Key申请首次调用API基本概念最小单元推理模型智能体自定义界面总结前言最

Java通过反射获取方法参数名的方式小结

《Java通过反射获取方法参数名的方式小结》这篇文章主要为大家详细介绍了Java如何通过反射获取方法参数名的方式,文中的示例代码讲解详细,感兴趣的小伙伴可以跟随小编一起学习一下... 目录1、前言2、解决方式方式2.1: 添加编译参数配置 -parameters方式2.2: 使用Spring的内部工具类 -

使用C++将处理后的信号保存为PNG和TIFF格式

《使用C++将处理后的信号保存为PNG和TIFF格式》在信号处理领域,我们常常需要将处理结果以图像的形式保存下来,方便后续分析和展示,C++提供了多种库来处理图像数据,本文将介绍如何使用stb_ima... 目录1. PNG格式保存使用stb_imagephp_write库1.1 安装和包含库1.2 代码解

JAVA调用Deepseek的api完成基本对话简单代码示例

《JAVA调用Deepseek的api完成基本对话简单代码示例》:本文主要介绍JAVA调用Deepseek的api完成基本对话的相关资料,文中详细讲解了如何获取DeepSeekAPI密钥、添加H... 获取API密钥首先,从DeepSeek平台获取API密钥,用于身份验证。添加HTTP客户端依赖使用Jav

redis防止短信恶意调用的实现

《redis防止短信恶意调用的实现》本文主要介绍了在场景登录或注册接口中使用短信验证码时遇到的恶意调用问题,并通过使用Redis分布式锁来解决,具有一定的参考价值,感兴趣的可以了解一下... 目录1.场景2.排查3.解决方案3.1 Redis锁实现3.2 方法调用1.场景登录或注册接口中,使用短信验证码场

使用C/C++调用libcurl调试消息的方式

《使用C/C++调用libcurl调试消息的方式》在使用C/C++调用libcurl进行HTTP请求时,有时我们需要查看请求的/应答消息的内容(包括请求头和请求体)以方便调试,libcurl提供了多种... 目录1. libcurl 调试工具简介2. 输出请求消息使用 CURLOPT_VERBOSE使用 C

vscode保存代码时自动eslint格式化图文教程

《vscode保存代码时自动eslint格式化图文教程》:本文主要介绍vscode保存代码时自动eslint格式化的相关资料,包括打开设置文件并复制特定内容,文中通过代码介绍的非常详细,需要的朋友... 目录1、点击设置2、选择远程--->点击右上角打开设置3、会弹出settings.json文件,将以下内

Python调用另一个py文件并传递参数常见的方法及其应用场景

《Python调用另一个py文件并传递参数常见的方法及其应用场景》:本文主要介绍在Python中调用另一个py文件并传递参数的几种常见方法,包括使用import语句、exec函数、subproce... 目录前言1. 使用import语句1.1 基本用法1.2 导入特定函数1.3 处理文件路径2. 使用ex