政安晨:示例演绎TensorFlow的官方指南(二){Estimator}

2024-02-07 21:20

本文主要是介绍政安晨:示例演绎TensorFlow的官方指南(二){Estimator},希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

咱们接着演绎TensorFlow官方指南,我的这个系列的上一篇文章为:

政安晨:示例演绎TensorFlow的官方指南(一){基础知识}icon-default.png?t=N7T8https://blog.csdn.net/snowdenkeke/article/details/136067030为什么要演绎官方指南,我在上一篇说过了,这次没有废话,直接开始。

Estimator介绍


政安晨:

咱们先看一下Estimator的背景。

TensorFlow的Estimator API是一种高级的机器学习API,用于简化模型的训练、评估和推理过程。它提供了一种更加高层次的抽象,使开发者能够更加专注于模型的架构和数据流水线的设计,而不需要太多地关注底层的实现细节。

Estimator API提供了一套统一的接口,可以用于各种机器学习任务,如分类、回归、聚类等。它具有以下几个主要特点:

  1. 封装了模型的训练、评估和推理过程,提供了一种简单且一致的方式来组织代码和配置模型。

  2. 支持分布式训练,可以轻松地在多个GPU或多台机器上进行训练,以加速模型的训练过程。

  3. 提供了一系列内置的模型,如线性模型、DNN模型、CNN模型等,可以根据任务的需求快速构建模型。

  4. 可以使用预定义的特征列(feature columns)来处理和预处理输入数据,简化了数据准备的过程。

  5. 可以使用高层的tf.data.Dataset API来读取和处理数据,使数据加载和预处理过程更加灵活和高效。

使用Estimator API时,需要定义一个Estimator对象,这个对象包含了模型的结构和参数。然后,通过调用Estimator对象的train()方法来训练模型,evaluate()方法来评估模型,predict()方法来进行预测。在训练模型时,可以通过tf.estimator.TrainSpec对象来指定训练数据的路径和其他相关参数。在评估模型时,可以通过tf.estimator.EvalSpec对象来指定评估数据的路径和其他相关参数。

总之,Estimator API提供了一种简单、灵活且高效的方式来构建、训练和评估机器学习模型,使开发者能够更加专注于模型的设计和业务逻辑。


这篇官方文档介绍了 tf.estimator,它是一种高级 TensorFlow API。Estimator 封装了以下操作:

  • 训练
  • 评估
  • 预测
  • 导出以供使用

您可以使用我们提供的预制 Estimator 或编写您自己的自定义 Estimator。所有 Estimator(无论是预制还是自定义)都是基于 tf.estimator.Estimator 类的类。

有关 API 设计概述,请参阅白皮书。

优势

与 tf.keras.Model 类似,estimator 是模型级别的抽象。tf.estimator 提供了一些目前仍在为 tf.keras 开发中的功能。包括:

  • 基于参数服务器的训练
  • 完整的 TFX 集成

政安晨:

为了后面的演绎,我们先设置一下环境:


Estimator 功能

Estimator 提供了以下优势:

  • 您可以在本地主机上或分布式多服务器环境中运行基于 Estimator 的模型,而无需更改模型。此外,您还可以在 CPU、GPU 或 TPU 上运行基于 Estimator 的模型,而无需重新编码模型。
  • Estimator 提供了安全的分布式训练循环,可控制如何以及何时进行以下操作:
    • 加载数据
    • 处理异常
    • 创建检查点文件并从故障中恢复
    • 保存 TensorBoard 摘要

在用 Estimator 编写应用时,您必须将数据输入流水线与模型分离。这种分离简化了使用不同数据集进行的实验。

预制 Estimator

使用预制 Estimator,您能够在比基础 TensorFlow API 高很多的概念层面上工作。您无需再担心创建计算图或会话,因为 Estimator 会替您完成所有“基础工作”。此外,使用预制 Estimator,您只需改动较少代码就能试验不同的模型架构。例如,tf.estimator.DNNClassifier 是一个预制 Estimator 类,可基于密集的前馈神经网络对分类模型进行训练。

预制 Estimator 程序结构

依赖于预制 Estimator 的 TensorFlow 程序通常包括以下四个步骤:

1. 编写一个或多个数据集导入函数。

例如,您可以创建一个函数来导入训练集,创建另一个函数来导入测试集。每个数据集导入函数必须返回以下两个对象:

  • 字典,其中键是特征名称,值是包含相应特征数据的张量(或 SparseTensor)
  • 包含一个或多个标签的张量

例如,以下代码展示了输入函数的基本框架:

def input_fn(dataset):     ...  # manipulate dataset, extracting the feature dict and the label     return feature_dict, label

政安晨:

数据框架其实是这样的,不知为何官方文档中没有给出?

def train_input_fn():titanic_file = tf.keras.utils.get_file("train.csv", "https://storage.googleapis.com/tf-datasets/titanic/train.csv")titanic = tf.data.experimental.make_csv_dataset(titanic_file, batch_size=32,label_name="survived")titanic_batches = (titanic.cache().repeat().shuffle(500).prefetch(tf.data.AUTOTUNE))return titanic_batches

执行如下:


2. 定义特征列。

每个 tf.feature_column 标识了特征名称、特征类型,以及任何输入预处理。例如,以下代码段创建了三个包含整数或浮点数据的特征列。前两个特征列仅标识了特征的名称和类型。第三个特征列还指定了一个会被程序调用以缩放原始数据的 lambda:

# Define three numeric feature columns. population = tf.feature_column.numeric_column('population') crime_rate = tf.feature_column.numeric_column('crime_rate') median_education = tf.feature_column.numeric_column(   'median_education',   normalizer_fn=lambda x: x - global_education_mean)

3. 实例化相关预制 Estimator。

例如,下面是对名为 LinearClassifier 的预制 Estimator 进行实例化的示例:

# Instantiate an estimator, passing the feature columns. estimator = tf.estimator.LinearClassifier(   feature_columns=[population, crime_rate, median_education])

4. 调用训练、评估或推断方法。

例如,所有 Estimator 都会提供一个用于训练模型的 train 方法。

# `input_fn` is the function created in Step 1 estimator.train(input_fn=my_training_set, steps=2000)

预制 Estimator 的优势

预制 Estimator 对最佳做法进行了编码,具有以下优势:

  • 确定计算图不同部分的运行位置,以及在单台机器或集群上实施策略的最佳做法。
  • 事件(摘要)编写和通用摘要的最佳做法。

如果不使用预制 Estimator,则您必须自己实现上述功能。

自定义 Estimator

每个 Estimator(无论预制还是自定义)的核心是其模型函数,这是一种为训练、评估和预测构建计算图的方法。当您使用预制 Estimator 时,已经有人为您实现了模型函数。当使用自定义 Estimator 时,您必须自己编写模型函数。

推荐工作流

  1. 假设存在一个合适的预制 Estimator,用它构建您的第一个模型,并将其结果作为基准。
  2. 使用此预制 Estimator 构建并测试您的整个流水线,包括数据的完整性和可靠性。
  3. 如果有其他合适的预制 Estimator,可通过运行实验确定哪个预制 Estimator 能够生成最佳结果。
  4. 如果可能,您可以通过构建自己的自定义 Estimator 进一步改进模型。
import tensorflow as tf
import tensorflow_datasets as tfds
tfds.disable_progress_bar()

从 Keras 模型创建 Estimator

您可以使用 tf.keras.estimator.model_to_estimator 将现有的 Keras 模型转换为 Estimator。这样一来,您的 Keras 模型就可以利用 Estimator 的优势,例如分布式训练。

实例化 Keras MobileNet V2 模型并用训练中使用的优化器、损失和指标来编译模型:

keras_mobilenet_v2 = tf.keras.applications.MobileNetV2(input_shape=(160, 160, 3), include_top=False)
keras_mobilenet_v2.trainable = Falseestimator_model = tf.keras.Sequential([keras_mobilenet_v2,tf.keras.layers.GlobalAveragePooling2D(),tf.keras.layers.Dense(1)
])# Compile the model
estimator_model.compile(optimizer='adam',loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),metrics=['accuracy'])

政安晨执行:

从已编译的 Keras 模型创建 Estimator。Keras 模型的初始模型状态会保留在已创建的 Estimator中:

est_mobilenet_v2 = tf.keras.estimator.model_to_estimator(keras_model=estimator_model)

您可以像对待任何其他 Estimator 一样对待派生的 Estimator

IMG_SIZE = 160  # All images will be resized to 160x160def preprocess(image, label):image = tf.cast(image, tf.float32)image = (image/127.5) - 1image = tf.image.resize(image, (IMG_SIZE, IMG_SIZE))return image, label
def train_input_fn(batch_size):data = tfds.load('cats_vs_dogs', as_supervised=True)train_data = data['train']train_data = train_data.map(preprocess).shuffle(500).batch(batch_size)return train_data

要进行训练,可调用 Estimator 的训练函数:

est_mobilenet_v2.train(input_fn=lambda: train_input_fn(32), steps=500)

同样,要进行评估,可调用 Estimator 的评估函数:

est_mobilenet_v2.evaluate(input_fn=lambda: train_input_fn(32), steps=10)

有关详细信息,请参阅 tf.keras.estimator.model_to_estimator 文档。

写在最后

其实这一篇中官方指南并不详尽,尤其是最后的训练部分,咱们补充了一些,但仍然存在缺失,我将在后续的文章中以实际项目为例,详细演绎。

这篇关于政安晨:示例演绎TensorFlow的官方指南(二){Estimator}的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!


原文地址:
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.chinasem.cn/article/688972

相关文章

Java字符串操作技巧之语法、示例与应用场景分析

《Java字符串操作技巧之语法、示例与应用场景分析》在Java算法题和日常开发中,字符串处理是必备的核心技能,本文全面梳理Java中字符串的常用操作语法,结合代码示例、应用场景和避坑指南,可快速掌握字... 目录引言1. 基础操作1.1 创建字符串1.2 获取长度1.3 访问字符2. 字符串处理2.1 子字

Linux内核参数配置与验证详细指南

《Linux内核参数配置与验证详细指南》在Linux系统运维和性能优化中,内核参数(sysctl)的配置至关重要,本文主要来聊聊如何配置与验证这些Linux内核参数,希望对大家有一定的帮助... 目录1. 引言2. 内核参数的作用3. 如何设置内核参数3.1 临时设置(重启失效)3.2 永久设置(重启仍生效

C++使用printf语句实现进制转换的示例代码

《C++使用printf语句实现进制转换的示例代码》在C语言中,printf函数可以直接实现部分进制转换功能,通过格式说明符(formatspecifier)快速输出不同进制的数值,下面给大家分享C+... 目录一、printf 原生支持的进制转换1. 十进制、八进制、十六进制转换2. 显示进制前缀3. 指

Python列表去重的4种核心方法与实战指南详解

《Python列表去重的4种核心方法与实战指南详解》在Python开发中,处理列表数据时经常需要去除重复元素,本文将详细介绍4种最实用的列表去重方法,有需要的小伙伴可以根据自己的需要进行选择... 目录方法1:集合(set)去重法(最快速)方法2:顺序遍历法(保持顺序)方法3:副本删除法(原地修改)方法4:

前端CSS Grid 布局示例详解

《前端CSSGrid布局示例详解》CSSGrid是一种二维布局系统,可以同时控制行和列,相比Flex(一维布局),更适合用在整体页面布局或复杂模块结构中,:本文主要介绍前端CSSGri... 目录css Grid 布局详解(通俗易懂版)一、概述二、基础概念三、创建 Grid 容器四、定义网格行和列五、设置行

Node.js 数据库 CRUD 项目示例详解(完美解决方案)

《Node.js数据库CRUD项目示例详解(完美解决方案)》:本文主要介绍Node.js数据库CRUD项目示例详解(完美解决方案),本文给大家介绍的非常详细,对大家的学习或工作具有一定的参考... 目录项目结构1. 初始化项目2. 配置数据库连接 (config/db.js)3. 创建模型 (models/

使用Python实现全能手机虚拟键盘的示例代码

《使用Python实现全能手机虚拟键盘的示例代码》在数字化办公时代,你是否遇到过这样的场景:会议室投影电脑突然键盘失灵、躺在沙发上想远程控制书房电脑、或者需要给长辈远程协助操作?今天我要分享的Pyth... 目录一、项目概述:不止于键盘的远程控制方案1.1 创新价值1.2 技术栈全景二、需求实现步骤一、需求

Spring LDAP目录服务的使用示例

《SpringLDAP目录服务的使用示例》本文主要介绍了SpringLDAP目录服务的使用示例... 目录引言一、Spring LDAP基础二、LdapTemplate详解三、LDAP对象映射四、基本LDAP操作4.1 查询操作4.2 添加操作4.3 修改操作4.4 删除操作五、认证与授权六、高级特性与最佳

PyInstaller打包selenium-wire过程中常见问题和解决指南

《PyInstaller打包selenium-wire过程中常见问题和解决指南》常用的打包工具PyInstaller能将Python项目打包成单个可执行文件,但也会因为兼容性问题和路径管理而出现各种运... 目录前言1. 背景2. 可能遇到的问题概述3. PyInstaller 打包步骤及参数配置4. 依赖

Nginx中配置HTTP/2协议的详细指南

《Nginx中配置HTTP/2协议的详细指南》HTTP/2是HTTP协议的下一代版本,旨在提高性能、减少延迟并优化现代网络环境中的通信效率,本文将为大家介绍Nginx配置HTTP/2协议想详细步骤,需... 目录一、HTTP/2 协议概述1.HTTP/22. HTTP/2 的核心特性3. HTTP/2 的优