高层API助你快速上手深度学习----【第二课作业】十二生肖分类详解

本文主要是介绍高层API助你快速上手深度学习----【第二课作业】十二生肖分类详解,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

① 问题定义

十二生肖分类的本质是图像分类任务,我们采用CNN网络结构进行相关实践。

② 数据准备

2.1 解压缩数据集

我们将网上获取的数据集以压缩包的方式上传到aistudio数据集中,并加载到我们的项目内。

在使用之前我们进行数据集压缩包的一个解压。

!unzip -q -o data/data68755/signs.zip

2.2 数据标注

我们先看一下解压缩后的数据集长成什么样子。

.
├── test
│   ├── dog
│   ├── dragon
│   ├── goat
│   ├── horse
│   ├── monkey
│   ├── ox
│   ├── pig
│   ├── rabbit
│   ├── ratt
│   ├── rooster
│   ├── snake
│   └── tiger
├── train
│   ├── dog
│   ├── dragon
│   ├── goat
│   ├── horse
│   ├── monkey
│   ├── ox
│   ├── pig
│   ├── rabbit
│   ├── ratt
│   ├── rooster
│   ├── snake
│   └── tiger
└── valid├── dog├── dragon├── goat├── horse├── monkey├── ox├── pig├── rabbit├── ratt├── rooster├── snake└── tiger

数据集分为train、valid、test三个文件夹,每个文件夹内包含12个分类文件夹,每个分类文件夹内是具体的样本图片。

我们对这些样本进行一个标注处理,最终生成train.txt/valid.txt/test.txt三个数据标注文件。

import io
import os
from PIL import Image
from config import get# 数据集根目录
DATA_ROOT = 'signs'# 标签List
LABEL_MAP = get('LABEL_MAP')# 标注生成函数
def generate_annotation(mode):# 建立标注文件with open('{}/{}.txt'.format(DATA_ROOT, mode), 'w') as f:# 对应每个用途的数据文件夹,train/valid/testtrain_dir = '{}/{}'.format(DATA_ROOT, mode)# 遍历文件夹,获取里面的分类文件夹for path in os.listdir(train_dir):# 标签对应的数字索引,实际标注的时候直接使用数字索引label_index = LABEL_MAP.index(path)# 图像样本所在的路径image_path = '{}/{}'.format(train_dir, path)# 遍历所有图像for image in os.listdir(image_path):# 图像完整路径和名称image_file = '{}/{}'.format(image_path, image)try:# 验证图片格式是否okwith open(image_file, 'rb') as f_img:image = Image.open(io.BytesIO(f_img.read()))image.load()if image.mode == 'RGB':f.write('{}\t{}\n'.format(image_file, label_index))except:continuegenerate_annotation('train')  # 生成训练集标注文件
generate_annotation('valid')  # 生成验证集标注文件
generate_annotation('test')   # 生成测试集标注文件

2.3 数据集定义

接下来我们使用标注好的文件进行数据集类的定义,方便后续模型训练使用。

2.3.1 导入相关库

import paddle
import numpy as np
from config import getpaddle.__version__
'2.0.0'

2.3.2 导入数据集的定义实现

我们数据集的代码实现是在dataset.py中。

from dataset import ZodiacDataset

2.3.3 实例化数据集类

根据所使用的数据集需求实例化数据集类,并查看总样本量。

train_dataset = ZodiacDataset(mode='train')
valid_dataset = ZodiacDataset(mode='valid')print('训练数据集:{}张;验证数据集:{}张'.format(len(train_dataset), len(valid_dataset)))
训练数据集:7096张;验证数据集:639张

③ 模型选择和开发

3.1 网络构建

本次我们使用ResNet50网络来完成我们的案例实践。

1)ResNet系列网络

2)ResNet50结构

3)残差区块

4)ResNet其他版本

# 请补齐模型实例化代码network=paddle.vision.models.resnet101(num_classes=get('num_classes'), pretrained=True)
100%|██████████| 263160/263160 [00:03<00:00, 68870.66it/s]
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1263: UserWarning: Skip loading for fc.weight. fc.weight receives a shape [2048, 1000], but the expected shape is [2048, 12].warnings.warn(("Skip loading for {}. ".format(key) + str(err)))
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1263: UserWarning: Skip loading for fc.bias. fc.bias receives a shape [1000], but the expected shape is [12].warnings.warn(("Skip loading for {}. ".format(key) + str(err)))

模型可视化

model = paddle.Model(network)
model.summary((-1, ) + tuple(get('image_shape')))

④ 模型训练和优化

EPOCHS = get('epochs')
BATCH_SIZE = get('batch_size')# 请补齐模型训练过程代码
def create_optim(parameters):step_each_epoch = get('total_images') // get('batch_size')lr = paddle.optimizer.lr.CosineAnnealingDecay(learning_rate=get('LEARNING_RATE.params.lr'),T_max=step_each_epoch * EPOCHS)return paddle.optimizer.Momentum(learning_rate=lr,parameters=parameters,weight_decay=paddle.regularizer.L2Decay(get('OPTIMIZER.regularizer.factor')))# 模型训练配置
model.prepare(create_optim(network.parameters()),  # 优化器paddle.nn.CrossEntropyLoss(),        # 损失函数paddle.metric.Accuracy(topk=(1, 5))) # 评估指标# 训练可视化VisualDL工具的回调函数
visualdl = paddle.callbacks.VisualDL(log_dir='visualdl_log')# 启动模型全流程训练
model.fit(train_dataset,            # 训练数据集valid_dataset,            # 评估数据集epochs=EPOCHS,            # 总的训练轮次batch_size=BATCH_SIZE,    # 批次计算的样本量大小shuffle=True,             # 是否打乱样本集verbose=1,                # 日志展示格式save_dir='./chk_points/', # 分阶段的训练模型存储路径callbacks=[visualdl])     # 回调函数使用
The loss value printed in the log is the current step, and the metric is the average value of previous step.
Epoch 1/20/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/paddle/nn/layer/norm.py:636: UserWarning: When training, we now always track global mean and variance."When training, we now always track global mean and variance.")step 111/111 [==============================] - loss: 0.2994 - acc_top1: 0.8129 - acc_top5: 0.9569 - 2s/step         
save checkpoint at /home/aistudio/chk_points/0
Eval begin...
The loss value printed in the log is the current batch, and the metric is the average value of previous step.
step 10/10 [==============================] - loss: 0.2774 - acc_top1: 0.9233 - acc_top5: 0.9953 - 2s/step
Eval samples: 639
Epoch 2/20
step 111/111 [==============================] - loss: 0.4724 - acc_top1: 0.9012 - acc_top5: 0.9873 - 2s/step         
save checkpoint at /home/aistudio/chk_points/1
Eval begin...
The loss value printed in the log is the current batch, and the metric is the average value of previous step.
step 10/10 [==============================] - loss: 0.1623 - acc_top1: 0.9264 - acc_top5: 0.9953 - 2s/step
Eval samples: 639
Epoch 3/20
step 111/111 [==============================] - loss: 0.2196 - acc_top1: 0.9088 - acc_top5: 0.9893 - 2s/step         
save checkpoint at /home/aistudio/chk_points/2
Eval begin...
The loss value printed in the log is the current batch, and the metric is the average value of previous step.
step 10/10 [==============================] - loss: 0.3726 - acc_top1: 0.9452 - acc_top5: 0.9953 - 2s/step
Eval samples: 639
Epoch 4/20
step 111/111 [==============================] - loss: 0.1125 - acc_top1: 0.9218 - acc_top5: 0.9897 - 2s/step         
save checkpoint at /home/aistudio/chk_points/3
Eval begin...
The loss value printed in the log is the current batch, and the metric is the average value of previous step.
step 10/10 [==============================] - loss: 0.1343 - acc_top1: 0.9484 - acc_top5: 0.9922 - 2s/step
Eval samples: 639
Epoch 5/20
step 111/111 [==============================] - loss: 0.2677 - acc_top1: 0.9308 - acc_top5: 0.9903 - 2s/step         
save checkpoint at /home/aistudio/chk_points/4
Eval begin...
The loss value printed in the log is the current batch, and the metric is the average value of previous step.
step 10/10 [==============================] - loss: 0.2783 - acc_top1: 0.9499 - acc_top5: 0.9937 - 2s/step
Eval samples: 639
Epoch 6/20
step 111/111 [==============================] - loss: 0.2177 - acc_top1: 0.9312 - acc_top5: 0.9927 - 2s/step         
save checkpoint at /home/aistudio/chk_points/5
Eval begin...
The loss value printed in the log is the current batch, and the metric is the average value of previous step.
step 10/10 [==============================] - loss: 0.3158 - acc_top1: 0.9515 - acc_top5: 0.9969 - 2s/step
Eval samples: 639
Epoch 7/20
step 111/111 [==============================] - loss: 0.0833 - acc_top1: 0.9408 - acc_top5: 0.9938 - 2s/step         
save checkpoint at /home/aistudio/chk_points/6
Eval begin...
The loss value printed in the log is the current batch, and the metric is the average value of previous step.
step 10/10 [==============================] - loss: 0.1721 - acc_top1: 0.9484 - acc_top5: 1.0000 - 2s/step
Eval samples: 639
Epoch 8/20
step 111/111 [==============================] - loss: 0.1454 - acc_top1: 0.9474 - acc_top5: 0.9937 - 2s/step         
save checkpoint at /home/aistudio/chk_points/7
Eval begin...
The loss value printed in the log is the current batch, and the metric is the average value of previous step.
step 10/10 [==============================] - loss: 0.3457 - acc_top1: 0.9593 - acc_top5: 0.9953 - 2s/step
Eval samples: 639
Epoch 9/20
step 111/111 [==============================] - loss: 0.0632 - acc_top1: 0.9501 - acc_top5: 0.9938 - 2s/step         
save checkpoint at /home/aistudio/chk_points/8
Eval begin...
The loss value printed in the log is the current batch, and the metric is the average value of previous step.
step 10/10 [==============================] - loss: 0.2145 - acc_top1: 0.9671 - acc_top5: 0.9953 - 2s/step
Eval samples: 639
Epoch 10/20
step 111/111 [==============================] - loss: 0.1419 - acc_top1: 0.9507 - acc_top5: 0.9948 - 2s/step         
save checkpoint at /home/aistudio/chk_points/9
Eval begin...
The loss value printed in the log is the current batch, and the metric is the average value of previous step.
step 10/10 [==============================] - loss: 0.1284 - acc_top1: 0.9687 - acc_top5: 0.9984 - 2s/step
Eval samples: 639
Epoch 11/20
step 111/111 [==============================] - loss: 0.1837 - acc_top1: 0.9567 - acc_top5: 0.9962 - 2s/step         
save checkpoint at /home/aistudio/chk_points/10
Eval begin...
The loss value printed in the log is the current batch, and the metric is the average value of previous step.
step 10/10 [==============================] - loss: 0.2203 - acc_top1: 0.9624 - acc_top5: 0.9984 - 2s/step
Eval samples: 639
Epoch 12/20
step 111/111 [==============================] - loss: 0.0954 - acc_top1: 0.9631 - acc_top5: 0.9961 - 2s/step         
save checkpoint at /home/aistudio/chk_points/11
Eval begin...
The loss value printed in the log is the current batch, and the metric is the average value of previous step.
step 10/10 [==============================] - loss: 0.1996 - acc_top1: 0.9609 - acc_top5: 0.9969 - 2s/step
Eval samples: 639
Epoch 13/20
step 111/111 [==============================] - loss: 0.1991 - acc_top1: 0.9612 - acc_top5: 0.9946 - 2s/step         
save checkpoint at /home/aistudio/chk_points/12
Eval begin...
The loss value printed in the log is the current batch, and the metric is the average value of previous step.
step 10/10 [==============================] - loss: 0.1538 - acc_top1: 0.9562 - acc_top5: 0.9969 - 2s/step
Eval samples: 639
Epoch 14/20
step 111/111 [==============================] - loss: 0.2008 - acc_top1: 0.9643 - acc_top5: 0.9966 - 2s/step         
save checkpoint at /home/aistudio/chk_points/13
Eval begin...
The loss value printed in the log is the current batch, and the metric is the average value of previous step.
step 10/10 [==============================] - loss: 0.2580 - acc_top1: 0.9656 - acc_top5: 0.9984 - 2s/step
Eval samples: 639
Epoch 15/20
step 111/111 [==============================] - loss: 0.1571 - acc_top1: 0.9656 - acc_top5: 0.9965 - 2s/step         
save checkpoint at /home/aistudio/chk_points/14
Eval begin...
The loss value printed in the log is the current batch, and the metric is the average value of previous step.
step 10/10 [==============================] - loss: 0.2078 - acc_top1: 0.9577 - acc_top5: 1.0000 - 2s/step
Eval samples: 639
Epoch 16/20
step 111/111 [==============================] - loss: 0.0766 - acc_top1: 0.9656 - acc_top5: 0.9972 - 2s/step         
save checkpoint at /home/aistudio/chk_points/15
Eval begin...
The loss value printed in the log is the current batch, and the metric is the average value of previous step.
step 10/10 [==============================] - loss: 0.2247 - acc_top1: 0.9546 - acc_top5: 0.9969 - 2s/step
Eval samples: 639
Epoch 17/20
step 111/111 [==============================] - loss: 0.0959 - acc_top1: 0.9649 - acc_top5: 0.9969 - 2s/step         
save checkpoint at /home/aistudio/chk_points/16
Eval begin...
The loss value printed in the log is the current batch, and the metric is the average value of previous step.
step 10/10 [==============================] - loss: 0.2885 - acc_top1: 0.9593 - acc_top5: 0.9953 - 2s/step
Eval samples: 639
Epoch 18/20
step 111/111 [==============================] - loss: 0.0430 - acc_top1: 0.9682 - acc_top5: 0.9968 - 2s/step         
save checkpoint at /home/aistudio/chk_points/17
Eval begin...
The loss value printed in the log is the current batch, and the metric is the average value of previous step.
step 10/10 [==============================] - loss: 0.2386 - acc_top1: 0.9577 - acc_top5: 0.9969 - 2s/step
Eval samples: 639
Epoch 19/20
step 111/111 [==============================] - loss: 0.1366 - acc_top1: 0.9659 - acc_top5: 0.9968 - 2s/step         
save checkpoint at /home/aistudio/chk_points/18
Eval begin...
The loss value printed in the log is the current batch, and the metric is the average value of previous step.
step 10/10 [==============================] - loss: 0.2560 - acc_top1: 0.9562 - acc_top5: 0.9969 - 2s/step
Eval samples: 639
Epoch 20/20
step 111/111 [==============================] - loss: 0.0891 - acc_top1: 0.9666 - acc_top5: 0.9970 - 2s/step         
save checkpoint at /home/aistudio/chk_points/19
Eval begin...
The loss value printed in the log is the current batch, and the metric is the average value of previous step.
step 10/10 [==============================] - loss: 0.2363 - acc_top1: 0.9609 - acc_top5: 0.9953 - 2s/step
Eval samples: 639
save checkpoint at /home/aistudio/chk_points/final

模型存储

将我们训练得到的模型进行保存,以便后续评估和测试使用。

model.save(get('model_save_dir'))

⑤ 模型评估和测试

5.1 批量预测测试

5.1.1 测试数据集

predict_dataset = ZodiacDataset(mode='test')
print('测试数据集样本量:{}'.format(len(predict_dataset)))
测试数据集样本量:646

5.1.2 执行预测

from paddle.static import InputSpec# 请补充网络结构# 网络结构示例化
network = paddle.vision.models.resnet101(num_classes=get('num_classes'))# 模型封装
model_2 = paddle.Model(network, inputs=[InputSpec(shape=[-1] + get('image_shape'), dtype='float32', name='image')])# 训练好的模型加载
model_2.load(get('model_save_dir'))# 模型配置
model_2.prepare()# 执行预测
result = model_2.predict(predict_dataset)
Predict begin.../opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/paddle/fluid/layers/utils.py:77: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop workingreturn (isinstance(seq, collections.Sequence) andstep 646/646 [==============================] - 278ms/step        
Predict samples: 646
# 样本映射
LABEL_MAP = get('LABEL_MAP')# 随机取样本展示
indexs = [2, 38, 56, 92, 100, 303]for idx in indexs:predict_label = np.argmax(result[0][idx])real_label = predict_dataset[idx][1]print('样本ID:{}, 真实标签:{}, 预测值:{}'.format(idx, LABEL_MAP[real_label], LABEL_MAP[predict_label]))
样本ID:2, 真实标签:pig, 预测值:pig
样本ID:38, 真实标签:pig, 预测值:pig
样本ID:56, 真实标签:ratt, 预测值:ratt
样本ID:92, 真实标签:ratt, 预测值:ratt
样本ID:100, 真实标签:ratt, 预测值:ratt
样本ID:303, 真实标签:snake, 预测值:snake

⑥ 模型部署

model_2.save('infer/zodiac', training=False)
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/paddle/fluid/layers/math_op_patch.py:298: UserWarning: /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/paddle/vision/models/resnet.py:145
The behavior of expression A + B has been unified with elementwise_add(X, Y, axis=-1) from Paddle 2.0. If your code works well in the older versions but crashes in this version, try to use elementwise_add(X, Y, axis=0) instead of A + B. This transitional warning will be dropped in the future.op_type, op_type, EXPRESSION_MAP[method_name]))

这篇关于高层API助你快速上手深度学习----【第二课作业】十二生肖分类详解的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/414874

相关文章

HarmonyOS学习(七)——UI(五)常用布局总结

自适应布局 1.1、线性布局(LinearLayout) 通过线性容器Row和Column实现线性布局。Column容器内的子组件按照垂直方向排列,Row组件中的子组件按照水平方向排列。 属性说明space通过space参数设置主轴上子组件的间距,达到各子组件在排列上的等间距效果alignItems设置子组件在交叉轴上的对齐方式,且在各类尺寸屏幕上表现一致,其中交叉轴为垂直时,取值为Vert

Ilya-AI分享的他在OpenAI学习到的15个提示工程技巧

Ilya(不是本人,claude AI)在社交媒体上分享了他在OpenAI学习到的15个Prompt撰写技巧。 以下是详细的内容: 提示精确化:在编写提示时,力求表达清晰准确。清楚地阐述任务需求和概念定义至关重要。例:不用"分析文本",而用"判断这段话的情感倾向:积极、消极还是中性"。 快速迭代:善于快速连续调整提示。熟练的提示工程师能够灵活地进行多轮优化。例:从"总结文章"到"用

Spring Security基于数据库验证流程详解

Spring Security 校验流程图 相关解释说明(认真看哦) AbstractAuthenticationProcessingFilter 抽象类 /*** 调用 #requiresAuthentication(HttpServletRequest, HttpServletResponse) 决定是否需要进行验证操作。* 如果需要验证,则会调用 #attemptAuthentica

基于人工智能的图像分类系统

目录 引言项目背景环境准备 硬件要求软件安装与配置系统设计 系统架构关键技术代码示例 数据预处理模型训练模型预测应用场景结论 1. 引言 图像分类是计算机视觉中的一个重要任务,目标是自动识别图像中的对象类别。通过卷积神经网络(CNN)等深度学习技术,我们可以构建高效的图像分类系统,广泛应用于自动驾驶、医疗影像诊断、监控分析等领域。本文将介绍如何构建一个基于人工智能的图像分类系统,包括环境

作业提交过程之HDFSMapReduce

作业提交全过程详解 (1)作业提交 第1步:Client调用job.waitForCompletion方法,向整个集群提交MapReduce作业。 第2步:Client向RM申请一个作业id。 第3步:RM给Client返回该job资源的提交路径和作业id。 第4步:Client提交jar包、切片信息和配置文件到指定的资源提交路径。 第5步:Client提交完资源后,向RM申请运行MrAp

【前端学习】AntV G6-08 深入图形与图形分组、自定义节点、节点动画(下)

【课程链接】 AntV G6:深入图形与图形分组、自定义节点、节点动画(下)_哔哩哔哩_bilibili 本章十吾老师讲解了一个复杂的自定义节点中,应该怎样去计算和绘制图形,如何给一个图形制作不间断的动画,以及在鼠标事件之后产生动画。(有点难,需要好好理解) <!DOCTYPE html><html><head><meta charset="UTF-8"><title>06

学习hash总结

2014/1/29/   最近刚开始学hash,名字很陌生,但是hash的思想却很熟悉,以前早就做过此类的题,但是不知道这就是hash思想而已,说白了hash就是一个映射,往往灵活利用数组的下标来实现算法,hash的作用:1、判重;2、统计次数;

认识、理解、分类——acm之搜索

普通搜索方法有两种:1、广度优先搜索;2、深度优先搜索; 更多搜索方法: 3、双向广度优先搜索; 4、启发式搜索(包括A*算法等); 搜索通常会用到的知识点:状态压缩(位压缩,利用hash思想压缩)。

电脑桌面文件删除了怎么找回来?别急,快速恢复攻略在此

在日常使用电脑的过程中,我们经常会遇到这样的情况:一不小心,桌面上的某个重要文件被删除了。这时,大多数人可能会感到惊慌失措,不知所措。 其实,不必过于担心,因为有很多方法可以帮助我们找回被删除的桌面文件。下面,就让我们一起来了解一下这些恢复桌面文件的方法吧。 一、使用撤销操作 如果我们刚刚删除了桌面上的文件,并且还没有进行其他操作,那么可以尝试使用撤销操作来恢复文件。在键盘上同时按下“C

OpenHarmony鸿蒙开发( Beta5.0)无感配网详解

1、简介 无感配网是指在设备联网过程中无需输入热点相关账号信息,即可快速实现设备配网,是一种兼顾高效性、可靠性和安全性的配网方式。 2、配网原理 2.1 通信原理 手机和智能设备之间的信息传递,利用特有的NAN协议实现。利用手机和智能设备之间的WiFi 感知订阅、发布能力,实现了数字管家应用和设备之间的发现。在完成设备间的认证和响应后,即可发送相关配网数据。同时还支持与常规Sof