经典神经网络(6)ResNet及其在Fashion-MNIST数据集上的应用

2024-03-22 10:30

本文主要是介绍经典神经网络(6)ResNet及其在Fashion-MNIST数据集上的应用,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

经典神经网络(6)ResNet及其在Fashion-MNIST数据集上的应用

1 ResNet的简述

  1. ResNet 提出了一种残差学习框架来解决网络退化问题,从而训练更深的网络。这种框架可以结合已有的各种网络结构,充分发挥二者的优势。

  2. ResNet以三种方式挑战了传统的神经网络架构:

    • ResNet 通过引入跳跃连接来绕过残差层,这允许数据直接流向任何后续层。

      这与传统的、顺序的pipeline 形成鲜明对比:传统的架构中,网络依次处理低级feature 到高级feature

    • ResNet 的层数非常深,高达1202层。而ALexNet 这样的架构,网络层数要小两个量级。

    • 通过实验发现,训练好的 ResNet 中去掉单个层并不会影响其预测性能。而训练好的AlexNet 等网络中,移除层会导致预测性能损失。

  3. ImageNet分类数据集中,拥有152层的残差网络,以3.75% top-5 的错误率获得了ILSVRC 2015 分类比赛的冠军。

  4. 很多证据表明:残差学习是通用的,不仅可以应用于视觉问题,也可应用于非视觉问题。

  5. 论文地址: https://arxiv.org/pdf/1512.03385.pdf

  6. 卷积神经网络领域的两次技术爆炸,第一次是AlexNet,第二次就是ResNet了。

1.1 网络退化问题

  • 1、理论上来讲网络深度越深越好。网络越深,提取的图片特征越多越丰富,但随之会带来很多的问题(通过Batch Normalization 在很大程度上解决),比如过拟合或者计算量爆炸、梯度消失、梯度爆炸等,导致网络在一定深度下就达到了局部最优解。

  • 2、ResNet 论文作者发现:随着网络的深度的增加,准确率达到饱和之后迅速下降,而这种下降不是由过拟合引起的。这称作网络退化问题。如果更深的网络训练误差更大,则说明是由于优化算法引起的:越深的网络,求解优化问题越难。如下所示:更深的网络导致更高的训练误差和测试误差。

在这里插入图片描述

  • 3、理论上讲,较深的模型不应该比和它对应的、较浅的模型更差。因为较深的模型是较浅的模型的超空间。较深的模型可以这样得到:先构建较浅的模型,然后添加很多恒等映射的网络层。实际上我们的较深的模型后面添加的不是恒等映射,而是一些非线性层。因此,退化问题表明:通过多个非线性层来近似横等映射可能是困难的

在这里插入图片描述

  • 4、针对这⼀问题,何恺明等⼈提出了残差⽹络(ResNet)。它在2015年的ImageNet图像识别挑战赛夺魁,并深刻影响了后来的深度神经⽹络的设计。残差⽹络的核⼼思想是:每个附加层都应该更容易地包含原始函数作为其元素之⼀

1.2 残差块(residual blocks)

1.2.1 残差块的理解

在这里插入图片描述

1、假设需要学习的是映射 y = H(x),残差块使用堆叠的非线性层拟合残差:y = F(x,W) + x 。

其中:

  • x 和 y 是块的输入和输出向量。
  • F(x,W)是要学习的残差映射。因为 F(x,W) = H(x) - x,因此称F为残差。
  • + :通过快捷连接逐个元素相加来执行。快捷连接 指的是那些跳过一层或者更多层的连接。
    • 快捷连接简单的执行恒等映射,并将其输出添加到堆叠层的输出。
    • 快捷连接既不增加额外的参数,也不增加计算复杂度。
  • 相加之后通过非线性激活函数,这可以视作对整个残差块添加非线性,即 relu(y)

2、残差映射易于捕捉恒等映射的细微波动。比如5正常映射为5.1,加入残差后变成 5+0.1。此时输入变成5.2,对于没有残差结构的结果,影响仅为0.1/5.1 = 2%。而对于残差结构,变成 5+0.2 , 由0.1变成了0.2 影响为100%。

3、残差映射 H ( x ) = F ( x ) + x ,在反向传播的时候就变成了 H ′ ( x ) = F ′ ( x ) + 1,这里的加1也可以保证梯度消失现象

4、作者也证明了退化问题在任何数据集上都普遍存在。在imagenet上拿到冠军之后,迁移学习用到了coco同样拿到了好几个赛道的冠军,说明残差结构是普适的。最后又和VGG比了一下,比VGG深了8倍,计算复杂性却还比VGG小 。

1.2.2 残差函数F的形式的可变性

  • 层数可变:论文中的实验包含有两层堆叠、三层堆叠,实际任务中也可以包含更多层的堆叠。

    如果F只有一层,则残差块退化线性层:y = Wx + x 。此时对网络并没有什么提升。

  • 连接形式可变:不仅可用于全连接层,可也用于卷积层。此时F代表多个卷积层的堆叠,而最终的逐元素加法+ 在两个feature map 上逐通道进行。

    此时 x 也是一个feature map,而不再是一个向量。

1.2.3 残差学习成功的原因

学习残差F(x,W)比学习原始映射H(x)要更容易。

  • 1、当原始映射H就是一个恒等映射时, 就是一个F零映射。此时求解器只需要简单的将堆叠的非线性连接的权重推向零即可。

    实际任务中原始映射 H可能不是一个恒等映射:

    • 如果H 更偏向于恒等映射(而不是更偏向于非恒等映射),则F就是关于恒等映射的抖动,会更容易学习。
    • 如果原始映射H 更偏向于零映射,那么学习 本身要更容易。但是在实际应用中,零映射非常少见,因为它会导致输出全为0。
  • 2、如果原始映射H是一个非恒等映射,则可以考虑对残差模块使用缩放因子。如Inception-Resnet 中:在残差模块与快捷连接叠加之前,对残差进行缩放。注意:ResNet 作者在随后的论文中指出:不应该对恒等映射进行缩放。

  • 3、可以通过观察残差 F的输出来判断:如果F的输出均为0附近的、较小的数,则说明原始映射H更偏向于恒等映射;否则,说明原始映射H更偏向于非横等映射。

1.2.4 残差块代码实现

在这里插入图片描述

from torch import nn
from torch.nn import functional as F
import torch'''
⼀种是当use_1x1conv=False时,应⽤ReLU⾮线性函数之前,将输⼊添加到输出。
另⼀种是当use_1x1conv=True时,添加通过1 × 1卷积调整通道和分辨率ResNet沿⽤了VGG完整的3 × 3卷积层设计。
残差块⾥⾸先有2个有相同输出通道数的3 × 3卷积层。
每个卷积层后接⼀个批量规范化层和ReLU激活函数。
然后我们通过跨层数据通路,跳过这2个卷积运算,将输⼊直接加在最后的ReLU激活函数前。
这样的设计要求2个卷积层的输出与输⼊形状⼀样,从⽽使它们可以相加。如果想改变通道数,就需要引⼊⼀个额外的1 × 1卷积层来将输⼊变换成需要的形状后再做相加运算。
'''
class Residual(nn.Module):def __init__(self,input_channels, num_channels,use_1x1conv=False, strides=1):super().__init__()self.conv1 = nn.Conv2d(input_channels, num_channels,kernel_size=3, padding=1, stride=strides)self.conv2 = nn.Conv2d(num_channels, num_channels,kernel_size=3, padding=1)if use_1x1conv:self.conv3 = nn.Conv2d(input_channels, num_channels,kernel_size=1, stride=strides)else:self.conv3 = Noneself.bn1 = nn.BatchNorm2d(num_channels)self.bn2 = nn.BatchNorm2d(num_channels)def forward(self, X):Y = F.relu(self.bn1(self.conv1(X)))Y = self.bn2(self.conv2(Y))if self.conv3:X = self.conv3(X)Y += Xreturn F.relu(Y)if __name__ == '__main__':blk = Residual(3, 3)X = torch.rand(4, 3, 6, 6)Y = blk(X)print(Y.shape)  # 输⼊和输出形状⼀致 torch.Size([4, 3, 6, 6])blk = Residual(3, 6, use_1x1conv=True, strides=2)Y = blk(X)print(Y.shape)  # 在增加输出通道数的同时,减半输出的高和宽 torch.Size([4, 6, 3, 3])

1.3 ResNet网络

1.3.1 四种plain 网络

plain 网络:一些简单网络结构的叠加,如下图所示。图中给出了四种plain 网络,它们的区别主要是网络深度不同。其中,输入图片尺寸 224x224 。

ResNet 简单的在plain 网络上添加快捷连接来实现。

FLOPsfloating point operations 的缩写,意思是浮点运算量,用于衡量算法/模型的复杂度。

FLOPSfloating point per second的缩写,意思是每秒浮点运算次数,用于衡量计算速度。

在这里插入图片描述

相对于输入的feature map,残差块的输出feature map 尺寸可能会发生变化:

  • 输出 feature map 的通道数增加,此时需要扩充快捷连接的输出feature map 。否则快捷连接的输出 feature map 无法和残差块的feature map 累加。

    有两种扩充方式:

    • 直接通过 0 来填充需要扩充的维度。
    • 通过1x1 卷积来扩充维度。
  • 输出 feature map 的尺寸减半。此时需要对快捷连接执行步长为 2 的池化/卷积:如果快捷连接已经采用 1x1 卷积,则该卷积步长为2 ;否则采用步长为 2 的最大池化 。

1.3.2 模型预测能力

VGG-1934层 plain 网络Resnet-34
计算复杂度(FLOPs)19.6 billion3.5 billion3.6 billion

ImageNet 验证集上执行10-crop 测试的结果。

  • A 类模型:快捷连接中,所有需要扩充的维度的填充 0 。
  • B 类模型:快捷连接中,所有需要扩充的维度通过1x1 卷积来扩充。
  • C 类模型:所有快捷连接都通过1x1 卷积来执行线性变换。

C 优于BB 优于A。但是 C 引入更多的参数,相对于这种微弱的提升,性价比较低。所以后续的ResNet 均采用 B 类模型。

模型top-1 误差率top-5 误差率
VGG-1628.07%9.33%
GoogleNet-9.15%
PReLU-net24.27%7.38%
plain-3428.54%10.02%
ResNet-34 A25.03%7.76%
ResNet-34 B24.52%7.46%
ResNet-34 C24.19%7.40%
ResNet-5022.85%6.71%
ResNet-10121.75%6.05%
ResNet-15221.43%5.71%

1.3.3 ResNet-18实现

import torch.nn as nn
import torch
from _06_Residual import Residualclass ResNet18(nn.Module):def __init__(self):super(ResNet18, self).__init__()self.model = self.get_net()def forward(self, X):X = self.model(X)return Xdef get_net(self):'''ResNet的前两层跟GoogLeNet中的⼀样:在输出通道数为64、步幅为2的7 × 7卷积层后,接步幅为2的3 × 3的最⼤汇聚层。不同之处在于ResNet每个卷积层后增加了批量规范化层。'''b1 = nn.Sequential(nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3),nn.BatchNorm2d(64), nn.ReLU(),nn.MaxPool2d(kernel_size=3, stride=2, padding=1))'''GoogLeNet在后⾯接了4个由Inception块组成的模块。ResNet则使⽤4个由残差块组成的模块,每个模块使⽤若⼲个同样输出通道数的残差块。第⼀个模块的通道数同输⼊通道数⼀致。由于之前已经使⽤了步幅为2的最⼤汇聚层,所以⽆须减⼩⾼和宽。之后的每个模块在第⼀个残差块⾥将上⼀个模块的通道数翻倍,并将⾼和宽减半。'''b2 = nn.Sequential(*self.resnet_block(64, 64, 2, first_block=True))b3 = nn.Sequential(*self.resnet_block(64, 128, 2))b4 = nn.Sequential(*self.resnet_block(128, 256, 2))b5 = nn.Sequential(*self.resnet_block(256, 512, 2))net = nn.Sequential(b1, b2, b3, b4, b5,nn.AdaptiveAvgPool2d((1, 1)),nn.Flatten(), nn.Linear(512, 10))return netdef resnet_block(self, input_channels, num_channels, num_residuals, first_block=False):blk = []for i in range(num_residuals):if i == 0 and not first_block:blk.append(Residual(input_channels, num_channels, use_1x1conv=True, strides=2))else:blk.append(Residual(num_channels, num_channels))return blkif __name__ == '__main__':net = ResNet18()X = torch.rand(size=(1, 1, 224, 224), dtype=torch.float32)for layer in net.model:X = layer(X)print(layer.__class__.__name__, 'output shape:', X.shape)
Sequential output shape: torch.Size([1, 64, 56, 56])
Sequential output shape: torch.Size([1, 64, 56, 56])
Sequential output shape: torch.Size([1, 128, 28, 28])
Sequential output shape: torch.Size([1, 256, 14, 14])
Sequential output shape: torch.Size([1, 512, 7, 7])
AdaptiveAvgPool2d output shape: torch.Size([1, 512, 1, 1])
Flatten output shape: torch.Size([1, 512])
Linear output shape: torch.Size([1, 10])

2 ResNet-18在Fashion-MNIST数据集上的应用示例

2.1 创建ResNet网络模型

如1.2.4及1.3.3代码所示。

2.2 读取Fashion-MNIST数据集

其他所有的函数,与经典神经网络(1)LeNet及其在Fashion-MNIST数据集上的应用完全一致。

batch_size = 256# 为了使Fashion-MNIST上的训练短⼩精悍,将输⼊的⾼和宽从224降到96,简化计算
train_iter,test_iter = get_mnist_data(batch_size,resize=96)

2.3 在GPU上进行模型训练

from _06_ResNet18 import ResNet18# 初始化模型
net = ResNet18()lr, num_epochs = 0.05, 10
train_ch(net, train_iter, test_iter, num_epochs, lr, try_gpu())

在这里插入图片描述

3 ResNet18微调进行水果图像分类

3.1 爬取水果图像数据

import requests
import urllib3
urllib3.disable_warnings()import time
import os
import random
import pandas as pd
import shutil
import matplotlib.pyplot as plt
%matplotlib inline
import cv2
import numpy as np
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim import lr_scheduler# 忽略烦人的红色提示
import warnings
warnings.filterwarnings("ignore")# 进度条库
from tqdm import tqdm
# http请求参数
cookies = {'BDqhfp': '%E7%8B%97%E7%8B%97%26%26NaN-1undefined%26%2618880%26%2621','BIDUPSID': '06338E0BE23C6ADB52165ACEB972355B','PSTM': '1646905430','BAIDUID': '104BD58A7C408DABABCAC9E0A1B184B4:FG=1','BDORZ': 'B490B5EBF6F3CD402E515D22BCDA1598','H_PS_PSSID': '35836_35105_31254_36024_36005_34584_36142_36120_36032_35993_35984_35319_26350_35723_22160_36061','BDSFRCVID': '8--OJexroG0xMovDbuOS5T78igKKHJQTDYLtOwXPsp3LGJLVgaSTEG0PtjcEHMA-2ZlgogKK02OTH6KF_2uxOjjg8UtVJeC6EG0Ptf8g0M5','H_BDCLCKID_SF': 'tJPqoKtbtDI3fP36qR3KhPt8Kpby2D62aKDs2nopBhcqEIL4QTQM5p5yQ2c7LUvtynT2KJnz3Po8MUbSj4QoDjFjXJ7RJRJbK6vwKJ5s5h5nhMJSb67JDMP0-4F8exry523ioIovQpn0MhQ3DRoWXPIqbN7P-p5Z5mAqKl0MLPbtbb0xXj_0D6bBjHujtT_s2TTKLPK8fCnBDP59MDTjhPrMypomWMT-0bFH_-5L-l5js56SbU5hW5LSQxQ3QhLDQNn7_JjOX-0bVIj6Wl_-etP3yarQhxQxtNRdXInjtpvhHR38MpbobUPUDa59LUvEJgcdot5yBbc8eIna5hjkbfJBQttjQn3hfIkj0DKLtD8bMC-RDjt35n-Wqxobbtof-KOhLTrJaDkWsx7Oy4oTj6DD5lrG0P6RHmb8ht59JROPSU7mhqb_3MvB-fnEbf7r-2TP_R6GBPQtqMbIQft20-DIeMtjBMJaJRCqWR7jWhk2hl72ybCMQlRX5q79atTMfNTJ-qcH0KQpsIJM5-DWbT8EjHCet5DJJn4j_Dv5b-0aKRcY-tT5M-Lf5eT22-usy6Qd2hcH0KLKDh6gb4PhQKuZ5qutLTb4QTbqWKJcKfb1MRjvMPnF-tKZDb-JXtr92nuDal5TtUthSDnTDMRhXfIL04nyKMnitnr9-pnLJpQrh459XP68bTkA5bjZKxtq3mkjbPbDfn02eCKuj6tWj6j0DNRabK6aKC5bL6rJabC3b5CzXU6q2bDeQN3OW4Rq3Irt2M8aQI0WjJ3oyU7k0q0vWtvJWbbvLT7johRTWqR4enjb3MonDh83Mxb4BUrCHRrzWn3O5hvvhKoO3MA-yUKmDloOW-TB5bbPLUQF5l8-sq0x0bOte-bQXH_E5bj2qRCqVIKa3f','BDSFRCVID_BFESS': '8--OJexroG0xMovDbuOS5T78igKKHJQTDYLtOwXPsp3LGJLVgaSTEG0PtjcEHMA-2ZlgogKK02OTH6KF_2uxOjjg8UtVJeC6EG0Ptf8g0M5','H_BDCLCKID_SF_BFESS': 'tJPqoKtbtDI3fP36qR3KhPt8Kpby2D62aKDs2nopBhcqEIL4QTQM5p5yQ2c7LUvtynT2KJnz3Po8MUbSj4QoDjFjXJ7RJRJbK6vwKJ5s5h5nhMJSb67JDMP0-4F8exry523ioIovQpn0MhQ3DRoWXPIqbN7P-p5Z5mAqKl0MLPbtbb0xXj_0D6bBjHujtT_s2TTKLPK8fCnBDP59MDTjhPrMypomWMT-0bFH_-5L-l5js56SbU5hW5LSQxQ3QhLDQNn7_JjOX-0bVIj6Wl_-etP3yarQhxQxtNRdXInjtpvhHR38MpbobUPUDa59LUvEJgcdot5yBbc8eIna5hjkbfJBQttjQn3hfIkj0DKLtD8bMC-RDjt35n-Wqxobbtof-KOhLTrJaDkWsx7Oy4oTj6DD5lrG0P6RHmb8ht59JROPSU7mhqb_3MvB-fnEbf7r-2TP_R6GBPQtqMbIQft20-DIeMtjBMJaJRCqWR7jWhk2hl72ybCMQlRX5q79atTMfNTJ-qcH0KQpsIJM5-DWbT8EjHCet5DJJn4j_Dv5b-0aKRcY-tT5M-Lf5eT22-usy6Qd2hcH0KLKDh6gb4PhQKuZ5qutLTb4QTbqWKJcKfb1MRjvMPnF-tKZDb-JXtr92nuDal5TtUthSDnTDMRhXfIL04nyKMnitnr9-pnLJpQrh459XP68bTkA5bjZKxtq3mkjbPbDfn02eCKuj6tWj6j0DNRabK6aKC5bL6rJabC3b5CzXU6q2bDeQN3OW4Rq3Irt2M8aQI0WjJ3oyU7k0q0vWtvJWbbvLT7johRTWqR4enjb3MonDh83Mxb4BUrCHRrzWn3O5hvvhKoO3MA-yUKmDloOW-TB5bbPLUQF5l8-sq0x0bOte-bQXH_E5bj2qRCqVIKa3f','indexPageSugList': '%5B%22%E7%8B%97%E7%8B%97%22%5D','cleanHistoryStatus': '0','BAIDUID_BFESS': '104BD58A7C408DABABCAC9E0A1B184B4:FG=1','BDRCVFR[dG2JNJb_ajR]': 'mk3SLVN4HKm','BDRCVFR[-pGxjrCMryR]': 'mk3SLVN4HKm','ab_sr': '1.0.1_Y2YxZDkwMWZkMmY2MzA4MGU0OTNhMzVlNTcwMmM2MWE4YWU4OTc1ZjZmZDM2N2RjYmVkMzFiY2NjNWM4Nzk4NzBlZTliYWU0ZTAyODkzNDA3YzNiMTVjMTllMzQ0MGJlZjAwYzk5MDdjNWM0MzJmMDdhOWNhYTZhMjIwODc5MDMxN2QyMmE1YTFmN2QyY2M1M2VmZDkzMjMyOThiYmNhZA==','delPer': '0','PSINO': '2','BA_HECTOR': '8h24a024042g05alup1h3g0aq0q'
}headers = {'Connection': 'keep-alive','sec-ch-ua': '" Not;A Brand";v="99", "Google Chrome";v="97", "Chromium";v="97"','Accept': 'text/plain, */*; q=0.01','X-Requested-With': 'XMLHttpRequest','sec-ch-ua-mobile': '?0','User-Agent': 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/97.0.4692.99 Safari/537.36','sec-ch-ua-platform': '"macOS"','Sec-Fetch-Site': 'same-origin','Sec-Fetch-Mode': 'cors','Sec-Fetch-Dest': 'empty','Referer': 'https://image.baidu.com/search/index?tn=baiduimage&ipn=r&ct=201326592&cl=2&lm=-1&st=-1&fm=result&fr=&sf=1&fmq=1647837998851_R&pv=&ic=&nc=1&z=&hd=&latest=&copyright=&se=1&showtab=0&fb=0&width=&height=&face=0&istype=2&dyTabStr=MCwzLDIsNiwxLDUsNCw4LDcsOQ%3D%3D&ie=utf-8&sid=&word=%E7%8B%97%E7%8B%97','Accept-Language': 'zh-CN,zh;q=0.9'
}def download_single_class(file_path, keyword, DOWNLOAD_NUM=100):if not os.path.exists(file_path + "/dataset"):os.makedirs(file_path + "/dataset")print(f'新建{file_path}/dataset文件夹')if not os.path.exists(file_path + "/dataset/" + keyword):os.makedirs(file_path + "/dataset/"+ keyword)print('新建文件夹:{}/dataset/{}'.format(file_path, keyword))else:print('文件夹:{}/dataset/{}已经存在,之后将爬取的图片保存到该文件夹中'.format(file_path, keyword))count = 1with tqdm(total=DOWNLOAD_NUM, position=0, leave=True) as pbar:# 爬取第几张num = 1# 是否继续爬取FLAG = Truewhile FLAG:page = 30 * countparams = (('tn', 'resultjson_com'),('logid', '12508239107856075440'),('ipn', 'rj'),('ct', '201326592'),('is', ''),('fp', 'result'),('fr', ''),('word', f'{keyword}'),('queryWord', f'{keyword}'),('cl', '2'),('lm', '-1'),('ie', 'utf-8'),('oe', 'utf-8'),('adpicid', ''),('st', '-1'),('z', ''),('ic', ''),('hd', ''),('latest', ''),('copyright', ''),('s', ''),('se', ''),('tab', ''),('width', ''),('height', ''),('face', '0'),('istype', '2'),('qc', ''),('nc', '1'),('expermode', ''),('nojc', ''),('isAsync', ''),('pn', f'{page}'),('rn', '30'),('gsm', '1e'),('1647838001666', ''),)response = requests.get('https://image.baidu.com/search/acjson', headers=headers, params=params,cookies=cookies)if response.status_code == 200:try:json_data = response.json().get("data")if json_data:for x in json_data:type = x.get("type")if type not in ["gif"]:img = x.get("thumbURL")fromPageTitleEnc = x.get("fromPageTitleEnc")try:resp = requests.get(url=img, verify=False)time.sleep(1)# print(f"链接 {img}")# 保存文件名# file_save_path = f'dataset/{keyword}/{num}-{fromPageTitleEnc}.{type}'file_save_path = file_path + f'/dataset/{keyword}/{num}.{type}'with open(file_save_path, 'wb') as f:f.write(resp.content)f.flush()# print('第 {} 张图像 {} 爬取完成'.format(num, fromPageTitleEnc))num += 1pbar.update(1)  # 进度条更新# 爬取数量达到要求if num > DOWNLOAD_NUM:FLAG = Falseprint('{} 张图像爬取完毕'.format(num - 1))breakexcept Exception:passexcept:passelse:breakcount += 1
# 测试爬取香蕉
file_path = 'D:\python\kaggle\pictures_classfication_data'
download_single_class(file_path,'香蕉', DOWNLOAD_NUM=2)
# 爬取多类水果
class_list = ['苹果','梨','葡萄','火龙果','大枣','柑橘','柚子','桃','杏','西瓜','荔枝','甘蔗','柿子','羊角蜜','香蕉','菠萝','芒果','哈密瓜','石榴','椰子']for class_name in class_list:download_single_class(file_path,class_name)
新建文件夹:D:\python\kaggle\pictures_classfication_data/dataset/苹果100%|██████████| 100/100 [03:17<00:00,  1.98s/it]100 张图像爬取完毕新建文件夹:D:\python\kaggle\pictures_classfication_data/dataset/椰子100%|██████████| 100/100 [02:57<00:00,  1.77s/it]100 张图像爬取完毕

3.2 划分为训练集和测试集

file_path = file_path + '/dataset'classes = os.listdir(file_path)# 创建 train 文件夹
os.mkdir(os.path.join(file_path, 'train'))# 创建 test 文件夹
os.mkdir(os.path.join(file_path, 'val'))# 在 train 和 test 文件夹中创建各类别子文件夹
for fruit in classes:os.mkdir(os.path.join(file_path, 'train', fruit))os.mkdir(os.path.join(file_path, 'val', fruit))
test_frac = 0.2  # 测试集比例
random.seed(123) # 随机数种子,便于复现
df = pd.DataFrame()print('{:^18} {:^18} {:^18}'.format('类别', '训练集数据个数', '测试集数据个数'))for fruit in classes: # 遍历每个类别# 读取该类别的所有图像文件名old_dir = os.path.join(file_path, fruit)images_filename = os.listdir(old_dir)random.shuffle(images_filename) # 随机打乱# 划分训练集和测试集testset_numer = int(len(images_filename) * test_frac) # 测试集图像个数testset_images = images_filename[:testset_numer]      # 获取拟移动至 test 目录的测试集图像文件名trainset_images = images_filename[testset_numer:]     # 获取拟移动至 train 目录的训练集图像文件名# 移动图像至 test 目录for image in testset_images:old_img_path = os.path.join(file_path, fruit, image)         # 获取原始文件路径new_test_path = os.path.join(file_path, 'val', fruit, image) # 获取 test 目录的新文件路径shutil.move(old_img_path, new_test_path) # 移动文件# 移动图像至 train 目录for image in trainset_images:old_img_path = os.path.join(file_path, fruit, image)           # 获取原始文件路径new_train_path = os.path.join(file_path, 'train', fruit, image) # 获取 train 目录的新文件路径shutil.move(old_img_path, new_train_path) # 移动文件# 删除旧文件夹assert len(os.listdir(old_dir)) == 0 # 确保旧文件夹中的所有图像都被移动走shutil.rmtree(old_dir) # 删除文件夹# 输出每一类别的数据个数print('{:^18} {:^18} {:^18}'.format(fruit, len(trainset_images), len(testset_images)))# 保存到表格中df = df.append({'class':fruit, 'trainset':len(trainset_images), 'testset':len(testset_images)}, ignore_index=True)
        类别              训练集数据个数            测试集数据个数      
# 数据集各类别数量统计表格,导出为 csv 文件
df['total'] = df['trainset'] + df['testset']
df.head()
df.to_csv('数据量统计.csv', index=False)
classtestsettrainsettotal
0哈密瓜20.080.0100.0
1大枣20.080.0100.0
220.080.0100.0
3柑橘20.080.0100.0
4柚子20.080.0100.0

3.3 可视化图像文件

def show_images(imgs, num_rows, num_cols, titles=None, scale=1.5):figsize = (num_cols * scale, num_rows * scale)_, axes = plt.subplots(num_rows, num_cols, figsize=figsize)axes = axes.flatten()for i, (ax, img) in enumerate(zip(axes, imgs)):ax.imshow(img)ax.axes.get_xaxis().set_visible(False)ax.axes.get_yaxis().set_visible(False)if titles:ax.set_title(titles[i])return axes#读取图像,解决imread不能读取中文路径路径的问题
def cv_imread(file_path):cv_img = cv2.imdecode(np.fromfile(file_path,dtype=np.uint8),-1)return cv_img
# 读取训练集【西瓜】文件夹所有的图像
folder_path = os.path.join(file_path ,'train' ,  '西瓜')images = []
for each_img in os.listdir(folder_path):img_path = os.path.join(folder_path, each_img)img_bgr = cv_imread(img_path)img_rgb = cv2.cvtColor(img_bgr,cv2.COLOR_BGR2RGB)images.append(img_rgb)show_images([images[i] for i in range(32)],num_rows=4, num_cols=8, scale=1.0)

在这里插入图片描述

3.4 利用ResNet-18迁移学习,训练数据集

'''
将dataset水果分类打成zip压缩包,上传的linux机器上,用GPU训练Linux下的默认编码是UTF8,Windows下生成的zip文件中的编码是GBK/GB2312等.zip文件在Linux下解压时出现乱码问题.执行一下命令:
unzip -O GB18030 dataset.zip
'''
file_path = '/root/autodl-fs/data/fruit20/dataset'
def try_gpu(i=0):if torch.cuda.device_count() >= i + 1:return torch.device(f'cuda:{i}')return torch.device('cpu')'''
1、图像预处理
'''
from torchvision import transforms# 训练集图像预处理:缩放裁剪、图像增强、转 Tensor、归一化
train_transform = transforms.Compose([transforms.RandomResizedCrop(224),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])# 测试集图像预处理-RCTN:缩放、裁剪、转 Tensor、归一化
test_transform = transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])])
'''
2、载入水果图像分类数据集
'''
train_path = os.path.join(file_path, 'train')
test_path = os.path.join(file_path, 'val')
print('训练集路径', train_path)
print('测试集路径', test_path)from torchvision import datasets
# 载入训练集
train_dataset = datasets.ImageFolder(train_path, train_transform)
# 载入测试集
test_dataset = datasets.ImageFolder(test_path, test_transform)print('训练集图像数量', len(train_dataset))
print('类别个数', len(train_dataset.classes))
print('各类别名称', train_dataset.classes)
print('测试集图像数量', len(test_dataset))
print('类别个数', len(test_dataset.classes))
print('各类别名称', test_dataset.classes)
训练集路径 /root/autodl-fs/data/fruit20/dataset/train
测试集路径 /root/autodl-fs/data/fruit20/dataset/val
训练集图像数量 1600
类别个数 20
各类别名称 ['哈密瓜', '大枣', '杏', '柑橘', '柚子', '柿子', '桃', '梨', '椰子', '火龙果', '甘蔗', '石榴', '羊角蜜', '芒果', '苹果', '荔枝', '菠萝', '葡萄', '西瓜', '香蕉']
测试集图像数量 400
类别个数 20
各类别名称 ['哈密瓜', '大枣', '杏', '柑橘', '柚子', '柿子', '桃', '梨', '椰子', '火龙果', '甘蔗', '石榴', '羊角蜜', '芒果', '苹果', '荔枝', '菠萝', '葡萄', '西瓜', '香蕉']
'''
3、类别索引  映射字典
'''
# 各类别名称
class_names = train_dataset.classes
n_class = len(class_names)
# 映射关系:类别 到 索引号
train_dataset.class_to_idx
# 映射关系:索引号 到 类别
idx_to_labels = {y:x for x,y in train_dataset.class_to_idx.items()}
idx_to_labels
{0: '哈密瓜',1: '大枣',2: '杏',3: '柑橘',4: '柚子',5: '柿子',6: '桃',7: '梨',8: '椰子',9: '火龙果',10: '甘蔗',11: '石榴',12: '羊角蜜',13: '芒果',14: '苹果',15: '荔枝',16: '菠萝',17: '葡萄',18: '西瓜',19: '香蕉'}
# 保存为本地的 npy 文件
np.save('idx_to_labels.npy', idx_to_labels)
np.save('labels_to_idx.npy', train_dataset.class_to_idx)
'''
4、加载数据集
'''
from torch.utils.data import DataLoaderBATCH_SIZE = 32# 训练集的数据加载器
train_iter = DataLoader(train_dataset,batch_size=BATCH_SIZE,shuffle=True,num_workers=0)# 测试集的数据加载器
test_iter = DataLoader(test_dataset,batch_size=BATCH_SIZE,shuffle=False,num_workers=0)
'''
5、微调最后一层,创建resnet-18模型
'''
net = torchvision.models.resnet18(pretrained=True) # 载入预训练模型# 修改全连接层,使得全连接层的输出与当前数据集类别数对应
# 新建的层默认 requires_grad=True
net.fc = nn.Linear(net.fc.in_features, n_class)
'''
6、模型训练
'''import torch.nn as nn
from AccumulatorClass import Accumulatordef accuracy(y_hat, y):"""计算预测正确的数量"""if len(y_hat.shape) > 1 and y_hat.shape[1] > 1:y_hat = y_hat.argmax(axis=1)cmp = y_hat.type(y.dtype) == yreturn float(cmp.type(y.dtype).sum())def evaluate_accuracy_gpu(net, data_iter, device=None):"""使⽤GPU计算模型在数据集上的精度"""if isinstance(net, nn.Module):net.eval() # 设置为评估模式if not device:device = next(iter(net.parameters())).device# 正确预测的数量,总预测的数量metric = Accumulator(2)with torch.no_grad():for X, y in data_iter:if isinstance(X, list):# BERT微调所需的X = [x.to(device) for x in X]else:X = X.to(device)y = y.to(device)metric.add(accuracy(net(X), y), y.numel())return metric[0] / metric[1]
from AnimatorClass import Animator
from TimerClass import Timerdef train_ch(net, train_iter, test_iter, num_epochs, lr, device):"""⽤GPU训练模型"""print('training on', device)net.to(device)optimizer = torch.optim.SGD(net.fc.parameters(), lr=lr)# 只微调训练最后一层全连接层的参数,其它层冻结# optimizer = torch.optim.Adam(net.fc.parameters())# 学习率降低策略# lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)# 交叉熵损失loss = nn.CrossEntropyLoss()animator = Animator(xlabel='epoch', xlim=[1, num_epochs],legend=['train loss', 'train acc', 'test acc'])timer, num_batches = Timer(), len(train_iter)num_batches = len(train_iter)best_test_accuracy = 0.0for epoch in range(num_epochs):# 训练损失之和,训练准确率之和,样本数metric = Accumulator(3)net.train()for i, (X, y) in enumerate(train_iter):timer.start()optimizer.zero_grad()X, y = X.to(device), y.to(device)y_hat = net(X)l = loss(y_hat, y)l.backward()optimizer.step()# lr_scheduler.step()with torch.no_grad():metric.add(l * X.shape[0], accuracy(y_hat, y), X.shape[0])timer.stop()train_l = metric[0] / metric[2]train_acc = metric[1] / metric[2]if (i + 1) % (num_batches // 5) == 0 or i == num_batches - 1:animator.add(epoch + (i + 1) / num_batches,(train_l, train_acc, None))test_acc = evaluate_accuracy_gpu(net, test_iter)if test_acc > best_test_accuracy:# 删除旧的最佳模型文件(如有)old_best_checkpoint_path = 'checkpoint/best-{:.3f}.pth'.format(best_test_accuracy)if os.path.exists(old_best_checkpoint_path):os.remove(old_best_checkpoint_path)# 保存新的最佳模型文件new_best_checkpoint_path = 'checkpoint/best-{:.3f}.pth'.format(test_acc)torch.save(net, new_best_checkpoint_path)print('保存新的最佳模型', 'checkpoint/best-{:.3f}.pth'.format(test_acc))best_test_accuracy = test_accanimator.add(epoch + 1, (None, None, test_acc))print(f'best_test_accuracy = {best_test_accuracy:.3f}')    print(f'loss {train_l:.3f}, train acc {train_acc:.3f}, test acc {test_acc:.3f}')print(f'{metric[2] * num_epochs / timer.sum():.1f} examples/sec on {str(device)}')
# 初始化模型
lr, num_epochs = 0.1, 10
train_ch(net, train_iter, test_iter, num_epochs, lr, try_gpu())
best_test_accuracy = 0.840
loss 0.546, train acc 0.832, test acc 0.810
565.2 examples/sec on cuda:0

在这里插入图片描述

3.5 模型预测

best_test_accuracy = 0.840# 载入最佳模型作为当前模型
net = torch.load('checkpoint/best-{:.3f}.pth'.format(best_test_accuracy))
net.to(try_gpu())test_iter = DataLoader(test_dataset,batch_size=BATCH_SIZE,shuffle=True,num_workers=0)def get_fruit_labels(labels):"""返回fruit20数据集的⽂本标签"""text_labels = test_dataset.classesreturn [text_labels[int(i)] for i in labels]def predict(net,test_iter, n=10):for X,y in test_iter:trues = get_fruit_labels(y[0:n])outputs = net(X.to(try_gpu())) # 输入模型,执行前向预测_, preds = torch.max(outputs, 1)preds = get_fruit_labels(preds.cpu().numpy()[0:n])print('trues:',trues)print('preds:',preds)breakpredict(net,test_iter)
trues: ['菠萝', '梨', '柿子', '大枣', '芒果', '菠萝', '芒果', '苹果', '桃', '菠萝']
preds: ['菠萝', '梨', '柿子', '大枣', '芒果', '菠萝', '芒果', '苹果', '桃', '菠萝']

这篇关于经典神经网络(6)ResNet及其在Fashion-MNIST数据集上的应用的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/835116

相关文章

Redis的数据过期策略和数据淘汰策略

《Redis的数据过期策略和数据淘汰策略》本文主要介绍了Redis的数据过期策略和数据淘汰策略,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一... 目录一、数据过期策略1、惰性删除2、定期删除二、数据淘汰策略1、数据淘汰策略概念2、8种数据淘汰策略

轻松上手MYSQL之JSON函数实现高效数据查询与操作

《轻松上手MYSQL之JSON函数实现高效数据查询与操作》:本文主要介绍轻松上手MYSQL之JSON函数实现高效数据查询与操作的相关资料,MySQL提供了多个JSON函数,用于处理和查询JSON数... 目录一、jsON_EXTRACT 提取指定数据二、JSON_UNQUOTE 取消双引号三、JSON_KE

Python给Excel写入数据的四种方法小结

《Python给Excel写入数据的四种方法小结》本文主要介绍了Python给Excel写入数据的四种方法小结,包含openpyxl库、xlsxwriter库、pandas库和win32com库,具有... 目录1. 使用 openpyxl 库2. 使用 xlsxwriter 库3. 使用 pandas 库

SpringBoot定制JSON响应数据的实现

《SpringBoot定制JSON响应数据的实现》本文主要介绍了SpringBoot定制JSON响应数据的实现,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们... 目录前言一、如何使用@jsonView这个注解?二、应用场景三、实战案例注解方式编程方式总结 前言

使用Python在Excel中创建和取消数据分组

《使用Python在Excel中创建和取消数据分组》Excel中的分组是一种通过添加层级结构将相邻行或列组织在一起的功能,当分组完成后,用户可以通过折叠或展开数据组来简化数据视图,这篇博客将介绍如何使... 目录引言使用工具python在Excel中创建行和列分组Python在Excel中创建嵌套分组Pyt

在Rust中要用Struct和Enum组织数据的原因解析

《在Rust中要用Struct和Enum组织数据的原因解析》在Rust中,Struct和Enum是组织数据的核心工具,Struct用于将相关字段封装为单一实体,便于管理和扩展,Enum用于明确定义所有... 目录为什么在Rust中要用Struct和Enum组织数据?一、使用struct组织数据:将相关字段绑

在Mysql环境下对数据进行增删改查的操作方法

《在Mysql环境下对数据进行增删改查的操作方法》本文介绍了在MySQL环境下对数据进行增删改查的基本操作,包括插入数据、修改数据、删除数据、数据查询(基本查询、连接查询、聚合函数查询、子查询)等,并... 目录一、插入数据:二、修改数据:三、删除数据:1、delete from 表名;2、truncate

Java实现Elasticsearch查询当前索引全部数据的完整代码

《Java实现Elasticsearch查询当前索引全部数据的完整代码》:本文主要介绍如何在Java中实现查询Elasticsearch索引中指定条件下的全部数据,通过设置滚动查询参数(scrol... 目录需求背景通常情况Java 实现查询 Elasticsearch 全部数据写在最后需求背景通常情况下

5分钟获取deepseek api并搭建简易问答应用

《5分钟获取deepseekapi并搭建简易问答应用》本文主要介绍了5分钟获取deepseekapi并搭建简易问答应用,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需... 目录1、获取api2、获取base_url和chat_model3、配置模型参数方法一:终端中临时将加

Java中注解与元数据示例详解

《Java中注解与元数据示例详解》Java注解和元数据是编程中重要的概念,用于描述程序元素的属性和用途,:本文主要介绍Java中注解与元数据的相关资料,文中通过代码介绍的非常详细,需要的朋友可以参... 目录一、引言二、元数据的概念2.1 定义2.2 作用三、Java 注解的基础3.1 注解的定义3.2 内