百度飞浆ResNet50大模型微调实现十二种猫图像分类

2023-10-10 05:44

本文主要是介绍百度飞浆ResNet50大模型微调实现十二种猫图像分类,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

12种猫分类比赛传送门

要求很简单,给train和test集,训练模型实现图像分类。

这里使用的是残差连接模型,这个平台有预训练好的模型,可以直接拿来主义。

训练十几个迭代,每个批次60左右,准确率达到90%以上

一、导入库,解压文件

import os
import zipfile
import random
import json
import cv2
import numpy as np
from PIL import Imageimport matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
import paddle
import paddle.nn as nn
from paddle.io import Dataset,DataLoader
from paddle.nn import \Layer, \Conv2D, Linear, \Embedding, MaxPool2D, \BatchNorm2D, ReLUimport paddle.vision.transforms as transforms
from paddle.vision.models import resnet50
from paddle.metric import Accuracytrain_parameters = {"input_size": [3, 224, 224],                     # 输入图片的shape"class_dim": 12,                                 # 分类数"src_path":"data/data10954/cat_12_train.zip",   # 原始数据集路径"src_test_path":"data/data10954/cat_12_test.zip",   # 原始数据集路径"target_path":"/home/aistudio/data/dataset",     # 要解压的路径 "train_list_path": "./train.txt",                # train_data.txt路径"eval_list_path": "./eval.txt",                  # eval_data.txt路径"label_dict":{},                                 # 标签字典"readme_path": "/home/aistudio/data/readme.json",# readme.json路径"num_epochs":6,                                 # 训练轮数"train_batch_size": 16,                          # 批次的大小"learning_strategy": {                           # 优化函数相关的配置"lr": 0.0005                                  # 超参数学习率} 
}scr_path=train_parameters['src_path']
target_path=train_parameters['target_path']
src_test_path=train_parameters["src_test_path"]
z = zipfile.ZipFile(scr_path, 'r')
z.extractall(path=target_path)
z = zipfile.ZipFile(src_test_path, 'r')
z.extractall(path=target_path)
z.close()
for imgpath in os.listdir(target_path + '/cat_12_train'):src = os.path.join(target_path + '/cat_12_train/', imgpath)img = Image.open(src)if img.mode != 'RGB':img = img.convert('RGB')img.save(src)for imgpath in os.listdir(target_path + '/cat_12_test'):src = os.path.join(target_path + '/cat_12_test/', imgpath)img = Image.open(src)if img.mode != 'RGB':img = img.convert('RGB')img.save(src)

 解压后将所有图像变为RGB图像

二、加载训练集,进行预处理、数据增强、格式变换

transform = transforms.Compose([transforms.Resize(size=224),transforms.ColorJitter(0.2, 0.2, 0.2, 0.2),transforms.RandomHorizontalFlip(),transforms.RandomRotation(15),transforms.RandomResizedCrop(size=224, scale=(0.8, 1.0)),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])x_train,x_eval,y_train=[],[],[]#获取训练图像和标签、测试图像和标签
contents=[]
with open('data/data10954/train_list.txt')as f:contents=f.read().split('\n')for item in contents:if item=='':continuepath='data/dataset/'+item.split('\t')[0]data=np.array(Image.open(path).convert('RGB'))data=np.array(transform(data))x_train.append(data)y_train.append(int(item.split('\t')[-1]))contetns=os.listdir('data/dataset/cat_12_test')
for item in contetns:path='data/dataset/cat_12_test/'+itemdata=np.array(Image.open(path).convert('RGB'))data=np.array(transform(data))x_eval.append(data)

重点是transforms变换的预处理

三、划分训练集和测试集

x_train=np.array(x_train)y_train=np.array(y_train)x_eval=np.array(x_eval)x_train,x_test,y_train,y_test=train_test_split(x_train,y_train,test_size=0.2,random_state=42,stratify=y_train)x_train=paddle.to_tensor(x_train,dtype='float32')
y_train=paddle.to_tensor(y_train,dtype='int64')
x_test=paddle.to_tensor(x_test,dtype='float32')
y_test=paddle.to_tensor(y_test,dtype='int64')
x_eval=paddle.to_tensor(x_eval,dtype='float32')

 这是必要的,可以随时利用测试集查看准确率

四、加载预训练模型,选择损失函数和优化器

learning_rate=0.001
epochs =5  # 迭代轮数
batch_size = 50  # 批次大小
weight_decay=1e-5
num_class=12cnn=resnet50(pretrained=True)
checkpoint=paddle.load('checkpoint.pdparams')for param in cnn.parameters():param.requires_grad=False
cnn.fc = nn.Linear(2048, num_class)
cnn.set_dict(checkpoint['cnn_state_dict'])
criterion=nn.CrossEntropyLoss()
optimizer = paddle.optimizer.Adam(learning_rate=learning_rate, parameters=cnn.fc.parameters(),weight_decay=weight_decay)

第一次训练把加载模型注释掉即可,优化器包含最后一层全连接的参数

五、模型训练 

if x_train.shape[3]==3:x_train=paddle.transpose(x_train,perm=(0,3,1,2))dataset = paddle.io.TensorDataset([x_train, y_train])
data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
for epoch in range(epochs):for batch_data, batch_labels in data_loader:outputs = cnn(batch_data)loss = criterion(outputs, batch_labels)print(epoch)loss.backward()optimizer.step()optimizer.clear_grad()print(f"Epoch [{epoch+1}/{epochs}], Loss: {loss.numpy()[0]}")#保存参数
paddle.save({'cnn_state_dict': cnn.state_dict(),}, 'checkpoint.pdparams')

 使用批处理,这个很重要,不然平台分分钟炸了

六、测试集准确率

num_class=12
batch_size=64
cnn=resnet50(pretrained=True)
checkpoint=paddle.load('checkpoint.pdparams')for param in cnn.parameters():param.requires_grad=False
cnn.fc = nn.Linear(2048, num_class)
cnn.set_dict(checkpoint['cnn_state_dict'])cnn.eval()if x_test.shape[3]==3:x_test=paddle.transpose(x_test,perm=(0,3,1,2))
dataset = paddle.io.TensorDataset([x_test, y_test])
data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)with paddle.no_grad():score=0for batch_data, batch_labels in data_loader:predictions = cnn(batch_data)predicted_probabilities = paddle.nn.functional.softmax(predictions, axis=1)predicted_labels = paddle.argmax(predicted_probabilities, axis=1) print(predicted_labels)for i in range(len(predicted_labels)):if predicted_labels[i].numpy()==batch_labels[i]:score+=1print(score/len(y_test))

设置eval模式,使用批处理测试准确率 

这篇关于百度飞浆ResNet50大模型微调实现十二种猫图像分类的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/178364

相关文章

python使用watchdog实现文件资源监控

《python使用watchdog实现文件资源监控》watchdog支持跨平台文件资源监控,可以检测指定文件夹下文件及文件夹变动,下面我们来看看Python如何使用watchdog实现文件资源监控吧... python文件监控库watchdogs简介随着Python在各种应用领域中的广泛使用,其生态环境也

el-select下拉选择缓存的实现

《el-select下拉选择缓存的实现》本文主要介绍了在使用el-select实现下拉选择缓存时遇到的问题及解决方案,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的... 目录项目场景:问题描述解决方案:项目场景:从左侧列表中选取字段填入右侧下拉多选框,用户可以对右侧

Python pyinstaller实现图形化打包工具

《Pythonpyinstaller实现图形化打包工具》:本文主要介绍一个使用PythonPYQT5制作的关于pyinstaller打包工具,代替传统的cmd黑窗口模式打包页面,实现更快捷方便的... 目录1.简介2.运行效果3.相关源码1.简介一个使用python PYQT5制作的关于pyinstall

使用Python实现大文件切片上传及断点续传的方法

《使用Python实现大文件切片上传及断点续传的方法》本文介绍了使用Python实现大文件切片上传及断点续传的方法,包括功能模块划分(获取上传文件接口状态、临时文件夹状态信息、切片上传、切片合并)、整... 目录概要整体架构流程技术细节获取上传文件状态接口获取临时文件夹状态信息接口切片上传功能文件合并功能小

python实现自动登录12306自动抢票功能

《python实现自动登录12306自动抢票功能》随着互联网技术的发展,越来越多的人选择通过网络平台购票,特别是在中国,12306作为官方火车票预订平台,承担了巨大的访问量,对于热门线路或者节假日出行... 目录一、遇到的问题?二、改进三、进阶–展望总结一、遇到的问题?1.url-正确的表头:就是首先ur

C#实现文件读写到SQLite数据库

《C#实现文件读写到SQLite数据库》这篇文章主要为大家详细介绍了使用C#将文件读写到SQLite数据库的几种方法,文中的示例代码讲解详细,感兴趣的小伙伴可以参考一下... 目录1. 使用 BLOB 存储文件2. 存储文件路径3. 分块存储文件《文件读写到SQLite数据库China编程的方法》博客中,介绍了文

Redis主从复制实现原理分析

《Redis主从复制实现原理分析》Redis主从复制通过Sync和CommandPropagate阶段实现数据同步,2.8版本后引入Psync指令,根据复制偏移量进行全量或部分同步,优化了数据传输效率... 目录Redis主DodMIK从复制实现原理实现原理Psync: 2.8版本后总结Redis主从复制实

JAVA利用顺序表实现“杨辉三角”的思路及代码示例

《JAVA利用顺序表实现“杨辉三角”的思路及代码示例》杨辉三角形是中国古代数学的杰出研究成果之一,是我国北宋数学家贾宪于1050年首先发现并使用的,:本文主要介绍JAVA利用顺序表实现杨辉三角的思... 目录一:“杨辉三角”题目链接二:题解代码:三:题解思路:总结一:“杨辉三角”题目链接题目链接:点击这里

基于Python实现PDF动画翻页效果的阅读器

《基于Python实现PDF动画翻页效果的阅读器》在这篇博客中,我们将深入分析一个基于wxPython实现的PDF阅读器程序,该程序支持加载PDF文件并显示页面内容,同时支持页面切换动画效果,文中有详... 目录全部代码代码结构初始化 UI 界面加载 PDF 文件显示 PDF 页面页面切换动画运行效果总结主

SpringBoot实现基于URL和IP的访问频率限制

《SpringBoot实现基于URL和IP的访问频率限制》在现代Web应用中,接口被恶意刷新或暴力请求是一种常见的攻击手段,为了保护系统资源,需要对接口的访问频率进行限制,下面我们就来看看如何使用... 目录1. 引言2. 项目依赖3. 配置 Redis4. 创建拦截器5. 注册拦截器6. 创建控制器8.