Vitis AI 基本认知(Tiny-VGG 标签获取+预测后处理)

2024-08-29 06:04

本文主要是介绍Vitis AI 基本认知(Tiny-VGG 标签获取+预测后处理),希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

目录

1. 简介

2. 解析

2.1 获取标签

2.1.1 载入数据集

2.1.2 标签-Index

2.1.3 保存和读取类别标签

2.2 读取单个图片

2.3 载入模型并推理

2.3.1 tiny-vgg 模型结构

2.3.2 运行推理

 2.4 置信度柱状图

2.5 预测标签

3. 完整代码

4. 总结


1. 简介

本博文在《Vitis AI 基本认知(Tiny-VGG 项目代码详解)-CSDN博客》基础上,详细介绍如何使用TensorFlow框架进行单个图片的推理,从获取和处理数据集的标签开始,到模型的加载与推理,再到结果的可视化展示。关键信息如下:

  • 获取数据集的标签
  • 保存和读取类别标签
  • 加载模型并推理
  • 绘制图像
  • 使用中文标签
  • 置信度柱状图

2. 解析

2.1 获取标签

2.1.1 载入数据集

通过 image_dataset_from_directory 方法

vali_dataset = tf.keras.preprocessing.image_dataset_from_directory('./dataset/class_10_val/val_images/',image_size=(64, 64),batch_size=32)

取出一个图片,并查看其标签:

for images, labels in vali_dataset.take(1):# 取出第一个图片和标签image = images[0].numpy().astype("uint8")label = labels[0].numpy()# 显示图片plt.figure(figsize=(2, 2))plt.imshow(image)plt.title(f"Label: {label}")plt.axis('off')plt.show()

2.1.2 标签-Index

查看类别标签及其 Index:

class_names = vali_dataset.class_namesfor i, class_name in enumerate(class_names):print(f"Class name: {class_name:<4}, Index: {i}")
---
Class name: 咖啡   , Index: 0
Class name: 小熊猫 , Index: 1
Class name: 披萨   , Index: 2
Class name: 救生艇 , Index: 3
Class name: 校车   , Index: 4
Class name: 橙子   , Index: 5
Class name: 灯笼椒 , Index: 6
Class name: 瓢虫   , Index: 7
Class name: 考拉   , Index: 8
Class name: 跑车   , Index: 9

类别标签对应的 one-hot 标签:

for index, class_name in enumerate(class_names):one_hot = tf.one_hot(index, len(class_names)).numpy()print(f"Class: {class_name}, One-hot: {one_hot}")
---
Class: 咖啡  , One-hot: [1. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
Class: 小熊猫, One-hot: [0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]
Class: 披萨  , One-hot: [0. 0. 1. 0. 0. 0. 0. 0. 0. 0.]
Class: 救生艇, One-hot: [0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]
Class: 校车  , One-hot: [0. 0. 0. 0. 1. 0. 0. 0. 0. 0.]
Class: 橙子  , One-hot: [0. 0. 0. 0. 0. 1. 0. 0. 0. 0.]
Class: 灯笼椒, One-hot: [0. 0. 0. 0. 0. 0. 1. 0. 0. 0.]
Class: 瓢虫  , One-hot: [0. 0. 0. 0. 0. 0. 0. 1. 0. 0.]
Class: 考拉  , One-hot: [0. 0. 0. 0. 0. 0. 0. 0. 1. 0.]
Class: 跑车  , One-hot: [0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]

2.1.3 保存和读取类别标签

将类别标签写入文本文档:

with open('tiny_VGG_class_names.txt', 'w') as file:for class_name in class_names:file.write(f"{class_name}\n")

从文本文档中读取类别标签: 

with open('tiny_VGG_class_names.txt', 'r') as file:class_names = [line.strip() for line in file]print(class_names)
---
['咖啡', '小熊猫', '披萨', '救生艇', '校车', '橙子', '灯笼椒', '瓢虫', '考拉', '跑车']

2.2 读取单个图片

读取图片,并显示在 Jupyter Lab 中:

img = cv2.imread('./dataset/class_10_val/val_images/橙子/val_1067.JPEG')plt.figure(figsize=(2, 2))
plt.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
plt.axis('off')
plt.show()

 对图片归一化操作:

normalization_layer = tf.keras.layers.Rescaling(1./255)
img_norm = normalization_layer(img)
img_norm = np.expand_dims(img_norm, axis=0)
np.shape(img_norm)
---
(1, 64, 64, 3)

训练过程中,对数据集做了归一化处理,推理时也要做同样的处理。

2.3 载入模型并推理

2.3.1 tiny-vgg 模型结构

# Create an instance of the model
filters = 10
tiny_vgg = Sequential([Conv2D(filters, (3, 3), input_shape=(64, 64, 3), name='conv_1_1'),Activation('relu', name='relu_1_1'),Conv2D(filters, (3, 3), name='conv_1_2'),Activation('relu', name='relu_1_2'),MaxPool2D((2, 2), name='max_pool_1'),Conv2D(filters, (3, 3), name='conv_2_1'),Activation('relu', name='relu_2_1'),Conv2D(filters, (3, 3), name='conv_2_2'),Activation('relu', name='relu_2_2'),MaxPool2D((2, 2), name='max_pool_2'),Flatten(name='flatten'),Dense(NUM_CLASS, activation='softmax', name='output')
])

2.3.2 运行推理

tiny_vgg = tf.keras.models.load_model('trained_vgg_best.h5')
prediction = tiny_vgg.predict(img_norm)
prediction
---
array([[6.2276758e-02, 3.6967881e-03, 9.2534656e-06, 4.8701441e-01,3.6426269e-02, 2.9939638e-02, 7.1093095e-03, 2.9743392e-02,2.1278052e-02, 3.2250613e-01]], dtype=float32)

注意:模型的最后一层已经经过 softmax 计算,无需单独调用 softmax 计算概率:

sum = np.sum(prediction)
print(sum)
---
1.0

 2.4 置信度柱状图

fig = plt.figure(figsize=(18,6))# 绘制左图-预测图,调整比例
ax1 = plt.subplot(1,6,1)
ax1.imshow(img)
ax1.axis('off')# 绘制右图-柱状图,调整比例
ax2 = plt.subplot(1,6,(2,6))
y = prediction[0]
ax2.bar(class_names, y, alpha=0.5, width=0.3, color='yellow', edgecolor='red', lw=3)
ax2.set_xticks(x)
ax2.set_xticklabels(class_names, fontproperties=font)
plt.ylim([0, 1.0]) # y轴取值范围# 显示置信度数值
for i in range(len(y)):plt.text(i, y[i] + 0.01, f'{y[i]:.2f}', ha='center', fontsize=15)plt.xlabel('类别', fontsize=20, fontproperties=font)
plt.ylabel('置信度', fontsize=20, fontproperties=font)
ax2.tick_params(labelsize=16)plt.tight_layout()

2.5 预测标签

predict_label = class_names[np.argmax(prediction)]
print("类别: {}".format(predict_label))# 显示图片
plt.figure(figsize=(2, 2))
plt.imshow(img)
plt.axis('off')
plt.show()

3. 完整代码

import tensorflow as tf
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import cv2font = matplotlib.font_manager.FontProperties(fname="./SimHei.ttf")vali_dataset = tf.keras.preprocessing.image_dataset_from_directory('./dataset/class_10_val/val_images/',image_size=(64, 64),batch_size=32)class_names = vali_dataset.class_namesimg = cv2.imread('./dataset/class_10_train/橙子/n07747607_0.JPEG')
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)tiny_vgg = tf.keras.models.load_model('trained_vgg_best.h5')prediction = tiny_vgg.predict(img_norm)fig = plt.figure(figsize=(18,6))# 绘制左图-预测图,调整比例
ax1 = plt.subplot(1,6,1)
ax1.imshow(img)
ax1.axis('off')# 绘制右图-柱状图,调整比例
ax2 = plt.subplot(1,6,(2,6))
y = prediction[0]
ax2.bar(class_names, y, alpha=0.5, width=0.3, color='yellow', edgecolor='red', lw=3)
ax2.set_xticks(x)
ax2.set_xticklabels(class_names, fontproperties=font)
plt.ylim([0, 1.0]) # y轴取值范围# 显示置信度数值
for i in range(len(y)):plt.text(i, y[i] + 0.01, f'{y[i]:.2f}', ha='center', fontsize=15)plt.xlabel('类别', fontsize=20, fontproperties=font)
plt.ylabel('置信度', fontsize=20, fontproperties=font)
ax2.tick_params(labelsize=16)plt.tight_layout()

4. 总结

本博文详继续介绍 Tiny-VGG 项目,对模型进行单张图片的推理,关键要点包括:

1). 数据处理与标签管理:通过 image_dataset_from_directory 方法加载数据,并提取类别名称作为标签,同时展示了如何保存和读取类别标签到/从文本文件。

2). 图片预处理:读取单个图片,并对其进行归一化处理,以匹配训练时的数据处理方式,确保模型能正确解读输入数据。

3). 模型加载与推理:加载预训练的Tiny-VGG模型,并对单张图片进行推理,获取预测结果。

4). 结果可视化:通过绘制图片和置信度柱状图来可视化模型的预测结果,使用中文标签和显示每个类别的置信度值。

5). 实用代码示例:提供了完整的代码示例,包括数据加载、模型推理和结果展示,方便读者理解和实际操作。
 

这篇关于Vitis AI 基本认知(Tiny-VGG 标签获取+预测后处理)的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1117049

相关文章

Ilya-AI分享的他在OpenAI学习到的15个提示工程技巧

Ilya(不是本人,claude AI)在社交媒体上分享了他在OpenAI学习到的15个Prompt撰写技巧。 以下是详细的内容: 提示精确化:在编写提示时,力求表达清晰准确。清楚地阐述任务需求和概念定义至关重要。例:不用"分析文本",而用"判断这段话的情感倾向:积极、消极还是中性"。 快速迭代:善于快速连续调整提示。熟练的提示工程师能够灵活地进行多轮优化。例:从"总结文章"到"用

AI绘图怎么变现?想做点副业的小白必看!

在科技飞速发展的今天,AI绘图作为一种新兴技术,不仅改变了艺术创作的方式,也为创作者提供了多种变现途径。本文将详细探讨几种常见的AI绘图变现方式,帮助创作者更好地利用这一技术实现经济收益。 更多实操教程和AI绘画工具,可以扫描下方,免费获取 定制服务:个性化的创意商机 个性化定制 AI绘图技术能够根据用户需求生成个性化的头像、壁纸、插画等作品。例如,姓氏头像在电商平台上非常受欢迎,

从去中心化到智能化:Web3如何与AI共同塑造数字生态

在数字时代的演进中,Web3和人工智能(AI)正成为塑造未来互联网的两大核心力量。Web3的去中心化理念与AI的智能化技术,正相互交织,共同推动数字生态的变革。本文将探讨Web3与AI的融合如何改变数字世界,并展望这一新兴组合如何重塑我们的在线体验。 Web3的去中心化愿景 Web3代表了互联网的第三代发展,它基于去中心化的区块链技术,旨在创建一个开放、透明且用户主导的数字生态。不同于传统

AI一键生成 PPT

AI一键生成 PPT 操作步骤 作为一名打工人,是不是经常需要制作各种PPT来分享我的生活和想法。但是,你们知道,有时候灵感来了,时间却不够用了!😩直到我发现了Kimi AI——一个能够自动生成PPT的神奇助手!🌟 什么是Kimi? 一款月之暗面科技有限公司开发的AI办公工具,帮助用户快速生成高质量的演示文稿。 无论你是职场人士、学生还是教师,Kimi都能够为你的办公文

基本知识点

1、c++的输入加上ios::sync_with_stdio(false);  等价于 c的输入,读取速度会加快(但是在字符串的题里面和容易出现问题) 2、lower_bound()和upper_bound() iterator lower_bound( const key_type &key ): 返回一个迭代器,指向键值>= key的第一个元素。 iterator upper_bou

Andrej Karpathy最新采访:认知核心模型10亿参数就够了,AI会打破教育不公的僵局

夕小瑶科技说 原创  作者 | 海野 AI圈子的红人,AI大神Andrej Karpathy,曾是OpenAI联合创始人之一,特斯拉AI总监。上一次的动态是官宣创办一家名为 Eureka Labs 的人工智能+教育公司 ,宣布将长期致力于AI原生教育。 近日,Andrej Karpathy接受了No Priors(投资博客)的采访,与硅谷知名投资人 Sara Guo 和 Elad G

【Prometheus】PromQL向量匹配实现不同标签的向量数据进行运算

✨✨ 欢迎大家来到景天科技苑✨✨ 🎈🎈 养成好习惯,先赞后看哦~🎈🎈 🏆 作者简介:景天科技苑 🏆《头衔》:大厂架构师,华为云开发者社区专家博主,阿里云开发者社区专家博主,CSDN全栈领域优质创作者,掘金优秀博主,51CTO博客专家等。 🏆《博客》:Python全栈,前后端开发,小程序开发,人工智能,js逆向,App逆向,网络系统安全,数据分析,Django,fastapi

AI hospital 论文Idea

一、Benchmarking Large Language Models on Communicative Medical Coaching: A Dataset and a Novel System论文地址含代码 大多数现有模型和工具主要迎合以患者为中心的服务。这项工作深入探讨了LLMs在提高医疗专业人员的沟通能力。目标是构建一个模拟实践环境,人类医生(即医学学习者)可以在其中与患者代理进行医学

AI行业应用(不定期更新)

ChatPDF 可以让你上传一个 PDF 文件,然后针对这个 PDF 进行小结和提问。你可以把各种各样你要研究的分析报告交给它,快速获取到想要知道的信息。https://www.chatpdf.com/

【IPV6从入门到起飞】5-1 IPV6+Home Assistant(搭建基本环境)

【IPV6从入门到起飞】5-1 IPV6+Home Assistant #搭建基本环境 1 背景2 docker下载 hass3 创建容器4 浏览器访问 hass5 手机APP远程访问hass6 更多玩法 1 背景 既然电脑可以IPV6入站,手机流量可以访问IPV6网络的服务,为什么不在电脑搭建Home Assistant(hass),来控制你的设备呢?@智能家居 @万物互联