TfLite: TensorFlow模型格式和Post-training quantization

2024-06-03 14:58

本文主要是介绍TfLite: TensorFlow模型格式和Post-training quantization,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

TensorFlow的模型格式

TensorFlow的模型格式有很多种,针对不同场景可以使用不同的格式,只要符合规范的模型都可以轻易部署到在线服务或移动设备上,这里简单列举一下。

  • Checkpoint: 用于保存模型的权重,主要用于模型训练过程中参数的备份和模型训练热启动。
  • GraphDef:用于保存模型的Graph,不包含模型权重,加上checkpoint后就有模型上线的全部信息。
  • SavedModel:使用saved_model接口导出的模型文件,包含模型Graph和权限可直接用于上线,TensorFlow和Keras模型推荐使用这种模型格式。
  • FrozenGraph:使用freeze_graph.py对checkpoint和GraphDef进行整合和优化,可以直接部署到Android、iOS等移动设备上。
  • TFLite:基于flatbuf对模型进行优化,可以直接部署到Android、iOS等移动设备上,使用接口和FrozenGraph有些差异

TensorFlow的模型格式有以上几种,由不同工具生成,有不同的用途。使用tensorlfow底层API和keras的方式不同,但这些格式和是否为keras没有关系。SavedModel和FrozenGraph是两个不同的格式

 

 Saving and Serializing Models with TensorFlow Keras

https://www.tensorflow.org/beta/guide/keras/saving_and_serializing

 Whole-model saving

# Save the model
model.save('path_to_my_model.h5')# Recreate the exact same model purely from the file
new_model = keras.models.load_model('path_to_my_model.h5')

model保存为h5格式,从model文件加载model

Export to SavedModel

You can also export a whole model to the TensorFlow SavedModel format. SavedModel is a standalone serialization format for Tensorflow objects, supported by TensorFlow serving as well as TensorFlow implementations other than Python. 

# Export the model to a SavedModel
keras.experimental.export_saved_model(model, 'path_to_saved_model')# Recreate the exact same model
new_model = keras.experimental.load_from_saved_model('path_to_saved_model')# Check that the state is preserved
new_predictions = new_model.predict(x_test)

 Architecture-only saving

 

config = model.get_config()

You can alternatively use to_json() from from_json(), which uses a JSON string to store the config instead of a Python dict. This is useful to save the config to disk.

json_config = model.to_json()

 Weights-only saving

weights = model.get_weights()  # Retrieves the state of the model.
model.set_weights(weights)  # Sets the state of the model.

 Model Optimization 的发展历史

最初提出(只对weight量化)

https://medium.com/tensorflow/introducing-the-model-optimization-toolkit-for-tensorflow-254aca1ba0a3

Introducing the Model Optimization Toolkit for TensorFlow

也就是post-training quantization via “hybrid operations”, 混合数据类型运算(只对weight量化)

如下面kernel/filter数据类型是kTfLiteUInt8而输出等是kTfLiteFloat32,所以是 hybrid

Tensor   1 img                  kTfLiteFloat32  kTfLiteArenaRw       3136 bytes ( 0.0 MB)  1 784
Tensor   2 mnist_model/dense/MatMtranspose kTfLiteUInt8   kTfLiteMmapRo      50176 bytes ( 0.0 MB)  64 784
Tensor   3 mnist_model/dense/MatMul_bias kTfLiteFloat32   kTfLiteMmapRo        256 bytes ( 0.0 MB)  64
Tensor   4 mnist_model/dense/Relu kTfLiteFloat32  kTfLiteArenaRw        256 bytes ( 0.0 MB)  1 64
 

当前最新(对weight和activation量化)

https://medium.com/tensorflow/tensorflow-model-optimization-toolkit-post-training-integer-quantization-b4964a1ea9ba

TensorFlow Model Optimization Toolkit — Post-Training Integer Quantization(后面提供了教程链接)

后上面的区别,需要输入样本数据集representative_dataset,

def representative_dataset_gen():
  data = tfds.load(...)

  for _ in range(num_calibration_steps):
    image, = data.take(1)
    yield [image]

converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = tf.lite.RepresentativeDataset(
    representative_dataset_gen)

量化后遇到的问题

并不是量化后的模型文件就能执行推断成功,还有看算子的对量化的支持实现。

full_connect对量化的支持

tflite 支持 hybrid混合运算

  TfLiteRegistration* Register_FULLY_CONNECTED() {
    return Register_FULLY_CONNECTED_PIE();
 }

  TfLiteRegistration* Register_FULLY_CONNECTED_PIE() {
    static TfLiteRegistration r = {fully_connected::Init, fully_connected::Free,
                                   fully_connected::Prepare,
                                   fully_connected::Eval<fully_connected::kPie>};
    return &r; 
  }

  template <KernelType kernel_type>
  TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
    auto* params =
        reinterpret_cast<TfLiteFullyConnectedParams*>(node->builtin_data);
    OpData* data = reinterpret_cast<OpData*>(node->user_data);
  
    const TfLiteTensor* input = GetInput(context, node, kInputTensor);
    const TfLiteTensor* filter = GetInput(context, node, kWeightsTensor);
    const TfLiteTensor* bias = GetOptionalInputTensor(context, node, kBiasTensor);
    TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
  
    switch (filter->type) {  // Already know in/out types are same.
      case kTfLiteUInt8:
    if (params->weights_format ==
                   kTfLiteFullyConnectedWeightsFormatDefault) {
          printf("EvalQuantized<kernel_type>\n");
          return EvalQuantized<kernel_type>(context, node, params, data, input,
                                            filter, bias, output);
        } 
 }

  template <KernelType kernel_type>
  TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node,
                             TfLiteFullyConnectedParams* params, OpData* data,
                             const TfLiteTensor* input,
                             const TfLiteTensor* filter, const TfLiteTensor* bias,
                             TfLiteTensor* output) {
    gemmlowp::GemmContext* gemm_context = gemm_support::GetFromContext(context);
  
    int32_t input_offset = -input->params.zero_point;
    int32_t filter_offset = -filter->params.zero_point;
    int32_t output_offset = output->params.zero_point;

    if (kernel_type == kPie && input->type == kTfLiteFloat32) {
      printf("kPie:\n");
      // Pie currently only supports quantized models and float inputs/outputs.
      TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/0);
      TfLiteTensor* scaling_factors = GetTemporary(context, node, /*index=*/1);
      return EvalHybrid(context, node, params, data, input, filter, bias,
                        input_quantized, scaling_factors, output);
    }
  }

Tflite for mcu就不支持hybrid,最后异常

  TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
    auto* params =
        reinterpret_cast<TfLiteFullyConnectedParams*>(node->builtin_data);
  
    const TfLiteTensor* input = GetInput(context, node, kInputTensor);
    const TfLiteTensor* filter = GetInput(context, node, kWeightsTensor);
    const TfLiteTensor* bias = GetOptionalInputTensor(context, node, kBiasTensor);
    TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
  
    TfLiteType data_type = input->type;
    OpData local_data_object;
    OpData* data = &local_data_object;
    TF_LITE_ENSURE_STATUS(CalculateOpData(context, params, data_type, input,
                                          filter, bias, output, data));
  
    switch (filter->type) {  // Already know in/out types are same.
      case kTfLiteFloat32:
        return EvalFloat(context, node, params, data, input, filter, bias,
                         output);
      case kTfLiteUInt8:
        return EvalQuantized(context, node, params, data, input, filter, bias,
                             output);

 }

  TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node,
                             TfLiteFullyConnectedParams* params, OpData* data,
                             const TfLiteTensor* input,
                             const TfLiteTensor* filter, const TfLiteTensor* bias,
                             TfLiteTensor* output) {
    const int32_t input_offset = -input->params.zero_point;
    const int32_t filter_offset = -filter->params.zero_point;
    const int32_t output_offset = output->params.zero_point;
  
    tflite::FullyConnectedParams op_params;
    op_params.input_offset = input_offset;                                                                      
    op_params.weights_offset = filter_offset;
    op_params.output_offset = output_offset;
    op_params.output_multiplier = data->output_multiplier;
    // Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
    op_params.output_shift = -data->output_shift;
    op_params.quantized_activation_min = data->output_activation_min;
    op_params.quantized_activation_max = data->output_activation_max;
  
  #define TF_LITE_FULLY_CONNECTED(output_data_type)                      \
    reference_ops::FullyConnected(                                       \
        op_params, GetTensorShape(input), GetTensorData<uint8_t>(input), \
        GetTensorShape(filter), GetTensorData<uint8_t>(filter),          \
        GetTensorShape(bias), GetTensorData<int32_t>(bias),              \
        GetTensorShape(output), GetTensorData<output_data_type>(output), \
        nullptr)
    switch (output->type) {// float类型
      case kTfLiteUI t8: 
        TF_LITE_FULLY_CONNECTED(uint8_t);
        break;
      case kTfLiteInt16:
        TF_LITE_FULLY_CONNECTED(int16_t);
        break;
      default:
        printf("output type: %d\n", output->type);
        context->ReportError(
            context,
            "Quantized FullyConnected expects output data type uint8 or int16");
  }
 

这篇关于TfLite: TensorFlow模型格式和Post-training quantization的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1027318

相关文章

C#使用HttpClient进行Post请求出现超时问题的解决及优化

《C#使用HttpClient进行Post请求出现超时问题的解决及优化》最近我的控制台程序发现有时候总是出现请求超时等问题,通常好几分钟最多只有3-4个请求,在使用apipost发现并发10个5分钟也... 目录优化结论单例HttpClient连接池耗尽和并发并发异步最终优化后优化结论我直接上优化结论吧,

Golang的CSP模型简介(最新推荐)

《Golang的CSP模型简介(最新推荐)》Golang采用了CSP(CommunicatingSequentialProcesses,通信顺序进程)并发模型,通过goroutine和channe... 目录前言一、介绍1. 什么是 CSP 模型2. Goroutine3. Channel4. Channe

IDEA如何将String类型转json格式

《IDEA如何将String类型转json格式》在Java中,字符串字面量中的转义字符会被自动转换,但通过网络获取的字符串可能不会自动转换,为了解决IDEA无法识别JSON字符串的问题,可以在本地对字... 目录问题描述问题原因解决方案总结问题描述最近做项目需要使用Ai生成json,可生成String类型

Python基于火山引擎豆包大模型搭建QQ机器人详细教程(2024年最新)

《Python基于火山引擎豆包大模型搭建QQ机器人详细教程(2024年最新)》:本文主要介绍Python基于火山引擎豆包大模型搭建QQ机器人详细的相关资料,包括开通模型、配置APIKEY鉴权和SD... 目录豆包大模型概述开通模型付费安装 SDK 环境配置 API KEY 鉴权Ark 模型接口Prompt

SpringBoot中Get请求和POST请求接收参数示例详解

《SpringBoot中Get请求和POST请求接收参数示例详解》文章详细介绍了SpringBoot中Get请求和POST请求的参数接收方式,包括方法形参接收参数、实体类接收参数、HttpServle... 目录1、Get请求1.1 方法形参接收参数 这种方式一般适用参数比较少的情况,并且前后端参数名称必须

大模型研发全揭秘:客服工单数据标注的完整攻略

在人工智能(AI)领域,数据标注是模型训练过程中至关重要的一步。无论你是新手还是有经验的从业者,掌握数据标注的技术细节和常见问题的解决方案都能为你的AI项目增添不少价值。在电信运营商的客服系统中,工单数据是客户问题和解决方案的重要记录。通过对这些工单数据进行有效标注,不仅能够帮助提升客服自动化系统的智能化水平,还能优化客户服务流程,提高客户满意度。本文将详细介绍如何在电信运营商客服工单的背景下进行

Andrej Karpathy最新采访:认知核心模型10亿参数就够了,AI会打破教育不公的僵局

夕小瑶科技说 原创  作者 | 海野 AI圈子的红人,AI大神Andrej Karpathy,曾是OpenAI联合创始人之一,特斯拉AI总监。上一次的动态是官宣创办一家名为 Eureka Labs 的人工智能+教育公司 ,宣布将长期致力于AI原生教育。 近日,Andrej Karpathy接受了No Priors(投资博客)的采访,与硅谷知名投资人 Sara Guo 和 Elad G

Retrieval-based-Voice-Conversion-WebUI模型构建指南

一、模型介绍 Retrieval-based-Voice-Conversion-WebUI(简称 RVC)模型是一个基于 VITS(Variational Inference with adversarial learning for end-to-end Text-to-Speech)的简单易用的语音转换框架。 具有以下特点 简单易用:RVC 模型通过简单易用的网页界面,使得用户无需深入了

透彻!驯服大型语言模型(LLMs)的五种方法,及具体方法选择思路

引言 随着时间的发展,大型语言模型不再停留在演示阶段而是逐步面向生产系统的应用,随着人们期望的不断增加,目标也发生了巨大的变化。在短短的几个月的时间里,人们对大模型的认识已经从对其zero-shot能力感到惊讶,转变为考虑改进模型质量、提高模型可用性。 「大语言模型(LLMs)其实就是利用高容量的模型架构(例如Transformer)对海量的、多种多样的数据分布进行建模得到,它包含了大量的先验

图神经网络模型介绍(1)

我们将图神经网络分为基于谱域的模型和基于空域的模型,并按照发展顺序详解每个类别中的重要模型。 1.1基于谱域的图神经网络         谱域上的图卷积在图学习迈向深度学习的发展历程中起到了关键的作用。本节主要介绍三个具有代表性的谱域图神经网络:谱图卷积网络、切比雪夫网络和图卷积网络。 (1)谱图卷积网络 卷积定理:函数卷积的傅里叶变换是函数傅里叶变换的乘积,即F{f*g}