Tensorflow入门实战 T05-运动鞋识别

2024-06-20 06:04

本文主要是介绍Tensorflow入门实战 T05-运动鞋识别,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

目录

一、完整代码

二、训练过程

(1)打印2行10列的数据。

(2)查看数据集中的一张图片

(3)训练过程(训练50个epoch)

(4)训练结果的精确度

三、遇到的问题


  • 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍖 原作者:K同学啊 | 接辅导、项目定制

这篇博客的主要内容是,关于运动鞋的识别。

运动鞋数据集包含训练集和测试集,共578张。

一、完整代码

from tensorflow import keras
from keras import layers, models
import os, PIL, pathlib
import matplotlib.pyplot as plt
import tensorflow as tfgpus = tf.config.list_physical_devices("GPU")if gpus:gpu0 = gpus[0]  # 如果有多个GPU,仅使用第0个GPUtf.config.experimental.set_memory_growth(gpu0, True)  # 设置GPU显存用量按需使用tf.config.set_visible_devices([gpu0], "GPU")# 导入数据集
data_dir = "/Users/MsLiang/Documents/mySelf_project/pythonProject_pytorch/learn_demo/P_model/p05_sport/sport_data"
data_dir = pathlib.Path(data_dir)  # 打印文件夹目录
image_count = len(list(data_dir.glob('*/*/*.jpg')))
print("图片总数为:",image_count)
roses = list(data_dir.glob('train/nike/*.jpg'))
result = PIL.Image.open(str(roses[0]))
# result.show()# 数据预处理
batch_size = 32
img_height = 224
img_width = 224"""
关于image_dataset_from_directory()的详细介绍可以参考文章:https://mtyjkh.blog.csdn.net/article/details/117018789
"""
train_dir = str(data_dir) + "/train"
train_ds = tf.keras.preprocessing.image_dataset_from_directory(train_dir,seed=123,image_size=(img_height, img_width),batch_size=batch_size)"""
关于image_dataset_from_directory()的详细介绍可以参考文章:https://mtyjkh.blog.csdn.net/article/details/117018789
"""
test_dir = str(data_dir) + "/test/"
val_ds = tf.keras.preprocessing.image_dataset_from_directory(test_dir,seed=123,image_size=(img_height, img_width),batch_size=batch_size)class_names = train_ds.class_names
print(class_names)  # 打印结果: ['adidas', 'nike']# 可视化数据
plt.figure(figsize=(20, 10))for images, labels in train_ds.take(1):for i in range(20):ax = plt.subplot(5, 10, i + 1)plt.imshow(images[i].numpy().astype("uint8"))plt.title(class_names[labels[i]])plt.axis("off")
plt.show()# 检查数据
for image_batch, labels_batch in train_ds:print(image_batch.shape)   # (32, 224, 224, 3)print(labels_batch.shape)  # (32,)break# 配置数据集
AUTOTUNE = tf.data.AUTOTUNEtrain_ds = train_ds.cache().shuffle(1000).prefetch(buffer_size=AUTOTUNE)
val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)# 搭建神经网络
"""
关于卷积核的计算不懂的可以参考文章:https://blog.csdn.net/qq_38251616/article/details/114278995layers.Dropout(0.4) 作用是防止过拟合,提高模型的泛化能力。
关于Dropout层的更多介绍可以参考文章:https://mtyjkh.blog.csdn.net/article/details/115826689
"""model = models.Sequential([keras.layers.experimental.preprocessing.Rescaling(1. / 255, input_shape=(img_height, img_width, 3)),layers.Conv2D(16, (3, 3), activation='relu', input_shape=(img_height, img_width, 3)),  # 卷积层1,卷积核3*3layers.AveragePooling2D((2, 2)),  # 池化层1,2*2采样layers.Conv2D(32, (3, 3), activation='relu'),  # 卷积层2,卷积核3*3layers.AveragePooling2D((2, 2)),  # 池化层2,2*2采样layers.Dropout(0.3),layers.Conv2D(64, (3, 3), activation='relu'),  # 卷积层3,卷积核3*3layers.Dropout(0.3),layers.Flatten(),  # Flatten层,连接卷积层与全连接层layers.Dense(128, activation='relu'),  # 全连接层,特征进一步提取layers.Dense(len(class_names))  # 输出层,输出预期结果
])# model.summary()  # 打印网络结构# 设置动态学习率
# 设置初始学习率
initial_learning_rate = 0.1lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(initial_learning_rate,decay_steps=10,      # 敲黑板!!!这里是指 steps,不是指epochsdecay_rate=0.92,     # lr经过一次衰减就会变成 decay_rate*lrstaircase=True)# 将指数衰减学习率送入优化器
optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)model.compile(optimizer=optimizer,loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),metrics=['accuracy'])from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStoppingepochs = 50# 保存最佳模型参数
checkpointer = ModelCheckpoint('best_model.h5',monitor='val_accuracy',verbose=1,save_best_only=True,save_weights_only=True)# 设置早停
earlystopper = EarlyStopping(monitor='val_accuracy',min_delta=0.001,patience=20,verbose=1)# 模型训练
history = model.fit(train_ds,validation_data=val_ds,epochs=epochs,callbacks=[checkpointer, earlystopper])# 模型评估图
acc = history.history['accuracy']
val_acc = history.history['val_accuracy']loss = history.history['loss']
val_loss = history.history['val_loss']epochs_range = range(len(loss))plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
plt.plot(epochs_range, acc, label='Training Accuracy')
plt.plot(epochs_range, val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss, label='Training Loss')
plt.plot(epochs_range, val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()# 指定图片进行预测
# 加载效果最好的模型权重
model.load_weights('best_model.h5')from PIL import Image
import numpy as np# img = Image.open("./45-data/Monkeypox/M06_01_04.jpg")  # 这里选择你需要预测的图片
img = Image.open(str(data_dir) + "/test/nike/1.jpg")  # 这里选择你需要预测的图片
image = tf.image.resize(img, [img_height, img_width])img_array = tf.expand_dims(image, 0) # /255.0  # 记得做归一化处理(与训练集处理方式保持一致)predictions = model.predict(img_array) # 这里选用你已经训练好的模型
print("预测结果为:",class_names[np.argmax(predictions)])

二、训练过程

(1)打印2行10列的数据。

(2)查看数据集中的一张图片

(3)训练过程(训练50个epoch)

模型早听了。

(4)训练结果的精确度

三、遇到的问题

在进行字符串拼接的时候,出现unsupported operand type(s) for +: 'PosixPath' and 'str'

查了下相关资料,添加str( )就可以。

原因:因为前面的data_dir 是经过pathlib.Path() 处理的。

添加str( ) 完美解决。

这篇关于Tensorflow入门实战 T05-运动鞋识别的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1077302

相关文章

网页解析 lxml 库--实战

lxml库使用流程 lxml 是 Python 的第三方解析库,完全使用 Python 语言编写,它对 XPath表达式提供了良好的支 持,因此能够了高效地解析 HTML/XML 文档。本节讲解如何通过 lxml 库解析 HTML 文档。 pip install lxml lxm| 库提供了一个 etree 模块,该模块专门用来解析 HTML/XML 文档,下面来介绍一下 lxml 库

Spring Security 从入门到进阶系列教程

Spring Security 入门系列 《保护 Web 应用的安全》 《Spring-Security-入门(一):登录与退出》 《Spring-Security-入门(二):基于数据库验证》 《Spring-Security-入门(三):密码加密》 《Spring-Security-入门(四):自定义-Filter》 《Spring-Security-入门(五):在 Sprin

性能分析之MySQL索引实战案例

文章目录 一、前言二、准备三、MySQL索引优化四、MySQL 索引知识回顾五、总结 一、前言 在上一讲性能工具之 JProfiler 简单登录案例分析实战中已经发现SQL没有建立索引问题,本文将一起从代码层去分析为什么没有建立索引? 开源ERP项目地址:https://gitee.com/jishenghua/JSH_ERP 二、准备 打开IDEA找到登录请求资源路径位置

阿里开源语音识别SenseVoiceWindows环境部署

SenseVoice介绍 SenseVoice 专注于高精度多语言语音识别、情感辨识和音频事件检测多语言识别: 采用超过 40 万小时数据训练,支持超过 50 种语言,识别效果上优于 Whisper 模型。富文本识别:具备优秀的情感识别,能够在测试数据上达到和超过目前最佳情感识别模型的效果。支持声音事件检测能力,支持音乐、掌声、笑声、哭声、咳嗽、喷嚏等多种常见人机交互事件进行检测。高效推

C#实战|大乐透选号器[6]:实现实时显示已选择的红蓝球数量

哈喽,你好啊,我是雷工。 关于大乐透选号器在前面已经记录了5篇笔记,这是第6篇; 接下来实现实时显示当前选中红球数量,蓝球数量; 以下为练习笔记。 01 效果演示 当选择和取消选择红球或蓝球时,在对应的位置显示实时已选择的红球、蓝球的数量; 02 标签名称 分别设置Label标签名称为:lblRedCount、lblBlueCount

数论入门整理(updating)

一、gcd lcm 基础中的基础,一般用来处理计算第一步什么的,分数化简之类。 LL gcd(LL a, LL b) { return b ? gcd(b, a % b) : a; } <pre name="code" class="cpp">LL lcm(LL a, LL b){LL c = gcd(a, b);return a / c * b;} 例题:

Java 创建图形用户界面(GUI)入门指南(Swing库 JFrame 类)概述

概述 基本概念 Java Swing 的架构 Java Swing 是一个为 Java 设计的 GUI 工具包,是 JAVA 基础类的一部分,基于 Java AWT 构建,提供了一系列轻量级、可定制的图形用户界面(GUI)组件。 与 AWT 相比,Swing 提供了许多比 AWT 更好的屏幕显示元素,更加灵活和可定制,具有更好的跨平台性能。 组件和容器 Java Swing 提供了许多

【IPV6从入门到起飞】5-1 IPV6+Home Assistant(搭建基本环境)

【IPV6从入门到起飞】5-1 IPV6+Home Assistant #搭建基本环境 1 背景2 docker下载 hass3 创建容器4 浏览器访问 hass5 手机APP远程访问hass6 更多玩法 1 背景 既然电脑可以IPV6入站,手机流量可以访问IPV6网络的服务,为什么不在电脑搭建Home Assistant(hass),来控制你的设备呢?@智能家居 @万物互联

滚雪球学Java(87):Java事务处理:JDBC的ACID属性与实战技巧!真有两下子!

咦咦咦,各位小可爱,我是你们的好伙伴——bug菌,今天又来给大家普及Java SE啦,别躲起来啊,听我讲干货还不快点赞,赞多了我就有动力讲得更嗨啦!所以呀,养成先点赞后阅读的好习惯,别被干货淹没了哦~ 🏆本文收录于「滚雪球学Java」专栏,专业攻坚指数级提升,助你一臂之力,带你早日登顶🚀,欢迎大家关注&&收藏!持续更新中,up!up!up!! 环境说明:Windows 10

poj 2104 and hdu 2665 划分树模板入门题

题意: 给一个数组n(1e5)个数,给一个范围(fr, to, k),求这个范围中第k大的数。 解析: 划分树入门。 bing神的模板。 坑爹的地方是把-l 看成了-1........ 一直re。 代码: poj 2104: #include <iostream>#include <cstdio>#include <cstdlib>#include <al