GoogleNet网络训练集和测试集搭建

2024-04-18 11:28

本文主要是介绍GoogleNet网络训练集和测试集搭建,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

测试集和训练集都是在之前搭建好的基础上进行修改的,重点记录与之前不同的代码。

还是使用的花分类的数据集进行训练和测试的。

一、训练集

1、搭建网络

设置参数:使用辅助分类器,采用权重初始化

net = GoogleNet(num_classes=5, aux_logits=True, init_weights=True)

2、参数输出

之前的模型只有 1 个输出,但由于GoogleNet使用了两个辅助分类器,所以会有 3 个输出。

定义三个输出,分别计算主分类器、辅助分类器1、辅助分类器2的损失函数并相加,最后将损失函数反向传播,使用优化器更新参数模型。 

不单独放代码了,不知道哪里是改动的。图片中红色框中是改动的

整个训练集的代码

import torch
import torch.nn as nn
from torchvision import transforms, datasets, utils
import matplotlib as plt
import matplotlib.pyplot as plt
import numpy as np
import torch.optim as optim
from model import GoogleNet
import os
import json
import timedevice = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)data_transform = {"train": transforms.Compose([transforms.RandomResizedCrop(224),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),"val": transforms.Compose([transforms.Resize((224, 224)),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])}data_root = os.path.abspath(os.path.join(os.getcwd(), "../.."))
image_path = data_root + "/data_set/flower_data"
# train set
train_dataset = datasets.ImageFolder(root=image_path + "/train",transform=data_transform["train"])
train_num = len(train_dataset)# {'daisy': 0, 'dandelion': 1, 'roses': 2, 'sunflower': 3, 'tulips': 4}
flower_list = train_dataset.class_to_idx
cla_dict = dict((val, key) for key, val in flower_list.items())
# 把文件写入接送文件
json_str = json.dumps(cla_dict, indent=4)
with open('class_indices,json', 'w') as json_file:json_file.write(json_str)batch_size = 32
train_loader = torch.utils.data.DataLoader(train_dataset,batch_size=batch_size, shuffle=True,num_workers=0)
#
validate_dataset = datasets.ImageFolder(root=image_path + "/val",transform=data_transform["val"])
val_num = len(validate_dataset)
validate_loader = torch.utils.data.DataLoader(validate_dataset, batch_size=batch_size,shuffle=False, num_workers=0)# test_data_iter = iter(validate_loader)
# test_image, test_label = next(test_data_iter)
#
# # 查看图片
# def imshow(img):
#     img = img / 2 + 0.5
#     nping = img.numpy()
#     plt.imshow(np.transpose(nping, (1, 2, 0)))
#     plt.show()
# # print labels
# print(' '.join('%5s' % str(cla_dict[test_label[j].item()]) for j in range(4)))
# # show images
# imshow(utils.make_grid(test_image))net = GoogleNet(num_classes=5, aux_logits=True, init_weights=True)
net.to(device)
loss_function = nn.CrossEntropyLoss()optimizer = optim.Adam(net.parameters(), lr=0.0003)best_acc = 0.0
save_path = './GoogleNet.pth'
# best_acc = 0.0
for epoch in range(2):# trainnet.train()running_loss = 0.0t1 = time.perf_counter()for step, data in enumerate(train_loader, start=0):images, labels = dataoptimizer.zero_grad()logits, aux_logits2, aux_logits1 = net(images.to(device))loss0 = loss_function(logits, labels.to(device))loss1 = loss_function(aux_logits1, labels.to(device))loss2 = loss_function(aux_logits2, labels.to(device))loss = loss0 + loss1 * 0.3 + loss2 * 0.3loss.backward()optimizer.step()# print statisticsrunning_loss += loss.item()rate = (step+1) / len(train_loader)a = "*" * int(rate*50)b = "." *int((1-rate)*50)print("\rtrain loss: (:3.0f)%[()->:.3f)".format(int(rate * 100), a, b, loss), end="")print()print(time.perf_counter()-t1)net.eval()acc = 0.0with torch.no_grad():for data_test in validate_loader:test_images, test_labels = data_testoutputs = net(test_images.to(device))predict_y = torch.max(outputs, dim=1)[1]acc += (predict_y == test_labels.to(device)).sum().item()accurate_test = acc / val_numif accurate_test > best_acc:best_acc = accurate_testtorch.save(net.state_dict(), save_path)print('[epoch %d] train_loss: %.3f test_accuracy: %.3f' %(epoch + 1, running_loss / step, acc / val_num))
print("Finished Training")

训练完成 

 中间有几次报错,不过在看懂报错后很快改过来了。

二、测试集

载入模型

在创建模型的时候,aux_logits不会构建辅助分类器,但是之前训练的参数会保存。

所以,在载入模型的时候,要设置参数strict=False, 它可以精准匹配当前模型与所需要载入的权重模型的结构。

辅助分类器中的参数全部存放在unexpecte_keys中。

测试集全部代码

 可以自己找图片进行预测看准确率。

import torch
import matplotlib.pyplot as plt
import json
from model import GoogleNet
from PIL import Image
from torchvision import transformsdata_transform = transforms.Compose([transforms.Resize((224, 224)),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])# load image
img = Image.open("8.jpeg")
plt.imshow(img)
img = data_transform(img)
img = torch.unsqueeze(img, dim=0)# read class_indent
try:json_file = open('./class_indices,json', 'r')class_indict = json.load(json_file)
except Exception as e:print(e)exit(-1)# create model
model = GoogleNet(num_classes=5, aux_logits=False)
model_weight_path = "./GoogleNet.pth"
missing_keys, unexpected_keys = model.load_state_dict(torch.load(model_weight_path), strict=False)
model.eval()
with torch.no_grad():output = torch.squeeze(model(img))predict = torch.softmax(output, dim=0)predict_cla = torch.argmax(predict).numpy()
print(class_indict[str(predict_cla)], predict[predict_cla].item())
plt.show()

准确率好低,可能是模型训练的还不够吧。

这篇关于GoogleNet网络训练集和测试集搭建的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/914651

相关文章

利用Python快速搭建Markdown笔记发布系统

《利用Python快速搭建Markdown笔记发布系统》这篇文章主要为大家详细介绍了使用Python生态的成熟工具,在30分钟内搭建一个支持Markdown渲染、分类标签、全文搜索的私有化知识发布系统... 目录引言:为什么要自建知识博客一、技术选型:极简主义开发栈二、系统架构设计三、核心代码实现(分步解析

Linux系统配置NAT网络模式的详细步骤(附图文)

《Linux系统配置NAT网络模式的详细步骤(附图文)》本文详细指导如何在VMware环境下配置NAT网络模式,包括设置主机和虚拟机的IP地址、网关,以及针对Linux和Windows系统的具体步骤,... 目录一、配置NAT网络模式二、设置虚拟机交换机网关2.1 打开虚拟机2.2 管理员授权2.3 设置子

揭秘Python Socket网络编程的7种硬核用法

《揭秘PythonSocket网络编程的7种硬核用法》Socket不仅能做聊天室,还能干一大堆硬核操作,这篇文章就带大家看看Python网络编程的7种超实用玩法,感兴趣的小伙伴可以跟随小编一起... 目录1.端口扫描器:探测开放端口2.简易 HTTP 服务器:10 秒搭个网页3.局域网游戏:多人联机对战4.

使用Python实现快速搭建本地HTTP服务器

《使用Python实现快速搭建本地HTTP服务器》:本文主要介绍如何使用Python快速搭建本地HTTP服务器,轻松实现一键HTTP文件共享,同时结合二维码技术,让访问更简单,感兴趣的小伙伴可以了... 目录1. 概述2. 快速搭建 HTTP 文件共享服务2.1 核心思路2.2 代码实现2.3 代码解读3.

MySQL双主搭建+keepalived高可用的实现

《MySQL双主搭建+keepalived高可用的实现》本文主要介绍了MySQL双主搭建+keepalived高可用的实现,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,... 目录一、测试环境准备二、主从搭建1.创建复制用户2.创建复制关系3.开启复制,确认复制是否成功4.同

SpringBoot使用OkHttp完成高效网络请求详解

《SpringBoot使用OkHttp完成高效网络请求详解》OkHttp是一个高效的HTTP客户端,支持同步和异步请求,且具备自动处理cookie、缓存和连接池等高级功能,下面我们来看看SpringB... 目录一、OkHttp 简介二、在 Spring Boot 中集成 OkHttp三、封装 OkHttp

Linux系统之主机网络配置方式

《Linux系统之主机网络配置方式》:本文主要介绍Linux系统之主机网络配置方式,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录一、查看主机的网络参数1、查看主机名2、查看IP地址3、查看网关4、查看DNS二、配置网卡1、修改网卡配置文件2、nmcli工具【通用

使用Python高效获取网络数据的操作指南

《使用Python高效获取网络数据的操作指南》网络爬虫是一种自动化程序,用于访问和提取网站上的数据,Python是进行网络爬虫开发的理想语言,拥有丰富的库和工具,使得编写和维护爬虫变得简单高效,本文将... 目录网络爬虫的基本概念常用库介绍安装库Requests和BeautifulSoup爬虫开发发送请求解

使用DeepSeek搭建个人知识库(在笔记本电脑上)

《使用DeepSeek搭建个人知识库(在笔记本电脑上)》本文介绍了如何在笔记本电脑上使用DeepSeek和开源工具搭建个人知识库,通过安装DeepSeek和RAGFlow,并使用CherryStudi... 目录部署环境软件清单安装DeepSeek安装Cherry Studio安装RAGFlow设置知识库总

Linux搭建Mysql主从同步的教程

《Linux搭建Mysql主从同步的教程》:本文主要介绍Linux搭建Mysql主从同步的教程,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录linux搭建mysql主从同步1.启动mysql服务2.修改Mysql主库配置文件/etc/my.cnf3.重启主库my