第十二章 迁移学习-实战宝可梦精灵

2024-08-22 08:52

本文主要是介绍第十二章 迁移学习-实战宝可梦精灵,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

文章目录

  • 一、Pokemon数据集
    • 1.1 数据集收集
    • 1.2 数据集划分
    • 1.3 数据集加载
    • 1.4 数据预处理
    • 1.5 pytorch自定义数据库实现
  • 二、ResNet网络搭建
  • 三、训练与测试
  • 四、迁移学习
    • 4.1 pytorch实现迁移学习

一、Pokemon数据集

1.1 数据集收集

在这里插入图片描述

# git下载
git lfs install
git clone https://www.modelscope.cn/datasets/ModelBulider/pokemon.git

1.2 数据集划分

在这里插入图片描述


1.3 数据集加载

在这里插入图片描述

  • 加载数据
    ① 继承 torch.utils.data.Dataset
    ② 实现 __len__ 函数,其返回数据集的数量(整型数字)
    ③ 实现 __getitem__函数,根据索引值返回一个数据
    在这里插入图片描述

举例:
在这里插入图片描述


1.4 数据预处理

将尺寸大小不一致的数据(图片)预处理为大小一致的1数据
② 数据增强(旋转、裁剪等)
③ 归一化(均值、方差)
④ 转换为 Tensor 数据类型
在这里插入图片描述


1.5 pytorch自定义数据库实现

# -*- coding: UTF-8 -*-
'''
@version: 1.0
@PackageName: code - pokemon.py
@author: yonghao
@Description: 
@since 2021/03/01 19:41
'''
from visdom import Visdom
import time
import torch
import os, glob
import random, csv
from PIL import Image
from torchvision import transforms
from torch.utils.data import Dataset, DataLoaderroot = 'D:\\个人\\学习资料\\学习视频\\深度学习与PyTorch入门实战教程\\12.迁移学习-实战宝可梦精灵\\project_code\\pokemon'class Pokemon(Dataset):
def __init__(self, root, resize, mode='train'):
'''
初始化数据集
:paramroot: 图片存储的位置
:paramresize: 重新编辑图片的尺寸
:parammode: 初始化图片的类型(可以是数据集中各中分类)
'''
super(Pokemon, self).__init__()
self.root = rootself.resize = resizeself.mode = modeself.name2label = {}
# 创建 类名-> label 的映射字典
# os.listdir()每次顺序都不一样,故使用sorted()排序,使 类名-> label 的映射字典固定
for name in sorted(os.listdir(os.path.join(root))):
# 只读取文件夹名
if not os.path.isdir(os.path.join(root, name)):
continue
self.name2label[name] = len(self.name2label)
self.images, self.labels = self.load_csv('images.csv')
# 根据mode设定数据集的比例
if mode == 'train': # 60%
self.images = self.images[:int(0.6 * len(self.images))]
self.labels = self.labels[:int(0.6 * len(self.labels))]
elif mode == 'val': # 20%
self.images = self.images[int(0.6 * len(self.images)):int(0.8 * len(self.images))]
self.labels = self.labels[int(0.6 * len(self.labels)):int(0.8 * len(self.labels))]
else: # 20%
self.images = self.images[int(0.8 * len(self.images)):]
self.labels = self.labels[int(0.8 * len(self.labels)):]def __len__(self):
return len(self.images)def __getitem__(self, item) -> tuple:
# item ~ [0,len(images)-1]
# self.images , self.labels
# image , label
img, label = self.images[item], self.labels[item]
tf = transforms.Compose([
lambda x: Image.open(x).convert('RGB'), # string path => image data
transforms.Resize((int(1.25 * self.resize), int(1.25 * self.resize))), # 调整尺寸
transforms.RandomRotation(15), # 旋转
transforms.CenterCrop(self.resize), # 中心裁剪
transforms.ToTensor(),
# 注意transforms.Normalize() 应该在transforms.ToTensor() 后面
# 数据在通道层上归一化,会使变化图片的像素
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
# 返回由img,label 组成的Tensor 元组
img = tf(img)
label = torch.tensor(label)return img, labeldef denormalize(self, x_het):
'''
图像逆正则化显示
:paramx_het: 正则化后的数据
:return:
'''# x_het = (x - mean) / std
mean, std = torch.tensor([0.485, 0.456, 0.406]), torch.tensor([0.229, 0.224, 0.225])
# x = x_het * std + mean
# x:[channel , h , w] , mean:[3] -> [3,1,1] , std:[3] -> [3,1,1]
mean = mean.unsqueeze(dim=-1).unsqueeze(dim=-1)std = std.unsqueeze(dim=-1).unsqueeze(dim=-1)x = x_het * std + meanreturn xdef load_csv(self, filename):
'''
加载图片数据 与 其label数据
:paramfilename: 加载数据的文件名
:return:
'''
# 仅在第一次调用时创建csv文件,保存 图片路径——>label 的映射关系
if not os.path.exists(os.path.join(self.root, filename)):
images = []
for name in self.name2label.keys():
'''
python在模块glob中定义了glob()函数,实现了对目录内容进行匹配的功能,
glob.glob()函数接受通配模式作为输入,并返回所有匹配的文件名和路径名列表
与os.listdir类似
'''
images += glob.glob(os.path.join(self.root, name, '*.png'))
images += glob.glob(os.path.join(self.root, name, '*.jpg'))
images += glob.glob(os.path.join(self.root, name, '*.jpeg'))
# 1167 , 'D:\\个人\\学习资料\\学习视频\\深度学习与PyTorch入门实战教程\\12.迁移学习-实战宝可梦精灵\\project_code\\pokemon\\bulbasaur\\00000000.png'# 打乱的是图片的存储路径
random.shuffle(images)# 使用上下文管理,对文件进行操作
'''
with是从Python2.5引入的一个新的语法,它是一种上下文管理协议,目的在于从流程图中把try,except 和finally 关键字和资源分配释放相关代码统统去掉,简化try….except….finlally的处理流程。with通过__enter__方法初始化,然后在__exit__中做善后以及处理异常。所以使用with处理的对象必须有__enter__()和__exit__()这两个方法。其中__enter__()方法在语句体(with语句包裹起来的代码块)执行之前进入运行,__exit__()方法在语句体执行完毕退出后运行。with 语句适用于对资源进行访问的场合,确保不管使用过程中是否发生异常都会执行必要的“清理”操作,释放资源,比如文件使用后自动关闭、线程中锁的自动获取和释放等。紧跟with后面的语句会被求值,返回对象的__enter__()方法被调用,这个方法的返回值将被赋值给as关键字后面的变量,当with后面的代码块全部被执行完之后,将调用前面返回对象的__exit__()方法
'''
with open(os.path.join(self.root, filename), mode='w', newline='') as f:
writer = csv.writer(f)
for img in images:
# os.sep 为系统自动识别的文件路径分隔符
name = img.split(os.sep)[-2]
label = self.name2label[name]
writer.writerow([img, label])images, labels = [], []
with open(os.path.join(root, filename), mode='r') as f:
reader = csv.reader(f)
for row in reader:
img, label = rowimages.append(img)
labels.append(int(label))assert len(images) == len(labels)return images, labelsdef main():
vis = Visdom()
# 获取数据集(单个数据做返回)
db = Pokemon(root, 64, mode='train')
img, label = next(iter(db))
print('sample:', img.shape, label.shape)
vis.image(img, win='img_win_het', opts=dict(title='norm_img_show'))
vis.image(db.denormalize(img), win='img_win', opts=dict(title='img_show'))# 批量导出数据
loader = DataLoader(db, batch_size=32, shuffle=True)
for x, y in loader:
vis.images(db.denormalize(x), nrow=8, win='batch', opts=dict(title='batch'))
vis.text(str(y.numpy()), win='label', opts=dict(title='bacth-y'))
time.sleep(10)if __name__ == '__main__':
main()

二、ResNet网络搭建

# -*- coding: UTF-8 -*-
'''
@version: 1.0
@PackageName: 实战代码- resnet.py
@author: yonghao
@Description: 创建残差网络结构
@since 2021/03/01 17:51
'''
import torch
import torch.nn.functional as F
from torch import nn
import utilsclass ResBlk(nn.Module):
'''
创建ResBlock
'''def __init__(self, ch_in, ch_out, stride=1):
'''
创建ResBlock模块
:paramch_in: 输入的通道数
:paramch_out: 输出的通道数
:paramstride: 卷积步长
'''
super(ResBlk, self).__init__()
self.conv1 = nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=stride, padding=1)
self.bn1 = nn.BatchNorm2d(ch_out)
self.conv2 = nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1)
self.bn2 = nn.BatchNorm2d(ch_out)
if ch_in == ch_out:
self.extra = nn.Sequential()
else:
self.extra = nn.Sequential(
nn.Conv2d(ch_in, ch_out, kernel_size=1, stride=stride),
nn.BatchNorm2d(ch_out)
)def forward(self, x):
out = F.relu(self.bn1(self.conv1(x)))
out = self.bn2(self.conv2(out))out = out + self.extra(x)
out = F.relu(out)
return outclass ResNet18(nn.Module):def __init__(self, num_class):
'''
创建18层的ResNet
:paramnum_class:分类数量
'''
super(ResNet18, self).__init__()self.conv1 = nn.Sequential(
nn.Conv2d(3, 16, kernel_size=3, stride=3, padding=2),
nn.BatchNorm2d(16)
)# followed 4 blocks
# [b , 16 , h , w] => [b , 32 , h , w]
self.blk1 = ResBlk(16, 32, stride=3)
# [b , 32 , h , w] => [b , 64 , h , w]
self.blk2 = ResBlk(32, 64, stride=3)
# [b , 64 , h , w] => [b , 128 , h , w]
self.blk3 = ResBlk(64, 128, stride=2)
# [b , 128 , h , w] => [b , 256 , h , w]
self.blk4 = ResBlk(128, 256, stride=2)
# [b , 256 , h , 2] => [b , 256*h*w]
self.flat = utils.Flatten()
# [b , 256*h*w] => [b , num_class]
self.out_layer = nn.Linear(256 * 3 * 3, num_class)def forward(self, x):
x = F.relu(self.conv1(x), inplace=True)
x = self.blk1(x)
x = self.blk2(x)
x = self.blk3(x)
x = self.blk4(x)
# print(x.shape)
x = self.flat(x)
out = self.out_layer(x)
return outdef mian():
# 测试ResBlk,当ch_in==ch_out时正确
# 当ch_in==ch_out时报异常
blk = ResBlk(64, 128, stride=2)
tmp = torch.randn(2, 64, 64, 64)
out = blk(tmp)
print('block:', out.shape)model = ResNet18(5)
tmp = torch.randn(2, 3, 224, 224)
out = model(tmp)
print("resnet:", out.shape)
p = sum([i.numel() for i in model.parameters()])
print('parameters size:', p)if __name__ == '__main__':
mian()

三、训练与测试

在这里插入图片描述

# -*- coding: UTF-8 -*-
'''
@version: 1.0
@PackageName: project_code - process.py
@author: yonghao
@Description: 实现训练过程 与 测试过程
@since 2021/03/02 18:54
'''
import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from model.resnet import ResNet18
from pokemon import Pokemon# 批量数量
bacthsz = 32# 学习率
lr = 1e-3# 迭代次数
epochs = 10# device = torch.device('cpu')
# if torch.cuda.is_available():
#     device = torch.device('cuda')# 设置固定随机初始值
torch.manual_seed(1234)# 训练集
train_db = Pokemon('pokemon', 224, mode='train')
train_loader = DataLoader(train_db, batch_size=bacthsz, shuffle=True, num_workers=4)# 验证集
val_db = Pokemon('pokemon', 224, mode='val')
val_loader = DataLoader(val_db, batch_size=bacthsz, num_workers=2)# 测试集
test_db = Pokemon('pokemon', 224, mode='test')
test_loader = DataLoader(test_db, batch_size=bacthsz, num_workers=2)def evaluate(model, loader):
correct = 0
total = len(loader.dataset)
for x, y in loader:
# x, y = x.to(device), y.to(device)
# x:[b , c , h , w] , y:[b]
# out:[b,class_num]
with torch.no_grad():
out = model(x)
pred = out.argmax(dim=1)
correct += torch.eq(pred, y).sum().float().item()return correct / totaldef main():
# model = ResNet18(5).to(device)
model = ResNet18(5)
optimizer = optim.Adam(model.parameters(), lr=lr)
criteon = nn.CrossEntropyLoss()# 用于保存最高精度
best_acc = 0
best_epoch = 0
# 训练过程
for epoch in range(epochs):
for step, (x, y) in enumerate(train_loader):
# [b , c , h , w] , y[b]
# x, y = x.to(device), y.to(device)
logits = model(x)
loss = criteon(logits, y)
optimizer.zero_grad()
loss.backward()
optimizer.step()# validation
if epoch % 2 == 0:
val_acc = evaluate(model, val_loader)
if val_acc > best_acc:
best_epoch = epochbest_acc = val_acctorch.save(model.state_dict(), 'best.mdl')
print('best acc:', best_acc, "best epoch:", best_epoch)# 测试过程
model.load_state_dict(torch.load('best.mdl'))
print('loaded from ckpt!')test_acc = evaluate(model, test_loader)
print('test acc:', test_acc)if __name__ == '__main__':
'''
best acc: 0.8969957081545065 best epoch: 8
loaded from ckpt!
test acc: 0.8931623931623932
'''
main()

四、迁移学习

将处理相类似信号(特别是数据量较大)的神经网络嫁接过来,应用到本实验中
在这里插入图片描述

  • 具体的嫁接过程
    ① 尽量保留网络前、中部分
    ② 去除最后一层,根据自己的分类任务定制最后一层
    在这里插入图片描述

4.1 pytorch实现迁移学习

from torchvision.models import resnet18model = resnet18(pretrained=True)
# 17 layer out:[32, 512, 1, 1]
model = nn.Sequential(*list(model.children())[:-1],
utils.Flatten(),# 降维度
nn.Linear(512, 5))

这篇关于第十二章 迁移学习-实战宝可梦精灵的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1095808

相关文章

网页解析 lxml 库--实战

lxml库使用流程 lxml 是 Python 的第三方解析库,完全使用 Python 语言编写,它对 XPath表达式提供了良好的支 持,因此能够了高效地解析 HTML/XML 文档。本节讲解如何通过 lxml 库解析 HTML 文档。 pip install lxml lxm| 库提供了一个 etree 模块,该模块专门用来解析 HTML/XML 文档,下面来介绍一下 lxml 库

HarmonyOS学习(七)——UI(五)常用布局总结

自适应布局 1.1、线性布局(LinearLayout) 通过线性容器Row和Column实现线性布局。Column容器内的子组件按照垂直方向排列,Row组件中的子组件按照水平方向排列。 属性说明space通过space参数设置主轴上子组件的间距,达到各子组件在排列上的等间距效果alignItems设置子组件在交叉轴上的对齐方式,且在各类尺寸屏幕上表现一致,其中交叉轴为垂直时,取值为Vert

Ilya-AI分享的他在OpenAI学习到的15个提示工程技巧

Ilya(不是本人,claude AI)在社交媒体上分享了他在OpenAI学习到的15个Prompt撰写技巧。 以下是详细的内容: 提示精确化:在编写提示时,力求表达清晰准确。清楚地阐述任务需求和概念定义至关重要。例:不用"分析文本",而用"判断这段话的情感倾向:积极、消极还是中性"。 快速迭代:善于快速连续调整提示。熟练的提示工程师能够灵活地进行多轮优化。例:从"总结文章"到"用

【前端学习】AntV G6-08 深入图形与图形分组、自定义节点、节点动画(下)

【课程链接】 AntV G6:深入图形与图形分组、自定义节点、节点动画(下)_哔哩哔哩_bilibili 本章十吾老师讲解了一个复杂的自定义节点中,应该怎样去计算和绘制图形,如何给一个图形制作不间断的动画,以及在鼠标事件之后产生动画。(有点难,需要好好理解) <!DOCTYPE html><html><head><meta charset="UTF-8"><title>06

学习hash总结

2014/1/29/   最近刚开始学hash,名字很陌生,但是hash的思想却很熟悉,以前早就做过此类的题,但是不知道这就是hash思想而已,说白了hash就是一个映射,往往灵活利用数组的下标来实现算法,hash的作用:1、判重;2、统计次数;

性能分析之MySQL索引实战案例

文章目录 一、前言二、准备三、MySQL索引优化四、MySQL 索引知识回顾五、总结 一、前言 在上一讲性能工具之 JProfiler 简单登录案例分析实战中已经发现SQL没有建立索引问题,本文将一起从代码层去分析为什么没有建立索引? 开源ERP项目地址:https://gitee.com/jishenghua/JSH_ERP 二、准备 打开IDEA找到登录请求资源路径位置

C#实战|大乐透选号器[6]:实现实时显示已选择的红蓝球数量

哈喽,你好啊,我是雷工。 关于大乐透选号器在前面已经记录了5篇笔记,这是第6篇; 接下来实现实时显示当前选中红球数量,蓝球数量; 以下为练习笔记。 01 效果演示 当选择和取消选择红球或蓝球时,在对应的位置显示实时已选择的红球、蓝球的数量; 02 标签名称 分别设置Label标签名称为:lblRedCount、lblBlueCount

零基础学习Redis(10) -- zset类型命令使用

zset是有序集合,内部除了存储元素外,还会存储一个score,存储在zset中的元素会按照score的大小升序排列,不同元素的score可以重复,score相同的元素会按照元素的字典序排列。 1. zset常用命令 1.1 zadd  zadd key [NX | XX] [GT | LT]   [CH] [INCR] score member [score member ...]

【机器学习】高斯过程的基本概念和应用领域以及在python中的实例

引言 高斯过程(Gaussian Process,简称GP)是一种概率模型,用于描述一组随机变量的联合概率分布,其中任何一个有限维度的子集都具有高斯分布 文章目录 引言一、高斯过程1.1 基本定义1.1.1 随机过程1.1.2 高斯分布 1.2 高斯过程的特性1.2.1 联合高斯性1.2.2 均值函数1.2.3 协方差函数(或核函数) 1.3 核函数1.4 高斯过程回归(Gauss

【学习笔记】 陈强-机器学习-Python-Ch15 人工神经网络(1)sklearn

系列文章目录 监督学习:参数方法 【学习笔记】 陈强-机器学习-Python-Ch4 线性回归 【学习笔记】 陈强-机器学习-Python-Ch5 逻辑回归 【课后题练习】 陈强-机器学习-Python-Ch5 逻辑回归(SAheart.csv) 【学习笔记】 陈强-机器学习-Python-Ch6 多项逻辑回归 【学习笔记 及 课后题练习】 陈强-机器学习-Python-Ch7 判别分析 【学