利用pytorch实现迁移学习之猫狗分类器(dog vs cat)

2024-02-11 07:08

本文主要是介绍利用pytorch实现迁移学习之猫狗分类器(dog vs cat),希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

迁移学习

迁移学习(Transfer learning) 就是把已学训练好的模型参数迁移到新的模型来帮助新模型训练。考虑到大部分数据或任务是存在相关性的,所以通过迁移学习我们可以将已经学到的模型参数(也可理解为模型学到的知识)通过某种方式来分享给新模型从而加快并优化模型的学习效率不用像大多数网络那样从零学习。
本文使用VGG16模型用于迁移学习,最终得到一个能对猫狗图片进行辨识的CNN(卷积神经网络),测试集用来验证模型是否能够很好的工作。

猫狗分类器

本文使用迁移学习实现猫狗分类器。
数据集来自Kaggle比赛:Dogs vs. Cats Redux: Kernels Edition

利用pytorch实现迁移学习

首先进行图片的导入和预览

path = "dog_vs_cat"
transform = transforms.Compose([transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize([0.5,0.5,0.5], [0.5,0.5,0.5])])data_image = {x:datasets.ImageFolder(root = os.path.join(path,x),transform = transform)for x in ["train", "val"]}data_loader_image = {x:torch.utils.data.DataLoader(dataset=data_image[x],batch_size = 4,shuffle = True)for x in ["train", "val"]}

输入的图片需要分辨率为224*224,所以使用transform.CenterCrop(224)对原始图片进行裁剪。载入的图片训练集合为20000个和验证集合为5000个,原始图片全部为训练集合,需自己拆分出一部分验证集合,输出的Label,1代表为狗,0代表猫。

X_train,y_train = next(iter(data_loader_image["train"]))
mean = [0.5, 0.5, 0.5]
std = [0.5, 0.5, 0.5]
img = torchvision.utils.make_grid(X_train)
img = img.numpy().transpose((1,2,0))
img = img*std + meanprint([classes[i] for i in y_train])
plt.imshow(img)

图片预览
训练集的图片都为2242243。
迁移模型,打印出原始VGG模型结构为:

VGG((features): Sequential((0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(1): ReLU(inplace)(2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(3): ReLU(inplace)(4): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), dilation=(1, 1), ceil_mode=False)(5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(6): ReLU(inplace)(7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(8): ReLU(inplace)(9): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), dilation=(1, 1), ceil_mode=False)(10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(11): ReLU(inplace)(12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(13): ReLU(inplace)(14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(15): ReLU(inplace)(16): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), dilation=(1, 1), ceil_mode=False)(17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(18): ReLU(inplace)(19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(20): ReLU(inplace)(21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(22): ReLU(inplace)(23): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), dilation=(1, 1), ceil_mode=False)(24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(25): ReLU(inplace)(26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(27): ReLU(inplace)(28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(29): ReLU(inplace)(30): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), dilation=(1, 1), ceil_mode=False))(classifier): Sequential((0): Linear(in_features=25088, out_features=4096, bias=True)(1): ReLU(inplace)(2): Dropout(p=0.5)(3): Linear(in_features=4096, out_features=4096, bias=True)(4): ReLU(inplace)(5): Dropout(p=0.5)(6): Linear(in_features=4096, out_features=1000, bias=True))
)

迁移过来的VGG16模型需适应新的需求,达到对猫狗图片很好的识别,因此改写VGG16的全连接层的最后一部分并且重新训练参数。
即使只是训练整个全连接层的全部参数,普通的电脑也会花费大量的时间,所以这里只训练全连接层的最后一层,就能达到很好的效果。

model.classifier = torch.nn.Sequential(torch.nn.Linear(25088, 4096),torch.nn.ReLU(),torch.nn.Dropout(p=0.5),torch.nn.Linear(4096, 4096),torch.nn.ReLU(),torch.nn.Dropout(p=0.5),torch.nn.Linear(4096, 2))for parma in model.parameters():parma.requires_grad = Falsefor index, parma in enumerate(model.classifier.parameters()):if index == 6:parma.requires_grad = Trueif use_gpu:model = model.cuda()cost = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.classifier.parameters())

parma.requires_grid = False是冻结参数,即使发生新的训练也不会进行参数的更新。
这里还对全连接层的最后一层进行了改写,torch.nn.Linear(4096, 2)使得最后输出的结果只有两个,即只需要对猫狗进行分辨。
optimizer = torch.optim.Adam(model.classifier.parameters())只对全连接层参数进行更新优化,loss计算依然使用交叉熵。
对改写后的模型进行查看:

VGG((features): Sequential((0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(1): ReLU(inplace)(2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(3): ReLU(inplace)(4): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), dilation=(1, 1), ceil_mode=False)(5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(6): ReLU(inplace)(7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(8): ReLU(inplace)(9): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), dilation=(1, 1), ceil_mode=False)(10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(11): ReLU(inplace)(12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(13): ReLU(inplace)(14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(15): ReLU(inplace)(16): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), dilation=(1, 1), ceil_mode=False)(17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(18): ReLU(inplace)(19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(20): ReLU(inplace)(21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(22): ReLU(inplace)(23): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), dilation=(1, 1), ceil_mode=False)(24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(25): ReLU(inplace)(26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(27): ReLU(inplace)(28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(29): ReLU(inplace)(30): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), dilation=(1, 1), ceil_mode=False))(classifier): Sequential((0): Linear(in_features=25088, out_features=4096, bias=True)(1): ReLU()(2): Dropout(p=0.5)(3): Linear(in_features=4096, out_features=4096, bias=True)(4): ReLU()(5): Dropout(p=0.5)(6): Linear(in_features=4096, out_features=2, bias=True))
)

这里我使用纯cpu进行训练,因为速度贼慢,我只进行了对100张图片进行训练的demo,进行1次训练的结果为:

Epoch0/1
----------
Batch 5, Train Loss:1.2027, Train ACC:90.0000
Batch 10, Train Loss:0.6853, Train ACC:92.0000
Batch 15, Train Loss:0.7109, Train ACC:91.0000
Batch 20, Train Loss:0.5332, Train ACC:93.0000
Batch 25, Train Loss:0.5215, Train ACC:94.0000
Batch 30, Train Loss:0.4346, Train ACC:95.0000
Batch 35, Train Loss:0.4213, Train ACC:95.0000
Batch 40, Train Loss:0.3748, Train ACC:95.0000
Batch 45, Train Loss:0.3541, Train ACC:95.0000
Batch 50, Train Loss:0.3501, Train ACC:94.0000
train Loss:0.3501, Correct:94.0000
val Loss:0.9151, Correct:88.0000
Training time is:6m 4s

看到训练的Loss为0.3501, Accuraty准确率为94%。验证集的Loss为0.9151,Accuraty准确率为88%。这只是100张图片的一次训练,更加多的图片以及多次的训练可能会得到一个更加好的结果。

随机输入测试集合产看预测结果:
预测结果
可以看到预测结果没有出现错误,本文输入时采用了随机裁剪,如果对原始图片进行缩放可能会提升模型的预测准确率,此外还可以增加数据个数、训练次数、数据增强处理。
完整代码链接:xiutangseeker/dog_vs_cat

这篇关于利用pytorch实现迁移学习之猫狗分类器(dog vs cat)的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/699155

相关文章

SpringBoot3实现Gzip压缩优化的技术指南

《SpringBoot3实现Gzip压缩优化的技术指南》随着Web应用的用户量和数据量增加,网络带宽和页面加载速度逐渐成为瓶颈,为了减少数据传输量,提高用户体验,我们可以使用Gzip压缩HTTP响应,... 目录1、简述2、配置2.1 添加依赖2.2 配置 Gzip 压缩3、服务端应用4、前端应用4.1 N

SpringBoot实现数据库读写分离的3种方法小结

《SpringBoot实现数据库读写分离的3种方法小结》为了提高系统的读写性能和可用性,读写分离是一种经典的数据库架构模式,在SpringBoot应用中,有多种方式可以实现数据库读写分离,本文将介绍三... 目录一、数据库读写分离概述二、方案一:基于AbstractRoutingDataSource实现动态

Python FastAPI+Celery+RabbitMQ实现分布式图片水印处理系统

《PythonFastAPI+Celery+RabbitMQ实现分布式图片水印处理系统》这篇文章主要为大家详细介绍了PythonFastAPI如何结合Celery以及RabbitMQ实现简单的分布式... 实现思路FastAPI 服务器Celery 任务队列RabbitMQ 作为消息代理定时任务处理完整

Java枚举类实现Key-Value映射的多种实现方式

《Java枚举类实现Key-Value映射的多种实现方式》在Java开发中,枚举(Enum)是一种特殊的类,本文将详细介绍Java枚举类实现key-value映射的多种方式,有需要的小伙伴可以根据需要... 目录前言一、基础实现方式1.1 为枚举添加属性和构造方法二、http://www.cppcns.co

使用Python实现快速搭建本地HTTP服务器

《使用Python实现快速搭建本地HTTP服务器》:本文主要介绍如何使用Python快速搭建本地HTTP服务器,轻松实现一键HTTP文件共享,同时结合二维码技术,让访问更简单,感兴趣的小伙伴可以了... 目录1. 概述2. 快速搭建 HTTP 文件共享服务2.1 核心思路2.2 代码实现2.3 代码解读3.

MySQL双主搭建+keepalived高可用的实现

《MySQL双主搭建+keepalived高可用的实现》本文主要介绍了MySQL双主搭建+keepalived高可用的实现,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,... 目录一、测试环境准备二、主从搭建1.创建复制用户2.创建复制关系3.开启复制,确认复制是否成功4.同

Java实现文件图片的预览和下载功能

《Java实现文件图片的预览和下载功能》这篇文章主要为大家详细介绍了如何使用Java实现文件图片的预览和下载功能,文中的示例代码讲解详细,感兴趣的小伙伴可以跟随小编一起学习一下... Java实现文件(图片)的预览和下载 @ApiOperation("访问文件") @GetMapping("

使用Sentinel自定义返回和实现区分来源方式

《使用Sentinel自定义返回和实现区分来源方式》:本文主要介绍使用Sentinel自定义返回和实现区分来源方式,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录Sentinel自定义返回和实现区分来源1. 自定义错误返回2. 实现区分来源总结Sentinel自定

Java实现时间与字符串互相转换详解

《Java实现时间与字符串互相转换详解》这篇文章主要为大家详细介绍了Java中实现时间与字符串互相转换的相关方法,文中的示例代码讲解详细,感兴趣的小伙伴可以跟随小编一起学习一下... 目录一、日期格式化为字符串(一)使用预定义格式(二)自定义格式二、字符串解析为日期(一)解析ISO格式字符串(二)解析自定义

opencv图像处理之指纹验证的实现

《opencv图像处理之指纹验证的实现》本文主要介绍了opencv图像处理之指纹验证的实现,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学... 目录一、简介二、具体案例实现1. 图像显示函数2. 指纹验证函数3. 主函数4、运行结果三、总结一、