This article introduces saving and restoring models in TensorFlow with `tf.train.Saver`. We hope it serves as a useful reference for developers facing this problem; follow along to learn how it works.
def train(args, sess, model):
    optimizer = tf.train.AdamOptimizer(args.learning_rate, beta1=args.momentum,
                                       name="AdamOptimizer_G").minimize(model.g_loss_all,
                                                                        var_list=model.c_vars)
    epoch = 0

    # Saver
    saver = tf.train.Saver()

    # Restore
    if args.continue_training:
        tf.local_variables_initializer().run()
        last_ckpt = tf.train.latest_checkpoint(args.checkpoints_path)
        saver.restore(sess, last_ckpt)
        ckpt_name = str(last_ckpt)
        print("Loaded model file from " + ckpt_name)
        epoch = int(ckpt_name.split('-')[-1])
    else:
        tf.global_variables_initializer().run()
        tf.local_variables_initializer().run()

    while epoch < args.train_step:
        pass  # training step goes here

        # Save
        if epoch % 10 == 0:
            saver.save(sess, args.checkpoints_path + "/model", global_step=epoch)
        epoch += 1

    print("Done.")
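Note how the resume logic recovers the epoch counter: `saver.save(..., global_step=epoch)` appends `-<global_step>` to the checkpoint prefix, so splitting the checkpoint name on `-` and taking the last piece yields the epoch to resume from. A minimal sketch of just that parsing step, using a hypothetical checkpoint path for illustration:

```python
# Hypothetical checkpoint name, as produced by
# saver.save(sess, args.checkpoints_path + "/model", global_step=120):
# tf.train.Saver appends "-<global_step>" to the supplied prefix.
last_ckpt = "checkpoints/model-120"  # assumed example path, not from the article

# Recover the epoch counter from the suffix so training resumes where it stopped.
epoch = int(last_ckpt.split('-')[-1])
print(epoch)  # → 120
```

This is why the code restores from `tf.train.latest_checkpoint(args.checkpoints_path)` rather than a fixed filename: the step suffix changes every time a new checkpoint is written.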
That concludes this article on saving and restoring with TensorFlow's `Saver`. We hope it is helpful to fellow programmers!