基于Mindspore,通过Resnet50迁移学习实现猫十二分类

2024-03-07 14:44

本文主要是介绍基于Mindspore,通过Resnet50迁移学习实现猫十二分类,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

使用平台介绍

使用平台:启智AI协作平台
使用数据集:百度猫十二分类

数据集介绍

有cat_12_train和cat_12_test和train_list.txt
train_list.txt内有每张图片所对应的标签

Minspore部分操作科普

数据集加载

Mindspore加载图片数据集就直接调整成这种格式就行,然后可以用这个函数加载,自动生成两个列,一列是图片,一列是标签;ImageFolderDataset函数会自动读取和处理数据集,标签就是文件夹的名称
在这里插入图片描述
在这里插入图片描述

数据处理

map函数里可以一键进行处理和映射,定义好数据处理函数,直接把路径和标签Map处理,后面可以带上是训练集还是测试集的标签;,要处理图片就指定 input_columns 参数为image,这个是前面数据读取形成的;
本项目这里用的是ImageFolderDataset,可以自动生成图片对应数据的两个列,要是相对数据处理就设置为数据列名即可
在这里插入图片描述

数据批处理和重复

batch就是批处理,把数据分成指定数量的一个个批次
在这里插入图片描述
最后repeat对数据进行重复
在这里插入图片描述

整体数据处理流程

就是读取数据形成数据和标签对应的列(读取数据函数有很多),然后定义数据预处理函数,在map函数里一键映射,指定要处理的列一键处理,最后对数据进行批次划分,就拿到可以放进训练网络函数的规范数据集了。当然使用时候还要用create_tuple_iterator或者create_dict_iterator函数形成可迭代的数据集。

迁移学习

在迁移学习中,固定特征训练和模型微调都是常用的技术

固定特征训练

在源任务上训练一个模型,并将其应用到目标任务上。在这个过程中,模型的特征提取器是固定的,只对输出层进行调整。这种方法可以利用源任务中已经学习到的特征,从而减少目标任务的训练时间和数据需求。固定特征训练通常适用于目标任务和源任务具有相似的特征空间,并且目标任务的数据量较小的情况。

模型微调

使用预训练模型作为初始模型,并在目标任务的数据集上进行进步训练。在微调阶段,可以根据目标任务的数据和特定要求调整模型的参数使其适应目标任务。模型微调的主要目的是通过在目标任务上的有限训练来调整预训练模型,以取得更好的性能。

举例

以训练一个猫、狗分类器为例,固定特征训练是指在一个大型的猫狗图片数据集上训练一个通用的图像识别模型,然后将该模型应用于特定的猫、狗分类任务。在这个过程中,我们只需要调整模型的输出层,使其能够正确地对猫和狗进行分类。而模型微调则是指使用一个已经在大型数据集上训练好的通用图像识别模型,然后在特定的猫、狗图片数据集上进行进一步训练,以优化模型的性能。在这个过程中我们可以根据猫、狗图片的特点来调整模型的参数,使其能够更好地识别猫和狗。

数据处理过程

由于本项目采用的数据集是百度所提供的猫十二分类,要使用ImageFolderDataset的话,形式不太匹配

现有形式

cat_12_train和cat_12_test里面都是一张一张的图片
train_list.txt内有每张图片所对应的标签
在这里插入图片描述

处理后形式

train和val文件夹内分别有十二个子文件夹,代表12类猫,每个子文件夹内又有一张张的图片
在这里插入图片描述在这里插入图片描述

处理代码

这里有相关代码进行自动划分,但是对于训练集和测试集的划分,我直接采用了手动操作,也可以用代码来实现的;

# 处理异常图片
dir_lit = os.listdir('./work/cat_12_train/')
# dir_lit为一个列表,里面是一张张图片的名称
for list in dir_lit:
# list是图片名称,这里的操作是把这个图片形成一个个的路径img_path=os.path.join('./work/cat_12_train/',list)print(img_path)# 如果不是RGB那就转换为RGBimg=Image.open(img_path)if img.mode != 'RGB':img = img.convert('RGB')img.save(img_path)dir_lit = os.listdir('./work/cat_12_test/')
for list in dir_lit:img_path=os.path.join('./work/cat_12_test/',list)img=Image.open(img_path)if img.mode != 'RGB':img = img.convert('RGB')img.save(img_path)
# 整理数据格式
# 创建12个文件夹分别对应标签
path='./work/MyDataset/'
for i in range(12):if not os.path.exists(path+str(i)):os.mkdir(path+str(i))else:continue
#读取每一行
with open(f'./work/train_list.txt','r')as f:img_path=f.readlines()print(img_path)# 里面是一个个的'cat_12_train/8GOkTtqw7E6IHZx4olYnhzvXLCiRsUfM.jpg\t0\n'# 把对应文件放到对应标签文件夹下
for img in img_path:# 拿取每一张图片路径# print(img)img_src= img.split('\t')[0]# img_src为一个个图片路径rel_src= img_src.split('cat_12_train/')[1]# rel_src为图片名称img_label = img.split('\t')[1]img_label = img_label.split('\n')[0]# img_label为图片标签print(img_src)print(rel_src)print(img_label)# os.system(f'cp ./work/{img_src} ./work/MyDataset/{img_label}/{rel_src}')shutil.copy(f'./work/{img_src}',f'./work/MyDataset/{img_label}/{rel_src}')
print('图片处理完毕')

整体代码

# 解压上传的数据集压缩包并查看数据集结构
!unzip MyDataset.zip -d data/
import os
print(os.listdir("data"))
print(os.listdir("data/train"))
print(os.listdir('data/train/1'))

输出

['val', 'train']
['9', '10', '5', '8', '11', '2', '7', '3', '0', '1', '6', '4']
['DKkQylbgdrWRjYap63MCJe0UBLhcHXPm.jpg', 'k9HWNaG2Z1wUKAOYdSDu7vRr4xBqmTCV.jpg', 'LfIoOrSNvKHQzsGtm5eMZc0lRuBXhTP6.jpg', '0esFjXNqc5xbMmUaJkRVPwQorWlu3LvA.jpg', 'PMoFIabq0W9U2wETZr7yf4JLYdBxv6hQ.jpg', 'QlUX4zHfPZ3LxRDswqm5FeMbnTWNaj6g.jpg', '7E9oOUcQjkLMvpAtNymHCRSqFfdGVDK4.jpg', 'pznq7EivBH9LwrNysIWxgGeomTlOP8cZ.jpg', '3Ndv9X6uTgzFtnoA01VECIBPj7xqlewG.jpg', 'o1g6adKmS4lBDw2F5buAYnetUWh7xXGz.jpg', '7QZTYlspK2fqdJUwjC0HDmOFrM5W4PX9.jpg', 'qgimvDE8Zaf4PJ32dkNhwVy5nxATOtrX.jpg', 'RJnThakSOGUzeFBdigXAm2NsL8jyYvu3.jpg', 'Ig61xq3ME78fdCRTDWhaKkcyuOQj24PG.jpg', 'tAdqSefI0DohNuU6wgVyPca7Qz5lYTOH.jpg', 'pTe31kYFqwyOGmV50sbhgoLQ9KcjJaxd.jpg', 'HUuwb4gRqoPWD3Lvrsa9hVcQ7FSfOT8t.jpg', 'cqkJDEpWiwS69UxFtKMPRgb4mXhj1LAs.jpg', '2pjBVbqF30cUTvRIYtsCGfgwPKOJz4ua.jpg', 'e7f1iucltpVQTXroFR96xawm2BDZYnNG.jpg', 'tWcMpXTe78zo2ikhbUOqPud6VJ5RfSEw.jpg', 'gcTxA5NwLztvWr7YPCMnDjdFyfqoa2uK.jpg', 'qbKjsR05lrFVYfLChtMGD7im36cUgAnE.jpg', 'S4hfUR5kOj7CXr826Gxa9t3bEBPioJq1.jpg', 'A8PMtHzoyFE0WgjpZ2qUYbduL4T9arxN.jpg', '8E0bSi3h1aVy2cNWpgOsKvxZCQtzqkLU.jpg', 'oDc9XxipzfBjUAEl0hOmyd4PNr5v8IsG.jpg', 'AVGJoCPsX3L6I5Y2M94kEO7vNHmt1Ure.jpg', 'SqyF8c0Rak1NedXpYvlI9TsVwzOhtGZJ.jpg', '04Iv3QNKtu2DAfRTgs9XZwBMb1Cm7l6P.jpg', '94NwrzYLo8iMtagTR3SfPGHmWvZXbyUx.jpg', 'oiU4YjnhNpI3JWagx8SuTCktA6qZXGRH.jpg', 'e416wAERYOQ7NutUJDcIVFk2oPWpC3q8.jpg', '9I3enpUrZ5xD2TvRAOFt4S7lBfVMdqsJ.jpg', 'oLNGFnUPmQhxOkdbv37HwSj8uql4z1sp.jpg', 'nHfDoId8SXKzMt1weR7bJlaWPcNx32yv.jpg', 'dzV1Psxncp6H8g4KWhX3mbrTfqwuLaNv.jpg', 'OW6e1GbpNsfmxFvLQKMnIByX4hDcS2io.jpg', 'MOGw0PDqjmnLdViez26b7WY3hU85vatH.jpg', '3R5BWakTdG2hKjJoiNxg0pr61LsDqSuM.jpg', '41OaVziAEuenpKqv7LCYMPsGH6BkQotD.jpg', 'fRbdkW0GDAhBjpTVeonPItycEia45Ns9.jpg', 'RsYG3VJi7NTXptoPQvWKhcFaDqIe4EkC.jpg', 'jWPXtA07yYrcRxNBUkwC9dpuSs2M463e.jpg', 'bqRVATEuI4x3kJSO9DitWjYms8KoG2Lz.jpg', 'OCXPGzodQsZHRnfMFaBkqW9hKYxA4glr.jpg', 'hAPzcCeE04sSadDIFB627ipyOWgjX8mq.jpg', 'PIbktpOd2DqHwVceLzUyE7CmohMjNA91.jpg', 'adAjP14SXL6vVJ95TRrMIYDiHl7BUqbF.jpg', 'hTxYnXrQ35vwKL1NEMSIot02djHy8i6D.jpg', 'kwuLVmg7n9I4iEOzMQC1NxJfvX6Bhoqs.jpg', '78WTn5auMmQshIZi2qDAYdR1oKcwEzfk.jpg', 'hOEPm8o26CBptkD5yT3fsgbM4dRuVZzw.jpg', 'PIpXbRiu68dm2s1DfxHJGAYLegOUMzoB.jpg', '8NxvitwMaCpsuEQT17nDXzFR5gAZ4rfc.jpg', 'JgxcdpvW7f6lKMShPjHFeZ2RDX13UCiN.jpg', 'BhHLRN0QTWOwl3UEG9J17XScni2P5gVe.jpg', '7Gw2o8LJTF4ecZI6nl1WuDrsAOSfQPaq.jpg', 't6xZhQkD2jWCOi1r3fK7T9slGHVbwgNd.jpg', 'gbfyYtlWaAO4iUCPK2cFVkoQX9MmwJTI.jpg', '9PQic6o4VyZ13pLAYu2avFWSbJRz57fC.jpg', 'Cf2Z3j6hYliVOduEvK5NJp4yba0wSGcA.jpg', '2QvYgMIzELXH4Fy8GNDBaPS3W0tVZ5xq.jpg', 'oFXrWl80gRMenqPbG2uZv5wk4KmHjaNd.jpg', 'l1NsjeJKdvFimRgh6IEZuqfxCw5An8o4.jpg', 'tvWgSwN9m3BZ5qOXjK7LexVCIn6F4AHk.jpg', 'ahcTZUbOmJsloVt8vGMjwPqIXd0x9iy1.jpg', 'hp0nNWXar81lB3eSYE9kGcdDL4tJfuwC.jpg', 'HNehmorRIS6M2iDj48gZL9OKva7Xck1n.jpg', '0GX4YKdcwBi15lTpR7ExWO2ZagseoVNI.jpg', 'b57SiGKYPaE1DrfxJtVeQdlOAUojLZhR.jpg', 'f4gLhHjyKdxTumna7F5pGWPqVIRY01et.jpg', 'AHxQ1GFgRLs802diT3VIlwOoqkW9Sar4.jpg', '09i4DcyrWktZb1naHFEpL5elhG3CvYxu.jpg', 'AieqOGKQC7fgDayW9kLuJ5mUHx4XP2I6.jpg', 'MwZIekE7oxPtRpTHQVf4l6qA2iC1zWLh.jpg', '6M8xAZBdQLkuDcF14HEz53J0IboiPfUa.jpg', 'eD4gfaQTFdoWUCnhPj6YmIZBl5AMxNik.jpg', 'PwBJK7rZHDhq46ynYoj9Saxip0IldMV2.jpg', 'neHaTbwPkdVmoOA3JyWxR7Lh92NDpzEf.jpg', 'E38k6xhQFYKALn4tDlwiPBfpdCygeSNs.jpg', 'N1VpzqjoRmPZCQ5KasAv72TwtMDFrG0d.jpg', 'RvXKbfDuF4W6exgVcInE3SktJa9LzBj8.jpg', 'aMoxSymjdiUwbJ6k5NzGR09uILQ4sEc3.jpg', 'l1WfIcvOZk9jAn2xQwtSCEgY5XhRyoFB.jpg', '0IWfLUGk53iHt9SElNzKsBCDwuMjpPbR.jpg', '80vskwDtCRAz9iWYjnrhIGfeXdUZx16b.jpg', '0WglfKCD5Gu2LqI1msTSZa6orBO7XAz9.jpg', 'os5kaDubPM1hY7f6gRrSOZqNQFEAU89v.jpg', 'm2azqs0NGPDdjR8rUTxWF4covLE5ikQZ.jpg', 'siKAzUrV9eykjlCQ0odZEhnW7FIgTuLm.jpg', 'jZWaf0ne2R1pDo9hTBkCbA8YOq3LlQ5x.jpg', 'Kny8zFiIt4vxNSO6g7Lu9kGfdVJoqPC3.jpg', 'LTMkHx9w2nfsRiZec3bEVtmujpv7qS1y.jpg', 'oZin4PuwTet39xWCYhUBfvlzGyISb5DV.jpg', 'oJ4HWQkZDvta8rUyinRu9fVNs3BX1Kj7.jpg', 'mUp6082yMQghXY13OtvxabTrNSEeiu4B.jpg', 'RZWpn9jGxcKSUb3Y56fVMQHlJhEIeNiA.jpg', 'L86JQlekn9Ko01TbXHYMFImdZ2upxg5h.jpg', 'I0YcgXB97QL6MtHlU3p1znqWdCGOD4mo.jpg', 'IG9NFCfMybKYiQhquOd5H3DjwlSakW6U.jpg', 'SZYosxl4cHRWyT3h5JFqNjGdnIag6907.jpg', '4dMVtGvRJbrjK6X17STZ9Lx3kgeEioqp.jpg', 'q9YrDFK73Jfv15SHpTWelGAIwnBxt0iE.jpg', 'spNU7J8uk6BXiAyQErHegYMzjOaFR2qV.jpg', 'p1ji352o6vhd8l0Q4uNVRZrIgkSLnfBq.jpg', '7WQ0ByMPtJAdZ8h5OkveLi3ScuU6bIY1.jpg', 'hajCi0GDlVP2ONg6FeSWrvubQ34ozwkx.jpg', 'RKLDkUwmFg0Oj5tPeIs31y7J8AQZ9dni.jpg', 'hWOAp1EV6nJzYxaHt03T8GPNe7ujUiF9.jpg', 'B0a2VHnwQv1byMDTlEiOJXxI7Scs9Zjf.jpg', 'gMzOoyTrGniBj1vxN0AeD9VQsFHU7aKR.jpg', 'CiBq0GVawv1rdYyLDjcWoIXP6SKbzH8F.jpg', 'G71cYNEBD6shJLkgVzwb52m8oRluKUHS.jpg', '7IdLnFCb3a25cKNV6tXuYi1fe0hJQMOU.jpg', 'jsThJuVYQxUKSz3btXdA5q8M1O9Cioaw.jpg', '2OpyK1cm85obujwEMqGWNv9V7PnQfJ3U.jpg', 'DLIZr5TjPepd8csioJXMbYHk94RmKx6v.jpg', 'Q80DEFkGlxJj2qR37t4ZKpY6zMdvuIyr.jpg', 'EFxXsVJ0qHkomcBhnLfY96W5U4yOliQG.jpg', 'okw9N05dAnsxgW2IuQy7eGhz1iLOqVrJ.jpg', 'obRL95fxtP1uCNBwQiTjWsdqUvgp2Z43.jpg', '6jTZ5sfCpGwJWIK3DaYQvixLbNt48nHr.jpg', 'gLqBoG3ah0AHXIYWS7dFTt6pxDw8snQv.jpg', '1lzs3kM8NiILvcgDYtn6fdCoSeXauJ5P.jpg', 'FEuyDnKSIJ1a6UtY5LB2rGpRqOm3xP9Z.jpg', 'OKvn2uJmWQi4R9Cs8B7fxbkZtoczq31Y.jpg', 'ByYKkZHb8omRPcvfe5GzXsxOQ3DlLuUq.jpg', 'SKoiaj8C3UGyvJQXh5zWwrxNmYkdEHqn.jpg', '2cKUvXCjm5HNWksY1b4ioIgdSFqyMtEJ.jpg', 'kJK9OA3hXpMWeUY7cifvrz0BItn1VS2T.jpg', 't1DnLxSZXwWTgeJsyE02lrjHfdM35po8.jpg', '5C76eISyb3vmPZuMYcARHU8aFQrBWf1k.jpg', 'puBcg8Fh6tXs27doz1aAIl4L0iVYC3wE.jpg', 'MV5C7YmuzG1LyZplFXvqOQkW4JStjcNP.jpg', 'ruleKNQvzwqmy5sn9MDd7I2RUJjVCWh8.jpg', 'l9Z3gPwjC5HbhINcfVO8dnz1qAxBrJkU.jpg', 'H9BcFOo8UI3jX2CyW0mzxn7agJNAsZQS.jpg', 'jHUJE37YZOGAXInPmyCSp9f0o4uvRe5W.jpg', 'I8jNkAVgZ1yqDw5K9b0Wm4rETfiGBcUF.jpg', 'kKzQrE6GjfpeFhsXx2Ddu9YaTHc3PUbB.jpg', 'jla5O2TkVhefr07XDLpMEonuG6yJWgYd.jpg', 'Km9BZsaSUoxQ4VArcXYyHThIDRbq2t7l.jpg', 'fBp0Yor4EQtWkM7I3TsnNHLXuvCFacjS.jpg']

超参数设置

batch_size = 18                             # 批量大小
image_size = 224                            # 训练图像空间大小
num_epochs = 10                             # 训练周期数
lr = 0.001                                  # 学习率
momentum = 0.9                              # 动量
workers = 4                                 # 并行线程个数

数据预处理

import mindspore as ms
import mindspore.dataset as ds
import mindspore.dataset.vision as vision# 数据集目录路径
data_path_train = "data/train/"
data_path_val = "data/val/"# 创建训练数据集def create_dataset_canidae(dataset_path, usage):"""数据加载"""data_set = ds.ImageFolderDataset(dataset_path,num_parallel_workers=workers,shuffle=True,)# 数据增强操作mean = [0.485 * 255, 0.456 * 255, 0.406 * 255]std = [0.229 * 255, 0.224 * 255, 0.225 * 255]scale = 32if usage == "train":# Define map operations for training datasettrans = [vision.RandomCropDecodeResize(size=image_size, scale=(0.08, 1.0), ratio=(0.75, 1.333)),vision.RandomHorizontalFlip(prob=0.5),vision.Normalize(mean=mean, std=std),vision.HWC2CHW()]else:# Define map operations for inference datasettrans = [vision.Decode(),vision.Resize(image_size + scale),vision.CenterCrop(image_size),vision.Normalize(mean=mean, std=std),vision.HWC2CHW()]# 数据映射操作data_set = data_set.map(operations=trans,input_columns='image',num_parallel_workers=workers)# 批量操作data_set = data_set.batch(batch_size)return data_setdataset_train = create_dataset_canidae(data_path_train, "train")
step_size_train = dataset_train.get_dataset_size()
dataset_val = create_dataset_canidae(data_path_val, "val")
step_size_val = dataset_val.get_dataset_size()
print(step_size_train)
print(step_size_val)
data = next(dataset_val.create_dict_iterator())
images = data["image"]
labels = data["label"]
print("Tensor of image", images.shape)
print("Labels:", labels)

输出

96
24
Tensor of image (18, 3, 224, 224)
Labels: [ 1  2  4  3  0 10  5  4 11  9 11  6  7  1 11  5  1  3]

数据集可视化查看

import matplotlib.pyplot as plt
import numpy as np# class_name对应label,按文件夹字符串从小到大的顺序标记label
class_name = {0: "0", 1: "1",2: "2", 3: "3",4: "4", 5: "5",6: "6", 7: "7",8: "8", 9: "9",10: "10", 11: "11",12: "12"}plt.figure(figsize=(5, 5))
for i in range(4):# 获取图像及其对应的labeldata_image = images[i].asnumpy()data_label = labels[i]# 处理图像供展示使用data_image = np.transpose(data_image, (1, 2, 0))mean = np.array([0.485, 0.456, 0.406])std = np.array([0.229, 0.224, 0.225])data_image = std * data_image + meandata_image = np.clip(data_image, 0, 1)# 显示图像plt.subplot(2, 2, i+1)plt.imshow(data_image)plt.title(class_name[int(labels[i].asnumpy())])plt.axis("off")plt.show()

在这里插入图片描述

网络结构搭建

from typing import Type, Union, List, Optional
from mindspore import nn, train
from mindspore.common.initializer import Normalweight_init = Normal(mean=0, sigma=0.02)
gamma_init = Normal(mean=1, sigma=0.02)
class ResidualBlockBase(nn.Cell):expansion: int = 1  # 最后一个卷积核数量与第一个卷积核数量相等def __init__(self, in_channel: int, out_channel: int,stride: int = 1, norm: Optional[nn.Cell] = None,down_sample: Optional[nn.Cell] = None) -> None:super(ResidualBlockBase, self).__init__()if not norm:self.norm = nn.BatchNorm2d(out_channel)else:self.norm = normself.conv1 = nn.Conv2d(in_channel, out_channel,kernel_size=3, stride=stride,weight_init=weight_init)self.conv2 = nn.Conv2d(in_channel, out_channel,kernel_size=3, weight_init=weight_init)self.relu = nn.ReLU()self.down_sample = down_sampledef construct(self, x):"""ResidualBlockBase construct."""identity = x  # shortcuts分支out = self.conv1(x)  # 主分支第一层:3*3卷积层out = self.norm(out)out = self.relu(out)out = self.conv2(out)  # 主分支第二层:3*3卷积层out = self.norm(out)if self.down_sample is not None:identity = self.down_sample(x)out += identity  # 输出为主分支与shortcuts之和out = self.relu(out)return out
class ResidualBlock(nn.Cell):expansion = 4  # 最后一个卷积核的数量是第一个卷积核数量的4倍def __init__(self, in_channel: int, out_channel: int,stride: int = 1, down_sample: Optional[nn.Cell] = None) -> None:super(ResidualBlock, self).__init__()self.conv1 = nn.Conv2d(in_channel, out_channel,kernel_size=1, weight_init=weight_init)self.norm1 = nn.BatchNorm2d(out_channel)self.conv2 = nn.Conv2d(out_channel, out_channel,kernel_size=3, stride=stride,weight_init=weight_init)self.norm2 = nn.BatchNorm2d(out_channel)self.conv3 = nn.Conv2d(out_channel, out_channel * self.expansion,kernel_size=1, weight_init=weight_init)self.norm3 = nn.BatchNorm2d(out_channel * self.expansion)self.relu = nn.ReLU()self.down_sample = down_sampledef construct(self, x):identity = x  # shortscuts分支out = self.conv1(x)  # 主分支第一层:1*1卷积层out = self.norm1(out)out = self.relu(out)out = self.conv2(out)  # 主分支第二层:3*3卷积层out = self.norm2(out)out = self.relu(out)out = self.conv3(out)  # 主分支第三层:1*1卷积层out = self.norm3(out)if self.down_sample is not None:identity = self.down_sample(x)out += identity  # 输出为主分支与shortcuts之和out = self.relu(out)return out
def make_layer(last_out_channel, block: Type[Union[ResidualBlockBase, ResidualBlock]],channel: int, block_nums: int, stride: int = 1):down_sample = None  # shortcuts分支if stride != 1 or last_out_channel != channel * block.expansion:down_sample = nn.SequentialCell([nn.Conv2d(last_out_channel, channel * block.expansion,kernel_size=1, stride=stride, weight_init=weight_init),nn.BatchNorm2d(channel * block.expansion, gamma_init=gamma_init)])layers = []layers.append(block(last_out_channel, channel, stride=stride, down_sample=down_sample))in_channel = channel * block.expansion# 堆叠残差网络for _ in range(1, block_nums):layers.append(block(in_channel, channel))return nn.SequentialCell(layers)
from mindspore import load_checkpoint, load_param_into_netclass ResNet(nn.Cell):def __init__(self, block: Type[Union[ResidualBlockBase, ResidualBlock]],layer_nums: List[int], num_classes: int, input_channel: int) -> None:super(ResNet, self).__init__()self.relu = nn.ReLU()# 第一个卷积层,输入channel为3(彩色图像),输出channel为64self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, weight_init=weight_init)self.norm = nn.BatchNorm2d(64)# 最大池化层,缩小图片的尺寸self.max_pool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='same')# 各个残差网络结构块定义,self.layer1 = make_layer(64, block, 64, layer_nums[0])self.layer2 = make_layer(64 * block.expansion, block, 128, layer_nums[1], stride=2)self.layer3 = make_layer(128 * block.expansion, block, 256, layer_nums[2], stride=2)self.layer4 = make_layer(256 * block.expansion, block, 512, layer_nums[3], stride=2)# 平均池化层self.avg_pool = nn.AvgPool2d()# flattern层self.flatten = nn.Flatten()# 全连接层self.fc = nn.Dense(in_channels=input_channel, out_channels=num_classes)def construct(self, x):x = self.conv1(x)x = self.norm(x)x = self.relu(x)x = self.max_pool(x)x = self.layer1(x)x = self.layer2(x)x = self.layer3(x)x = self.layer4(x)x = self.avg_pool(x)x = self.flatten(x)x = self.fc(x)return xdef _resnet(model_url: str, block: Type[Union[ResidualBlockBase, ResidualBlock]],layers: List[int], num_classes: int, pretrained: bool, pretrianed_ckpt: str,input_channel: int):model = ResNet(block, layers, num_classes, input_channel)if pretrained:# 加载预训练模型# download(url=model_url, path=pretrianed_ckpt, replace=True)param_dict = load_checkpoint(pretrianed_ckpt)load_param_into_net(model, param_dict)return modeldef resnet50(num_classes: int = 1000, pretrained: bool = False):"ResNet50模型"resnet50_url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/models/application/resnet50_224_new.ckpt"resnet50_ckpt = "./LoadPretrainedModel/resnet50_224_new.ckpt"return _resnet(resnet50_url, ResidualBlock, [3, 4, 6, 3], num_classes,pretrained, resnet50_ckpt, 2048)

形式一:模型微调

模型训练
from mindspore import nn, train
from mindspore.nn import Loss, Accuracy
!pip install download
import mindspore as ms
from download import download
network = resnet50(pretrained=True)# 全连接层输入层的大小
in_channels = network.fc.in_channels
# 输出通道数大小为狼狗分类数2
head = nn.Dense(in_channels, 12)
# 重置全连接层
network.fc = head# 平均池化层kernel size为7
avg_pool = nn.AvgPool2d(kernel_size=7)
# 重置平均池化层
network.avg_pool = avg_poolimport mindspore as ms
import mindspore# 定义优化器和损失函数
opt = nn.Momentum(params=network.trainable_params(), learning_rate=lr, momentum=momentum)
loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')# 实例化模型
model = train.Model(network, loss_fn, opt, metrics={"Accuracy": Accuracy()})def forward_fn(inputs, targets):logits = network(inputs)loss = loss_fn(logits, targets)return lossgrad_fn = mindspore.ops.value_and_grad(forward_fn, None, opt.parameters)def train_step(inputs, targets):loss, grads = grad_fn(inputs, targets)opt(grads)return loss# 创建迭代器
data_loader_train = dataset_train.create_tuple_iterator(num_epochs=num_epochs)
# 最佳模型保存路径
best_ckpt_dir = "./BestCheckpoint"
best_ckpt_path = "./BestCheckpoint/resnet50-best.ckpt"
import os
import time# 开始循环训练
print("Start Training Loop ...")best_acc = 0for epoch in range(num_epochs):losses = []network.set_train()epoch_start = time.time()# 为每轮训练读入数据for i, (images, labels) in enumerate(data_loader_train):labels = labels.astype(ms.int32)loss = train_step(images, labels)losses.append(loss)# 每个epoch结束后,验证准确率acc = model.eval(dataset_val)['Accuracy']epoch_end = time.time()epoch_seconds = (epoch_end - epoch_start) * 1000step_seconds = epoch_seconds/step_size_trainprint("-" * 20)print("Epoch: [%3d/%3d], Average Train Loss: [%5.3f], Accuracy: [%5.3f]" % (epoch+1, num_epochs, sum(losses)/len(losses), acc))print("epoch time: %5.3f ms, per step time: %5.3f ms" % (epoch_seconds, step_seconds))if acc > best_acc:best_acc = accif not os.path.exists(best_ckpt_dir):os.mkdir(best_ckpt_dir)ms.save_checkpoint(network, best_ckpt_path)print("=" * 80)
print(f"End of validation the best Accuracy is: {best_acc: 5.3f}, "f"save the best ckpt file in {best_ckpt_path}", flush=True)

输出

Start Training Loop ...
--------------------
Epoch: [  1/ 10], Average Train Loss: [1.774], Accuracy: [0.838]
epoch time: 60892.337 ms, per step time: 634.295 ms
--------------------
Epoch: [  2/ 10], Average Train Loss: [0.762], Accuracy: [0.905]
epoch time: 8745.406 ms, per step time: 91.098 ms
--------------------
Epoch: [  3/ 10], Average Train Loss: [0.568], Accuracy: [0.921]
epoch time: 8449.129 ms, per step time: 88.012 ms
--------------------
Epoch: [  4/ 10], Average Train Loss: [0.508], Accuracy: [0.910]
epoch time: 8199.763 ms, per step time: 85.414 ms
--------------------
Epoch: [  5/ 10], Average Train Loss: [0.459], Accuracy: [0.900]
epoch time: 7856.060 ms, per step time: 81.834 ms
--------------------
Epoch: [  6/ 10], Average Train Loss: [0.405], Accuracy: [0.931]
epoch time: 8138.927 ms, per step time: 84.780 ms
--------------------
Epoch: [  7/ 10], Average Train Loss: [0.368], Accuracy: [0.919]
epoch time: 8333.523 ms, per step time: 86.808 ms
--------------------
Epoch: [  8/ 10], Average Train Loss: [0.354], Accuracy: [0.912]
epoch time: 8271.008 ms, per step time: 86.156 ms
--------------------
Epoch: [  9/ 10], Average Train Loss: [0.338], Accuracy: [0.928]
epoch time: 8457.969 ms, per step time: 88.104 ms
--------------------
Epoch: [ 10/ 10], Average Train Loss: [0.338], Accuracy: [0.907]
epoch time: 8183.743 ms, per step time: 85.247 ms
================================================================================
End of validation the best Accuracy is:  0.931, save the best ckpt file in ./BestCheckpoint/resnet50-best.ckpt
模型评估
import matplotlib.pyplot as plt
import mindspore as msdef visualize_model(best_ckpt_path, val_ds):net = resnet50()# 全连接层输入层的大小in_channels = net.fc.in_channels# 输出通道数大小为分类数12head = nn.Dense(in_channels, 12)# 重置全连接层net.fc = head# 平均池化层kernel size为7avg_pool = nn.AvgPool2d(kernel_size=7)# 重置平均池化层net.avg_pool = avg_pool# 加载模型参数param_dict = ms.load_checkpoint(best_ckpt_path)ms.load_param_into_net(net, param_dict)model = train.Model(net)#print(net)# 加载验证集的数据进行验证data = next(val_ds.create_dict_iterator())images = data["image"].asnumpy()print(type(images))print(images.shape)#print(images)labels = data["label"].asnumpy()#print(labels)class_name = {0: "0", 1: "1",2: "2", 3: "3",4: "4", 5: "5",6: "6", 7: "7",8: "8", 9: "9",10: "10", 11: "11",12: "12"}# 预测图像类别data_pre=ms.Tensor(data["image"])print(data_pre.shape)print(type(data_pre))output = model.predict(data_pre)#print(output)pred = np.argmax(output.asnumpy(), axis=1)# 显示图像及图像的预测值plt.figure(figsize=(5, 5))for i in range(4):plt.subplot(2, 2, i + 1)# 若预测正确,显示为蓝色;若预测错误,显示为红色color = 'blue' if pred[i] == labels[i] else 'red'plt.title('predict:{}'.format(class_name[pred[i]]), color=color)picture_show = np.transpose(images[i], (1, 2, 0))mean = np.array([0.485, 0.456, 0.406])std = np.array([0.229, 0.224, 0.225])picture_show = std * picture_show + meanpicture_show = np.clip(picture_show, 0, 1)plt.imshow(picture_show)plt.axis('off')plt.show()
visualize_model('BestCheckpoint/resnet50-best.ckpt', dataset_val)

输出
在这里插入图片描述

形式二:固定特征训练

模型训练
net_work = resnet50(pretrained=True)
# 全连接层输入层的大小
in_channels = net_work.fc.in_channels
# 输出通道数大小为分类数12
head = nn.Dense(in_channels, 12)
# 重置全连接层
net_work.fc = head
# 平均池化层kernel size为7
avg_pool = nn.AvgPool2d(kernel_size=7)
# 重置平均池化层
net_work.avg_pool = avg_pool
# 冻结除最后一层外的所有参数
for param in net_work.get_parameters():if param.name not in ["fc.weight", "fc.bias"]:param.requires_grad = False
# 定义优化器和损失函数
opt = nn.Momentum(params=net_work.trainable_params(), learning_rate=lr, momentum=0.5)
loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
def forward_fn(inputs, targets):logits = net_work(inputs)loss = loss_fn(logits, targets)return loss
grad_fn = ms.ops.value_and_grad(forward_fn, None, opt.parameters)
def train_step(inputs, targets):loss, grads = grad_fn(inputs, targets)opt(grads)return loss
# 实例化模型
model1 = train.Model(net_work, loss_fn, opt, metrics={"Accuracy": Accuracy()})
dataset_train = create_dataset_canidae(data_path_train, "train")
step_size_train = dataset_train.get_dataset_size()
dataset_val = create_dataset_canidae(data_path_val, "val")
step_size_val = dataset_val.get_dataset_size()
num_epochs = 10
# 创建迭代器
data_loader_train = dataset_train.create_tuple_iterator(num_epochs=num_epochs)
data_loader_val = dataset_val.create_tuple_iterator(num_epochs=num_epochs)
best_ckpt_dir = "./BestCheckpoint"
best_ckpt_path = "./BestCheckpoint/resnet50-best-freezing-param.ckpt"
# 开始循环训练
print("Start Training Loop ...")
best_acc = 0
for epoch in range(num_epochs):losses = []net_work.set_train()epoch_start = time.time()# 为每轮训练读入数据for i, (images, labels) in enumerate(data_loader_train):labels = labels.astype(ms.int32)loss = train_step(images, labels)losses.append(loss)# 每个epoch结束后,验证准确率acc = model1.eval(dataset_val)['Accuracy']epoch_end = time.time()epoch_seconds = (epoch_end - epoch_start) * 1000step_seconds = epoch_seconds/step_size_trainprint("-" * 20)print("Epoch: [%3d/%3d], Average Train Loss: [%5.3f], Accuracy: [%5.3f]" % (epoch+1, num_epochs, sum(losses)/len(losses), acc))print("epoch time: %5.3f ms, per step time: %5.3f ms" % (epoch_seconds, step_seconds))if acc > best_acc:best_acc = accif not os.path.exists(best_ckpt_dir):os.mkdir(best_ckpt_dir)ms.save_checkpoint(net_work, best_ckpt_path)
print("=" * 80)
print(f"End of validation the best Accuracy is: {best_acc: 5.3f}, "f"save the best ckpt file in {best_ckpt_path}", flush=True)
模型评估
visualize_model(best_ckpt_path, dataset_val)

这篇关于基于Mindspore,通过Resnet50迁移学习实现猫十二分类的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/783860

相关文章

HarmonyOS学习(七)——UI(五)常用布局总结

自适应布局 1.1、线性布局(LinearLayout) 通过线性容器Row和Column实现线性布局。Column容器内的子组件按照垂直方向排列,Row组件中的子组件按照水平方向排列。 属性说明space通过space参数设置主轴上子组件的间距,达到各子组件在排列上的等间距效果alignItems设置子组件在交叉轴上的对齐方式,且在各类尺寸屏幕上表现一致,其中交叉轴为垂直时,取值为Vert

Ilya-AI分享的他在OpenAI学习到的15个提示工程技巧

Ilya(不是本人,claude AI)在社交媒体上分享了他在OpenAI学习到的15个Prompt撰写技巧。 以下是详细的内容: 提示精确化:在编写提示时,力求表达清晰准确。清楚地阐述任务需求和概念定义至关重要。例:不用"分析文本",而用"判断这段话的情感倾向:积极、消极还是中性"。 快速迭代:善于快速连续调整提示。熟练的提示工程师能够灵活地进行多轮优化。例:从"总结文章"到"用

基于人工智能的图像分类系统

目录 引言项目背景环境准备 硬件要求软件安装与配置系统设计 系统架构关键技术代码示例 数据预处理模型训练模型预测应用场景结论 1. 引言 图像分类是计算机视觉中的一个重要任务,目标是自动识别图像中的对象类别。通过卷积神经网络(CNN)等深度学习技术,我们可以构建高效的图像分类系统,广泛应用于自动驾驶、医疗影像诊断、监控分析等领域。本文将介绍如何构建一个基于人工智能的图像分类系统,包括环境

【前端学习】AntV G6-08 深入图形与图形分组、自定义节点、节点动画(下)

【课程链接】 AntV G6:深入图形与图形分组、自定义节点、节点动画(下)_哔哩哔哩_bilibili 本章十吾老师讲解了一个复杂的自定义节点中,应该怎样去计算和绘制图形,如何给一个图形制作不间断的动画,以及在鼠标事件之后产生动画。(有点难,需要好好理解) <!DOCTYPE html><html><head><meta charset="UTF-8"><title>06

学习hash总结

2014/1/29/   最近刚开始学hash,名字很陌生,但是hash的思想却很熟悉,以前早就做过此类的题,但是不知道这就是hash思想而已,说白了hash就是一个映射,往往灵活利用数组的下标来实现算法,hash的作用:1、判重;2、统计次数;

hdu1043(八数码问题,广搜 + hash(实现状态压缩) )

利用康拓展开将一个排列映射成一个自然数,然后就变成了普通的广搜题。 #include<iostream>#include<algorithm>#include<string>#include<stack>#include<queue>#include<map>#include<stdio.h>#include<stdlib.h>#include<ctype.h>#inclu

认识、理解、分类——acm之搜索

普通搜索方法有两种:1、广度优先搜索;2、深度优先搜索; 更多搜索方法: 3、双向广度优先搜索; 4、启发式搜索(包括A*算法等); 搜索通常会用到的知识点:状态压缩(位压缩,利用hash思想压缩)。

【C++】_list常用方法解析及模拟实现

相信自己的力量,只要对自己始终保持信心,尽自己最大努力去完成任何事,就算事情最终结果是失败了,努力了也不留遗憾。💓💓💓 目录   ✨说在前面 🍋知识点一:什么是list? •🌰1.list的定义 •🌰2.list的基本特性 •🌰3.常用接口介绍 🍋知识点二:list常用接口 •🌰1.默认成员函数 🔥构造函数(⭐) 🔥析构函数 •🌰2.list对象

【Prometheus】PromQL向量匹配实现不同标签的向量数据进行运算

✨✨ 欢迎大家来到景天科技苑✨✨ 🎈🎈 养成好习惯,先赞后看哦~🎈🎈 🏆 作者简介:景天科技苑 🏆《头衔》:大厂架构师,华为云开发者社区专家博主,阿里云开发者社区专家博主,CSDN全栈领域优质创作者,掘金优秀博主,51CTO博客专家等。 🏆《博客》:Python全栈,前后端开发,小程序开发,人工智能,js逆向,App逆向,网络系统安全,数据分析,Django,fastapi

让树莓派智能语音助手实现定时提醒功能

最初的时候是想直接在rasa 的chatbot上实现,因为rasa本身是带有remindschedule模块的。不过经过一番折腾后,忽然发现,chatbot上实现的定时,语音助手不一定会有响应。因为,我目前语音助手的代码设置了长时间无应答会结束对话,这样一来,chatbot定时提醒的触发就不会被语音助手获悉。那怎么让语音助手也具有定时提醒功能呢? 我最后选择的方法是用threading.Time