使用estimator结构训练tf模型

2024-08-23 14:18

本文主要是介绍使用estimator结构训练tf模型,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

一、使用estimator训练模型的流程

1、构建model_fn

def my_metric_fn(labels, predictions):return {'accuracy': tf.metrics.accuracy(labels, predictions)}def model_fn(features, labels, mode, params):""" TODO: 模型函数必须有这四个参数:param features: # 输入的特征数据:param labels: # 输入的标签数据:param mode: # train、evaluate或predict:param params: #超参数,对应Estimator传来的参数:return: TPUEstimatorSpec类型的对象"""eval_metrics=(my_metric_fn, [labels, predictions])output_spec = tf.contrib.tpu.TPUEstimatorSpec(mode=mode, # "train" or "eval" or "predict"loss=total_loss, # double类型eval_metrics=eval_metrics, scaffold_fn=scaffold_fn)  # None or funreturn output_spec

2、定义estimator

run_config = tf.contrib.tpu.RunConfig(cluster=tpu_cluster_resolver,master=FLAGS.master,model_dir=FLAGS.output_dir,save_checkpoints_steps=FLAGS.save_checkpoints_steps,keep_checkpoint_max=FLAGS.keep_checkpoint_max,tf_random_seed=FLAGS.random_seed,tpu_config=tf.contrib.tpu.TPUConfig(iterations_per_loop=FLAGS.save_checkpoints_steps,num_shards=FLAGS.num_tpu_cores,per_host_input_for_training=is_per_host))# 自定义估算器
estimator = tf.contrib.tpu.TPUEstimator(use_tpu=FLAGS.use_tpu,model_fn=model_fn,  # 模型函数config=run_config,  # 设置参数对象train_batch_size=FLAGS.train_batch_size,eval_batch_size=FLAGS.eval_batch_size,predict_batch_size=FLAGS.predict_batch_size)

3、训练模型

def train_input_fn(params):batch_size = params["batch_size"]d = tf.data.TFRecordDataset(input_file)if is_training:d = d.repeat()d = d.shuffle(buffer_size=100, seed=random.randint(1, 10000))d = d.apply(tf.data.experimental.map_and_batch(lambda record: _decode_record(record, name_to_features),batch_size=batch_size,drop_remainder=drop_remainder))return destimator.train(input_fn=train_input_fn, max_steps=next_checkpoint)

4、验证模型

def eval_input_fn(params): # 部分代码 只看框架即可batch_size = params["batch_size"]d = tf.data.TFRecordDataset(input_file)if is_training:d = d.repeat()d = d.shuffle(buffer_size=100, seed=random.randint(1, 10000))d = d.apply(tf.data.experimental.map_and_batch(lambda record: _decode_record(record, name_to_features),batch_size=batch_size,drop_remainder=drop_remainder))return d
result = estimator.evaluate(input_fn=eval_input_fn, steps=eval_steps)  # type:dict
for key in sorted(result.keys()):log_info = "  %s = %s"%(key, str(result[key]))

5、测试模型

def predict_input_fn(params): # 部分代码 只看框架即可batch_size = params["batch_size"]d = tf.data.TFRecordDataset(input_file)if is_training:d = d.repeat()d = d.shuffle(buffer_size=100, seed=random.randint(1, 10000))d = d.apply(tf.data.experimental.map_and_batch(lambda record: _decode_record(record, name_to_features),batch_size=batch_size,drop_remainder=drop_remainder))return d
result = estimator.predict(input_fn=predict_input_fn)  # type:dict
for key in sorted(result.keys()):log_info = "  %s = %s"%(key, str(result[key]))

二、使用estimator训练模型的样例

这篇关于使用estimator结构训练tf模型的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!


原文地址:
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.chinasem.cn/article/1099613

相关文章

Java使用Curator进行ZooKeeper操作的详细教程

《Java使用Curator进行ZooKeeper操作的详细教程》ApacheCurator是一个基于ZooKeeper的Java客户端库,它极大地简化了使用ZooKeeper的开发工作,在分布式系统... 目录1、简述2、核心功能2.1 CuratorFramework2.2 Recipes3、示例实践3

springboot security使用jwt认证方式

《springbootsecurity使用jwt认证方式》:本文主要介绍springbootsecurity使用jwt认证方式,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地... 目录前言代码示例依赖定义mapper定义用户信息的实体beansecurity相关的类提供登录接口测试提供一

go中空接口的具体使用

《go中空接口的具体使用》空接口是一种特殊的接口类型,它不包含任何方法,本文主要介绍了go中空接口的具体使用,具有一定的参考价值,感兴趣的可以了解一下... 目录接口-空接口1. 什么是空接口?2. 如何使用空接口?第一,第二,第三,3. 空接口几个要注意的坑坑1:坑2:坑3:接口-空接口1. 什么是空接

springboot security快速使用示例详解

《springbootsecurity快速使用示例详解》:本文主要介绍springbootsecurity快速使用示例,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝... 目录创www.chinasem.cn建spring boot项目生成脚手架配置依赖接口示例代码项目结构启用s

Python如何使用__slots__实现节省内存和性能优化

《Python如何使用__slots__实现节省内存和性能优化》你有想过,一个小小的__slots__能让你的Python类内存消耗直接减半吗,没错,今天咱们要聊的就是这个让人眼前一亮的技巧,感兴趣的... 目录背景:内存吃得满满的类__slots__:你的内存管理小助手举个大概的例子:看看效果如何?1.

java中使用POI生成Excel并导出过程

《java中使用POI生成Excel并导出过程》:本文主要介绍java中使用POI生成Excel并导出过程,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录需求说明及实现方式需求完成通用代码版本1版本2结果展示type参数为atype参数为b总结注:本文章中代码均为

Java的IO模型、Netty原理解析

《Java的IO模型、Netty原理解析》Java的I/O是以流的方式进行数据输入输出的,Java的类库涉及很多领域的IO内容:标准的输入输出,文件的操作、网络上的数据传输流、字符串流、对象流等,这篇... 目录1.什么是IO2.同步与异步、阻塞与非阻塞3.三种IO模型BIO(blocking I/O)NI

Spring Boot3虚拟线程的使用步骤详解

《SpringBoot3虚拟线程的使用步骤详解》虚拟线程是Java19中引入的一个新特性,旨在通过简化线程管理来提升应用程序的并发性能,:本文主要介绍SpringBoot3虚拟线程的使用步骤,... 目录问题根源分析解决方案验证验证实验实验1:未启用keep-alive实验2:启用keep-alive扩展建

使用Java实现通用树形结构构建工具类

《使用Java实现通用树形结构构建工具类》这篇文章主要为大家详细介绍了如何使用Java实现通用树形结构构建工具类,文中的示例代码讲解详细,感兴趣的小伙伴可以跟随小编一起学习一下... 目录完整代码一、设计思想与核心功能二、核心实现原理1. 数据结构准备阶段2. 循环依赖检测算法3. 树形结构构建4. 搜索子

基于Flask框架添加多个AI模型的API并进行交互

《基于Flask框架添加多个AI模型的API并进行交互》:本文主要介绍如何基于Flask框架开发AI模型API管理系统,允许用户添加、删除不同AI模型的API密钥,感兴趣的可以了解下... 目录1. 概述2. 后端代码说明2.1 依赖库导入2.2 应用初始化2.3 API 存储字典2.4 路由函数2.5 应