百度飞浆ResNet50大模型微调实现十二种猫图像分类

2023-10-12 19:01

本文主要是介绍百度飞浆ResNet50大模型微调实现十二种猫图像分类,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

12种猫分类比赛传送门

要求很简单,给train和test集,训练模型实现图像分类。

这里使用的是残差连接模型,这个平台有预训练好的模型,可以直接拿来主义。

训练十几个迭代,每个批次60左右,准确率达到90%以上

一、导入库,解压文件

import os
import zipfile
import random
import json
import cv2
import numpy as np
from PIL import Imageimport matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
import paddle
import paddle.nn as nn
from paddle.io import Dataset,DataLoader
from paddle.nn import \Layer, \Conv2D, Linear, \Embedding, MaxPool2D, \BatchNorm2D, ReLUimport paddle.vision.transforms as transforms
from paddle.vision.models import resnet50
from paddle.metric import Accuracytrain_parameters = {"input_size": [3, 224, 224],                     # 输入图片的shape"class_dim": 12,                                 # 分类数"src_path":"data/data10954/cat_12_train.zip",   # 原始数据集路径"src_test_path":"data/data10954/cat_12_test.zip",   # 原始数据集路径"target_path":"/home/aistudio/data/dataset",     # 要解压的路径 "train_list_path": "./train.txt",                # train_data.txt路径"eval_list_path": "./eval.txt",                  # eval_data.txt路径"label_dict":{},                                 # 标签字典"readme_path": "/home/aistudio/data/readme.json",# readme.json路径"num_epochs":6,                                 # 训练轮数"train_batch_size": 16,                          # 批次的大小"learning_strategy": {                           # 优化函数相关的配置"lr": 0.0005                                  # 超参数学习率} 
}scr_path=train_parameters['src_path']
target_path=train_parameters['target_path']
src_test_path=train_parameters["src_test_path"]
z = zipfile.ZipFile(scr_path, 'r')
z.extractall(path=target_path)
z = zipfile.ZipFile(src_test_path, 'r')
z.extractall(path=target_path)
z.close()
for imgpath in os.listdir(target_path + '/cat_12_train'):src = os.path.join(target_path + '/cat_12_train/', imgpath)img = Image.open(src)if img.mode != 'RGB':img = img.convert('RGB')img.save(src)for imgpath in os.listdir(target_path + '/cat_12_test'):src = os.path.join(target_path + '/cat_12_test/', imgpath)img = Image.open(src)if img.mode != 'RGB':img = img.convert('RGB')img.save(src)

 解压后将所有图像变为RGB图像

二、加载训练集,进行预处理、数据增强、格式变换

transform = transforms.Compose([transforms.Resize(size=224),transforms.ColorJitter(0.2, 0.2, 0.2, 0.2),transforms.RandomHorizontalFlip(),transforms.RandomRotation(15),transforms.RandomResizedCrop(size=224, scale=(0.8, 1.0)),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])x_train,x_eval,y_train=[],[],[]#获取训练图像和标签、测试图像和标签
contents=[]
with open('data/data10954/train_list.txt')as f:contents=f.read().split('\n')for item in contents:if item=='':continuepath='data/dataset/'+item.split('\t')[0]data=np.array(Image.open(path).convert('RGB'))data=np.array(transform(data))x_train.append(data)y_train.append(int(item.split('\t')[-1]))contetns=os.listdir('data/dataset/cat_12_test')
for item in contetns:path='data/dataset/cat_12_test/'+itemdata=np.array(Image.open(path).convert('RGB'))data=np.array(transform(data))x_eval.append(data)

重点是transforms变换的预处理

三、划分训练集和测试集

x_train=np.array(x_train)y_train=np.array(y_train)x_eval=np.array(x_eval)x_train,x_test,y_train,y_test=train_test_split(x_train,y_train,test_size=0.2,random_state=42,stratify=y_train)x_train=paddle.to_tensor(x_train,dtype='float32')
y_train=paddle.to_tensor(y_train,dtype='int64')
x_test=paddle.to_tensor(x_test,dtype='float32')
y_test=paddle.to_tensor(y_test,dtype='int64')
x_eval=paddle.to_tensor(x_eval,dtype='float32')

 这是必要的,可以随时利用测试集查看准确率

四、加载预训练模型,选择损失函数和优化器

learning_rate=0.001
epochs =5  # 迭代轮数
batch_size = 50  # 批次大小
weight_decay=1e-5
num_class=12cnn=resnet50(pretrained=True)
checkpoint=paddle.load('checkpoint.pdparams')for param in cnn.parameters():param.requires_grad=False
cnn.fc = nn.Linear(2048, num_class)
cnn.set_dict(checkpoint['cnn_state_dict'])
criterion=nn.CrossEntropyLoss()
optimizer = paddle.optimizer.Adam(learning_rate=learning_rate, parameters=cnn.fc.parameters(),weight_decay=weight_decay)

第一次训练把加载模型注释掉即可,优化器包含最后一层全连接的参数

五、模型训练 

if x_train.shape[3]==3:x_train=paddle.transpose(x_train,perm=(0,3,1,2))dataset = paddle.io.TensorDataset([x_train, y_train])
data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
for epoch in range(epochs):for batch_data, batch_labels in data_loader:outputs = cnn(batch_data)loss = criterion(outputs, batch_labels)print(epoch)loss.backward()optimizer.step()optimizer.clear_grad()print(f"Epoch [{epoch+1}/{epochs}], Loss: {loss.numpy()[0]}")#保存参数
paddle.save({'cnn_state_dict': cnn.state_dict(),}, 'checkpoint.pdparams')

 使用批处理,这个很重要,不然平台分分钟炸了

六、测试集准确率

num_class=12
batch_size=64
cnn=resnet50(pretrained=True)
checkpoint=paddle.load('checkpoint.pdparams')for param in cnn.parameters():param.requires_grad=False
cnn.fc = nn.Linear(2048, num_class)
cnn.set_dict(checkpoint['cnn_state_dict'])cnn.eval()if x_test.shape[3]==3:x_test=paddle.transpose(x_test,perm=(0,3,1,2))
dataset = paddle.io.TensorDataset([x_test, y_test])
data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)with paddle.no_grad():score=0for batch_data, batch_labels in data_loader:predictions = cnn(batch_data)predicted_probabilities = paddle.nn.functional.softmax(predictions, axis=1)predicted_labels = paddle.argmax(predicted_probabilities, axis=1) print(predicted_labels)for i in range(len(predicted_labels)):if predicted_labels[i].numpy()==batch_labels[i]:score+=1print(score/len(y_test))

设置eval模式,使用批处理测试准确率 

这篇关于百度飞浆ResNet50大模型微调实现十二种猫图像分类的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/197831

相关文章

如何使用Java实现请求deepseek

《如何使用Java实现请求deepseek》这篇文章主要为大家详细介绍了如何使用Java实现请求deepseek功能,文中的示例代码讲解详细,感兴趣的小伙伴可以跟随小编一起学习一下... 目录1.deepseek的api创建2.Java实现请求deepseek2.1 pom文件2.2 json转化文件2.2

python使用fastapi实现多语言国际化的操作指南

《python使用fastapi实现多语言国际化的操作指南》本文介绍了使用Python和FastAPI实现多语言国际化的操作指南,包括多语言架构技术栈、翻译管理、前端本地化、语言切换机制以及常见陷阱和... 目录多语言国际化实现指南项目多语言架构技术栈目录结构翻译工作流1. 翻译数据存储2. 翻译生成脚本

如何通过Python实现一个消息队列

《如何通过Python实现一个消息队列》这篇文章主要为大家详细介绍了如何通过Python实现一个简单的消息队列,文中的示例代码讲解详细,感兴趣的小伙伴可以跟随小编一起学习一下... 目录如何通过 python 实现消息队列如何把 http 请求放在队列中执行1. 使用 queue.Queue 和 reque

Python如何实现PDF隐私信息检测

《Python如何实现PDF隐私信息检测》随着越来越多的个人信息以电子形式存储和传输,确保这些信息的安全至关重要,本文将介绍如何使用Python检测PDF文件中的隐私信息,需要的可以参考下... 目录项目背景技术栈代码解析功能说明运行结php果在当今,数据隐私保护变得尤为重要。随着越来越多的个人信息以电子形

使用 sql-research-assistant进行 SQL 数据库研究的实战指南(代码实现演示)

《使用sql-research-assistant进行SQL数据库研究的实战指南(代码实现演示)》本文介绍了sql-research-assistant工具,该工具基于LangChain框架,集... 目录技术背景介绍核心原理解析代码实现演示安装和配置项目集成LangSmith 配置(可选)启动服务应用场景

使用Python快速实现链接转word文档

《使用Python快速实现链接转word文档》这篇文章主要为大家详细介绍了如何使用Python快速实现链接转word文档功能,文中的示例代码讲解详细,感兴趣的小伙伴可以跟随小编一起学习一下... 演示代码展示from newspaper import Articlefrom docx import

前端原生js实现拖拽排课效果实例

《前端原生js实现拖拽排课效果实例》:本文主要介绍如何实现一个简单的课程表拖拽功能,通过HTML、CSS和JavaScript的配合,我们实现了课程项的拖拽、放置和显示功能,文中通过实例代码介绍的... 目录1. 效果展示2. 效果分析2.1 关键点2.2 实现方法3. 代码实现3.1 html部分3.2

0基础租个硬件玩deepseek,蓝耘元生代智算云|本地部署DeepSeek R1模型的操作流程

《0基础租个硬件玩deepseek,蓝耘元生代智算云|本地部署DeepSeekR1模型的操作流程》DeepSeekR1模型凭借其强大的自然语言处理能力,在未来具有广阔的应用前景,有望在多个领域发... 目录0基础租个硬件玩deepseek,蓝耘元生代智算云|本地部署DeepSeek R1模型,3步搞定一个应

Deepseek R1模型本地化部署+API接口调用详细教程(释放AI生产力)

《DeepseekR1模型本地化部署+API接口调用详细教程(释放AI生产力)》本文介绍了本地部署DeepSeekR1模型和通过API调用将其集成到VSCode中的过程,作者详细步骤展示了如何下载和... 目录前言一、deepseek R1模型与chatGPT o1系列模型对比二、本地部署步骤1.安装oll

Java深度学习库DJL实现Python的NumPy方式

《Java深度学习库DJL实现Python的NumPy方式》本文介绍了DJL库的背景和基本功能,包括NDArray的创建、数学运算、数据获取和设置等,同时,还展示了如何使用NDArray进行数据预处理... 目录1 NDArray 的背景介绍1.1 架构2 JavaDJL使用2.1 安装DJL2.2 基本操