TensorFlow之Estimator(三)详解

2024-01-23 18:38
文章标签 详解 tensorflow estimator

本文主要是介绍TensorFlow之Estimator(三)详解,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

目录

1. Estimator初识

1.1 框架结构

1.2 Estimator优势

1.3 Estimator使用步骤

1.3.1 下面通过伪代码的形式介绍如何使用Estimator:

2. 深入理解Estimator

2.1  从源代码来理解Estimator

2.2  构建model_fn

2.3  Config

2.4  什么是tf.estimator.EstimatorSpec

2.4.1 传入参数

2.4.2 不同模式需要传入不同参数


1. Estimator初识

1.1 框架结构

在介绍Estimator之前需要对它在TensorFlow这个大框架的定位有个大致的认识,如下图所示:

可以看到Estimator是属于High level的API,而Mid-level API分别是:

  • Layers:用来构建网络结构
  • Datasets: 用来构建数据读取pipeline
  • Metrics:用来评估网络性能

可以看到如果使用Estimator,我们只需要关注这三个部分即可,而不用再关心一些太细节的东西,另外也不用再使用烦人的Session了。

1.2 Estimator优势

本文档介绍了Estimator一种可极大地简化机器学习编程的高阶TensorFlow API。用了Estimator你会得到数不清的好处。

  • 您可以在本地主机上或分布式多服务器环境中运行基于 Estimator 的模型,而无需更改模型。此外,您可以在CPU、GPU或TPU上运行基于Estimator 的模型,而无需重新编码模型
  • 使用dataset高效处理数据,搭配上Estimator再GPU或者TPU上高效的运行模型,提高整体的模型运行的时间。
  • 使用Estimator编写应用时,您必须将数据输入管道从模型中分离出来。这种分离简化了不同数据集的实验流程
  • Estimator提供安全的分布式训练循环,可以控制如何以及何时:
    • 构建图
    • 初始化变量
    • 开始排队
    • 处理异常
    • 创建检查点文件并从故障中恢复
    • 保存 TensorBoard 的摘要
  • Estimator简化了在模型开发者之间共享实现的过程。
  • 您可以使用高级直观代码开发先进的模型。简言之,采用Estimator创建模型通常比采用低阶TensorFlow API更简单。
  • Estimator本身在tf.layers之上构建而成,可以简化自定义过程。

1.3 Estimator使用步骤

  • 创建一个或多个输入函数,即 input_fn
  • 定义模型的特征列,即 feature_columns
  • 实例化 Estimator,指定特征列和各种超参数。
  • 在 Estimator 对象上调用一个或多个方法,传递适当的输入函数作为数据的来源。(train, evaluate, predict)

1.3.1 下面通过伪代码的形式介绍如何使用Estimator:

  • 创建一个或多个输入函数,即input_fn
def train_input_fn(features, labels, batch_size):"""An input function for training"""# Convert the inputs to a Dataset.dataset = tf.data.Dataset.from_tensor_slices((dict(features), labels))# Shuffle, repeat, and batch the examples.return dataset.shuffle(1000).repeat().batch(batch_size)

注意, features需要是字典 (另外此处的feature与我们常说的提取特征的feature还不太一样,也可以指原图数据(raw image),或者其他未作处理的数据)。下面定义的my_feature_column会传给Estimator用于解析features。

  • 定义模型的特征列,即feature_columns
# Feature columns describe how to use the input.
my_feature_columns = []
for key in train_x.keys(): 		    my_feature_columns.append(tf.feature_column.numeric_column(key=key))
  • 实例化 Estimator,指定特征列和各种超参数。
# Build a DNN with 2 hidden layers and 10 nodes in each hidden layer.
classifier = tf.estimator.DNNClassifier(feature_columns=my_feature_columns,# Two hidden layers of 10 nodes each.hidden_units=[10, 10],# The model must choose between 3 classes.n_classes=3)

注意在实例化Estimator的时候不用把数据传进来,你只需要把feature_columns传进来即可,告诉Estimator需要解析哪些特征值,而数据集需要在训练和评估模型的时候才传。

  • 在 Estimator 对象上调用一个或多个方法,传递适当的输入函数作为数据的来源
    • train(训练)
# Train the Model.
classifier.train(input_fn=lambda:iris_data.train_input_fn(train_x, train_y, args.batch_size),steps=args.train_steps)
  • evaluate(评估)
# Evaluate the model.
eval_result = classifier.evaluate(input_fn=lambda:iris_data.eval_input_fn(test_x, test_y, args.batch_size))print('\nTest set accuracy: {accuracy:0.3f}\n'.format(**eval_result))
  • predict(预测)
# Generate predictions from the model
expected = ['Setosa', 'Versicolor', 'Virginica']
predict_x = {'SepalLength': [5.1, 5.9, 6.9],'SepalWidth': [3.3, 3.0, 3.1],'PetalLength': [1.7, 4.2, 5.4],'PetalWidth': [0.5, 1.5, 2.1],
}predictions = classifier.predict(input_fn=lambda:iris_data.eval_input_fn(predict_x,batch_size=args.batch_size))

2. 深入理解Estimator

上面的示例中简单地介绍了Estimator,网络使用的是预创建好的DNNClassifier,其他预创建网络结构有如下:

当然在实际任务中这些网络并不能满足我们的需求,所以我们需要能够使用自定义的网络结构,那么如何实现呢?我之前看官网的教程,反正看的有点蒙,因为时不时就又蹦出一个新的参数来实现不同功能,所以就纳闷到底有多少参数可以使用?没办法只能从源代码开始啃着硬骨头(其实也不硬。。。之前只是懒)。

2.1  从源代码来理解Estimator

Estimator的源代码如下(为方便说明,已经掐头去尾):

class Estimator(object):def __init__(self, model_fn, model_dir=None, config=None, params=None, warm_start_from=None):...

可以看到需要传入的参数如下:

  • model_dir: 指定checkpoints和其他日志存放的路径。
  • model_fn: 这个是需要我们自定义的网络模型函数,后面详细介绍
  • config: 用于控制内部和checkpoints等,如果model_fn函数也定义config这个变量,则会将config传给model_fn
  • params: 该参数的值会传递给model_fn。
  • warm_start_from: 指定checkpoint路径,会导入该checkpoint开始训练

2.2  构建model_fn

模型函数一般定义如下:

def my_model_fn(features, 	# This is batch_features from input_fn,`Tensor` or dict of `Tensor` (depends on data passed to `fit`).labels,     # This is batch_labels from input_fnmode,      # An instance of tf.estimator.ModeKeysparams,  	# Additional configurationconfig=None):
  • 前两个参数是从输入函数中返回的特征和标签批次;也就是说,features 和 labels 是模型将使用的数据。
  • params 是一个字典,它可以传入许多参数用来构建网络或者定义训练方式等。例如通过设置params['n_classes']来定义最终输出节点的个数等。
  • config 通常用来控制checkpoint或者分布式什么,这里不深入研究。
  • mode 参数表示调用程序是请求训练、评估还是预测,分别通过tf.estimator.ModeKeys.TRAIN / EVAL / PREDICT 来定义。另外通过观察DNNClassifier的源代码可以看到,mode这个参数并不用手动传入,因为Estimator会自动调整。例如当你调用estimator.train(...)的时候,mode则会被赋值tf.estimator.ModeKeys.TRAIN

model_fn需要对于不同的模式提供不同的处理方式,并且都需要返回一个tf.estimator.EstimatorSpec的实例。

咋听起来可能有点不知所云,大白话版本就是:模型有训练,验证和测试三种阶段,而且对于不同模式,对数据有不同的处理方式。例如在训练阶段,我们需要将数据喂给模型,模型基于输入数据给出预测值,然后我们在通过预测值和真实值计算出loss,最后用loss更新网络参数,而在评估阶段,我们则不需要反向传播更新网络参数,换句话说,mdoel_fn需要对三种模式设置三套代码。

另外model_fn需要返回什么东西呢?Estimator规定model_fn需要返回tf.estimator.EstimatorSpec,这样它才好更具一般化的进行处理。

2.3  Config

此处的config需要传入tf.estimator.RunConfig,其源代码如下:

class RunConfig(object):"""This class specifies the configurations for an `Estimator` run."""def __init__(self,model_dir=None,tf_random_seed=None,save_summary_steps=100,save_checkpoints_steps=_USE_DEFAULT,save_checkpoints_secs=_USE_DEFAULT,session_config=None,keep_checkpoint_max=5,keep_checkpoint_every_n_hours=10000,log_step_count_steps=100,train_distribute=None,device_fn=None,protocol=None,eval_distribute=None,experimental_distribute=None,experimental_max_worker_delay_secs=None,session_creation_timeout_secs=7200):
  • model_dir: 指定存储模型参数,graph等的路径
  • save_summary_steps: 每隔多少step就存一次Summaries,不知道summary是啥
  • save_checkpoints_steps:每隔多少个step就存一次checkpoint
  • save_checkpoints_secs: 每隔多少秒就存一次checkpoint,不可以和save_checkpoints_steps同时指定。如果二者都不指定,则使用默认值,即每600秒存一次。如果二者都设置为None,则不存checkpoints。

注意上面三个**save-**参数会控制保存checkpoints(模型结构和参数)和event文件(用于tensorboard),如果你都不想保存,那么你需要将这三个参数都置为FALSE

  • keep_checkpoint_max:指定最多保留多少个checkpoints,也就是说当超出指定数量后会将旧的checkpoint删除。当设置为None0时,则保留所有checkpoints。
  • keep_checkpoint_every_n_hours
  • log_step_count_steps:该参数的作用是,(相对于总的step数而言)指定每隔多少step就记录一次训练过程中loss的值,同时也会记录global steps/s,通过这个也可以得到模型训练的速度快慢。(天啦,终于找到这个参数了。。。。之前用TPU测模型速度,每次都得等好久才输出一次global steps/s的数据。。。蓝瘦香菇)

后面这些参数与分布式有关,以后有时间再慢慢了解。

  • train_distribute
  • device_fn
  • protocol
  • eval_distribute
  • experimental_distribute
  • experimental_max_worker_delay_secs

2.4  什么是tf.estimator.EstimatorSpec

2.4.1 传入参数

它是一个class(类),是定义在model_fn中的,并且model_fn返回的也是它的一个实例,这个实例是用来初始化Estimator类的。其源代码如下:

class EstimatorSpec():def __new__(cls,mode,predictions=None,loss=None,train_op=None,eval_metric_ops=None,export_outputs=None,training_chief_hooks=None,training_hooks=None,scaffold=None,evaluation_hooks=None,prediction_hooks=None):

重要函数参数:

  • mode:一个ModeKeys,指定是training(训练)、evaluation(计算)还是prediction(预测).
  • predictions:Predictions Tensor or dict of Tensor.
  • loss:Training loss Tensor. Must be either scalar, or with shape [1].
  • train_op:适用于训练的步骤.
  • eval_metric_ops: Dict of metric results keyed by name.
    The values of the dict can be one of the following:
    • (1) instance of Metric class.
    • (2) Results of calling a metric function, namely a (metric_tensor, update_op) tuple. metric_tensor should be evaluated without any impact on state (typically is a pure computation results based on variables.). For example, it should not trigger the update_op or requires any input fetching.

其他参数的作用可参见源代码说明

2.4.2 不同模式需要传入不同参数

根据mode的值的不同,需要不同的参数,即:

  • 对于mode == ModeKeys.TRAIN:必填字段是loss和train_op.
  • 对于mode == ModeKeys.EVAL:必填字段是loss.
  • 对于mode == ModeKeys.PREDICT:必填字段是predictions.

上面的参数说明看起来还是一头雾水,下面给出例子帮助理解:

(1)最简单的情况: predict

只需要传入modepredictions

# Compute predictions.
predicted_classes = tf.argmax(logits, 1)
if mode == tf.estimator.ModeKeys.PREDICT:predictions = {'class_ids': predicted_classes[:, tf.newaxis],'probabilities': tf.nn.softmax(logits),'logits': logits,}return tf.estimator.EstimatorSpec(mode, predictions=predictions)

(2)评估模式:eval

需要传入mode,loss,eval_metric_ops

如果调用 Estimator 的 evaluate 方法,则 model_fn 会收到 mode = ModeKeys.EVAL。在这种情况下,模型函数必须返回一个包含模型损失和一个或多个指标(可选)的 tf.estimator.EstimatorSpec。

loss示例如下:

# Compute loss.
loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)

TensorFlow提供了一个指标模块tf.metrics来计算常用的指标,这里以accuracy为例:

# Compute evaluation metrics.
accuracy = tf.metrics.accuracy(labels=labels,predictions=predicted_classes,name='acc_op')

返回方式如下:

metrics = {'accuracy': accuracy}if mode == tf.estimator.ModeKeys.EVAL:return tf.estimator.EstimatorSpec(mode, loss=loss, eval_metric_ops=metrics)

(3)训练模式:train

需要传入mode,loss,train_op

loss同eval模式:

# Compute loss.
loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)

train_op示例:

optimizer = tf.train.AdagradOptimizer(learning_rate=0.1)
train_op = optimizer.minimize(loss,global_step=tf.train.get_global_step())

返回值:

return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)

(4)通用模式

model_fn可以填充独立于模式的所有参数.在这种情况下,Estimator将忽略某些参数.在eval和infer模式中,train_op将被忽略.例子如下:

def my_model_fn(mode, features, labels):predictions = ...loss = ...train_op = ...return tf.estimator.EstimatorSpec(mode=mode,predictions=predictions,loss=loss,train_op=train_op)

这篇关于TensorFlow之Estimator(三)详解的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/637162

相关文章

Mysql 中的多表连接和连接类型详解

《Mysql中的多表连接和连接类型详解》这篇文章详细介绍了MySQL中的多表连接及其各种类型,包括内连接、左连接、右连接、全外连接、自连接和交叉连接,通过这些连接方式,可以将分散在不同表中的相关数据... 目录什么是多表连接?1. 内连接(INNER JOIN)2. 左连接(LEFT JOIN 或 LEFT

Java中switch-case结构的使用方法举例详解

《Java中switch-case结构的使用方法举例详解》:本文主要介绍Java中switch-case结构使用的相关资料,switch-case结构是Java中处理多个分支条件的一种有效方式,它... 目录前言一、switch-case结构的基本语法二、使用示例三、注意事项四、总结前言对于Java初学者

Linux内核之内核裁剪详解

《Linux内核之内核裁剪详解》Linux内核裁剪是通过移除不必要的功能和模块,调整配置参数来优化内核,以满足特定需求,裁剪的方法包括使用配置选项、模块化设计和优化配置参数,图形裁剪工具如makeme... 目录简介一、 裁剪的原因二、裁剪的方法三、图形裁剪工具四、操作说明五、make menuconfig

详解Java中的敏感信息处理

《详解Java中的敏感信息处理》平时开发中常常会遇到像用户的手机号、姓名、身份证等敏感信息需要处理,这篇文章主要为大家整理了一些常用的方法,希望对大家有所帮助... 目录前后端传输AES 对称加密RSA 非对称加密混合加密数据库加密MD5 + Salt/SHA + SaltAES 加密平时开发中遇到像用户的

Springboot使用RabbitMQ实现关闭超时订单(示例详解)

《Springboot使用RabbitMQ实现关闭超时订单(示例详解)》介绍了如何在SpringBoot项目中使用RabbitMQ实现订单的延时处理和超时关闭,通过配置RabbitMQ的交换机、队列和... 目录1.maven中引入rabbitmq的依赖:2.application.yml中进行rabbit

C语言线程池的常见实现方式详解

《C语言线程池的常见实现方式详解》本文介绍了如何使用C语言实现一个基本的线程池,线程池的实现包括工作线程、任务队列、任务调度、线程池的初始化、任务添加、销毁等步骤,感兴趣的朋友跟随小编一起看看吧... 目录1. 线程池的基本结构2. 线程池的实现步骤3. 线程池的核心数据结构4. 线程池的详细实现4.1 初

Python绘制土地利用和土地覆盖类型图示例详解

《Python绘制土地利用和土地覆盖类型图示例详解》本文介绍了如何使用Python绘制土地利用和土地覆盖类型图,并提供了详细的代码示例,通过安装所需的库,准备地理数据,使用geopandas和matp... 目录一、所需库的安装二、数据准备三、绘制土地利用和土地覆盖类型图四、代码解释五、其他可视化形式1.

SpringBoot使用Apache POI库读取Excel文件的操作详解

《SpringBoot使用ApachePOI库读取Excel文件的操作详解》在日常开发中,我们经常需要处理Excel文件中的数据,无论是从数据库导入数据、处理数据报表,还是批量生成数据,都可能会遇到... 目录项目背景依赖导入读取Excel模板的实现代码实现代码解析ExcelDemoInfoDTO 数据传输

如何用Java结合经纬度位置计算目标点的日出日落时间详解

《如何用Java结合经纬度位置计算目标点的日出日落时间详解》这篇文章主详细讲解了如何基于目标点的经纬度计算日出日落时间,提供了在线API和Java库两种计算方法,并通过实际案例展示了其应用,需要的朋友... 目录前言一、应用示例1、天安门升旗时间2、湖南省日出日落信息二、Java日出日落计算1、在线API2

使用Spring Cache时设置缓存键的注意事项详解

《使用SpringCache时设置缓存键的注意事项详解》在现代的Web应用中,缓存是提高系统性能和响应速度的重要手段之一,Spring框架提供了强大的缓存支持,通过​​@Cacheable​​、​​... 目录引言1. 缓存键的基本概念2. 默认缓存键生成器3. 自定义缓存键3.1 使用​​@Cacheab