使用estimator结构训练tf模型

2024-08-23 14:18

本文主要是介绍使用estimator结构训练tf模型,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

一、使用estimator训练模型的流程

1、构建model_fn

def my_metric_fn(labels, predictions):return {'accuracy': tf.metrics.accuracy(labels, predictions)}def model_fn(features, labels, mode, params):""" TODO: 模型函数必须有这四个参数:param features: # 输入的特征数据:param labels: # 输入的标签数据:param mode: # train、evaluate或predict:param params: #超参数,对应Estimator传来的参数:return: TPUEstimatorSpec类型的对象"""eval_metrics=(my_metric_fn, [labels, predictions])output_spec = tf.contrib.tpu.TPUEstimatorSpec(mode=mode, # "train" or "eval" or "predict"loss=total_loss, # double类型eval_metrics=eval_metrics, scaffold_fn=scaffold_fn)  # None or funreturn output_spec

2、定义estimator

run_config = tf.contrib.tpu.RunConfig(cluster=tpu_cluster_resolver,master=FLAGS.master,model_dir=FLAGS.output_dir,save_checkpoints_steps=FLAGS.save_checkpoints_steps,keep_checkpoint_max=FLAGS.keep_checkpoint_max,tf_random_seed=FLAGS.random_seed,tpu_config=tf.contrib.tpu.TPUConfig(iterations_per_loop=FLAGS.save_checkpoints_steps,num_shards=FLAGS.num_tpu_cores,per_host_input_for_training=is_per_host))# 自定义估算器
estimator = tf.contrib.tpu.TPUEstimator(use_tpu=FLAGS.use_tpu,model_fn=model_fn,  # 模型函数config=run_config,  # 设置参数对象train_batch_size=FLAGS.train_batch_size,eval_batch_size=FLAGS.eval_batch_size,predict_batch_size=FLAGS.predict_batch_size)

3、训练模型

def train_input_fn(params):batch_size = params["batch_size"]d = tf.data.TFRecordDataset(input_file)if is_training:d = d.repeat()d = d.shuffle(buffer_size=100, seed=random.randint(1, 10000))d = d.apply(tf.data.experimental.map_and_batch(lambda record: _decode_record(record, name_to_features),batch_size=batch_size,drop_remainder=drop_remainder))return destimator.train(input_fn=train_input_fn, max_steps=next_checkpoint)

4、验证模型

def eval_input_fn(params): # 部分代码 只看框架即可batch_size = params["batch_size"]d = tf.data.TFRecordDataset(input_file)if is_training:d = d.repeat()d = d.shuffle(buffer_size=100, seed=random.randint(1, 10000))d = d.apply(tf.data.experimental.map_and_batch(lambda record: _decode_record(record, name_to_features),batch_size=batch_size,drop_remainder=drop_remainder))return d
result = estimator.evaluate(input_fn=eval_input_fn, steps=eval_steps)  # type:dict
for key in sorted(result.keys()):log_info = "  %s = %s"%(key, str(result[key]))

5、测试模型

def predict_input_fn(params): # 部分代码 只看框架即可batch_size = params["batch_size"]d = tf.data.TFRecordDataset(input_file)if is_training:d = d.repeat()d = d.shuffle(buffer_size=100, seed=random.randint(1, 10000))d = d.apply(tf.data.experimental.map_and_batch(lambda record: _decode_record(record, name_to_features),batch_size=batch_size,drop_remainder=drop_remainder))return d
result = estimator.predict(input_fn=predict_input_fn)  # type:dict
for key in sorted(result.keys()):log_info = "  %s = %s"%(key, str(result[key]))

二、使用estimator训练模型的样例

这篇关于使用estimator结构训练tf模型的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1099613

相关文章

大模型研发全揭秘:客服工单数据标注的完整攻略

在人工智能(AI)领域,数据标注是模型训练过程中至关重要的一步。无论你是新手还是有经验的从业者,掌握数据标注的技术细节和常见问题的解决方案都能为你的AI项目增添不少价值。在电信运营商的客服系统中,工单数据是客户问题和解决方案的重要记录。通过对这些工单数据进行有效标注,不仅能够帮助提升客服自动化系统的智能化水平,还能优化客户服务流程,提高客户满意度。本文将详细介绍如何在电信运营商客服工单的背景下进行

中文分词jieba库的使用与实景应用(一)

知识星球:https://articles.zsxq.com/id_fxvgc803qmr2.html 目录 一.定义: 精确模式(默认模式): 全模式: 搜索引擎模式: paddle 模式(基于深度学习的分词模式): 二 自定义词典 三.文本解析   调整词出现的频率 四. 关键词提取 A. 基于TF-IDF算法的关键词提取 B. 基于TextRank算法的关键词提取

使用SecondaryNameNode恢复NameNode的数据

1)需求: NameNode进程挂了并且存储的数据也丢失了,如何恢复NameNode 此种方式恢复的数据可能存在小部分数据的丢失。 2)故障模拟 (1)kill -9 NameNode进程 [lytfly@hadoop102 current]$ kill -9 19886 (2)删除NameNode存储的数据(/opt/module/hadoop-3.1.4/data/tmp/dfs/na

Hadoop数据压缩使用介绍

一、压缩原则 (1)运算密集型的Job,少用压缩 (2)IO密集型的Job,多用压缩 二、压缩算法比较 三、压缩位置选择 四、压缩参数配置 1)为了支持多种压缩/解压缩算法,Hadoop引入了编码/解码器 2)要在Hadoop中启用压缩,可以配置如下参数

Makefile简明使用教程

文章目录 规则makefile文件的基本语法:加在命令前的特殊符号:.PHONY伪目标: Makefilev1 直观写法v2 加上中间过程v3 伪目标v4 变量 make 选项-f-n-C Make 是一种流行的构建工具,常用于将源代码转换成可执行文件或者其他形式的输出文件(如库文件、文档等)。Make 可以自动化地执行编译、链接等一系列操作。 规则 makefile文件

使用opencv优化图片(画面变清晰)

文章目录 需求影响照片清晰度的因素 实现降噪测试代码 锐化空间锐化Unsharp Masking频率域锐化对比测试 对比度增强常用算法对比测试 需求 对图像进行优化,使其看起来更清晰,同时保持尺寸不变,通常涉及到图像处理技术如锐化、降噪、对比度增强等 影响照片清晰度的因素 影响照片清晰度的因素有很多,主要可以从以下几个方面来分析 1. 拍摄设备 相机传感器:相机传

Andrej Karpathy最新采访:认知核心模型10亿参数就够了,AI会打破教育不公的僵局

夕小瑶科技说 原创  作者 | 海野 AI圈子的红人,AI大神Andrej Karpathy,曾是OpenAI联合创始人之一,特斯拉AI总监。上一次的动态是官宣创办一家名为 Eureka Labs 的人工智能+教育公司 ,宣布将长期致力于AI原生教育。 近日,Andrej Karpathy接受了No Priors(投资博客)的采访,与硅谷知名投资人 Sara Guo 和 Elad G

usaco 1.3 Mixing Milk (结构体排序 qsort) and hdu 2020(sort)

到了这题学会了结构体排序 于是回去修改了 1.2 milking cows 的算法~ 结构体排序核心: 1.结构体定义 struct Milk{int price;int milks;}milk[5000]; 2.自定义的比较函数,若返回值为正,qsort 函数判定a>b ;为负,a<b;为0,a==b; int milkcmp(const void *va,c

pdfmake生成pdf的使用

实际项目中有时会有根据填写的表单数据或者其他格式的数据,将数据自动填充到pdf文件中根据固定模板生成pdf文件的需求 文章目录 利用pdfmake生成pdf文件1.下载安装pdfmake第三方包2.封装生成pdf文件的共用配置3.生成pdf文件的文件模板内容4.调用方法生成pdf 利用pdfmake生成pdf文件 1.下载安装pdfmake第三方包 npm i pdfma

零基础学习Redis(10) -- zset类型命令使用

zset是有序集合,内部除了存储元素外,还会存储一个score,存储在zset中的元素会按照score的大小升序排列,不同元素的score可以重复,score相同的元素会按照元素的字典序排列。 1. zset常用命令 1.1 zadd  zadd key [NX | XX] [GT | LT]   [CH] [INCR] score member [score member ...]