本文主要是介绍PyTorch下的5种不同神经网络-ResNet,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!
1.导入模块
导入所需的Python库,包括图像处理、深度学习模型和数据加载
import osimport torchimport torch.nn as nnimport torch.optim as optimfrom torch.utils.data import Dataset, DataLoaderfrom PIL import Imagefrom torchvision import models, transforms
2.定义自定义图像数据集:
创建一个自定义的图像数据集类,用于加载和处理图像数据
class CustomImageDataset(Dataset):def __init__(self, main_dir, transform=None):self.main_dir = main_dirself.transform = transformself.files = []self.labels = []self.label_to_index = {}for index, label in enumerate(os.listdir(main_dir)):self.label_to_index[label] = indexlabel_dir = os.path.join(main_dir, label)if os.path.isdir(label_dir):for file in os.listdir(label_dir):self.files.append(os.path.join(label_dir, file))self.labels.append(label)def __len__(self):return len(self.files)def __getitem__(self, idx):image = Image.open(self.files[idx])label = self.labels[idx]if self.transform:image = self.transform(image)return image, self.label_to_index[label]
3.定义数据转换
定义一个数据转换过程,包括图像大小调整、转换为张量以及标准化
transform = transforms.Compose([transforms.Resize((224, 224)), # ResNet的输入图像大小transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), # 标准化])
4.创建数据集
使用自定义数据集类和定义的数据转换来创建数据集
dataset = CustomImageDataset(main_dir="F:\\A-GX\\A-SJJ\\flower_photos\\flower_photos", transform=transform)
5.创建数据加载器
使用数据集创建一个数据加载器,用于批量加载和处理数据。
data_loader = DataLoader(dataset, batch_size=32, shuffle=True)
6.加载预训练的ResNet模型
从PyTorch库中加载预训练的ResNet18模型
resnet_model = models.resnet18(pretrained=True)
7.修改最后几层以适应新的分类任务
修改ResNet模型的最后几层,以便它能够处理新的分类任务
num_ftrs = resnet_model.fc.in_featuresresnet_model.fc = nn.Linear(num_ftrs, len(dataset.label_to_index))
8.定义损失函数和优化器
定义用于训练模型的损失函数和优化器
criterion = nn.CrossEntropyLoss()optimizer = optim.Adam(resnet_model.parameters(), lr=0.001)
9.模型并行化
如果有多GPU,则使用nn.DataParallel来并行化模型
if torch.cuda.device_count() > 1:resnet_model = nn.DataParallel(resnet_model)
10.将模型发送到GPU
模型发送到GPU进行训练
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")resnet_model.to(device)
11.训练模型
使用数据加载器和定义的参数训练模型
num_epochs = 10for epoch in range(num_epochs):resnet_model.train()running_loss = 0.0for images, labels in data_loader:images, labels = images.to(device), labels.to(device)# 前向传播outputs = resnet_model(images)loss = criterion(outputs, labels)# 反向传播和优化optimizer.zero_grad()loss.backward()optimizer.step()running_loss += loss.item()print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {running_loss / len(data_loader):.4f}')
这篇关于PyTorch下的5种不同神经网络-ResNet的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!