PyTorch下的5种不同神经网络-ResNet

2024-06-21 22:28

本文主要是介绍PyTorch下的5种不同神经网络-ResNet,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

1.导入模块

导入所需的Python库,包括图像处理、深度学习模型和数据加载

import osimport torchimport torch.nn as nnimport torch.optim as optimfrom torch.utils.data import Dataset, DataLoaderfrom PIL import Imagefrom torchvision import models, transforms

2.定义自定义图像数据集:

创建一个自定义的图像数据集类,用于加载和处理图像数据

class CustomImageDataset(Dataset):def __init__(self, main_dir, transform=None):self.main_dir = main_dirself.transform = transformself.files = []self.labels = []self.label_to_index = {}for index, label in enumerate(os.listdir(main_dir)):self.label_to_index[label] = indexlabel_dir = os.path.join(main_dir, label)if os.path.isdir(label_dir):for file in os.listdir(label_dir):self.files.append(os.path.join(label_dir, file))self.labels.append(label)def __len__(self):return len(self.files)def __getitem__(self, idx):image = Image.open(self.files[idx])label = self.labels[idx]if self.transform:image = self.transform(image)return image, self.label_to_index[label]

3.定义数据转换

定义一个数据转换过程,包括图像大小调整、转换为张量以及标准化

transform = transforms.Compose([transforms.Resize((224, 224)),  # ResNet的输入图像大小transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),  # 标准化])

4.创建数据集

使用自定义数据集类和定义的数据转换来创建数据集

dataset = CustomImageDataset(main_dir="F:\\A-GX\\A-SJJ\\flower_photos\\flower_photos", transform=transform)

5.创建数据加载器

使用数据集创建一个数据加载器,用于批量加载和处理数据。

data_loader = DataLoader(dataset, batch_size=32, shuffle=True)

6.加载预训练的ResNet模型

从PyTorch库中加载预训练的ResNet18模型

resnet_model = models.resnet18(pretrained=True)

7.修改最后几层以适应新的分类任务

修改ResNet模型的最后几层,以便它能够处理新的分类任务

num_ftrs = resnet_model.fc.in_featuresresnet_model.fc = nn.Linear(num_ftrs, len(dataset.label_to_index))

8.定义损失函数和优化器

定义用于训练模型的损失函数和优化器

criterion = nn.CrossEntropyLoss()optimizer = optim.Adam(resnet_model.parameters(), lr=0.001)

9.模型并行化

如果有多GPU,则使用nn.DataParallel来并行化模型

if torch.cuda.device_count() > 1:resnet_model = nn.DataParallel(resnet_model)

10.将模型发送到GPU

模型发送到GPU进行训练

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")resnet_model.to(device)

11.训练模型

使用数据加载器和定义的参数训练模型

num_epochs = 10for epoch in range(num_epochs):resnet_model.train()running_loss = 0.0for images, labels in data_loader:images, labels = images.to(device), labels.to(device)# 前向传播outputs = resnet_model(images)loss = criterion(outputs, labels)# 反向传播和优化optimizer.zero_grad()loss.backward()optimizer.step()running_loss += loss.item()print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {running_loss / len(data_loader):.4f}')

这篇关于PyTorch下的5种不同神经网络-ResNet的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1082512

相关文章

2. c#从不同cs的文件调用函数

1.文件目录如下: 2. Program.cs文件的主函数如下 using System;using System.Collections.Generic;using System.Linq;using System.Threading.Tasks;using System.Windows.Forms;namespace datasAnalysis{internal static

【Prometheus】PromQL向量匹配实现不同标签的向量数据进行运算

✨✨ 欢迎大家来到景天科技苑✨✨ 🎈🎈 养成好习惯,先赞后看哦~🎈🎈 🏆 作者简介:景天科技苑 🏆《头衔》:大厂架构师,华为云开发者社区专家博主,阿里云开发者社区专家博主,CSDN全栈领域优质创作者,掘金优秀博主,51CTO博客专家等。 🏆《博客》:Python全栈,前后端开发,小程序开发,人工智能,js逆向,App逆向,网络系统安全,数据分析,Django,fastapi

uva 10061 How many zero's and how many digits ?(不同进制阶乘末尾几个0)+poj 1401

题意是求在base进制下的 n!的结果有几位数,末尾有几个0。 想起刚开始的时候做的一道10进制下的n阶乘末尾有几个零,以及之前有做过的一道n阶乘的位数。 当时都是在10进制下的。 10进制下的做法是: 1. n阶位数:直接 lg(n!)就是得数的位数。 2. n阶末尾0的个数:由于2 * 5 将会在得数中以0的形式存在,所以计算2或者计算5,由于因子中出现5必然出现2,所以直接一

图神经网络模型介绍(1)

我们将图神经网络分为基于谱域的模型和基于空域的模型,并按照发展顺序详解每个类别中的重要模型。 1.1基于谱域的图神经网络         谱域上的图卷积在图学习迈向深度学习的发展历程中起到了关键的作用。本节主要介绍三个具有代表性的谱域图神经网络:谱图卷积网络、切比雪夫网络和图卷积网络。 (1)谱图卷积网络 卷积定理:函数卷积的傅里叶变换是函数傅里叶变换的乘积,即F{f*g}

速了解MySQL 数据库不同存储引擎

快速了解MySQL 数据库不同存储引擎 MySQL 提供了多种存储引擎,每种存储引擎都有其特定的特性和适用场景。了解这些存储引擎的特性,有助于在设计数据库时做出合理的选择。以下是 MySQL 中几种常用存储引擎的详细介绍。 1. InnoDB 特点: 事务支持:InnoDB 是一个支持 ACID(原子性、一致性、隔离性、持久性)事务的存储引擎。行级锁:使用行级锁来提高并发性,减少锁竞争

MyBatis 切换不同的类型数据库方案

下属案例例当前结合SpringBoot 配置进行讲解。 背景: 实现一个工程里面在部署阶段支持切换不同类型数据库支持。 方案一 数据源配置 关键代码(是什么数据库,该怎么配就怎么配) spring:datasource:name: test# 使用druid数据源type: com.alibaba.druid.pool.DruidDataSource# @需要修改 数据库连接及驱动u

linux中使用rust语言在不同进程之间通信

第一种:使用mmap映射相同文件 fn main() {let pid = std::process::id();println!(

机器学习之监督学习(三)神经网络

机器学习之监督学习(三)神经网络基础 0. 文章传送1. 深度学习 Deep Learning深度学习的关键特点深度学习VS传统机器学习 2. 生物神经网络 Biological Neural Network3. 神经网络模型基本结构模块一:TensorFlow搭建神经网络 4. 反向传播梯度下降 Back Propagation Gradient Descent模块二:激活函数 activ

Nn criterions don’t compute the gradient w.r.t. targets error「pytorch」 (debug笔记)

Nn criterions don’t compute the gradient w.r.t. targets error「pytorch」 ##一、 缘由及解决方法 把这个pytorch-ddpg|github搬到jupyter notebook上运行时,出现错误Nn criterions don’t compute the gradient w.r.t. targets error。注:我用

图神经网络框架DGL实现Graph Attention Network (GAT)笔记

参考列表: [1]深入理解图注意力机制 [2]DGL官方学习教程一 ——基础操作&消息传递 [3]Cora数据集介绍+python读取 一、DGL实现GAT分类机器学习论文 程序摘自[1],该程序实现了利用图神经网络框架——DGL,实现图注意网络(GAT)。应用demo为对机器学习论文数据集——Cora,对论文所属类别进行分类。(下图摘自[3]) 1. 程序 Ubuntu:18.04