torch.fx量化resnet18_cifar10

2024-03-30 05:20
文章标签 量化 torch fx cifar10 resnet18

本文主要是介绍torch.fx量化resnet18_cifar10,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

运行环境

11th Gen Intel® Core™ i7-11370H @ 3.30GHz

从HuggingFace上下载训练好的模型 链接

import timm
model = timm.create_model("resnet18_cifar10", pretrained=True)

加载模型

import torch
from torch import nn
from torchvision.models.resnet import resnet18model=resnet18(pretrained=True) #加载pytorch官方提供的预训练模型结构
#手动修改网络层维度,与自定义模型结构相同
model.conv1 = nn.Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
model.maxpool = nn.Identity()  # type: ignore
model.fc=nn.Linear(model.fc.in_features,10)
#根据模型结构将模型的state_dict(参数)加载到模型中
#torch.load函数用pickle的反序列化功能将保存的对象文件加载到内存中。这个函数还支持将数据加载到指定的设备上。
#model.load_state_dict函数用于使用反序列化的state_dict加载模型的参数字典。
model.load_state_dict(torch.load("resnet18_cifar10.pth",map_location='cpu'))
model.to(torch.device("cpu"))
model.eval()#设置模型为推理模式

DataLoader加载Cifar10数据集

from torch.utils.data import DataLoader# 设置mean和scale
transform_test = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])#下载数据集
train_data = torchvision.datasets.CIFAR10(root='data', train=True, transform=torchvision.transforms.ToTensor(),download=True)
test_data = torchvision.datasets.CIFAR10(root='data', train=False, transform=transform_test,download=True)# DataLoader加载数据集
train_dataloader = DataLoader(train_data, batch_size=64)
test_dataloader = DataLoader(test_data, batch_size=64)

量化校准并量化

# 校准恢复精度
model_to_quantize=copy.deepcopy(model)
prepared_model=prepare_fx(model_to_quantize,qconfig_mapping,example_inputs=torch.randn([1,3,224,224]))
prepared_model.eval()
#量化校准
with torch.inference_mode():for inputs,labels in test_dataloader:prepared_model(inputs)
quantized_recover_model=convert_fx(prepared_model)
print(f"quantized model {quantized_recover_model.graph.print_tabular()}")

保存量化后的模型:

script_module = torch.jit.trace(quantized_recover_model, example_inputs=torch.randn([1, 3, 224, 224]))
torch.jit.save(script_module, "quant_model.pth")

加载序列化的量化模型并测试模型精度:

quantized_recover_model = torch.jit.load("C:\\Users\\tanfengfeng\\PycharmProjects\\pythonProject1\\quant_model.pth")
# quantized_recover_model.load_state_dict(torch.load("C:\\Users\\tanfengfeng\\PycharmProjects\\pythonProject1\\tmp.pt", map_location='cpu'))
with torch.autograd.profiler.profile(enabled=True, use_cuda=False, record_shapes=False, profile_memory=False) as prof:test(quantized_recover_model, test_dataloader, device='cpu')
print(prof.table())

测试模型精度:

def test(model, test_dataloader, device):best_acc = 0model.eval()test_loss = 0correct = 0total = 0test_acc = 0with torch.no_grad():#设置禁止计算梯度for batch_idx, (inputs, targets) in enumerate(test_dataloader):#从DataLoader获取数据inputs, targets = inputs.to(device), targets.to(device)outputs = model(inputs)#前向传播criterion = nn.CrossEntropyLoss()loss = criterion(outputs, targets)#计算交叉熵损失test_loss += loss.item()# item() 获取标量的数值_, predicted = outputs.max(1) # 返回第1个维度上的最大的(元素值,索引) predicted为每个样本预测的分类Idtotal += targets.size(0)correct += predicted.eq(targets).sum().item()test_acc = correct / totalprint('[INFO] Test Accurancy: {:.3f}'.format(test_acc), '\n')

完整代码:

# https://juejin.cn/post/7178317867492835383#heading-7
# https://huggingface.co/edadaltocg/resnet18_cifar10import os
import copy
import timeimport torch
from torch import nnimport torchvision
from torchvision import transforms
from torchvision.models.resnet import resnet18from torch.quantization.quantize_fx import prepare_fx, convert_fx
from torch.ao.quantization import get_default_qconfig_mappingfrom torch.utils.data import DataLoaderdef test(model, test_dataloader, device):best_acc = 0model.eval()test_loss = 0correct = 0total = 0test_acc = 0with torch.no_grad():  # 设置禁止计算梯度for batch_idx, (inputs, targets) in enumerate(test_dataloader):  # 从DataLoader获取数据inputs, targets = inputs.to(device), targets.to(device)outputs = model(inputs)  # 前向传播criterion = nn.CrossEntropyLoss()loss = criterion(outputs, targets)  # 计算交叉熵损失test_loss += loss.item()  # item() 获取标量的数值_, predicted = outputs.max(1)  # 返回第1个维度上的最大的(元素值,索引) predicted为每个样本预测的分类Idtotal += targets.size(0)correct += predicted.eq(targets).sum().item()test_acc = correct / totalprint('[INFO] Test Accurancy: {:.3f}'.format(test_acc), '\n')def print_size_of_model(model):torch.save(model.state_dict(), "tmp.pt")print(f"The model size:{os.path.getsize('tmp.pt') / 1e6}MB")model = resnet18(pretrained=True)
# 修改模型
model.conv1 = nn.Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
model.maxpool = nn.Identity()  # type: ignore
model.fc = nn.Linear(model.fc.in_features, 10)
model.load_state_dict(torch.load("C:\\Users\\tanfengfeng\\.cache\\torch\\hub\\checkpoints\\resnet18_cifar10.pth", map_location='cpu'))
model.to(torch.device("cpu"))
model.eval()# 设置mean和scale
transform_test = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])# 准备数据集
train_data = torchvision.datasets.CIFAR10(root='data', train=True, transform=torchvision.transforms.ToTensor(),download=True)
test_data = torchvision.datasets.CIFAR10(root='data', train=False, transform=transform_test, download=True)print("训练集的长度:{}".format(len(train_data)))
print("测试集的长度:{}".format(len(test_data)))# DataLoader加载数据集
train_dataloader = DataLoader(train_data, batch_size=64)
test_dataloader = DataLoader(test_data, batch_size=64)#量化
torch.backends.quantized.engine = 'fbgemm'  # 设置量化后端
qconfig_mapping = get_default_qconfig_mapping("fbgemm")
# 校准量化精度
model_to_quantize = copy.deepcopy(model)
prepared_model = prepare_fx(model_to_quantize, qconfig_mapping, example_inputs=torch.randn([1, 3, 224, 224]))
prepared_model.eval()
with torch.inference_mode():for inputs, labels in test_dataloader:prepared_model(inputs)quantized_recover_model = convert_fx(prepared_model)
# print(f"quantized model {quantized_recover_model.graph.print_tabular()}")
script_module = torch.jit.trace(quantized_recover_model, example_inputs=torch.randn([1, 3, 224, 224]))
torch.jit.save(script_module, "quant_model.pth")
# print_size_of_model(prepared_model)
# print_size_of_model(quantized_recover_model)#测试FP32模型精度和耗时
with torch.autograd.profiler.profile(enabled=True, use_cuda=False, record_shapes=False, profile_memory=False) as prof:test(model, test_dataloader, device='cpu')
print(prof.table())#测试int8模型精度和耗时
quantized_recover_model = torch.jit.load("C:\\Users\\tanfengfeng\\PycharmProjects\\pythonProject1\\quant_model.pth")
with torch.autograd.profiler.profile(enabled=True, use_cuda=False, record_shapes=False, profile_memory=False) as prof:test(quantized_recover_model, test_dataloader, device='cpu')
print(prof.table())

运行结果截图:
FP32:
在这里插入图片描述
在这里插入图片描述
量化后:
在这里插入图片描述
在这里插入图片描述

总结

量化后推理时间快了一倍,并且模型精度几乎没有损耗。

这篇关于torch.fx量化resnet18_cifar10的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/860582

相关文章

javafx 如何将项目打包为 Windows 的可执行文件exe

《javafx如何将项目打包为Windows的可执行文件exe》文章介绍了三种将JavaFX项目打包为.exe文件的方法:方法1使用jpackage(适用于JDK14及以上版本),方法2使用La... 目录方法 1:使用 jpackage(适用于 JDK 14 及更高版本)方法 2:使用 Launch4j(

JavaFX应用更新检测功能(在线自动更新方案)

JavaFX开发的桌面应用属于C端,一般来说需要版本检测和自动更新功能,这里记录一下一种版本检测和自动更新的方法。 1. 整体方案 JavaFX.应用版本检测、自动更新主要涉及一下步骤: 读取本地应用版本拉取远程版本并比较两个版本如果需要升级,那么拉取更新历史弹出升级控制窗口用户选择升级时,拉取升级包解压,重启应用用户选择忽略时,本地版本标志为忽略版本用户选择取消时,隐藏升级控制窗口 2.

JavaFX环境的搭建和一个简单的例子

之前在网上搜了很多与javaFX相关的资料,都说要在Eclepse上要安装sdk插件什么的,反正就是乱七八糟的一大片,最后还是没搞成功,所以我在这里写下我搭建javaFX成功的环境给大家做一个参考吧。希望能帮助到你们! 1.首先要保证你的jdk版本能够支持JavaFX的开发,jdk-7u25版本以上的都能支持,最好安装jdk8吧,因为jdk8对支持JavaFX有新的特性了,比如:3D等;

Unity 资源 之 Super Confetti FX:点亮项目的璀璨粒子之光

Unity 资源 之 Super Confetti FX:点亮项目的璀璨粒子之光 一,前言二,资源包内容三,免费获取资源包 一,前言 在创意的世界里,每一个细节都能决定一个项目的独特魅力。今天,要向大家介绍一款令人惊艳的粒子效果包 ——Super Confetti FX。 二,资源包内容 💥充满活力与动态,是 Super Confetti FX 最显著的标签。它宛如一位

pytorch torch.nn.functional.one_hot函数介绍

torch.nn.functional.one_hot 是 PyTorch 中用于生成独热编码(one-hot encoding)张量的函数。独热编码是一种常用的编码方式,特别适用于分类任务或对离散的类别标签进行处理。该函数将整数张量的每个元素转换为一个独热向量。 函数签名 torch.nn.functional.one_hot(tensor, num_classes=-1) 参数 t

torch.nn 与 torch.nn.functional的区别?

区别 PyTorch中torch.nn与torch.nn.functional的区别是:1.继承方式不同;2.可训练参数不同;3.实现方式不同;4.调用方式不同。 1.继承方式不同 torch.nn 中的模块大多数是通过继承torch.nn.Module 类来实现的,这些模块都是Python 类,需要进行实例化才能使用。而torch.nn.functional 中的函数是直接调用的,无需

量化交易面试:什么是连贯风险度量?

连贯风险度量(Coherent Risk Measures)是金融风险管理中的一个重要概念,旨在提供一种合理且一致的方式来评估和量化风险。连贯风险度量的提出是为了克服传统风险度量方法(如VaR,风险价值)的一些局限性。以下是对连贯风险度量的详细解释: 基本概念: 连贯风险度量是指满足特定公理的风险度量方法,这些公理确保了风险评估的一致性和合理性。 这些公理包括:非负性、次可加性、同质性和单调

Matlab)实现HSV非等间隔量化--相似判断:欧式距离--输出图片-

%************************************************************************** %                                 图像检索——提取颜色特征 %HSV空间颜色直方图(将RGB空间转化为HS

【Java】 在GUI开发中JavaFX是否仍占有一席之地?

文章目录 引言什么是JavaFX?如何使用JavaFX开发桌面应用程序1. 环境搭建2. 创建项目3. 设计UI界面4. 编写控制器代码5. 运行应用程序 使用JavaFX开发的好处1. 现代化的UI组件2. 跨平台支持3. 易于维护和扩展 JavaFX的优缺点优点缺点 JavaFX与Java包的兼容性JavaFX 8(随Java 8发布)JavaFX 9(2017)JavaFX 10(20

期货赫兹量化-种群优化算法:进化策略,(μ,λ)-ES 和 (μ+λ)-ES

进化策略(Evolution Strategies, ES)是一种启发式算法,旨在模仿自然选择的过程来解决复杂的优化问题,尤其在没有显式解、或搜索空间巨大的情况下表现良好。基于自然界的进化原理,进化策略通过突变、选择等遗传算子迭代生成解,并最终寻求全局最优解。 进化策略通常基于两个核心机制:突变和选择。突变是对当前解进行随机扰动,而选择则用于保留适应度更高的个体。本文详细介绍了 (μ,λ)-ES