TfLite: TensorFlow模型格式和Post-training quantization

2024-06-03 14:58

本文主要是介绍TfLite: TensorFlow模型格式和Post-training quantization,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

TensorFlow的模型格式

TensorFlow的模型格式有很多种,针对不同场景可以使用不同的格式,只要符合规范的模型都可以轻易部署到在线服务或移动设备上,这里简单列举一下。

  • Checkpoint: 用于保存模型的权重,主要用于模型训练过程中参数的备份和模型训练热启动。
  • GraphDef:用于保存模型的Graph,不包含模型权重,加上checkpoint后就有模型上线的全部信息。
  • SavedModel:使用saved_model接口导出的模型文件,包含模型Graph和权限可直接用于上线,TensorFlow和Keras模型推荐使用这种模型格式。
  • FrozenGraph:使用freeze_graph.py对checkpoint和GraphDef进行整合和优化,可以直接部署到Android、iOS等移动设备上。
  • TFLite:基于flatbuf对模型进行优化,可以直接部署到Android、iOS等移动设备上,使用接口和FrozenGraph有些差异

TensorFlow的模型格式有以上几种,由不同工具生成,有不同的用途。使用tensorlfow底层API和keras的方式不同,但这些格式和是否为keras没有关系。SavedModel和FrozenGraph是两个不同的格式

 

 Saving and Serializing Models with TensorFlow Keras

https://www.tensorflow.org/beta/guide/keras/saving_and_serializing

 Whole-model saving

# Save the model
model.save('path_to_my_model.h5')# Recreate the exact same model purely from the file
new_model = keras.models.load_model('path_to_my_model.h5')

model保存为h5格式,从model文件加载model

Export to SavedModel

You can also export a whole model to the TensorFlow SavedModel format. SavedModel is a standalone serialization format for Tensorflow objects, supported by TensorFlow serving as well as TensorFlow implementations other than Python. 

# Export the model to a SavedModel
keras.experimental.export_saved_model(model, 'path_to_saved_model')# Recreate the exact same model
new_model = keras.experimental.load_from_saved_model('path_to_saved_model')# Check that the state is preserved
new_predictions = new_model.predict(x_test)

 Architecture-only saving

 

config = model.get_config()

You can alternatively use to_json() from from_json(), which uses a JSON string to store the config instead of a Python dict. This is useful to save the config to disk.

json_config = model.to_json()

 Weights-only saving

weights = model.get_weights()  # Retrieves the state of the model.
model.set_weights(weights)  # Sets the state of the model.

 Model Optimization 的发展历史

最初提出(只对weight量化)

https://medium.com/tensorflow/introducing-the-model-optimization-toolkit-for-tensorflow-254aca1ba0a3

Introducing the Model Optimization Toolkit for TensorFlow

也就是post-training quantization via “hybrid operations”, 混合数据类型运算(只对weight量化)

如下面kernel/filter数据类型是kTfLiteUInt8而输出等是kTfLiteFloat32,所以是 hybrid

Tensor   1 img                  kTfLiteFloat32  kTfLiteArenaRw       3136 bytes ( 0.0 MB)  1 784
Tensor   2 mnist_model/dense/MatMtranspose kTfLiteUInt8   kTfLiteMmapRo      50176 bytes ( 0.0 MB)  64 784
Tensor   3 mnist_model/dense/MatMul_bias kTfLiteFloat32   kTfLiteMmapRo        256 bytes ( 0.0 MB)  64
Tensor   4 mnist_model/dense/Relu kTfLiteFloat32  kTfLiteArenaRw        256 bytes ( 0.0 MB)  1 64
 

当前最新(对weight和activation量化)

https://medium.com/tensorflow/tensorflow-model-optimization-toolkit-post-training-integer-quantization-b4964a1ea9ba

TensorFlow Model Optimization Toolkit — Post-Training Integer Quantization(后面提供了教程链接)

后上面的区别,需要输入样本数据集representative_dataset,

def representative_dataset_gen():
  data = tfds.load(...)

  for _ in range(num_calibration_steps):
    image, = data.take(1)
    yield [image]

converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = tf.lite.RepresentativeDataset(
    representative_dataset_gen)

量化后遇到的问题

并不是量化后的模型文件就能执行推断成功,还有看算子的对量化的支持实现。

full_connect对量化的支持

tflite 支持 hybrid混合运算

  TfLiteRegistration* Register_FULLY_CONNECTED() {
    return Register_FULLY_CONNECTED_PIE();
 }

  TfLiteRegistration* Register_FULLY_CONNECTED_PIE() {
    static TfLiteRegistration r = {fully_connected::Init, fully_connected::Free,
                                   fully_connected::Prepare,
                                   fully_connected::Eval<fully_connected::kPie>};
    return &r; 
  }

  template <KernelType kernel_type>
  TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
    auto* params =
        reinterpret_cast<TfLiteFullyConnectedParams*>(node->builtin_data);
    OpData* data = reinterpret_cast<OpData*>(node->user_data);
  
    const TfLiteTensor* input = GetInput(context, node, kInputTensor);
    const TfLiteTensor* filter = GetInput(context, node, kWeightsTensor);
    const TfLiteTensor* bias = GetOptionalInputTensor(context, node, kBiasTensor);
    TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
  
    switch (filter->type) {  // Already know in/out types are same.
      case kTfLiteUInt8:
    if (params->weights_format ==
                   kTfLiteFullyConnectedWeightsFormatDefault) {
          printf("EvalQuantized<kernel_type>\n");
          return EvalQuantized<kernel_type>(context, node, params, data, input,
                                            filter, bias, output);
        } 
 }

  template <KernelType kernel_type>
  TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node,
                             TfLiteFullyConnectedParams* params, OpData* data,
                             const TfLiteTensor* input,
                             const TfLiteTensor* filter, const TfLiteTensor* bias,
                             TfLiteTensor* output) {
    gemmlowp::GemmContext* gemm_context = gemm_support::GetFromContext(context);
  
    int32_t input_offset = -input->params.zero_point;
    int32_t filter_offset = -filter->params.zero_point;
    int32_t output_offset = output->params.zero_point;

    if (kernel_type == kPie && input->type == kTfLiteFloat32) {
      printf("kPie:\n");
      // Pie currently only supports quantized models and float inputs/outputs.
      TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/0);
      TfLiteTensor* scaling_factors = GetTemporary(context, node, /*index=*/1);
      return EvalHybrid(context, node, params, data, input, filter, bias,
                        input_quantized, scaling_factors, output);
    }
  }

Tflite for mcu就不支持hybrid,最后异常

  TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
    auto* params =
        reinterpret_cast<TfLiteFullyConnectedParams*>(node->builtin_data);
  
    const TfLiteTensor* input = GetInput(context, node, kInputTensor);
    const TfLiteTensor* filter = GetInput(context, node, kWeightsTensor);
    const TfLiteTensor* bias = GetOptionalInputTensor(context, node, kBiasTensor);
    TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
  
    TfLiteType data_type = input->type;
    OpData local_data_object;
    OpData* data = &local_data_object;
    TF_LITE_ENSURE_STATUS(CalculateOpData(context, params, data_type, input,
                                          filter, bias, output, data));
  
    switch (filter->type) {  // Already know in/out types are same.
      case kTfLiteFloat32:
        return EvalFloat(context, node, params, data, input, filter, bias,
                         output);
      case kTfLiteUInt8:
        return EvalQuantized(context, node, params, data, input, filter, bias,
                             output);

 }

  TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node,
                             TfLiteFullyConnectedParams* params, OpData* data,
                             const TfLiteTensor* input,
                             const TfLiteTensor* filter, const TfLiteTensor* bias,
                             TfLiteTensor* output) {
    const int32_t input_offset = -input->params.zero_point;
    const int32_t filter_offset = -filter->params.zero_point;
    const int32_t output_offset = output->params.zero_point;
  
    tflite::FullyConnectedParams op_params;
    op_params.input_offset = input_offset;                                                                      
    op_params.weights_offset = filter_offset;
    op_params.output_offset = output_offset;
    op_params.output_multiplier = data->output_multiplier;
    // Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
    op_params.output_shift = -data->output_shift;
    op_params.quantized_activation_min = data->output_activation_min;
    op_params.quantized_activation_max = data->output_activation_max;
  
  #define TF_LITE_FULLY_CONNECTED(output_data_type)                      \
    reference_ops::FullyConnected(                                       \
        op_params, GetTensorShape(input), GetTensorData<uint8_t>(input), \
        GetTensorShape(filter), GetTensorData<uint8_t>(filter),          \
        GetTensorShape(bias), GetTensorData<int32_t>(bias),              \
        GetTensorShape(output), GetTensorData<output_data_type>(output), \
        nullptr)
    switch (output->type) {// float类型
      case kTfLiteUI t8: 
        TF_LITE_FULLY_CONNECTED(uint8_t);
        break;
      case kTfLiteInt16:
        TF_LITE_FULLY_CONNECTED(int16_t);
        break;
      default:
        printf("output type: %d\n", output->type);
        context->ReportError(
            context,
            "Quantized FullyConnected expects output data type uint8 or int16");
  }
 

这篇关于TfLite: TensorFlow模型格式和Post-training quantization的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1027318

相关文章

HTML5表格语法格式详解

《HTML5表格语法格式详解》在HTML语法中,表格主要通过table、tr和td3个标签构成,本文通过实例代码讲解HTML5表格语法格式,感兴趣的朋友一起看看吧... 目录一、表格1.表格语法格式2.表格属性 3.例子二、不规则表格1.跨行2.跨列3.例子一、表格在html语法中,表格主要通过< tab

Spring Security基于数据库的ABAC属性权限模型实战开发教程

《SpringSecurity基于数据库的ABAC属性权限模型实战开发教程》:本文主要介绍SpringSecurity基于数据库的ABAC属性权限模型实战开发教程,本文给大家介绍的非常详细,对大... 目录1. 前言2. 权限决策依据RBACABAC综合对比3. 数据库表结构说明4. 实战开始5. MyBA

Python将博客内容html导出为Markdown格式

《Python将博客内容html导出为Markdown格式》Python将博客内容html导出为Markdown格式,通过博客url地址抓取文章,分析并提取出文章标题和内容,将内容构建成html,再转... 目录一、为什么要搞?二、准备如何搞?三、说搞咱就搞!抓取文章提取内容构建html转存markdown

Java的IO模型、Netty原理解析

《Java的IO模型、Netty原理解析》Java的I/O是以流的方式进行数据输入输出的,Java的类库涉及很多领域的IO内容:标准的输入输出,文件的操作、网络上的数据传输流、字符串流、对象流等,这篇... 目录1.什么是IO2.同步与异步、阻塞与非阻塞3.三种IO模型BIO(blocking I/O)NI

基于Flask框架添加多个AI模型的API并进行交互

《基于Flask框架添加多个AI模型的API并进行交互》:本文主要介绍如何基于Flask框架开发AI模型API管理系统,允许用户添加、删除不同AI模型的API密钥,感兴趣的可以了解下... 目录1. 概述2. 后端代码说明2.1 依赖库导入2.2 应用初始化2.3 API 存储字典2.4 路由函数2.5 应

如何自定义Nginx JSON日志格式配置

《如何自定义NginxJSON日志格式配置》Nginx作为最流行的Web服务器之一,其灵活的日志配置能力允许我们根据需求定制日志格式,本文将详细介绍如何配置Nginx以JSON格式记录访问日志,这种... 目录前言为什么选择jsON格式日志?配置步骤详解1. 安装Nginx服务2. 自定义JSON日志格式各

python dict转换成json格式的实现

《pythondict转换成json格式的实现》本文主要介绍了pythondict转换成json格式的实现,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下... 一开始你变成字典格式data = [ { 'a' : 1, 'b' : 2, 'c编程' : 3,

Python中Windows和macOS文件路径格式不一致的解决方法

《Python中Windows和macOS文件路径格式不一致的解决方法》在Python中,Windows和macOS的文件路径字符串格式不一致主要体现在路径分隔符上,这种差异可能导致跨平台代码在处理文... 目录方法 1:使用 os.path 模块方法 2:使用 pathlib 模块(推荐)方法 3:统一使

Java中使用注解校验手机号格式的详细指南

《Java中使用注解校验手机号格式的详细指南》在现代的Web应用开发中,数据校验是一个非常重要的环节,本文将详细介绍如何在Java中使用注解对手机号格式进行校验,感兴趣的小伙伴可以了解下... 目录1. 引言2. 数据校验的重要性3. Java中的数据校验框架4. 使用注解校验手机号格式4.1 @NotBl

Python批量调整Word文档中的字体、段落间距及格式

《Python批量调整Word文档中的字体、段落间距及格式》这篇文章主要为大家详细介绍了如何使用Python的docx库来批量处理Word文档,包括设置首行缩进、字体、字号、行间距、段落对齐方式等,需... 目录关键代码一级标题设置  正文设置完整代码运行结果最近关于批处理格式的问题我查了很多资料,但是都没