基于Mindspore,通过Resnet50迁移学习实现猫十二分类

2024-03-07 14:44

本文主要是介绍基于Mindspore,通过Resnet50迁移学习实现猫十二分类,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

使用平台介绍

使用平台:启智AI协作平台
使用数据集:百度猫十二分类

数据集介绍

有cat_12_train和cat_12_test和train_list.txt
train_list.txt内有每张图片所对应的标签

Minspore部分操作科普

数据集加载

Mindspore加载图片数据集就直接调整成这种格式就行,然后可以用这个函数加载,自动生成两个列,一列是图片,一列是标签;ImageFolderDataset函数会自动读取和处理数据集,标签就是文件夹的名称
在这里插入图片描述
在这里插入图片描述

数据处理

map函数里可以一键进行处理和映射,定义好数据处理函数,直接把路径和标签Map处理,后面可以带上是训练集还是测试集的标签;,要处理图片就指定 input_columns 参数为image,这个是前面数据读取形成的;
本项目这里用的是ImageFolderDataset,可以自动生成图片对应数据的两个列,要是相对数据处理就设置为数据列名即可
在这里插入图片描述

数据批处理和重复

batch就是批处理,把数据分成指定数量的一个个批次
在这里插入图片描述
最后repeat对数据进行重复
在这里插入图片描述

整体数据处理流程

就是读取数据形成数据和标签对应的列(读取数据函数有很多),然后定义数据预处理函数,在map函数里一键映射,指定要处理的列一键处理,最后对数据进行批次划分,就拿到可以放进训练网络函数的规范数据集了。当然使用时候还要用create_tuple_iterator或者create_dict_iterator函数形成可迭代的数据集。

迁移学习

在迁移学习中,固定特征训练和模型微调都是常用的技术

固定特征训练

在源任务上训练一个模型,并将其应用到目标任务上。在这个过程中,模型的特征提取器是固定的,只对输出层进行调整。这种方法可以利用源任务中已经学习到的特征,从而减少目标任务的训练时间和数据需求。固定特征训练通常适用于目标任务和源任务具有相似的特征空间,并且目标任务的数据量较小的情况。

模型微调

使用预训练模型作为初始模型,并在目标任务的数据集上进行进步训练。在微调阶段,可以根据目标任务的数据和特定要求调整模型的参数使其适应目标任务。模型微调的主要目的是通过在目标任务上的有限训练来调整预训练模型,以取得更好的性能。

举例

以训练一个猫、狗分类器为例,固定特征训练是指在一个大型的猫狗图片数据集上训练一个通用的图像识别模型,然后将该模型应用于特定的猫、狗分类任务。在这个过程中,我们只需要调整模型的输出层,使其能够正确地对猫和狗进行分类。而模型微调则是指使用一个已经在大型数据集上训练好的通用图像识别模型,然后在特定的猫、狗图片数据集上进行进一步训练,以优化模型的性能。在这个过程中我们可以根据猫、狗图片的特点来调整模型的参数,使其能够更好地识别猫和狗。

数据处理过程

由于本项目采用的数据集是百度所提供的猫十二分类,要使用ImageFolderDataset的话,形式不太匹配

现有形式

cat_12_train和cat_12_test里面都是一张一张的图片
train_list.txt内有每张图片所对应的标签
在这里插入图片描述

处理后形式

train和val文件夹内分别有十二个子文件夹,代表12类猫,每个子文件夹内又有一张张的图片
在这里插入图片描述在这里插入图片描述

处理代码

这里有相关代码进行自动划分,但是对于训练集和测试集的划分,我直接采用了手动操作,也可以用代码来实现的;

# 处理异常图片
dir_lit = os.listdir('./work/cat_12_train/')
# dir_lit为一个列表,里面是一张张图片的名称
for list in dir_lit:
# list是图片名称,这里的操作是把这个图片形成一个个的路径img_path=os.path.join('./work/cat_12_train/',list)print(img_path)# 如果不是RGB那就转换为RGBimg=Image.open(img_path)if img.mode != 'RGB':img = img.convert('RGB')img.save(img_path)dir_lit = os.listdir('./work/cat_12_test/')
for list in dir_lit:img_path=os.path.join('./work/cat_12_test/',list)img=Image.open(img_path)if img.mode != 'RGB':img = img.convert('RGB')img.save(img_path)
# 整理数据格式
# 创建12个文件夹分别对应标签
path='./work/MyDataset/'
for i in range(12):if not os.path.exists(path+str(i)):os.mkdir(path+str(i))else:continue
#读取每一行
with open(f'./work/train_list.txt','r')as f:img_path=f.readlines()print(img_path)# 里面是一个个的'cat_12_train/8GOkTtqw7E6IHZx4olYnhzvXLCiRsUfM.jpg\t0\n'# 把对应文件放到对应标签文件夹下
for img in img_path:# 拿取每一张图片路径# print(img)img_src= img.split('\t')[0]# img_src为一个个图片路径rel_src= img_src.split('cat_12_train/')[1]# rel_src为图片名称img_label = img.split('\t')[1]img_label = img_label.split('\n')[0]# img_label为图片标签print(img_src)print(rel_src)print(img_label)# os.system(f'cp ./work/{img_src} ./work/MyDataset/{img_label}/{rel_src}')shutil.copy(f'./work/{img_src}',f'./work/MyDataset/{img_label}/{rel_src}')
print('图片处理完毕')

整体代码

# 解压上传的数据集压缩包并查看数据集结构
!unzip MyDataset.zip -d data/
import os
print(os.listdir("data"))
print(os.listdir("data/train"))
print(os.listdir('data/train/1'))

输出

['val', 'train']
['9', '10', '5', '8', '11', '2', '7', '3', '0', '1', '6', '4']
['DKkQylbgdrWRjYap63MCJe0UBLhcHXPm.jpg', 'k9HWNaG2Z1wUKAOYdSDu7vRr4xBqmTCV.jpg', 'LfIoOrSNvKHQzsGtm5eMZc0lRuBXhTP6.jpg', '0esFjXNqc5xbMmUaJkRVPwQorWlu3LvA.jpg', 'PMoFIabq0W9U2wETZr7yf4JLYdBxv6hQ.jpg', 'QlUX4zHfPZ3LxRDswqm5FeMbnTWNaj6g.jpg', '7E9oOUcQjkLMvpAtNymHCRSqFfdGVDK4.jpg', 'pznq7EivBH9LwrNysIWxgGeomTlOP8cZ.jpg', '3Ndv9X6uTgzFtnoA01VECIBPj7xqlewG.jpg', 'o1g6adKmS4lBDw2F5buAYnetUWh7xXGz.jpg', '7QZTYlspK2fqdJUwjC0HDmOFrM5W4PX9.jpg', 'qgimvDE8Zaf4PJ32dkNhwVy5nxATOtrX.jpg', 'RJnThakSOGUzeFBdigXAm2NsL8jyYvu3.jpg', 'Ig61xq3ME78fdCRTDWhaKkcyuOQj24PG.jpg', 'tAdqSefI0DohNuU6wgVyPca7Qz5lYTOH.jpg', 'pTe31kYFqwyOGmV50sbhgoLQ9KcjJaxd.jpg', 'HUuwb4gRqoPWD3Lvrsa9hVcQ7FSfOT8t.jpg', 'cqkJDEpWiwS69UxFtKMPRgb4mXhj1LAs.jpg', '2pjBVbqF30cUTvRIYtsCGfgwPKOJz4ua.jpg', 'e7f1iucltpVQTXroFR96xawm2BDZYnNG.jpg', 'tWcMpXTe78zo2ikhbUOqPud6VJ5RfSEw.jpg', 'gcTxA5NwLztvWr7YPCMnDjdFyfqoa2uK.jpg', 'qbKjsR05lrFVYfLChtMGD7im36cUgAnE.jpg', 'S4hfUR5kOj7CXr826Gxa9t3bEBPioJq1.jpg', 'A8PMtHzoyFE0WgjpZ2qUYbduL4T9arxN.jpg', '8E0bSi3h1aVy2cNWpgOsKvxZCQtzqkLU.jpg', 'oDc9XxipzfBjUAEl0hOmyd4PNr5v8IsG.jpg', 'AVGJoCPsX3L6I5Y2M94kEO7vNHmt1Ure.jpg', 'SqyF8c0Rak1NedXpYvlI9TsVwzOhtGZJ.jpg', '04Iv3QNKtu2DAfRTgs9XZwBMb1Cm7l6P.jpg', '94NwrzYLo8iMtagTR3SfPGHmWvZXbyUx.jpg', 'oiU4YjnhNpI3JWagx8SuTCktA6qZXGRH.jpg', 'e416wAERYOQ7NutUJDcIVFk2oPWpC3q8.jpg', '9I3enpUrZ5xD2TvRAOFt4S7lBfVMdqsJ.jpg', 'oLNGFnUPmQhxOkdbv37HwSj8uql4z1sp.jpg', 'nHfDoId8SXKzMt1weR7bJlaWPcNx32yv.jpg', 'dzV1Psxncp6H8g4KWhX3mbrTfqwuLaNv.jpg', 'OW6e1GbpNsfmxFvLQKMnIByX4hDcS2io.jpg', 'MOGw0PDqjmnLdViez26b7WY3hU85vatH.jpg', '3R5BWakTdG2hKjJoiNxg0pr61LsDqSuM.jpg', '41OaVziAEuenpKqv7LCYMPsGH6BkQotD.jpg', 'fRbdkW0GDAhBjpTVeonPItycEia45Ns9.jpg', 'RsYG3VJi7NTXptoPQvWKhcFaDqIe4EkC.jpg', 'jWPXtA07yYrcRxNBUkwC9dpuSs2M463e.jpg', 'bqRVATEuI4x3kJSO9DitWjYms8KoG2Lz.jpg', 'OCXPGzodQsZHRnfMFaBkqW9hKYxA4glr.jpg', 'hAPzcCeE04sSadDIFB627ipyOWgjX8mq.jpg', 'PIbktpOd2DqHwVceLzUyE7CmohMjNA91.jpg', 'adAjP14SXL6vVJ95TRrMIYDiHl7BUqbF.jpg', 'hTxYnXrQ35vwKL1NEMSIot02djHy8i6D.jpg', 'kwuLVmg7n9I4iEOzMQC1NxJfvX6Bhoqs.jpg', '78WTn5auMmQshIZi2qDAYdR1oKcwEzfk.jpg', 'hOEPm8o26CBptkD5yT3fsgbM4dRuVZzw.jpg', 'PIpXbRiu68dm2s1DfxHJGAYLegOUMzoB.jpg', '8NxvitwMaCpsuEQT17nDXzFR5gAZ4rfc.jpg', 'JgxcdpvW7f6lKMShPjHFeZ2RDX13UCiN.jpg', 'BhHLRN0QTWOwl3UEG9J17XScni2P5gVe.jpg', '7Gw2o8LJTF4ecZI6nl1WuDrsAOSfQPaq.jpg', 't6xZhQkD2jWCOi1r3fK7T9slGHVbwgNd.jpg', 'gbfyYtlWaAO4iUCPK2cFVkoQX9MmwJTI.jpg', '9PQic6o4VyZ13pLAYu2avFWSbJRz57fC.jpg', 'Cf2Z3j6hYliVOduEvK5NJp4yba0wSGcA.jpg', '2QvYgMIzELXH4Fy8GNDBaPS3W0tVZ5xq.jpg', 'oFXrWl80gRMenqPbG2uZv5wk4KmHjaNd.jpg', 'l1NsjeJKdvFimRgh6IEZuqfxCw5An8o4.jpg', 'tvWgSwN9m3BZ5qOXjK7LexVCIn6F4AHk.jpg', 'ahcTZUbOmJsloVt8vGMjwPqIXd0x9iy1.jpg', 'hp0nNWXar81lB3eSYE9kGcdDL4tJfuwC.jpg', 'HNehmorRIS6M2iDj48gZL9OKva7Xck1n.jpg', '0GX4YKdcwBi15lTpR7ExWO2ZagseoVNI.jpg', 'b57SiGKYPaE1DrfxJtVeQdlOAUojLZhR.jpg', 'f4gLhHjyKdxTumna7F5pGWPqVIRY01et.jpg', 'AHxQ1GFgRLs802diT3VIlwOoqkW9Sar4.jpg', '09i4DcyrWktZb1naHFEpL5elhG3CvYxu.jpg', 'AieqOGKQC7fgDayW9kLuJ5mUHx4XP2I6.jpg', 'MwZIekE7oxPtRpTHQVf4l6qA2iC1zWLh.jpg', '6M8xAZBdQLkuDcF14HEz53J0IboiPfUa.jpg', 'eD4gfaQTFdoWUCnhPj6YmIZBl5AMxNik.jpg', 'PwBJK7rZHDhq46ynYoj9Saxip0IldMV2.jpg', 'neHaTbwPkdVmoOA3JyWxR7Lh92NDpzEf.jpg', 'E38k6xhQFYKALn4tDlwiPBfpdCygeSNs.jpg', 'N1VpzqjoRmPZCQ5KasAv72TwtMDFrG0d.jpg', 'RvXKbfDuF4W6exgVcInE3SktJa9LzBj8.jpg', 'aMoxSymjdiUwbJ6k5NzGR09uILQ4sEc3.jpg', 'l1WfIcvOZk9jAn2xQwtSCEgY5XhRyoFB.jpg', '0IWfLUGk53iHt9SElNzKsBCDwuMjpPbR.jpg', '80vskwDtCRAz9iWYjnrhIGfeXdUZx16b.jpg', '0WglfKCD5Gu2LqI1msTSZa6orBO7XAz9.jpg', 'os5kaDubPM1hY7f6gRrSOZqNQFEAU89v.jpg', 'm2azqs0NGPDdjR8rUTxWF4covLE5ikQZ.jpg', 'siKAzUrV9eykjlCQ0odZEhnW7FIgTuLm.jpg', 'jZWaf0ne2R1pDo9hTBkCbA8YOq3LlQ5x.jpg', 'Kny8zFiIt4vxNSO6g7Lu9kGfdVJoqPC3.jpg', 'LTMkHx9w2nfsRiZec3bEVtmujpv7qS1y.jpg', 'oZin4PuwTet39xWCYhUBfvlzGyISb5DV.jpg', 'oJ4HWQkZDvta8rUyinRu9fVNs3BX1Kj7.jpg', 'mUp6082yMQghXY13OtvxabTrNSEeiu4B.jpg', 'RZWpn9jGxcKSUb3Y56fVMQHlJhEIeNiA.jpg', 'L86JQlekn9Ko01TbXHYMFImdZ2upxg5h.jpg', 'I0YcgXB97QL6MtHlU3p1znqWdCGOD4mo.jpg', 'IG9NFCfMybKYiQhquOd5H3DjwlSakW6U.jpg', 'SZYosxl4cHRWyT3h5JFqNjGdnIag6907.jpg', '4dMVtGvRJbrjK6X17STZ9Lx3kgeEioqp.jpg', 'q9YrDFK73Jfv15SHpTWelGAIwnBxt0iE.jpg', 'spNU7J8uk6BXiAyQErHegYMzjOaFR2qV.jpg', 'p1ji352o6vhd8l0Q4uNVRZrIgkSLnfBq.jpg', '7WQ0ByMPtJAdZ8h5OkveLi3ScuU6bIY1.jpg', 'hajCi0GDlVP2ONg6FeSWrvubQ34ozwkx.jpg', 'RKLDkUwmFg0Oj5tPeIs31y7J8AQZ9dni.jpg', 'hWOAp1EV6nJzYxaHt03T8GPNe7ujUiF9.jpg', 'B0a2VHnwQv1byMDTlEiOJXxI7Scs9Zjf.jpg', 'gMzOoyTrGniBj1vxN0AeD9VQsFHU7aKR.jpg', 'CiBq0GVawv1rdYyLDjcWoIXP6SKbzH8F.jpg', 'G71cYNEBD6shJLkgVzwb52m8oRluKUHS.jpg', '7IdLnFCb3a25cKNV6tXuYi1fe0hJQMOU.jpg', 'jsThJuVYQxUKSz3btXdA5q8M1O9Cioaw.jpg', '2OpyK1cm85obujwEMqGWNv9V7PnQfJ3U.jpg', 'DLIZr5TjPepd8csioJXMbYHk94RmKx6v.jpg', 'Q80DEFkGlxJj2qR37t4ZKpY6zMdvuIyr.jpg', 'EFxXsVJ0qHkomcBhnLfY96W5U4yOliQG.jpg', 'okw9N05dAnsxgW2IuQy7eGhz1iLOqVrJ.jpg', 'obRL95fxtP1uCNBwQiTjWsdqUvgp2Z43.jpg', '6jTZ5sfCpGwJWIK3DaYQvixLbNt48nHr.jpg', 'gLqBoG3ah0AHXIYWS7dFTt6pxDw8snQv.jpg', '1lzs3kM8NiILvcgDYtn6fdCoSeXauJ5P.jpg', 'FEuyDnKSIJ1a6UtY5LB2rGpRqOm3xP9Z.jpg', 'OKvn2uJmWQi4R9Cs8B7fxbkZtoczq31Y.jpg', 'ByYKkZHb8omRPcvfe5GzXsxOQ3DlLuUq.jpg', 'SKoiaj8C3UGyvJQXh5zWwrxNmYkdEHqn.jpg', '2cKUvXCjm5HNWksY1b4ioIgdSFqyMtEJ.jpg', 'kJK9OA3hXpMWeUY7cifvrz0BItn1VS2T.jpg', 't1DnLxSZXwWTgeJsyE02lrjHfdM35po8.jpg', '5C76eISyb3vmPZuMYcARHU8aFQrBWf1k.jpg', 'puBcg8Fh6tXs27doz1aAIl4L0iVYC3wE.jpg', 'MV5C7YmuzG1LyZplFXvqOQkW4JStjcNP.jpg', 'ruleKNQvzwqmy5sn9MDd7I2RUJjVCWh8.jpg', 'l9Z3gPwjC5HbhINcfVO8dnz1qAxBrJkU.jpg', 'H9BcFOo8UI3jX2CyW0mzxn7agJNAsZQS.jpg', 'jHUJE37YZOGAXInPmyCSp9f0o4uvRe5W.jpg', 'I8jNkAVgZ1yqDw5K9b0Wm4rETfiGBcUF.jpg', 'kKzQrE6GjfpeFhsXx2Ddu9YaTHc3PUbB.jpg', 'jla5O2TkVhefr07XDLpMEonuG6yJWgYd.jpg', 'Km9BZsaSUoxQ4VArcXYyHThIDRbq2t7l.jpg', 'fBp0Yor4EQtWkM7I3TsnNHLXuvCFacjS.jpg']

超参数设置

batch_size = 18                             # 批量大小
image_size = 224                            # 训练图像空间大小
num_epochs = 10                             # 训练周期数
lr = 0.001                                  # 学习率
momentum = 0.9                              # 动量
workers = 4                                 # 并行线程个数

数据预处理

import mindspore as ms
import mindspore.dataset as ds
import mindspore.dataset.vision as vision# 数据集目录路径
data_path_train = "data/train/"
data_path_val = "data/val/"# 创建训练数据集def create_dataset_canidae(dataset_path, usage):"""数据加载"""data_set = ds.ImageFolderDataset(dataset_path,num_parallel_workers=workers,shuffle=True,)# 数据增强操作mean = [0.485 * 255, 0.456 * 255, 0.406 * 255]std = [0.229 * 255, 0.224 * 255, 0.225 * 255]scale = 32if usage == "train":# Define map operations for training datasettrans = [vision.RandomCropDecodeResize(size=image_size, scale=(0.08, 1.0), ratio=(0.75, 1.333)),vision.RandomHorizontalFlip(prob=0.5),vision.Normalize(mean=mean, std=std),vision.HWC2CHW()]else:# Define map operations for inference datasettrans = [vision.Decode(),vision.Resize(image_size + scale),vision.CenterCrop(image_size),vision.Normalize(mean=mean, std=std),vision.HWC2CHW()]# 数据映射操作data_set = data_set.map(operations=trans,input_columns='image',num_parallel_workers=workers)# 批量操作data_set = data_set.batch(batch_size)return data_setdataset_train = create_dataset_canidae(data_path_train, "train")
step_size_train = dataset_train.get_dataset_size()
dataset_val = create_dataset_canidae(data_path_val, "val")
step_size_val = dataset_val.get_dataset_size()
print(step_size_train)
print(step_size_val)
data = next(dataset_val.create_dict_iterator())
images = data["image"]
labels = data["label"]
print("Tensor of image", images.shape)
print("Labels:", labels)

输出

96
24
Tensor of image (18, 3, 224, 224)
Labels: [ 1  2  4  3  0 10  5  4 11  9 11  6  7  1 11  5  1  3]

数据集可视化查看

import matplotlib.pyplot as plt
import numpy as np# class_name对应label,按文件夹字符串从小到大的顺序标记label
class_name = {0: "0", 1: "1",2: "2", 3: "3",4: "4", 5: "5",6: "6", 7: "7",8: "8", 9: "9",10: "10", 11: "11",12: "12"}plt.figure(figsize=(5, 5))
for i in range(4):# 获取图像及其对应的labeldata_image = images[i].asnumpy()data_label = labels[i]# 处理图像供展示使用data_image = np.transpose(data_image, (1, 2, 0))mean = np.array([0.485, 0.456, 0.406])std = np.array([0.229, 0.224, 0.225])data_image = std * data_image + meandata_image = np.clip(data_image, 0, 1)# 显示图像plt.subplot(2, 2, i+1)plt.imshow(data_image)plt.title(class_name[int(labels[i].asnumpy())])plt.axis("off")plt.show()

在这里插入图片描述

网络结构搭建

from typing import Type, Union, List, Optional
from mindspore import nn, train
from mindspore.common.initializer import Normalweight_init = Normal(mean=0, sigma=0.02)
gamma_init = Normal(mean=1, sigma=0.02)
class ResidualBlockBase(nn.Cell):expansion: int = 1  # 最后一个卷积核数量与第一个卷积核数量相等def __init__(self, in_channel: int, out_channel: int,stride: int = 1, norm: Optional[nn.Cell] = None,down_sample: Optional[nn.Cell] = None) -> None:super(ResidualBlockBase, self).__init__()if not norm:self.norm = nn.BatchNorm2d(out_channel)else:self.norm = normself.conv1 = nn.Conv2d(in_channel, out_channel,kernel_size=3, stride=stride,weight_init=weight_init)self.conv2 = nn.Conv2d(in_channel, out_channel,kernel_size=3, weight_init=weight_init)self.relu = nn.ReLU()self.down_sample = down_sampledef construct(self, x):"""ResidualBlockBase construct."""identity = x  # shortcuts分支out = self.conv1(x)  # 主分支第一层:3*3卷积层out = self.norm(out)out = self.relu(out)out = self.conv2(out)  # 主分支第二层:3*3卷积层out = self.norm(out)if self.down_sample is not None:identity = self.down_sample(x)out += identity  # 输出为主分支与shortcuts之和out = self.relu(out)return out
class ResidualBlock(nn.Cell):expansion = 4  # 最后一个卷积核的数量是第一个卷积核数量的4倍def __init__(self, in_channel: int, out_channel: int,stride: int = 1, down_sample: Optional[nn.Cell] = None) -> None:super(ResidualBlock, self).__init__()self.conv1 = nn.Conv2d(in_channel, out_channel,kernel_size=1, weight_init=weight_init)self.norm1 = nn.BatchNorm2d(out_channel)self.conv2 = nn.Conv2d(out_channel, out_channel,kernel_size=3, stride=stride,weight_init=weight_init)self.norm2 = nn.BatchNorm2d(out_channel)self.conv3 = nn.Conv2d(out_channel, out_channel * self.expansion,kernel_size=1, weight_init=weight_init)self.norm3 = nn.BatchNorm2d(out_channel * self.expansion)self.relu = nn.ReLU()self.down_sample = down_sampledef construct(self, x):identity = x  # shortscuts分支out = self.conv1(x)  # 主分支第一层:1*1卷积层out = self.norm1(out)out = self.relu(out)out = self.conv2(out)  # 主分支第二层:3*3卷积层out = self.norm2(out)out = self.relu(out)out = self.conv3(out)  # 主分支第三层:1*1卷积层out = self.norm3(out)if self.down_sample is not None:identity = self.down_sample(x)out += identity  # 输出为主分支与shortcuts之和out = self.relu(out)return out
def make_layer(last_out_channel, block: Type[Union[ResidualBlockBase, ResidualBlock]],channel: int, block_nums: int, stride: int = 1):down_sample = None  # shortcuts分支if stride != 1 or last_out_channel != channel * block.expansion:down_sample = nn.SequentialCell([nn.Conv2d(last_out_channel, channel * block.expansion,kernel_size=1, stride=stride, weight_init=weight_init),nn.BatchNorm2d(channel * block.expansion, gamma_init=gamma_init)])layers = []layers.append(block(last_out_channel, channel, stride=stride, down_sample=down_sample))in_channel = channel * block.expansion# 堆叠残差网络for _ in range(1, block_nums):layers.append(block(in_channel, channel))return nn.SequentialCell(layers)
from mindspore import load_checkpoint, load_param_into_netclass ResNet(nn.Cell):def __init__(self, block: Type[Union[ResidualBlockBase, ResidualBlock]],layer_nums: List[int], num_classes: int, input_channel: int) -> None:super(ResNet, self).__init__()self.relu = nn.ReLU()# 第一个卷积层,输入channel为3(彩色图像),输出channel为64self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, weight_init=weight_init)self.norm = nn.BatchNorm2d(64)# 最大池化层,缩小图片的尺寸self.max_pool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='same')# 各个残差网络结构块定义,self.layer1 = make_layer(64, block, 64, layer_nums[0])self.layer2 = make_layer(64 * block.expansion, block, 128, layer_nums[1], stride=2)self.layer3 = make_layer(128 * block.expansion, block, 256, layer_nums[2], stride=2)self.layer4 = make_layer(256 * block.expansion, block, 512, layer_nums[3], stride=2)# 平均池化层self.avg_pool = nn.AvgPool2d()# flattern层self.flatten = nn.Flatten()# 全连接层self.fc = nn.Dense(in_channels=input_channel, out_channels=num_classes)def construct(self, x):x = self.conv1(x)x = self.norm(x)x = self.relu(x)x = self.max_pool(x)x = self.layer1(x)x = self.layer2(x)x = self.layer3(x)x = self.layer4(x)x = self.avg_pool(x)x = self.flatten(x)x = self.fc(x)return xdef _resnet(model_url: str, block: Type[Union[ResidualBlockBase, ResidualBlock]],layers: List[int], num_classes: int, pretrained: bool, pretrianed_ckpt: str,input_channel: int):model = ResNet(block, layers, num_classes, input_channel)if pretrained:# 加载预训练模型# download(url=model_url, path=pretrianed_ckpt, replace=True)param_dict = load_checkpoint(pretrianed_ckpt)load_param_into_net(model, param_dict)return modeldef resnet50(num_classes: int = 1000, pretrained: bool = False):"ResNet50模型"resnet50_url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/models/application/resnet50_224_new.ckpt"resnet50_ckpt = "./LoadPretrainedModel/resnet50_224_new.ckpt"return _resnet(resnet50_url, ResidualBlock, [3, 4, 6, 3], num_classes,pretrained, resnet50_ckpt, 2048)

形式一:模型微调

模型训练
from mindspore import nn, train
from mindspore.nn import Loss, Accuracy
!pip install download
import mindspore as ms
from download import download
network = resnet50(pretrained=True)# 全连接层输入层的大小
in_channels = network.fc.in_channels
# 输出通道数大小为狼狗分类数2
head = nn.Dense(in_channels, 12)
# 重置全连接层
network.fc = head# 平均池化层kernel size为7
avg_pool = nn.AvgPool2d(kernel_size=7)
# 重置平均池化层
network.avg_pool = avg_poolimport mindspore as ms
import mindspore# 定义优化器和损失函数
opt = nn.Momentum(params=network.trainable_params(), learning_rate=lr, momentum=momentum)
loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')# 实例化模型
model = train.Model(network, loss_fn, opt, metrics={"Accuracy": Accuracy()})def forward_fn(inputs, targets):logits = network(inputs)loss = loss_fn(logits, targets)return lossgrad_fn = mindspore.ops.value_and_grad(forward_fn, None, opt.parameters)def train_step(inputs, targets):loss, grads = grad_fn(inputs, targets)opt(grads)return loss# 创建迭代器
data_loader_train = dataset_train.create_tuple_iterator(num_epochs=num_epochs)
# 最佳模型保存路径
best_ckpt_dir = "./BestCheckpoint"
best_ckpt_path = "./BestCheckpoint/resnet50-best.ckpt"
import os
import time# 开始循环训练
print("Start Training Loop ...")best_acc = 0for epoch in range(num_epochs):losses = []network.set_train()epoch_start = time.time()# 为每轮训练读入数据for i, (images, labels) in enumerate(data_loader_train):labels = labels.astype(ms.int32)loss = train_step(images, labels)losses.append(loss)# 每个epoch结束后,验证准确率acc = model.eval(dataset_val)['Accuracy']epoch_end = time.time()epoch_seconds = (epoch_end - epoch_start) * 1000step_seconds = epoch_seconds/step_size_trainprint("-" * 20)print("Epoch: [%3d/%3d], Average Train Loss: [%5.3f], Accuracy: [%5.3f]" % (epoch+1, num_epochs, sum(losses)/len(losses), acc))print("epoch time: %5.3f ms, per step time: %5.3f ms" % (epoch_seconds, step_seconds))if acc > best_acc:best_acc = accif not os.path.exists(best_ckpt_dir):os.mkdir(best_ckpt_dir)ms.save_checkpoint(network, best_ckpt_path)print("=" * 80)
print(f"End of validation the best Accuracy is: {best_acc: 5.3f}, "f"save the best ckpt file in {best_ckpt_path}", flush=True)

输出

Start Training Loop ...
--------------------
Epoch: [  1/ 10], Average Train Loss: [1.774], Accuracy: [0.838]
epoch time: 60892.337 ms, per step time: 634.295 ms
--------------------
Epoch: [  2/ 10], Average Train Loss: [0.762], Accuracy: [0.905]
epoch time: 8745.406 ms, per step time: 91.098 ms
--------------------
Epoch: [  3/ 10], Average Train Loss: [0.568], Accuracy: [0.921]
epoch time: 8449.129 ms, per step time: 88.012 ms
--------------------
Epoch: [  4/ 10], Average Train Loss: [0.508], Accuracy: [0.910]
epoch time: 8199.763 ms, per step time: 85.414 ms
--------------------
Epoch: [  5/ 10], Average Train Loss: [0.459], Accuracy: [0.900]
epoch time: 7856.060 ms, per step time: 81.834 ms
--------------------
Epoch: [  6/ 10], Average Train Loss: [0.405], Accuracy: [0.931]
epoch time: 8138.927 ms, per step time: 84.780 ms
--------------------
Epoch: [  7/ 10], Average Train Loss: [0.368], Accuracy: [0.919]
epoch time: 8333.523 ms, per step time: 86.808 ms
--------------------
Epoch: [  8/ 10], Average Train Loss: [0.354], Accuracy: [0.912]
epoch time: 8271.008 ms, per step time: 86.156 ms
--------------------
Epoch: [  9/ 10], Average Train Loss: [0.338], Accuracy: [0.928]
epoch time: 8457.969 ms, per step time: 88.104 ms
--------------------
Epoch: [ 10/ 10], Average Train Loss: [0.338], Accuracy: [0.907]
epoch time: 8183.743 ms, per step time: 85.247 ms
================================================================================
End of validation the best Accuracy is:  0.931, save the best ckpt file in ./BestCheckpoint/resnet50-best.ckpt
模型评估
import matplotlib.pyplot as plt
import mindspore as msdef visualize_model(best_ckpt_path, val_ds):net = resnet50()# 全连接层输入层的大小in_channels = net.fc.in_channels# 输出通道数大小为分类数12head = nn.Dense(in_channels, 12)# 重置全连接层net.fc = head# 平均池化层kernel size为7avg_pool = nn.AvgPool2d(kernel_size=7)# 重置平均池化层net.avg_pool = avg_pool# 加载模型参数param_dict = ms.load_checkpoint(best_ckpt_path)ms.load_param_into_net(net, param_dict)model = train.Model(net)#print(net)# 加载验证集的数据进行验证data = next(val_ds.create_dict_iterator())images = data["image"].asnumpy()print(type(images))print(images.shape)#print(images)labels = data["label"].asnumpy()#print(labels)class_name = {0: "0", 1: "1",2: "2", 3: "3",4: "4", 5: "5",6: "6", 7: "7",8: "8", 9: "9",10: "10", 11: "11",12: "12"}# 预测图像类别data_pre=ms.Tensor(data["image"])print(data_pre.shape)print(type(data_pre))output = model.predict(data_pre)#print(output)pred = np.argmax(output.asnumpy(), axis=1)# 显示图像及图像的预测值plt.figure(figsize=(5, 5))for i in range(4):plt.subplot(2, 2, i + 1)# 若预测正确,显示为蓝色;若预测错误,显示为红色color = 'blue' if pred[i] == labels[i] else 'red'plt.title('predict:{}'.format(class_name[pred[i]]), color=color)picture_show = np.transpose(images[i], (1, 2, 0))mean = np.array([0.485, 0.456, 0.406])std = np.array([0.229, 0.224, 0.225])picture_show = std * picture_show + meanpicture_show = np.clip(picture_show, 0, 1)plt.imshow(picture_show)plt.axis('off')plt.show()
visualize_model('BestCheckpoint/resnet50-best.ckpt', dataset_val)

输出
在这里插入图片描述

形式二:固定特征训练

模型训练
net_work = resnet50(pretrained=True)
# 全连接层输入层的大小
in_channels = net_work.fc.in_channels
# 输出通道数大小为分类数12
head = nn.Dense(in_channels, 12)
# 重置全连接层
net_work.fc = head
# 平均池化层kernel size为7
avg_pool = nn.AvgPool2d(kernel_size=7)
# 重置平均池化层
net_work.avg_pool = avg_pool
# 冻结除最后一层外的所有参数
for param in net_work.get_parameters():if param.name not in ["fc.weight", "fc.bias"]:param.requires_grad = False
# 定义优化器和损失函数
opt = nn.Momentum(params=net_work.trainable_params(), learning_rate=lr, momentum=0.5)
loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
def forward_fn(inputs, targets):logits = net_work(inputs)loss = loss_fn(logits, targets)return loss
grad_fn = ms.ops.value_and_grad(forward_fn, None, opt.parameters)
def train_step(inputs, targets):loss, grads = grad_fn(inputs, targets)opt(grads)return loss
# 实例化模型
model1 = train.Model(net_work, loss_fn, opt, metrics={"Accuracy": Accuracy()})
dataset_train = create_dataset_canidae(data_path_train, "train")
step_size_train = dataset_train.get_dataset_size()
dataset_val = create_dataset_canidae(data_path_val, "val")
step_size_val = dataset_val.get_dataset_size()
num_epochs = 10
# 创建迭代器
data_loader_train = dataset_train.create_tuple_iterator(num_epochs=num_epochs)
data_loader_val = dataset_val.create_tuple_iterator(num_epochs=num_epochs)
best_ckpt_dir = "./BestCheckpoint"
best_ckpt_path = "./BestCheckpoint/resnet50-best-freezing-param.ckpt"
# 开始循环训练
print("Start Training Loop ...")
best_acc = 0
for epoch in range(num_epochs):losses = []net_work.set_train()epoch_start = time.time()# 为每轮训练读入数据for i, (images, labels) in enumerate(data_loader_train):labels = labels.astype(ms.int32)loss = train_step(images, labels)losses.append(loss)# 每个epoch结束后,验证准确率acc = model1.eval(dataset_val)['Accuracy']epoch_end = time.time()epoch_seconds = (epoch_end - epoch_start) * 1000step_seconds = epoch_seconds/step_size_trainprint("-" * 20)print("Epoch: [%3d/%3d], Average Train Loss: [%5.3f], Accuracy: [%5.3f]" % (epoch+1, num_epochs, sum(losses)/len(losses), acc))print("epoch time: %5.3f ms, per step time: %5.3f ms" % (epoch_seconds, step_seconds))if acc > best_acc:best_acc = accif not os.path.exists(best_ckpt_dir):os.mkdir(best_ckpt_dir)ms.save_checkpoint(net_work, best_ckpt_path)
print("=" * 80)
print(f"End of validation the best Accuracy is: {best_acc: 5.3f}, "f"save the best ckpt file in {best_ckpt_path}", flush=True)
模型评估
visualize_model(best_ckpt_path, dataset_val)

这篇关于基于Mindspore,通过Resnet50迁移学习实现猫十二分类的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/783860

相关文章

Oracle查询优化之高效实现仅查询前10条记录的方法与实践

《Oracle查询优化之高效实现仅查询前10条记录的方法与实践》:本文主要介绍Oracle查询优化之高效实现仅查询前10条记录的相关资料,包括使用ROWNUM、ROW_NUMBER()函数、FET... 目录1. 使用 ROWNUM 查询2. 使用 ROW_NUMBER() 函数3. 使用 FETCH FI

Python脚本实现自动删除C盘临时文件夹

《Python脚本实现自动删除C盘临时文件夹》在日常使用电脑的过程中,临时文件夹往往会积累大量的无用数据,占用宝贵的磁盘空间,下面我们就来看看Python如何通过脚本实现自动删除C盘临时文件夹吧... 目录一、准备工作二、python脚本编写三、脚本解析四、运行脚本五、案例演示六、注意事项七、总结在日常使用

Java实现Excel与HTML互转

《Java实现Excel与HTML互转》Excel是一种电子表格格式,而HTM则是一种用于创建网页的标记语言,虽然两者在用途上存在差异,但有时我们需要将数据从一种格式转换为另一种格式,下面我们就来看看... Excel是一种电子表格格式,广泛用于数据处理和分析,而HTM则是一种用于创建网页的标记语言。虽然两

Java中Springboot集成Kafka实现消息发送和接收功能

《Java中Springboot集成Kafka实现消息发送和接收功能》Kafka是一个高吞吐量的分布式发布-订阅消息系统,主要用于处理大规模数据流,它由生产者、消费者、主题、分区和代理等组件构成,Ka... 目录一、Kafka 简介二、Kafka 功能三、POM依赖四、配置文件五、生产者六、消费者一、Kaf

使用Python实现在Word中添加或删除超链接

《使用Python实现在Word中添加或删除超链接》在Word文档中,超链接是一种将文本或图像连接到其他文档、网页或同一文档中不同部分的功能,本文将为大家介绍一下Python如何实现在Word中添加或... 在Word文档中,超链接是一种将文本或图像连接到其他文档、网页或同一文档中不同部分的功能。通过添加超

windos server2022里的DFS配置的实现

《windosserver2022里的DFS配置的实现》DFS是WindowsServer操作系统提供的一种功能,用于在多台服务器上集中管理共享文件夹和文件的分布式存储解决方案,本文就来介绍一下wi... 目录什么是DFS?优势:应用场景:DFS配置步骤什么是DFS?DFS指的是分布式文件系统(Distr

NFS实现多服务器文件的共享的方法步骤

《NFS实现多服务器文件的共享的方法步骤》NFS允许网络中的计算机之间共享资源,客户端可以透明地读写远端NFS服务器上的文件,本文就来介绍一下NFS实现多服务器文件的共享的方法步骤,感兴趣的可以了解一... 目录一、简介二、部署1、准备1、服务端和客户端:安装nfs-utils2、服务端:创建共享目录3、服

C#使用yield关键字实现提升迭代性能与效率

《C#使用yield关键字实现提升迭代性能与效率》yield关键字在C#中简化了数据迭代的方式,实现了按需生成数据,自动维护迭代状态,本文主要来聊聊如何使用yield关键字实现提升迭代性能与效率,感兴... 目录前言传统迭代和yield迭代方式对比yield延迟加载按需获取数据yield break显式示迭

Python实现高效地读写大型文件

《Python实现高效地读写大型文件》Python如何读写的是大型文件,有没有什么方法来提高效率呢,这篇文章就来和大家聊聊如何在Python中高效地读写大型文件,需要的可以了解下... 目录一、逐行读取大型文件二、分块读取大型文件三、使用 mmap 模块进行内存映射文件操作(适用于大文件)四、使用 pand

python实现pdf转word和excel的示例代码

《python实现pdf转word和excel的示例代码》本文主要介绍了python实现pdf转word和excel的示例代码,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价... 目录一、引言二、python编程1,PDF转Word2,PDF转Excel三、前端页面效果展示总结一