基于Keras的模型量化(PTQ、QAT)

2024-03-14 07:12
文章标签 模型 量化 keras ptq qat

本文主要是介绍基于Keras的模型量化(PTQ、QAT),希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

对PTQ和QAT的详细解释在这篇哦:
《模型量化(三)—— 量化感知训练QAT(pytorch)》

本文给的代码是基于tensorflow

目录

  • PTQ
    • 只量化权重
    • 权重和激活值全量化
  • QAT
    • 套路创建和训练模型
    • 用QAT克隆和微调预训练模型
    • 量化模型
    • 评估TF和TFLite

PTQ

只量化权重

只是优化了模型大小,对于模型的计算没什么优化,因为W * X时,W要反量化为浮点进行运算,相当于还增加了反量化这一累赘操作…

import tensorflow as tf
converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_quant_model = converter.convert()

权重和激活值全量化

通过量化权重和激活,可以同时改善功耗和时延,因为最关键的密集部分W * X使用 8 位而不是浮点数进行计算。这需要一个较小的calibration数据集来计算激活值反量化时的S和Z操作。

import tensorflow as tfdef representative_dataset_gen():for _ in range(num_calibration_steps):# Get sample input data as a numpy array in a method of your choosing.yield [input]converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_dataset_gen
tflite_quant_model = converter.convert()

详细内容参考官方的说明:
https://www.tensorflow.org/model_optimization/guide/quantization/post_training?hl=zh-cn

 

QAT

套路创建和训练模型

!pip install -q tensorflow
!pip install -q tensorflow-model-optimizationimport tempfile
import os
import tensorflow as tf
from tensorflow_model_optimization.python.core.keras.compat import keras# Load MNIST dataset
mnist = tf.keras.datasets.mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()# Normalize the input image so that each pixel value is between 0 to 1.
train_images = train_images / 255.0
test_images = test_images / 255.0# Define the model architecture.
model = keras.Sequential([keras.layers.InputLayer(input_shape=(28, 28)),keras.layers.Reshape(target_shape=(28, 28, 1)),keras.layers.Conv2D(filters=12, kernel_size=(3, 3), activation='relu'),keras.layers.MaxPooling2D(pool_size=(2, 2)),keras.layers.Flatten(),keras.layers.Dense(10)
])# Train the digit classification model
model.compile(optimizer='adam',loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),metrics=['accuracy'])model.fit(train_images,train_labels,epochs=1,validation_split=0.1,
)

 

用QAT克隆和微调预训练模型

import tensorflow_model_optimization as tfmotquantize_model = tfmot.quantization.keras.quantize_model# q_aware stands for for quantization aware.
q_aware_model = quantize_model(model)# `quantize_model` requires a recompile.
q_aware_model.compile(optimizer='adam',loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),metrics=['accuracy'])q_aware_model.summary()##############################################################
train_images_subset = train_images[0:1000] # out of 60000
train_labels_subset = train_labels[0:1000]q_aware_model.fit(train_images_subset, train_labels_subset,batch_size=500, epochs=1, validation_split=0.1)##############################################################
_, baseline_model_accuracy = model.evaluate(test_images, test_labels, verbose=0)_, q_aware_model_accuracy = q_aware_model.evaluate(test_images, test_labels, verbose=0)print('Baseline test accuracy:', baseline_model_accuracy)
print('Quant test accuracy:', q_aware_model_accuracy)
#现在模型还是float32,不是int8

 

量化模型

用上面介绍的PTQ方法

converter = tf.lite.TFLiteConverter.from_keras_model(q_aware_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]quantized_tflite_model = converter.convert()

 

评估TF和TFLite

import numpy as npdef evaluate_model(interpreter):input_index = interpreter.get_input_details()[0]["index"]output_index = interpreter.get_output_details()[0]["index"]# Run predictions on every image in the "test" dataset.prediction_digits = []for i, test_image in enumerate(test_images):if i % 1000 == 0:print('Evaluated on {n} results so far.'.format(n=i))# Pre-processing: add batch dimension and convert to float32 to match with# the model's input data format.test_image = np.expand_dims(test_image, axis=0).astype(np.float32)interpreter.set_tensor(input_index, test_image)# Run inference.interpreter.invoke()# Post-processing: remove batch dimension and find the digit with highest# probability.output = interpreter.tensor(output_index)digit = np.argmax(output()[0])prediction_digits.append(digit)print('\n')# Compare prediction results with ground truth labels to calculate accuracy.prediction_digits = np.array(prediction_digits)accuracy = (prediction_digits == test_labels).mean()return accuracy###################################################################
interpreter = tf.lite.Interpreter(model_content=quantized_tflite_model)
interpreter.allocate_tensors()test_accuracy = evaluate_model(interpreter)print('Quant TFLite test_accuracy:', test_accuracy)
print('Quant TF test accuracy:', q_aware_model_accuracy)

这篇关于基于Keras的模型量化(PTQ、QAT)的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/807616

相关文章

Java的IO模型、Netty原理解析

《Java的IO模型、Netty原理解析》Java的I/O是以流的方式进行数据输入输出的,Java的类库涉及很多领域的IO内容:标准的输入输出,文件的操作、网络上的数据传输流、字符串流、对象流等,这篇... 目录1.什么是IO2.同步与异步、阻塞与非阻塞3.三种IO模型BIO(blocking I/O)NI

基于Flask框架添加多个AI模型的API并进行交互

《基于Flask框架添加多个AI模型的API并进行交互》:本文主要介绍如何基于Flask框架开发AI模型API管理系统,允许用户添加、删除不同AI模型的API密钥,感兴趣的可以了解下... 目录1. 概述2. 后端代码说明2.1 依赖库导入2.2 应用初始化2.3 API 存储字典2.4 路由函数2.5 应

C#集成DeepSeek模型实现AI私有化的流程步骤(本地部署与API调用教程)

《C#集成DeepSeek模型实现AI私有化的流程步骤(本地部署与API调用教程)》本文主要介绍了C#集成DeepSeek模型实现AI私有化的方法,包括搭建基础环境,如安装Ollama和下载DeepS... 目录前言搭建基础环境1、安装 Ollama2、下载 DeepSeek R1 模型客户端 ChatBo

SpringBoot快速接入OpenAI大模型的方法(JDK8)

《SpringBoot快速接入OpenAI大模型的方法(JDK8)》本文介绍了如何使用AI4J快速接入OpenAI大模型,并展示了如何实现流式与非流式的输出,以及对函数调用的使用,AI4J支持JDK8... 目录使用AI4J快速接入OpenAI大模型介绍AI4J-github快速使用创建SpringBoot

0基础租个硬件玩deepseek,蓝耘元生代智算云|本地部署DeepSeek R1模型的操作流程

《0基础租个硬件玩deepseek,蓝耘元生代智算云|本地部署DeepSeekR1模型的操作流程》DeepSeekR1模型凭借其强大的自然语言处理能力,在未来具有广阔的应用前景,有望在多个领域发... 目录0基础租个硬件玩deepseek,蓝耘元生代智算云|本地部署DeepSeek R1模型,3步搞定一个应

Deepseek R1模型本地化部署+API接口调用详细教程(释放AI生产力)

《DeepseekR1模型本地化部署+API接口调用详细教程(释放AI生产力)》本文介绍了本地部署DeepSeekR1模型和通过API调用将其集成到VSCode中的过程,作者详细步骤展示了如何下载和... 目录前言一、deepseek R1模型与chatGPT o1系列模型对比二、本地部署步骤1.安装oll

Spring AI Alibaba接入大模型时的依赖问题小结

《SpringAIAlibaba接入大模型时的依赖问题小结》文章介绍了如何在pom.xml文件中配置SpringAIAlibaba依赖,并提供了一个示例pom.xml文件,同时,建议将Maven仓... 目录(一)pom.XML文件:(二)application.yml配置文件(一)pom.xml文件:首

如何在本地部署 DeepSeek Janus Pro 文生图大模型

《如何在本地部署DeepSeekJanusPro文生图大模型》DeepSeekJanusPro模型在本地成功部署,支持图片理解和文生图功能,通过Gradio界面进行交互,展示了其强大的多模态处... 目录什么是 Janus Pro1. 安装 conda2. 创建 python 虚拟环境3. 克隆 janus

本地私有化部署DeepSeek模型的详细教程

《本地私有化部署DeepSeek模型的详细教程》DeepSeek模型是一种强大的语言模型,本地私有化部署可以让用户在自己的环境中安全、高效地使用该模型,避免数据传输到外部带来的安全风险,同时也能根据自... 目录一、引言二、环境准备(一)硬件要求(二)软件要求(三)创建虚拟环境三、安装依赖库四、获取 Dee

DeepSeek模型本地部署的详细教程

《DeepSeek模型本地部署的详细教程》DeepSeek作为一款开源且性能强大的大语言模型,提供了灵活的本地部署方案,让用户能够在本地环境中高效运行模型,同时保护数据隐私,在本地成功部署DeepSe... 目录一、环境准备(一)硬件需求(二)软件依赖二、安装Ollama三、下载并部署DeepSeek模型选