基于Keras的模型量化(PTQ、QAT)

2024-03-14 07:12
文章标签 模型 量化 keras ptq qat

本文主要是介绍基于Keras的模型量化(PTQ、QAT),希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

对PTQ和QAT的详细解释在这篇哦:
《模型量化(三)—— 量化感知训练QAT(pytorch)》

本文给的代码是基于tensorflow

目录

  • PTQ
    • 只量化权重
    • 权重和激活值全量化
  • QAT
    • 套路创建和训练模型
    • 用QAT克隆和微调预训练模型
    • 量化模型
    • 评估TF和TFLite

PTQ

只量化权重

只是优化了模型大小,对于模型的计算没什么优化,因为W * X时,W要反量化为浮点进行运算,相当于还增加了反量化这一累赘操作…

import tensorflow as tf
converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_quant_model = converter.convert()

权重和激活值全量化

通过量化权重和激活,可以同时改善功耗和时延,因为最关键的密集部分W * X使用 8 位而不是浮点数进行计算。这需要一个较小的calibration数据集来计算激活值反量化时的S和Z操作。

import tensorflow as tfdef representative_dataset_gen():for _ in range(num_calibration_steps):# Get sample input data as a numpy array in a method of your choosing.yield [input]converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_dataset_gen
tflite_quant_model = converter.convert()

详细内容参考官方的说明:
https://www.tensorflow.org/model_optimization/guide/quantization/post_training?hl=zh-cn

 

QAT

套路创建和训练模型

!pip install -q tensorflow
!pip install -q tensorflow-model-optimizationimport tempfile
import os
import tensorflow as tf
from tensorflow_model_optimization.python.core.keras.compat import keras# Load MNIST dataset
mnist = tf.keras.datasets.mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()# Normalize the input image so that each pixel value is between 0 to 1.
train_images = train_images / 255.0
test_images = test_images / 255.0# Define the model architecture.
model = keras.Sequential([keras.layers.InputLayer(input_shape=(28, 28)),keras.layers.Reshape(target_shape=(28, 28, 1)),keras.layers.Conv2D(filters=12, kernel_size=(3, 3), activation='relu'),keras.layers.MaxPooling2D(pool_size=(2, 2)),keras.layers.Flatten(),keras.layers.Dense(10)
])# Train the digit classification model
model.compile(optimizer='adam',loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),metrics=['accuracy'])model.fit(train_images,train_labels,epochs=1,validation_split=0.1,
)

 

用QAT克隆和微调预训练模型

import tensorflow_model_optimization as tfmotquantize_model = tfmot.quantization.keras.quantize_model# q_aware stands for for quantization aware.
q_aware_model = quantize_model(model)# `quantize_model` requires a recompile.
q_aware_model.compile(optimizer='adam',loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),metrics=['accuracy'])q_aware_model.summary()##############################################################
train_images_subset = train_images[0:1000] # out of 60000
train_labels_subset = train_labels[0:1000]q_aware_model.fit(train_images_subset, train_labels_subset,batch_size=500, epochs=1, validation_split=0.1)##############################################################
_, baseline_model_accuracy = model.evaluate(test_images, test_labels, verbose=0)_, q_aware_model_accuracy = q_aware_model.evaluate(test_images, test_labels, verbose=0)print('Baseline test accuracy:', baseline_model_accuracy)
print('Quant test accuracy:', q_aware_model_accuracy)
#现在模型还是float32,不是int8

 

量化模型

用上面介绍的PTQ方法

converter = tf.lite.TFLiteConverter.from_keras_model(q_aware_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]quantized_tflite_model = converter.convert()

 

评估TF和TFLite

import numpy as npdef evaluate_model(interpreter):input_index = interpreter.get_input_details()[0]["index"]output_index = interpreter.get_output_details()[0]["index"]# Run predictions on every image in the "test" dataset.prediction_digits = []for i, test_image in enumerate(test_images):if i % 1000 == 0:print('Evaluated on {n} results so far.'.format(n=i))# Pre-processing: add batch dimension and convert to float32 to match with# the model's input data format.test_image = np.expand_dims(test_image, axis=0).astype(np.float32)interpreter.set_tensor(input_index, test_image)# Run inference.interpreter.invoke()# Post-processing: remove batch dimension and find the digit with highest# probability.output = interpreter.tensor(output_index)digit = np.argmax(output()[0])prediction_digits.append(digit)print('\n')# Compare prediction results with ground truth labels to calculate accuracy.prediction_digits = np.array(prediction_digits)accuracy = (prediction_digits == test_labels).mean()return accuracy###################################################################
interpreter = tf.lite.Interpreter(model_content=quantized_tflite_model)
interpreter.allocate_tensors()test_accuracy = evaluate_model(interpreter)print('Quant TFLite test_accuracy:', test_accuracy)
print('Quant TF test accuracy:', q_aware_model_accuracy)

这篇关于基于Keras的模型量化(PTQ、QAT)的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/807616

相关文章

详解如何使用Python从零开始构建文本统计模型

《详解如何使用Python从零开始构建文本统计模型》在自然语言处理领域,词汇表构建是文本预处理的关键环节,本文通过Python代码实践,演示如何从原始文本中提取多尺度特征,并通过动态调整机制构建更精确... 目录一、项目背景与核心思想二、核心代码解析1. 数据加载与预处理2. 多尺度字符统计3. 统计结果可

SpringBoot整合Sa-Token实现RBAC权限模型的过程解析

《SpringBoot整合Sa-Token实现RBAC权限模型的过程解析》:本文主要介绍SpringBoot整合Sa-Token实现RBAC权限模型的过程解析,本文给大家介绍的非常详细,对大家的学... 目录前言一、基础概念1.1 RBAC模型核心概念1.2 Sa-Token核心功能1.3 环境准备二、表结

Spring Security基于数据库的ABAC属性权限模型实战开发教程

《SpringSecurity基于数据库的ABAC属性权限模型实战开发教程》:本文主要介绍SpringSecurity基于数据库的ABAC属性权限模型实战开发教程,本文给大家介绍的非常详细,对大... 目录1. 前言2. 权限决策依据RBACABAC综合对比3. 数据库表结构说明4. 实战开始5. MyBA

Java的IO模型、Netty原理解析

《Java的IO模型、Netty原理解析》Java的I/O是以流的方式进行数据输入输出的,Java的类库涉及很多领域的IO内容:标准的输入输出,文件的操作、网络上的数据传输流、字符串流、对象流等,这篇... 目录1.什么是IO2.同步与异步、阻塞与非阻塞3.三种IO模型BIO(blocking I/O)NI

基于Flask框架添加多个AI模型的API并进行交互

《基于Flask框架添加多个AI模型的API并进行交互》:本文主要介绍如何基于Flask框架开发AI模型API管理系统,允许用户添加、删除不同AI模型的API密钥,感兴趣的可以了解下... 目录1. 概述2. 后端代码说明2.1 依赖库导入2.2 应用初始化2.3 API 存储字典2.4 路由函数2.5 应

C#集成DeepSeek模型实现AI私有化的流程步骤(本地部署与API调用教程)

《C#集成DeepSeek模型实现AI私有化的流程步骤(本地部署与API调用教程)》本文主要介绍了C#集成DeepSeek模型实现AI私有化的方法,包括搭建基础环境,如安装Ollama和下载DeepS... 目录前言搭建基础环境1、安装 Ollama2、下载 DeepSeek R1 模型客户端 ChatBo

SpringBoot快速接入OpenAI大模型的方法(JDK8)

《SpringBoot快速接入OpenAI大模型的方法(JDK8)》本文介绍了如何使用AI4J快速接入OpenAI大模型,并展示了如何实现流式与非流式的输出,以及对函数调用的使用,AI4J支持JDK8... 目录使用AI4J快速接入OpenAI大模型介绍AI4J-github快速使用创建SpringBoot

0基础租个硬件玩deepseek,蓝耘元生代智算云|本地部署DeepSeek R1模型的操作流程

《0基础租个硬件玩deepseek,蓝耘元生代智算云|本地部署DeepSeekR1模型的操作流程》DeepSeekR1模型凭借其强大的自然语言处理能力,在未来具有广阔的应用前景,有望在多个领域发... 目录0基础租个硬件玩deepseek,蓝耘元生代智算云|本地部署DeepSeek R1模型,3步搞定一个应

Deepseek R1模型本地化部署+API接口调用详细教程(释放AI生产力)

《DeepseekR1模型本地化部署+API接口调用详细教程(释放AI生产力)》本文介绍了本地部署DeepSeekR1模型和通过API调用将其集成到VSCode中的过程,作者详细步骤展示了如何下载和... 目录前言一、deepseek R1模型与chatGPT o1系列模型对比二、本地部署步骤1.安装oll

Spring AI Alibaba接入大模型时的依赖问题小结

《SpringAIAlibaba接入大模型时的依赖问题小结》文章介绍了如何在pom.xml文件中配置SpringAIAlibaba依赖,并提供了一个示例pom.xml文件,同时,建议将Maven仓... 目录(一)pom.XML文件:(二)application.yml配置文件(一)pom.xml文件:首