pytorch11:模型加载与保存、finetune迁移训练

2024-01-12 12:04

本文主要是介绍pytorch11:模型加载与保存、finetune迁移训练,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

在这里插入图片描述

目录

  • 一、模型加载与保存
    • 1.1 序列化与反序列化概念
    • 1.2 pytorch中的序列化与反序列化
    • 1.3 模型保存的两种方法
    • 1.4 模型加载两种方法
  • 二、断点训练
    • 2.1 断点保存代码
    • 2.2 断点恢复代码
  • 三、finetune
    • 3.1 迁移学习
    • 3.2 模型的迁移学习
    • 3.2 模型微调步骤
      • 3.2.1 模型微调步骤
      • 3.2.2 模型微调训练方法
    • 3.3 迁移训练实验

一、模型加载与保存

1.1 序列化与反序列化概念

序列化是将数据结构或对象转换为可以存储或传输的格式的过程,而反序列化则是将存储或传输的数据重新转换为数据结构或对象的过程。
在计算机科学中,序列化和反序列化通常用于数据持久化、网络传输和进程间通信等场景。以下是对序列化和反序列化的详细解读:

  1. 序列化:
    • 序列化的过程将数据结构或对象转换为字节流或其他格式,以便在存储或传输时能够被保存下来或发送出去。这通常涉及将数据结构中的字段和属性转换为二进制码或文本格式,以便能够被存储在文件中或通过网络传输。
    • 序列化的结果可以是二进制数据、JSON、XML等格式,不同的数据类型和应用场景可能采用不同的序列化格式。
    • 序列化的过程可以包括将对象进行扁平化、编码、压缩等操作,以便提高存储和传输的效率和安全性。
  2. 反序列化:
    • 反序列化的过程是将序列化后的数据重新转换为原始的数据结构或对象,使得在存储或传输后能够恢复原来的数据格式和内容。
    • 反序列化的过程需要根据序列化时采用的格式和规则,对序列化后的数据进行解码、解压缩等操作,最终还原为原始的数据结构或对象。
    • 反序列化的过程需要确保数据的完整性和正确性,以及适当地处理可能存在的异常和错误情况。
      在实际应用中,序列化和反序列化广泛应用于各种领域,如数据库持久化、分布式系统通信、缓存存储、远程过程调用等。常见的序列化和反序列化技术包括JSON、XML、Protocol Buffers、Thrift等,它们在不同的场景和需求下有着不同的优势和适用性。

通过序列化技术,将内存中的数据存储到硬盘,在需要使用的时候通过反序列化的方法转化成可读取数据。
在这里插入图片描述

1.2 pytorch中的序列化与反序列化

  1. torch.save(序列化):用于保存模型
    主要参数:
    • obj:对象
    • f:输出路径
  2. torch.load(反序列化):用于加载模型
    主要参数
    • f:文件路径
    • map_location:指定存放位置, cpu or gpu

1.3 模型保存的两种方法

方法1:保存整个Module模型
torch.save(net, path)
方法2:保存模型参数parameter
state_dict = net.state_dict()
torch.save(state_dict , path)

使用方法1会比较耗时耗费资源,通常我们会使用方法2,只保存模型训练过程中的参数。

代码实现:

import torch
import numpy as np
import torch.nn as nnclass LeNet2(nn.Module):def __init__(self, classes):super(LeNet2, self).__init__()self.features = nn.Sequential(nn.Conv2d(3, 6, 5),nn.ReLU(),nn.MaxPool2d(2, 2),nn.Conv2d(6, 16, 5),nn.ReLU(),nn.MaxPool2d(2, 2))self.classifier = nn.Sequential(nn.Linear(16*5*5, 120),nn.ReLU(),nn.Linear(120, 84),nn.ReLU(),nn.Linear(84, classes))def forward(self, x):x = self.features(x)x = x.view(x.size()[0], -1)x = self.classifier(x)return xdef initialize(self):for p in self.parameters():p.data.fill_(2024111)net = LeNet2(classes=2024)# "训练"
print("训练前: ", net.features[0].weight[0, ...])
net.initialize()  #模型模型训练参数改变
print("训练后: ", net.features[0].weight[0, ...])path_model = "./model.pkl"
path_state_dict = "./model_state_dict.pkl"# 保存整个模型
torch.save(net, path_model)# 保存模型参数
net_state_dict = net.state_dict()
torch.save(net_state_dict, path_state_dict)

输出结果:
在这里插入图片描述
在这里插入图片描述

1.4 模型加载两种方法

方法1:加载模型
代码实现:

# ================================== load net ===========================
flag = 1
# flag = 0
if flag:path_model = "./model.pkl"net_load = torch.load(path_model)print(net_load)

输出结果:
在这里插入图片描述
在这里插入图片描述

方法2: 加载参数
代码实现:

# ================================== load state_dict ===========================flag = 1
# flag = 0
if flag:path_state_dict = "./model_state_dict.pkl"state_dict_load = torch.load(path_state_dict)print(state_dict_load.keys())

输出结果:
将保存的参数名称打印出来;
在这里插入图片描述
方法3:将参数加载到新的模型当中

# ================================== update state_dict ===========================
flag = 1
# flag = 0
if flag:net_new = LeNet2(classes=2019)print("加载前: ", net_new.features[0].weight[0, ...])net_new.load_state_dict(state_dict_load)print("加载后: ", net_new.features[0].weight[0, ...])

输出结果:
在这里插入图片描述

以上完整代码:

# -*- coding: utf-8 -*-
import torch
import numpy as np
import torch.nn as nn
class LeNet2(nn.Module):def __init__(self, classes):super(LeNet2, self).__init__()self.features = nn.Sequential(nn.Conv2d(3, 6, 5),nn.ReLU(),nn.MaxPool2d(2, 2),nn.Conv2d(6, 16, 5),nn.ReLU(),nn.MaxPool2d(2, 2))self.classifier = nn.Sequential(nn.Linear(16*5*5, 120),nn.ReLU(),nn.Linear(120, 84),nn.ReLU(),nn.Linear(84, classes))def forward(self, x):x = self.features(x)x = x.view(x.size()[0], -1)x = self.classifier(x)return xdef initialize(self):for p in self.parameters():p.data.fill_(20191104)
# ================================== load net ===========================
# flag = 1
flag = 0
if flag:path_model = "./model.pkl"net_load = torch.load(path_model)print(net_load)
# ================================== load state_dict ===========================
flag = 1
# flag = 0
if flag:path_state_dict = "./model_state_dict.pkl"state_dict_load = torch.load(path_state_dict)print(state_dict_load.keys())
# ================================== update state_dict ===========================
flag = 1
# flag = 0
if flag:net_new = LeNet2(classes=2024)print("加载前: ", net_new.features[0].weight[0, ...])net_new.load_state_dict(state_dict_load)print("加载后: ", net_new.features[0].weight[0, ...])

二、断点训练

首先我们需要确定模型训练过程中哪些参数是会一直发生变化的,模型中的权值以及优化器中的可学习参数是一直发生变化的,数据以及损失函数是保持不变的。
在这里插入图片描述
断点训练函数方法:
在这里插入图片描述

2.1 断点保存代码

当训练到第5次的时候我们进行人为中断训练,将当前训练阶段的模型权值参数、优化器参数、训练轮数保存到checkpoint

    if (epoch+1) % checkpoint_interval == 0:  # checkpoint_interval初始值设置为5checkpoint = {"model_state_dict": net.state_dict(),"optimizer_state_dict": optimizer.state_dict(),"epoch": epoch}path_checkpoint = "./checkpoint_{}_epoch.pkl".format(epoch)torch.save(checkpoint, path_checkpoint)if epoch > 5:print("训练意外中断...")break

输出结果:
在这里插入图片描述

2.2 断点恢复代码

加载上一次训练相关参数数据

# ============================ step 5+/5 断点恢复 ============================
path_checkpoint = "./checkpoint_4_epoch.pkl"
checkpoint = torch.load(path_checkpoint)
net.load_state_dict(checkpoint['model_state_dict'])  # 加载网络模型参数
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])  # 加载优化器当中相关可学习参数
start_epoch = checkpoint['epoch']  # 加载上一次训练轮数
scheduler.last_epoch = start_epoch  # 学习率策略更新

输出结果:
当前训练初始轮数从第5轮开始训练,所以精度可以很快增加。
在这里插入图片描述
完整代码展示

save_checkpoint.py;保存断点参数数据

# -*- coding: utf-8 -*-
import os
import random
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import torch.optim as optim
from PIL import Image
from matplotlib import pyplot as plt
import syshello_pytorch_DIR = os.path.abspath(os.path.dirname(__file__) + os.path.sep + ".." + os.path.sep + "..")
sys.path.append(hello_pytorch_DIR)
from model.lenet import LeNet
from tools.my_dataset import RMBDataset
from tools.common_tools import set_seed
import torchvisionset_seed(1)  # 设置随机种子
rmb_label = {"1": 0, "100": 1}# 参数设置
checkpoint_interval = 5
MAX_EPOCH = 10
BATCH_SIZE = 16
LR = 0.01
log_interval = 10
val_interval = 1# ============================ step 1/5 数据 ============================BASE_DIR = os.path.dirname(os.path.abspath(__file__))
split_dir = os.path.abspath(os.path.join(BASE_DIR, "..", "..", "data", "rmb_split"))
train_dir = os.path.join(split_dir, "train")
valid_dir = os.path.join(split_dir, "valid")if not os.path.exists(split_dir):raise Exception(r"数据 {} 不存在, 回到lesson-06\1_split_dataset.py生成数据".format(split_dir))norm_mean = [0.485, 0.456, 0.406]
norm_std = [0.229, 0.224, 0.225]train_transform = transforms.Compose([transforms.Resize((32, 32)),transforms.RandomCrop(32, padding=4),transforms.RandomGrayscale(p=0.8),transforms.ToTensor(),transforms.Normalize(norm_mean, norm_std),
])valid_transform = transforms.Compose([transforms.Resize((32, 32)),transforms.ToTensor(),transforms.Normalize(norm_mean, norm_std),
])# 构建MyDataset实例
train_data = RMBDataset(data_dir=train_dir, transform=train_transform)
valid_data = RMBDataset(data_dir=valid_dir, transform=valid_transform)# 构建DataLoder
train_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
valid_loader = DataLoader(dataset=valid_data, batch_size=BATCH_SIZE)# ============================ step 2/5 模型 ============================net = LeNet(classes=2)
net.initialize_weights()# ============================ step 3/5 损失函数 ============================
criterion = nn.CrossEntropyLoss()  # 选择损失函数# ============================ step 4/5 优化器 ============================
optimizer = optim.SGD(net.parameters(), lr=LR, momentum=0.9)  # 选择优化器
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=6, gamma=0.1)  # 设置学习率下降策略# ============================ step 5+/5 断点恢复 ============================path_checkpoint = "./checkpoint_4_epoch.pkl"
checkpoint = torch.load(path_checkpoint)net.load_state_dict(checkpoint['model_state_dict'])  # 加载网络模型参数optimizer.load_state_dict(checkpoint['optimizer_state_dict'])  # 加载优化器当中相关可学习参数start_epoch = checkpoint['epoch']  # 加载上一次训练轮数scheduler.last_epoch = start_epoch  # 学习率策略更新# ============================ step 5/5 训练 ============================
train_curve = list()
valid_curve = list()for epoch in range(start_epoch + 1, MAX_EPOCH):loss_mean = 0.correct = 0.total = 0.net.train()for i, data in enumerate(train_loader):# forwardinputs, labels = dataoutputs = net(inputs)# backwardoptimizer.zero_grad()loss = criterion(outputs, labels)loss.backward()# update weightsoptimizer.step()# 统计分类情况_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).squeeze().sum().numpy()# 打印训练信息loss_mean += loss.item()train_curve.append(loss.item())if (i + 1) % log_interval == 0:loss_mean = loss_mean / log_intervalprint("Training:Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] Loss: {:.4f} Acc:{:.2%}".format(epoch, MAX_EPOCH, i + 1, len(train_loader), loss_mean, correct / total))loss_mean = 0.scheduler.step()  # 更新学习率if (epoch + 1) % checkpoint_interval == 0:checkpoint = {"model_state_dict": net.state_dict(),"optimizer_state_dic": optimizer.state_dict(),"loss": loss,"epoch": epoch}path_checkpoint = "./checkpint_{}_epoch.pkl".format(epoch)torch.save(checkpoint, path_checkpoint)# if epoch > 5:#     print("训练意外中断...")#     break# validate the modelif (epoch + 1) % val_interval == 0:correct_val = 0.total_val = 0.loss_val = 0.net.eval()with torch.no_grad():for j, data in enumerate(valid_loader):inputs, labels = dataoutputs = net(inputs)loss = criterion(outputs, labels)_, predicted = torch.max(outputs.data, 1)total_val += labels.size(0)correct_val += (predicted == labels).squeeze().sum().numpy()loss_val += loss.item()valid_curve.append(loss.item())print("Valid:\t Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] Loss: {:.4f} Acc:{:.2%}".format(epoch, MAX_EPOCH, j + 1, len(valid_loader), loss_val / len(valid_loader), correct / total))train_x = range(len(train_curve))
train_y = train_curvetrain_iters = len(train_loader)
valid_x = np.arange(1, len(valid_curve) + 1) * train_iters * val_interval  # 由于valid中记录的是epochloss,需要对记录点进行转换到iterations
valid_y = valid_curveplt.plot(train_x, train_y, label='Train')
plt.plot(valid_x, valid_y, label='Valid')plt.legend(loc='upper right')
plt.ylabel('loss value')
plt.xlabel('Iteration')
plt.show()

checkpoint_resume.py:加载断点参数数据

# -*- coding: utf-8 -*-
import os
import random
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import torch.optim as optim
from PIL import Image
from matplotlib import pyplot as plt
import syshello_pytorch_DIR = os.path.abspath(os.path.dirname(__file__) + os.path.sep + ".." + os.path.sep + "..")
sys.path.append(hello_pytorch_DIR)
from model.lenet import LeNet
from tools.my_dataset import RMBDataset
from tools.common_tools import set_seed
import torchvisionset_seed(1)  # 设置随机种子
rmb_label = {"1": 0, "100": 1}# 参数设置
checkpoint_interval = 5
MAX_EPOCH = 10
BATCH_SIZE = 16
LR = 0.01
log_interval = 10
val_interval = 1# ============================ step 1/5 数据 ============================BASE_DIR = os.path.dirname(os.path.abspath(__file__))
split_dir = os.path.abspath(os.path.join(BASE_DIR, "..", "..", "data", "rmb_split"))
train_dir = os.path.join(split_dir, "train")
valid_dir = os.path.join(split_dir, "valid")if not os.path.exists(split_dir):raise Exception(r"数据 {} 不存在, 回到lesson-06\1_split_dataset.py生成数据".format(split_dir))norm_mean = [0.485, 0.456, 0.406]
norm_std = [0.229, 0.224, 0.225]train_transform = transforms.Compose([transforms.Resize((32, 32)),transforms.RandomCrop(32, padding=4),transforms.RandomGrayscale(p=0.8),transforms.ToTensor(),transforms.Normalize(norm_mean, norm_std),
])valid_transform = transforms.Compose([transforms.Resize((32, 32)),transforms.ToTensor(),transforms.Normalize(norm_mean, norm_std),
])# 构建MyDataset实例
train_data = RMBDataset(data_dir=train_dir, transform=train_transform)
valid_data = RMBDataset(data_dir=valid_dir, transform=valid_transform)# 构建DataLoder
train_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
valid_loader = DataLoader(dataset=valid_data, batch_size=BATCH_SIZE)# ============================ step 2/5 模型 ============================net = LeNet(classes=2)
net.initialize_weights()# ============================ step 3/5 损失函数 ============================
criterion = nn.CrossEntropyLoss()  # 选择损失函数# ============================ step 4/5 优化器 ============================
optimizer = optim.SGD(net.parameters(), lr=LR, momentum=0.9)  # 选择优化器
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=6, gamma=0.1)  # 设置学习率下降策略# ============================ step 5+/5 断点恢复 ============================path_checkpoint = "./checkpoint_4_epoch.pkl"
checkpoint = torch.load(path_checkpoint)net.load_state_dict(checkpoint['model_state_dict'])  # 加载网络模型参数optimizer.load_state_dict(checkpoint['optimizer_state_dict'])  # 加载优化器当中相关可学习参数start_epoch = checkpoint['epoch']  # 加载上一次训练轮数scheduler.last_epoch = start_epoch  # 学习率策略更新# ============================ step 5/5 训练 ============================
train_curve = list()
valid_curve = list()for epoch in range(start_epoch + 1, MAX_EPOCH):loss_mean = 0.correct = 0.total = 0.net.train()for i, data in enumerate(train_loader):# forwardinputs, labels = dataoutputs = net(inputs)# backwardoptimizer.zero_grad()loss = criterion(outputs, labels)loss.backward()# update weightsoptimizer.step()# 统计分类情况_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).squeeze().sum().numpy()# 打印训练信息loss_mean += loss.item()train_curve.append(loss.item())if (i + 1) % log_interval == 0:loss_mean = loss_mean / log_intervalprint("Training:Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] Loss: {:.4f} Acc:{:.2%}".format(epoch, MAX_EPOCH, i + 1, len(train_loader), loss_mean, correct / total))loss_mean = 0.scheduler.step()  # 更新学习率if (epoch + 1) % checkpoint_interval == 0:checkpoint = {"model_state_dict": net.state_dict(),"optimizer_state_dic": optimizer.state_dict(),"loss": loss,"epoch": epoch}path_checkpoint = "./checkpint_{}_epoch.pkl".format(epoch)torch.save(checkpoint, path_checkpoint)# if epoch > 5:#     print("训练意外中断...")#     break# validate the modelif (epoch + 1) % val_interval == 0:correct_val = 0.total_val = 0.loss_val = 0.net.eval()with torch.no_grad():for j, data in enumerate(valid_loader):inputs, labels = dataoutputs = net(inputs)loss = criterion(outputs, labels)_, predicted = torch.max(outputs.data, 1)total_val += labels.size(0)correct_val += (predicted == labels).squeeze().sum().numpy()loss_val += loss.item()valid_curve.append(loss.item())print("Valid:\t Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] Loss: {:.4f} Acc:{:.2%}".format(epoch, MAX_EPOCH, j + 1, len(valid_loader), loss_val / len(valid_loader), correct / total))train_x = range(len(train_curve))
train_y = train_curvetrain_iters = len(train_loader)
valid_x = np.arange(1, len(valid_curve) + 1) * train_iters * val_interval  # 由于valid中记录的是epochloss,需要对记录点进行转换到iterations
valid_y = valid_curveplt.plot(train_x, train_y, label='Train')
plt.plot(valid_x, valid_y, label='Valid')plt.legend(loc='upper right')
plt.ylabel('loss value')
plt.xlabel('Iteration')
plt.show()

三、finetune

3.1 迁移学习

首先了解一下迁移学习(Transfer Learning)概念:它机器学习分支,研究源域(source domain)的知识如何应用到目标域(target
domain),来提高模型的性能。
在这里插入图片描述

图(a)是传统的机器学习过程,针对某一个任务进行网络模型训练
图(b)是迁移学习过程,通过对源任务进行模型训练得到一个"知识",当我们需要训练一个新的任务时,可以在源任务训练的"知识"上继续进行训练,从而得到target模型;

3.2 模型的迁移学习

假设我们已经训练好一个模型了,我们把网络训练过程中的权值当做"知识",当我想要再训练一个新的模型任务时,但是数据量较小,不足以训练一个较好的模型,我们把上一个模型的知识应用到新的任务当中,这就是模型的迁移训练,从而提高模型的精度和效果。就好比一个人学会了一门乐器之后,已经掌握了相关乐理知识,再让他去学习另外一门乐器就会更加容易学习!!!在这里插入图片描述

3.2 模型微调步骤

通常我们会找到模型训练过程中具有相同共性的部分,例如下面这个神经网络,分为两部分,分别是特征提取部分feature和图像分类部分classifier两个部分,当我们需要进行其他图像分类任务的时候,我们可以保留图像特征提取部分,改变分类部分的output类别数。在这里插入图片描述

3.2.1 模型微调步骤

  1. 获取预训练模型参数:源任务当中学习到的"知识"。
  2. 加载模型(load_state_dict):将知识加载到新的模型当中。
  3. 修改输出层,不同任务输出层类别数不同。

3.2.2 模型微调训练方法

  1. 固定预训练的参数(固定参数的方法:requires_grad =False;lr=0)。
  2. Features Extractor较小学习率(params_group),不同的参数组设置不同的学习率,例如特征提取模块我们希望它变动不大,可以设置较小的学习率,在特征分类模块可以设置较大的学习率。

3.3 迁移训练实验

1、数据准备
Finetune Resnet -18 用于二分类,蚂蚁蜜蜂二分类数据,训练集:各120~张 验证集:各70~张
下载Resnet -18预训练模型,下载地址:https://download.pytorch.org/models/resnet18-5c106cde.pth
在这里插入图片描述

2、实验结果
在经过25轮训练之后,不使用预训练模型的精度提升的很慢,但是使用了预训练模型之后精度可以快速上升。
在这里插入图片描述


数据处理模块,用于获取蜜蜂和蚂蚁图片路径以及文件夹标签:

class AntsDataset(Dataset):def __init__(self, data_dir, transform=None):# 初始化函数,接收数据目录和数据变换操作self.label_name = {"ants": 0, "bees": 1}  # 定义标签对应的字典self.data_info = self.get_img_info(data_dir)  # 获取图片信息self.transform = transform  # 保存数据变换操作def __getitem__(self, index):# 获取指定索引处的数据path_img, label = self.data_info[index]  # 获取图片路径和标签img = Image.open(path_img).convert('RGB')  # 打开图片并转换为RGB格式if self.transform is not None:img = self.transform(img)  # 对图片进行数据变换return img, label  # 返回处理后的图片和标签def __len__(self):# 返回数据集的长度return len(self.data_info)def get_img_info(self, data_dir):# 获取图片信息的函数data_info = list()  # 创建空列表用于保存图片信息for root, dirs, _ in os.walk(data_dir):# 遍历数据目录中的子目录for sub_dir in dirs:img_names = os.listdir(os.path.join(root, sub_dir))  # 获取子目录下的文件列表img_names = list(filter(lambda x: x.endswith('.jpg'), img_names))  # 筛选出.jpg格式的文件名# 遍历该子目录下的图片for i in range(len(img_names)):img_name = img_names[i]  # 获取图片文件名path_img = os.path.join(root, sub_dir, img_name)  # 构建完整的图片路径label = self.label_name[sub_dir]  # 获取图片的标签data_info.append((path_img, int(label)))  # 将图片路径和标签添加到数据信息列表中if len(data_info) == 0:raise Exception("\ndata_dir:{} is a empty dir! Please checkout your path to images!".format(data_dir))  # 若数据信息列表为空,则抛出异常return data_info  # 返回图片信息列表

加载预训练模型模块:

flag = 1
if flag:path_pretrained_model = os.path.join("finetune_resnet18-5c106cde.pth")if not os.path.exists(path_pretrained_model):raise Exception("\n{} 不存在,请下载 07-02-数据-模型finetune.zip\n放到 {}下,并解压即可".format(path_pretrained_model, os.path.dirname(path_pretrained_model)))state_dict_load = torch.load(path_pretrained_model)resnet18_ft.load_state_dict(state_dict_load)

冻结网络层方法

# 冻结所有网络层
flag_m1 = 0
# flag_m1 = 1
if flag_m1:for param in resnet18_ft.parameters():param.requires_grad = Falseprint("conv1.weights[0, 0, ...]:\n {}".format(resnet18_ft.fc.weight[0, 0, ...]))

替换fc层,因为我们是2分类模型,所以需要修改分类网络;

# 3/3 替换fc层
num_ftrs = resnet18_ft.fc.in_features  # 获取原网络中的特征输入
resnet18_ft.fc = nn.Linear(num_ftrs, classes)  #classes=2

网络分组模块,我们希望特征提取部分参数更新小一些,分类部分参数更新大一些;

if flag:# 将网络划分为两个参数组fc_params_id = list(map(id, resnet18_ft.fc.parameters()))"""这一行首先使用resnet18_ft.fc.parameters()获取了ResNet18模型中全连接层的参数,然后通过map(id, ...)将每个参数的内存地址映射为一个列表。这样得到的fc_params_id列表包含了全连接层参数的内存地址。"""base_params = filter(lambda p: id(p) not in fc_params_id, resnet18_ft.parameters())"""这一行使用filter函数和lambda表达式来过滤resnet18_ft模型中不属于全连接层的参数。具体来说,filter函数通过lambda p: id(p) not in fc_params_id对resnet18_ft.parameters()中的参数进行过滤,保留那些内存地址不在fc_params_id列表中的参数。这样就得到了base_params,其中包含了除全连接层参数外的其他层的参数,也就是特征提取部分。"""optimizer = optim.SGD([{'params': base_params, 'lr': LR * 0},  # 卷积层的学习率设置小一些,如果设置为0的话,会直接冻结卷积层{'params': resnet18_ft.fc.parameters(), 'lr': LR}], momentum=0.9)else:optimizer = optim.SGD(resnet18_ft.parameters(), lr=LR, momentum=0.9)  # 选择优化器

完整模型训练代码如下:

# -*- coding: utf-8 -*-
import os
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import torch.optim as optim
from matplotlib import pyplot as plt
from torch.utils.data import Dataset
from PIL import Image
import sys
import randomhello_pytorch_DIR = os.path.abspath(os.path.dirname(__file__) + os.path.sep + ".." + os.path.sep + "..")
sys.path.append(hello_pytorch_DIR)
import torchvision.models as models
import torchvisionBASEDIR = os.path.dirname(os.path.abspath(__file__))
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("use device :{}".format(device))# =====================参数设置=====================
def set_seed(seed=1):random.seed(seed)np.random.seed(seed)torch.manual_seed(seed)torch.cuda.manual_seed(seed)set_seed(1)  # 设置随机种子
label_name = {"ants": 0, "bees": 1}# 参数设置
MAX_EPOCH = 25
BATCH_SIZE = 16
LR = 0.001
log_interval = 10
val_interval = 1
classes = 2
start_epoch = -1
lr_decay_step = 5# =======================读取图片数据==============================
class AntsDataset(Dataset):def __init__(self, data_dir, transform=None):# 初始化函数,接收数据目录和数据变换操作self.label_name = {"ants": 0, "bees": 1}  # 定义标签对应的字典self.data_info = self.get_img_info(data_dir)  # 获取图片信息self.transform = transform  # 保存数据变换操作def __getitem__(self, index):# 获取指定索引处的数据path_img, label = self.data_info[index]  # 获取图片路径和标签img = Image.open(path_img).convert('RGB')  # 打开图片并转换为RGB格式if self.transform is not None:img = self.transform(img)  # 对图片进行数据变换return img, label  # 返回处理后的图片和标签def __len__(self):# 返回数据集的长度return len(self.data_info)def get_img_info(self, data_dir):# 获取图片信息的函数data_info = list()  # 创建空列表用于保存图片信息for root, dirs, _ in os.walk(data_dir):# 遍历数据目录中的子目录for sub_dir in dirs:img_names = os.listdir(os.path.join(root, sub_dir))  # 获取子目录下的文件列表img_names = list(filter(lambda x: x.endswith('.jpg'), img_names))  # 筛选出.jpg格式的文件名# 遍历该子目录下的图片for i in range(len(img_names)):img_name = img_names[i]  # 获取图片文件名path_img = os.path.join(root, sub_dir, img_name)  # 构建完整的图片路径label = self.label_name[sub_dir]  # 获取图片的标签data_info.append((path_img, int(label)))  # 将图片路径和标签添加到数据信息列表中if len(data_info) == 0:raise Exception("\ndata_dir:{} is a empty dir! Please checkout your path to images!".format(data_dir))  # 若数据信息列表为空,则抛出异常return data_info  # 返回图片信息列表# ============================ step 1/5 数据 ============================
data_dir = os.path.abspath(os.path.join(BASEDIR, "..", "..", "data", "hymenoptera_data"))
if not os.path.exists(data_dir):raise Exception("\n{} 不存在,请下载 07-02-数据-模型finetune.zip  放到\n{} 下,并解压即可".format(data_dir, os.path.dirname(data_dir)))train_dir = os.path.join(data_dir, "train")
valid_dir = os.path.join(data_dir, "val")norm_mean = [0.485, 0.456, 0.406]
norm_std = [0.229, 0.224, 0.225]train_transform = transforms.Compose([transforms.RandomResizedCrop(224),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize(norm_mean, norm_std),
])valid_transform = transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize(norm_mean, norm_std),
])# 构建MyDataset实例
train_data = AntsDataset(data_dir=train_dir, transform=train_transform)
valid_data = AntsDataset(data_dir=valid_dir, transform=valid_transform)# 构建DataLoder
train_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
valid_loader = DataLoader(dataset=valid_data, batch_size=BATCH_SIZE)# ============================ step 2/5 模型 ============================# 1/3 构建模型
resnet18_ft = models.resnet18()# 2/3 加载模型参数
# flag = 0
flag = 1
if flag:path_pretrained_model = os.path.join("finetune_resnet18-5c106cde.pth")if not os.path.exists(path_pretrained_model):raise Exception("\n{} 不存在,请下载 07-02-数据-模型finetune.zip\n放到 {}下,并解压即可".format(path_pretrained_model, os.path.dirname(path_pretrained_model)))state_dict_load = torch.load(path_pretrained_model)resnet18_ft.load_state_dict(state_dict_load)# 冻结所有网络层
flag_m1 = 0
# flag_m1 = 1
if flag_m1:for param in resnet18_ft.parameters():param.requires_grad = Falseprint("conv1.weights[0, 0, ...]:\n {}".format(resnet18_ft.fc.weight[0, 0, ...]))# 冻结卷积层
flag_c = 0
# flag_c = 1
if flag_c:for name, param in resnet18_ft.named_parameters():if "fc" in name:  # 如果参数名中不包含"fc",即不是全连接层的参数param.requires_grad = False# 3/3 替换fc层
num_ftrs = resnet18_ft.fc.in_features  # 获取原网络中的特征输入
resnet18_ft.fc = nn.Linear(num_ftrs, classes)resnet18_ft.to(device)
# ============================ step 3/5 损失函数 ============================
criterion = nn.CrossEntropyLoss()  # 选择损失函数# ============================ step 4/5 优化器 ============================
# 法2 : conv 小学习率
flag = 0
# flag = 1
if flag:# 将网络划分为两个参数组fc_params_id = list(map(id, resnet18_ft.fc.parameters()))"""这一行首先使用resnet18_ft.fc.parameters()获取了ResNet18模型中全连接层的参数,然后通过map(id, ...)将每个参数的内存地址映射为一个列表。这样得到的fc_params_id列表包含了全连接层参数的内存地址。"""base_params = filter(lambda p: id(p) not in fc_params_id, resnet18_ft.parameters())"""这一行使用filter函数和lambda表达式来过滤resnet18_ft模型中不属于全连接层的参数。具体来说,filter函数通过lambda p: id(p) not in fc_params_id对resnet18_ft.parameters()中的参数进行过滤,保留那些内存地址不在fc_params_id列表中的参数。这样就得到了base_params,其中包含了除全连接层参数外的其他层的参数,也就是特征提取部分。"""optimizer = optim.SGD([{'params': base_params, 'lr': LR * 0},  # 卷积层的学习率设置小一些,如果设置为0的话,会直接冻结卷积层{'params': resnet18_ft.fc.parameters(), 'lr': LR}], momentum=0.9)else:optimizer = optim.SGD(resnet18_ft.parameters(), lr=LR, momentum=0.9)  # 选择优化器# optimizer = optim.Adam(resnet18_ft.parameters(), lr=0.01)scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=lr_decay_step, gamma=0.1)  # 设置学习率下降策略# ============================ step 5/5 训练 ============================
train_curve = list()
valid_curve = list()for epoch in range(start_epoch + 1, MAX_EPOCH):loss_mean = 0.correct = 0.total = 0.resnet18_ft.train()for i, data in enumerate(train_loader):# forwardinputs, labels = datainputs, labels = inputs.to(device), labels.to(device)outputs = resnet18_ft(inputs)# backwardoptimizer.zero_grad()loss = criterion(outputs, labels)loss.backward()# update weightsoptimizer.step()# 统计分类情况_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).squeeze().cpu().sum().numpy()# 打印训练信息loss_mean += loss.item()train_curve.append(loss.item())if (i + 1) % log_interval == 0:loss_mean = loss_mean / log_intervalprint("Training:Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] Loss: {:.4f} Acc:{:.2%}".format(epoch, MAX_EPOCH, i + 1, len(train_loader), loss_mean, correct / total))loss_mean = 0.# if flag_m1:print("epoch:{} conv1.weights[0, 0, ...] :\n {}".format(epoch, resnet18_ft.fc.weight[0, 0, ...]))scheduler.step()  # 更新学习率# validate the modelif (epoch + 1) % val_interval == 0:correct_val = 0.total_val = 0.loss_val = 0.resnet18_ft.eval()with torch.no_grad():for j, data in enumerate(valid_loader):inputs, labels = datainputs, labels = inputs.to(device), labels.to(device)outputs = resnet18_ft(inputs)loss = criterion(outputs, labels)_, predicted = torch.max(outputs.data, 1)total_val += labels.size(0)correct_val += (predicted == labels).squeeze().cpu().sum().numpy()loss_val += loss.item()loss_val_mean = loss_val / len(valid_loader)valid_curve.append(loss_val_mean)print("Valid:\t Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] Loss: {:.4f} Acc:{:.2%}".format(epoch, MAX_EPOCH, j + 1, len(valid_loader), loss_val_mean, correct_val / total_val))train_x = range(len(train_curve))
train_y = train_curvetrain_iters = len(train_loader)
valid_x = np.arange(1, len(valid_curve) + 1) * train_iters * val_interval  # 由于valid中记录的是epochloss,需要对记录点进行转换到iterations
valid_y = valid_curveplt.plot(train_x, train_y, label='Train')
plt.plot(valid_x, valid_y, label='Valid')plt.legend(loc='upper right')
plt.ylabel('loss value')
plt.xlabel('Iteration')
plt.show()

在这里插入图片描述

这篇关于pytorch11:模型加载与保存、finetune迁移训练的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/597811

相关文章

大模型研发全揭秘:客服工单数据标注的完整攻略

在人工智能(AI)领域,数据标注是模型训练过程中至关重要的一步。无论你是新手还是有经验的从业者,掌握数据标注的技术细节和常见问题的解决方案都能为你的AI项目增添不少价值。在电信运营商的客服系统中,工单数据是客户问题和解决方案的重要记录。通过对这些工单数据进行有效标注,不仅能够帮助提升客服自动化系统的智能化水平,还能优化客户服务流程,提高客户满意度。本文将详细介绍如何在电信运营商客服工单的背景下进行

Andrej Karpathy最新采访:认知核心模型10亿参数就够了,AI会打破教育不公的僵局

夕小瑶科技说 原创  作者 | 海野 AI圈子的红人,AI大神Andrej Karpathy,曾是OpenAI联合创始人之一,特斯拉AI总监。上一次的动态是官宣创办一家名为 Eureka Labs 的人工智能+教育公司 ,宣布将长期致力于AI原生教育。 近日,Andrej Karpathy接受了No Priors(投资博客)的采访,与硅谷知名投资人 Sara Guo 和 Elad G

Retrieval-based-Voice-Conversion-WebUI模型构建指南

一、模型介绍 Retrieval-based-Voice-Conversion-WebUI(简称 RVC)模型是一个基于 VITS(Variational Inference with adversarial learning for end-to-end Text-to-Speech)的简单易用的语音转换框架。 具有以下特点 简单易用:RVC 模型通过简单易用的网页界面,使得用户无需深入了

透彻!驯服大型语言模型(LLMs)的五种方法,及具体方法选择思路

引言 随着时间的发展,大型语言模型不再停留在演示阶段而是逐步面向生产系统的应用,随着人们期望的不断增加,目标也发生了巨大的变化。在短短的几个月的时间里,人们对大模型的认识已经从对其zero-shot能力感到惊讶,转变为考虑改进模型质量、提高模型可用性。 「大语言模型(LLMs)其实就是利用高容量的模型架构(例如Transformer)对海量的、多种多样的数据分布进行建模得到,它包含了大量的先验

图神经网络模型介绍(1)

我们将图神经网络分为基于谱域的模型和基于空域的模型,并按照发展顺序详解每个类别中的重要模型。 1.1基于谱域的图神经网络         谱域上的图卷积在图学习迈向深度学习的发展历程中起到了关键的作用。本节主要介绍三个具有代表性的谱域图神经网络:谱图卷积网络、切比雪夫网络和图卷积网络。 (1)谱图卷积网络 卷积定理:函数卷积的傅里叶变换是函数傅里叶变换的乘积,即F{f*g}

秋招最新大模型算法面试,熬夜都要肝完它

💥大家在面试大模型LLM这个板块的时候,不知道面试完会不会复盘、总结,做笔记的习惯,这份大模型算法岗面试八股笔记也帮助不少人拿到过offer ✨对于面试大模型算法工程师会有一定的帮助,都附有完整答案,熬夜也要看完,祝大家一臂之力 这份《大模型算法工程师面试题》已经上传CSDN,还有完整版的大模型 AI 学习资料,朋友们如果需要可以微信扫描下方CSDN官方认证二维码免费领取【保证100%免费

【生成模型系列(初级)】嵌入(Embedding)方程——自然语言处理的数学灵魂【通俗理解】

【通俗理解】嵌入(Embedding)方程——自然语言处理的数学灵魂 关键词提炼 #嵌入方程 #自然语言处理 #词向量 #机器学习 #神经网络 #向量空间模型 #Siri #Google翻译 #AlexNet 第一节:嵌入方程的类比与核心概念【尽可能通俗】 嵌入方程可以被看作是自然语言处理中的“翻译机”,它将文本中的单词或短语转换成计算机能够理解的数学形式,即向量。 正如翻译机将一种语言

AI Toolkit + H100 GPU,一小时内微调最新热门文生图模型 FLUX

上个月,FLUX 席卷了互联网,这并非没有原因。他们声称优于 DALLE 3、Ideogram 和 Stable Diffusion 3 等模型,而这一点已被证明是有依据的。随着越来越多的流行图像生成工具(如 Stable Diffusion Web UI Forge 和 ComyUI)开始支持这些模型,FLUX 在 Stable Diffusion 领域的扩展将会持续下去。 自 FLU

SWAP作物生长模型安装教程、数据制备、敏感性分析、气候变化影响、R模型敏感性分析与贝叶斯优化、Fortran源代码分析、气候数据降尺度与变化影响分析

查看原文>>>全流程SWAP农业模型数据制备、敏感性分析及气候变化影响实践技术应用 SWAP模型是由荷兰瓦赫宁根大学开发的先进农作物模型,它综合考虑了土壤-水分-大气以及植被间的相互作用;是一种描述作物生长过程的一种机理性作物生长模型。它不但运用Richard方程,使其能够精确的模拟土壤中水分的运动,而且耦合了WOFOST作物模型使作物的生长描述更为科学。 本文让更多的科研人员和农业工作者

Flutter 进阶:绘制加载动画

绘制加载动画:由小圆组成的大圆 1. 定义 LoadingScreen 类2. 实现 _LoadingScreenState 类3. 定义 LoadingPainter 类4. 总结 实现加载动画 我们需要定义两个类:LoadingScreen 和 LoadingPainter。LoadingScreen 负责控制动画的状态,而 LoadingPainter 则负责绘制动画。