pytorch实现自己的深度神经网络(公共数据集)

2024-04-17 02:44

本文主要是介绍pytorch实现自己的深度神经网络(公共数据集),希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

一、训练文件——train.py

  注意:在运行此代码之前,需要配置好pytorch-GPU版本的环境,具体再次不谈。

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms# 检查GPU是否可用
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("Device:", device)# 数据预处理的转换
transform = transforms.Compose([transforms.Resize((256, 256)),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])# 加载CIFAR-10训练数据集
train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True,download=True, transform=transform)# 创建数据加载器
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=8,shuffle=True, num_workers=0)# 定义神经网络模型
class CNN(nn.Module):def __init__(self):super(CNN, self).__init__()self.conv1 = nn.Conv2d(3, 32, 3, padding=1)self.conv2 = nn.Conv2d(32, 64, 3, padding=1)self.conv3 = nn.Conv2d(64, 128, 3, padding=1)self.pool = nn.MaxPool2d(2, 2)self.fc1 = nn.Linear(128 * 32 * 32, 512)self.fc2 = nn.Linear(512, 10)def forward(self, x):x = self.pool(torch.relu(self.conv1(x)))x = self.pool(torch.relu(self.conv2(x)))x = self.pool(torch.relu(self.conv3(x)))x = x.view(-1, 128 * 32 * 32)x = torch.relu(self.fc1(x))x = self.fc2(x)return x# 实例化模型,并将其移动到可用设备上
model = CNN().to(device)# 定义损失函数
criterion = nn.CrossEntropyLoss()# 定义优化器
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)if __name__ == '__main__':# 训练神经网络for epoch in range(5):running_loss = 0.0for i, data in enumerate(train_loader, 0):inputs, labels = data[0].to(device), data[1].to(device)# 梯度清零optimizer.zero_grad()# 正向传播outputs = model(inputs)loss = criterion(outputs, labels)# 反向传播 + 优化loss.backward()optimizer.step()# 打印统计信息running_loss += loss.item()if i % 200 == 199:print('[%d, %5d] loss: %.3f' %(epoch + 1, i + 1, running_loss / 200))running_loss = 0.0print('Finished Training')# 保存模型至文件torch.save(model.state_dict(), 'cifar10_cnn_model.pth')

二、测试文件——val.py

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import cv2# 检查GPU是否可用
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("Device:", device)# 数据预处理的转换
transform = transforms.Compose([transforms.Resize((256, 256)),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])# 加载CIFAR-10测试数据集
test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False,download=True, transform=transform)# 创建测试数据加载器
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=8,shuffle=False, num_workers=0)# 加载模型并将其移动到可用设备上
class CNN(nn.Module):def __init__(self):super(CNN, self).__init__()self.conv1 = nn.Conv2d(3, 32, 3, padding=1)self.conv2 = nn.Conv2d(32, 64, 3, padding=1)self.conv3 = nn.Conv2d(64, 128, 3, padding=1)self.pool = nn.MaxPool2d(2, 2)self.fc1 = nn.Linear(128 * 32 * 32, 512)self.fc2 = nn.Linear(512, 10)def forward(self, x):x = self.pool(torch.relu(self.conv1(x)))x = self.pool(torch.relu(self.conv2(x)))x = self.pool(torch.relu(self.conv3(x)))x = x.view(-1, 128 * 32 * 32)x = torch.relu(self.fc1(x))x = self.fc2(x)return x
# 显示函数
def imshow(img):img = img / 2 + 0.5npimg = img.numpy()# 坐标转换plt.imshow(np.transpose(npimg, (1, 2, 0)))plt.show()model = CNN().to(device)
model.load_state_dict(torch.load('cifar10_cnn_model.pth'))
model.eval()if __name__ == '__main__':# 在测试集上测试模型correct = 0total = 0with torch.no_grad():for data in test_loader:images, labels = data[0].to(device), data[1].to(device)outputs = model(images)# 预测值的最大值以及最大值的类别索引_, predicted = torch.max(outputs, 1)total += labels.size(0)correct += (predicted == labels).sum().item()print('Accuracy on the test images: %d %%' % (100 * correct / total))# 显示测试集中的一些图片及其预测结果# 生成一个迭代器,从数据加载器中取出数据dataiter = iter(test_loader)# 从迭代器中获取下一个批次的数据images, labels = dataiter.next()# 将获取到的批次数据移动到device上,在这里也就是GPU上images, labels = images.to(device), labels.to(device)dip_flag = Falseif dip_flag == True:# -------------------------------------------# 可以选择 使用opencv显示# -------------------------------------------np_images = images.cpu().numpy()# 循环遍历并显示所有测试集图片for i in range(len(np_images)):# 从归一化中还原图像数据np_image = np.transpose(np_images[i], (1, 2, 0))   # 从CHW转换为HWCnp_image = np_image * 0.5 + 0.5# 将图像数据从float类型转换为unit8类型np_image = (np_image * 255).astype(np.uint8)# 使用opencv显示图像cv2.imshow("Image {}".format(i+1), np_image)cv2.waitKey(0)# 等待用户按下任意键继续显示下一张图像cv2.destroyAllWindows()imshow(torchvision.utils.make_grid(images.cpu()))print('GroundTruth: ', ' '.join('%5s' % test_dataset.classes[labels[j]] for j in range(8)))outputs = model(images)_, predicted = torch.max(outputs, 1)print('Predicted: ', ' '.join('%5s' % test_dataset.classes[predicted[j]]for j in range(8)))


直接运行即可,亲测可以运行

这篇关于pytorch实现自己的深度神经网络(公共数据集)的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/910589

相关文章

使用Java解析JSON数据并提取特定字段的实现步骤(以提取mailNo为例)

《使用Java解析JSON数据并提取特定字段的实现步骤(以提取mailNo为例)》在现代软件开发中,处理JSON数据是一项非常常见的任务,无论是从API接口获取数据,还是将数据存储为JSON格式,解析... 目录1. 背景介绍1.1 jsON简介1.2 实际案例2. 准备工作2.1 环境搭建2.1.1 添加

Java实现任务管理器性能网络监控数据的方法详解

《Java实现任务管理器性能网络监控数据的方法详解》在现代操作系统中,任务管理器是一个非常重要的工具,用于监控和管理计算机的运行状态,包括CPU使用率、内存占用等,对于开发者和系统管理员来说,了解这些... 目录引言一、背景知识二、准备工作1. Maven依赖2. Gradle依赖三、代码实现四、代码详解五

java如何分布式锁实现和选型

《java如何分布式锁实现和选型》文章介绍了分布式锁的重要性以及在分布式系统中常见的问题和需求,它详细阐述了如何使用分布式锁来确保数据的一致性和系统的高可用性,文章还提供了基于数据库、Redis和Zo... 目录引言:分布式锁的重要性与分布式系统中的常见问题和需求分布式锁的重要性分布式系统中常见的问题和需求

SpringBoot基于MyBatis-Plus实现Lambda Query查询的示例代码

《SpringBoot基于MyBatis-Plus实现LambdaQuery查询的示例代码》MyBatis-Plus是MyBatis的增强工具,简化了数据库操作,并提高了开发效率,它提供了多种查询方... 目录引言基础环境配置依赖配置(Maven)application.yml 配置表结构设计demo_st

python使用watchdog实现文件资源监控

《python使用watchdog实现文件资源监控》watchdog支持跨平台文件资源监控,可以检测指定文件夹下文件及文件夹变动,下面我们来看看Python如何使用watchdog实现文件资源监控吧... python文件监控库watchdogs简介随着Python在各种应用领域中的广泛使用,其生态环境也

el-select下拉选择缓存的实现

《el-select下拉选择缓存的实现》本文主要介绍了在使用el-select实现下拉选择缓存时遇到的问题及解决方案,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的... 目录项目场景:问题描述解决方案:项目场景:从左侧列表中选取字段填入右侧下拉多选框,用户可以对右侧

Python pyinstaller实现图形化打包工具

《Pythonpyinstaller实现图形化打包工具》:本文主要介绍一个使用PythonPYQT5制作的关于pyinstaller打包工具,代替传统的cmd黑窗口模式打包页面,实现更快捷方便的... 目录1.简介2.运行效果3.相关源码1.简介一个使用python PYQT5制作的关于pyinstall

使用Python实现大文件切片上传及断点续传的方法

《使用Python实现大文件切片上传及断点续传的方法》本文介绍了使用Python实现大文件切片上传及断点续传的方法,包括功能模块划分(获取上传文件接口状态、临时文件夹状态信息、切片上传、切片合并)、整... 目录概要整体架构流程技术细节获取上传文件状态接口获取临时文件夹状态信息接口切片上传功能文件合并功能小

python实现自动登录12306自动抢票功能

《python实现自动登录12306自动抢票功能》随着互联网技术的发展,越来越多的人选择通过网络平台购票,特别是在中国,12306作为官方火车票预订平台,承担了巨大的访问量,对于热门线路或者节假日出行... 目录一、遇到的问题?二、改进三、进阶–展望总结一、遇到的问题?1.url-正确的表头:就是首先ur

C#实现文件读写到SQLite数据库

《C#实现文件读写到SQLite数据库》这篇文章主要为大家详细介绍了使用C#将文件读写到SQLite数据库的几种方法,文中的示例代码讲解详细,感兴趣的小伙伴可以参考一下... 目录1. 使用 BLOB 存储文件2. 存储文件路径3. 分块存储文件《文件读写到SQLite数据库China编程的方法》博客中,介绍了文