第十二章 迁移学习-实战宝可梦精灵

2024-08-22 08:52

本文主要是介绍第十二章 迁移学习-实战宝可梦精灵,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

文章目录

  • 一、Pokemon数据集
    • 1.1 数据集收集
    • 1.2 数据集划分
    • 1.3 数据集加载
    • 1.4 数据预处理
    • 1.5 pytorch自定义数据库实现
  • 二、ResNet网络搭建
  • 三、训练与测试
  • 四、迁移学习
    • 4.1 pytorch实现迁移学习

一、Pokemon数据集

1.1 数据集收集

在这里插入图片描述

# git下载
git lfs install
git clone https://www.modelscope.cn/datasets/ModelBulider/pokemon.git

1.2 数据集划分

在这里插入图片描述


1.3 数据集加载

在这里插入图片描述

  • 加载数据
    ① 继承 torch.utils.data.Dataset
    ② 实现 __len__ 函数,其返回数据集的数量(整型数字)
    ③ 实现 __getitem__函数,根据索引值返回一个数据
    在这里插入图片描述

举例:
在这里插入图片描述


1.4 数据预处理

将尺寸大小不一致的数据(图片)预处理为大小一致的1数据
② 数据增强(旋转、裁剪等)
③ 归一化(均值、方差)
④ 转换为 Tensor 数据类型
在这里插入图片描述


1.5 pytorch自定义数据库实现

# -*- coding: UTF-8 -*-
'''
@version: 1.0
@PackageName: code - pokemon.py
@author: yonghao
@Description: 
@since 2021/03/01 19:41
'''
from visdom import Visdom
import time
import torch
import os, glob
import random, csv
from PIL import Image
from torchvision import transforms
from torch.utils.data import Dataset, DataLoaderroot = 'D:\\个人\\学习资料\\学习视频\\深度学习与PyTorch入门实战教程\\12.迁移学习-实战宝可梦精灵\\project_code\\pokemon'class Pokemon(Dataset):
def __init__(self, root, resize, mode='train'):
'''
初始化数据集
:paramroot: 图片存储的位置
:paramresize: 重新编辑图片的尺寸
:parammode: 初始化图片的类型(可以是数据集中各中分类)
'''
super(Pokemon, self).__init__()
self.root = rootself.resize = resizeself.mode = modeself.name2label = {}
# 创建 类名-> label 的映射字典
# os.listdir()每次顺序都不一样,故使用sorted()排序,使 类名-> label 的映射字典固定
for name in sorted(os.listdir(os.path.join(root))):
# 只读取文件夹名
if not os.path.isdir(os.path.join(root, name)):
continue
self.name2label[name] = len(self.name2label)
self.images, self.labels = self.load_csv('images.csv')
# 根据mode设定数据集的比例
if mode == 'train': # 60%
self.images = self.images[:int(0.6 * len(self.images))]
self.labels = self.labels[:int(0.6 * len(self.labels))]
elif mode == 'val': # 20%
self.images = self.images[int(0.6 * len(self.images)):int(0.8 * len(self.images))]
self.labels = self.labels[int(0.6 * len(self.labels)):int(0.8 * len(self.labels))]
else: # 20%
self.images = self.images[int(0.8 * len(self.images)):]
self.labels = self.labels[int(0.8 * len(self.labels)):]def __len__(self):
return len(self.images)def __getitem__(self, item) -> tuple:
# item ~ [0,len(images)-1]
# self.images , self.labels
# image , label
img, label = self.images[item], self.labels[item]
tf = transforms.Compose([
lambda x: Image.open(x).convert('RGB'), # string path => image data
transforms.Resize((int(1.25 * self.resize), int(1.25 * self.resize))), # 调整尺寸
transforms.RandomRotation(15), # 旋转
transforms.CenterCrop(self.resize), # 中心裁剪
transforms.ToTensor(),
# 注意transforms.Normalize() 应该在transforms.ToTensor() 后面
# 数据在通道层上归一化,会使变化图片的像素
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
# 返回由img,label 组成的Tensor 元组
img = tf(img)
label = torch.tensor(label)return img, labeldef denormalize(self, x_het):
'''
图像逆正则化显示
:paramx_het: 正则化后的数据
:return:
'''# x_het = (x - mean) / std
mean, std = torch.tensor([0.485, 0.456, 0.406]), torch.tensor([0.229, 0.224, 0.225])
# x = x_het * std + mean
# x:[channel , h , w] , mean:[3] -> [3,1,1] , std:[3] -> [3,1,1]
mean = mean.unsqueeze(dim=-1).unsqueeze(dim=-1)std = std.unsqueeze(dim=-1).unsqueeze(dim=-1)x = x_het * std + meanreturn xdef load_csv(self, filename):
'''
加载图片数据 与 其label数据
:paramfilename: 加载数据的文件名
:return:
'''
# 仅在第一次调用时创建csv文件,保存 图片路径——>label 的映射关系
if not os.path.exists(os.path.join(self.root, filename)):
images = []
for name in self.name2label.keys():
'''
python在模块glob中定义了glob()函数,实现了对目录内容进行匹配的功能,
glob.glob()函数接受通配模式作为输入,并返回所有匹配的文件名和路径名列表
与os.listdir类似
'''
images += glob.glob(os.path.join(self.root, name, '*.png'))
images += glob.glob(os.path.join(self.root, name, '*.jpg'))
images += glob.glob(os.path.join(self.root, name, '*.jpeg'))
# 1167 , 'D:\\个人\\学习资料\\学习视频\\深度学习与PyTorch入门实战教程\\12.迁移学习-实战宝可梦精灵\\project_code\\pokemon\\bulbasaur\\00000000.png'# 打乱的是图片的存储路径
random.shuffle(images)# 使用上下文管理,对文件进行操作
'''
with是从Python2.5引入的一个新的语法,它是一种上下文管理协议,目的在于从流程图中把try,except 和finally 关键字和资源分配释放相关代码统统去掉,简化try….except….finlally的处理流程。with通过__enter__方法初始化,然后在__exit__中做善后以及处理异常。所以使用with处理的对象必须有__enter__()和__exit__()这两个方法。其中__enter__()方法在语句体(with语句包裹起来的代码块)执行之前进入运行,__exit__()方法在语句体执行完毕退出后运行。with 语句适用于对资源进行访问的场合,确保不管使用过程中是否发生异常都会执行必要的“清理”操作,释放资源,比如文件使用后自动关闭、线程中锁的自动获取和释放等。紧跟with后面的语句会被求值,返回对象的__enter__()方法被调用,这个方法的返回值将被赋值给as关键字后面的变量,当with后面的代码块全部被执行完之后,将调用前面返回对象的__exit__()方法
'''
with open(os.path.join(self.root, filename), mode='w', newline='') as f:
writer = csv.writer(f)
for img in images:
# os.sep 为系统自动识别的文件路径分隔符
name = img.split(os.sep)[-2]
label = self.name2label[name]
writer.writerow([img, label])images, labels = [], []
with open(os.path.join(root, filename), mode='r') as f:
reader = csv.reader(f)
for row in reader:
img, label = rowimages.append(img)
labels.append(int(label))assert len(images) == len(labels)return images, labelsdef main():
vis = Visdom()
# 获取数据集(单个数据做返回)
db = Pokemon(root, 64, mode='train')
img, label = next(iter(db))
print('sample:', img.shape, label.shape)
vis.image(img, win='img_win_het', opts=dict(title='norm_img_show'))
vis.image(db.denormalize(img), win='img_win', opts=dict(title='img_show'))# 批量导出数据
loader = DataLoader(db, batch_size=32, shuffle=True)
for x, y in loader:
vis.images(db.denormalize(x), nrow=8, win='batch', opts=dict(title='batch'))
vis.text(str(y.numpy()), win='label', opts=dict(title='bacth-y'))
time.sleep(10)if __name__ == '__main__':
main()

二、ResNet网络搭建

# -*- coding: UTF-8 -*-
'''
@version: 1.0
@PackageName: 实战代码- resnet.py
@author: yonghao
@Description: 创建残差网络结构
@since 2021/03/01 17:51
'''
import torch
import torch.nn.functional as F
from torch import nn
import utilsclass ResBlk(nn.Module):
'''
创建ResBlock
'''def __init__(self, ch_in, ch_out, stride=1):
'''
创建ResBlock模块
:paramch_in: 输入的通道数
:paramch_out: 输出的通道数
:paramstride: 卷积步长
'''
super(ResBlk, self).__init__()
self.conv1 = nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=stride, padding=1)
self.bn1 = nn.BatchNorm2d(ch_out)
self.conv2 = nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1)
self.bn2 = nn.BatchNorm2d(ch_out)
if ch_in == ch_out:
self.extra = nn.Sequential()
else:
self.extra = nn.Sequential(
nn.Conv2d(ch_in, ch_out, kernel_size=1, stride=stride),
nn.BatchNorm2d(ch_out)
)def forward(self, x):
out = F.relu(self.bn1(self.conv1(x)))
out = self.bn2(self.conv2(out))out = out + self.extra(x)
out = F.relu(out)
return outclass ResNet18(nn.Module):def __init__(self, num_class):
'''
创建18层的ResNet
:paramnum_class:分类数量
'''
super(ResNet18, self).__init__()self.conv1 = nn.Sequential(
nn.Conv2d(3, 16, kernel_size=3, stride=3, padding=2),
nn.BatchNorm2d(16)
)# followed 4 blocks
# [b , 16 , h , w] => [b , 32 , h , w]
self.blk1 = ResBlk(16, 32, stride=3)
# [b , 32 , h , w] => [b , 64 , h , w]
self.blk2 = ResBlk(32, 64, stride=3)
# [b , 64 , h , w] => [b , 128 , h , w]
self.blk3 = ResBlk(64, 128, stride=2)
# [b , 128 , h , w] => [b , 256 , h , w]
self.blk4 = ResBlk(128, 256, stride=2)
# [b , 256 , h , 2] => [b , 256*h*w]
self.flat = utils.Flatten()
# [b , 256*h*w] => [b , num_class]
self.out_layer = nn.Linear(256 * 3 * 3, num_class)def forward(self, x):
x = F.relu(self.conv1(x), inplace=True)
x = self.blk1(x)
x = self.blk2(x)
x = self.blk3(x)
x = self.blk4(x)
# print(x.shape)
x = self.flat(x)
out = self.out_layer(x)
return outdef mian():
# 测试ResBlk,当ch_in==ch_out时正确
# 当ch_in==ch_out时报异常
blk = ResBlk(64, 128, stride=2)
tmp = torch.randn(2, 64, 64, 64)
out = blk(tmp)
print('block:', out.shape)model = ResNet18(5)
tmp = torch.randn(2, 3, 224, 224)
out = model(tmp)
print("resnet:", out.shape)
p = sum([i.numel() for i in model.parameters()])
print('parameters size:', p)if __name__ == '__main__':
mian()

三、训练与测试

在这里插入图片描述

# -*- coding: UTF-8 -*-
'''
@version: 1.0
@PackageName: project_code - process.py
@author: yonghao
@Description: 实现训练过程 与 测试过程
@since 2021/03/02 18:54
'''
import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from model.resnet import ResNet18
from pokemon import Pokemon# 批量数量
bacthsz = 32# 学习率
lr = 1e-3# 迭代次数
epochs = 10# device = torch.device('cpu')
# if torch.cuda.is_available():
#     device = torch.device('cuda')# 设置固定随机初始值
torch.manual_seed(1234)# 训练集
train_db = Pokemon('pokemon', 224, mode='train')
train_loader = DataLoader(train_db, batch_size=bacthsz, shuffle=True, num_workers=4)# 验证集
val_db = Pokemon('pokemon', 224, mode='val')
val_loader = DataLoader(val_db, batch_size=bacthsz, num_workers=2)# 测试集
test_db = Pokemon('pokemon', 224, mode='test')
test_loader = DataLoader(test_db, batch_size=bacthsz, num_workers=2)def evaluate(model, loader):
correct = 0
total = len(loader.dataset)
for x, y in loader:
# x, y = x.to(device), y.to(device)
# x:[b , c , h , w] , y:[b]
# out:[b,class_num]
with torch.no_grad():
out = model(x)
pred = out.argmax(dim=1)
correct += torch.eq(pred, y).sum().float().item()return correct / totaldef main():
# model = ResNet18(5).to(device)
model = ResNet18(5)
optimizer = optim.Adam(model.parameters(), lr=lr)
criteon = nn.CrossEntropyLoss()# 用于保存最高精度
best_acc = 0
best_epoch = 0
# 训练过程
for epoch in range(epochs):
for step, (x, y) in enumerate(train_loader):
# [b , c , h , w] , y[b]
# x, y = x.to(device), y.to(device)
logits = model(x)
loss = criteon(logits, y)
optimizer.zero_grad()
loss.backward()
optimizer.step()# validation
if epoch % 2 == 0:
val_acc = evaluate(model, val_loader)
if val_acc > best_acc:
best_epoch = epochbest_acc = val_acctorch.save(model.state_dict(), 'best.mdl')
print('best acc:', best_acc, "best epoch:", best_epoch)# 测试过程
model.load_state_dict(torch.load('best.mdl'))
print('loaded from ckpt!')test_acc = evaluate(model, test_loader)
print('test acc:', test_acc)if __name__ == '__main__':
'''
best acc: 0.8969957081545065 best epoch: 8
loaded from ckpt!
test acc: 0.8931623931623932
'''
main()

四、迁移学习

将处理相类似信号(特别是数据量较大)的神经网络嫁接过来,应用到本实验中
在这里插入图片描述

  • 具体的嫁接过程
    ① 尽量保留网络前、中部分
    ② 去除最后一层,根据自己的分类任务定制最后一层
    在这里插入图片描述

4.1 pytorch实现迁移学习

from torchvision.models import resnet18model = resnet18(pretrained=True)
# 17 layer out:[32, 512, 1, 1]
model = nn.Sequential(*list(model.children())[:-1],
utils.Flatten(),# 降维度
nn.Linear(512, 5))

这篇关于第十二章 迁移学习-实战宝可梦精灵的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1095808

相关文章

使用 sql-research-assistant进行 SQL 数据库研究的实战指南(代码实现演示)

《使用sql-research-assistant进行SQL数据库研究的实战指南(代码实现演示)》本文介绍了sql-research-assistant工具,该工具基于LangChain框架,集... 目录技术背景介绍核心原理解析代码实现演示安装和配置项目集成LangSmith 配置(可选)启动服务应用场景

Java深度学习库DJL实现Python的NumPy方式

《Java深度学习库DJL实现Python的NumPy方式》本文介绍了DJL库的背景和基本功能,包括NDArray的创建、数学运算、数据获取和设置等,同时,还展示了如何使用NDArray进行数据预处理... 目录1 NDArray 的背景介绍1.1 架构2 JavaDJL使用2.1 安装DJL2.2 基本操

在不同系统间迁移Python程序的方法与教程

《在不同系统间迁移Python程序的方法与教程》本文介绍了几种将Windows上编写的Python程序迁移到Linux服务器上的方法,包括使用虚拟环境和依赖冻结、容器化技术(如Docker)、使用An... 目录使用虚拟环境和依赖冻结1. 创建虚拟环境2. 冻结依赖使用容器化技术(如 docker)1. 创

在Java中使用ModelMapper简化Shapefile属性转JavaBean实战过程

《在Java中使用ModelMapper简化Shapefile属性转JavaBean实战过程》本文介绍了在Java中使用ModelMapper库简化Shapefile属性转JavaBean的过程,对比... 目录前言一、原始的处理办法1、使用Set方法来转换2、使用构造方法转换二、基于ModelMapper

Java实战之自助进行多张图片合成拼接

《Java实战之自助进行多张图片合成拼接》在当今数字化时代,图像处理技术在各个领域都发挥着至关重要的作用,本文为大家详细介绍了如何使用Java实现多张图片合成拼接,需要的可以了解下... 目录前言一、图片合成需求描述二、图片合成设计与实现1、编程语言2、基础数据准备3、图片合成流程4、图片合成实现三、总结前

SQL Server数据库迁移到MySQL的完整指南

《SQLServer数据库迁移到MySQL的完整指南》在企业应用开发中,数据库迁移是一个常见的需求,随着业务的发展,企业可能会从SQLServer转向MySQL,原因可能是成本、性能、跨平台兼容性等... 目录一、迁移前的准备工作1.1 确定迁移范围1.2 评估兼容性1.3 备份数据二、迁移工具的选择2.1

nginx-rtmp-module构建流媒体直播服务器实战指南

《nginx-rtmp-module构建流媒体直播服务器实战指南》本文主要介绍了nginx-rtmp-module构建流媒体直播服务器实战指南,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有... 目录1. RTMP协议介绍与应用RTMP协议的原理RTMP协议的应用RTMP与现代流媒体技术的关系2

C语言小项目实战之通讯录功能

《C语言小项目实战之通讯录功能》:本文主要介绍如何设计和实现一个简单的通讯录管理系统,包括联系人信息的存储、增加、删除、查找、修改和排序等功能,文中通过代码介绍的非常详细,需要的朋友可以参考下... 目录功能介绍:添加联系人模块显示联系人模块删除联系人模块查找联系人模块修改联系人模块排序联系人模块源代码如下

将sqlserver数据迁移到mysql的详细步骤记录

《将sqlserver数据迁移到mysql的详细步骤记录》:本文主要介绍将SQLServer数据迁移到MySQL的步骤,包括导出数据、转换数据格式和导入数据,通过示例和工具说明,帮助大家顺利完成... 目录前言一、导出SQL Server 数据二、转换数据格式为mysql兼容格式三、导入数据到MySQL数据

Golang操作DuckDB实战案例分享

《Golang操作DuckDB实战案例分享》DuckDB是一个嵌入式SQL数据库引擎,它与众所周知的SQLite非常相似,但它是为olap风格的工作负载设计的,DuckDB支持各种数据类型和SQL特性... 目录DuckDB的主要优点环境准备初始化表和数据查询单行或多行错误处理和事务完整代码最后总结Duck