本文主要是介绍基于Pytorch肺部感染识别案例(采用ResNet网络结构),希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!
一、整体流程
1. 数据集下载地址:https://www.kaggle.com/paultimothymooney/chest-xray-pneumonia/download
2. 数据集展示
案例主要流程:
第一步:加载预训练模型ResNet,该模型已在ImageNet上训练过。
第二步:冻结预训练模型中低层卷积层的参数(权重)。
第三步:用可训练参数的多层替换分类层。
第四步:在训练集上训练分类层。
第五步:微调超参数,根据需要解冻更多层。
ResNet 网络结构图
二、显示图片功能
#1加载库
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from torchvision import datasets, transforms
import os
from torchvision.utils import make_gridfrom torch.utils.data import DataLoader
#2、定义一个方法:显示图片
def img_show(inp, title=None):plt.figure(figsize=(14,3))inp = inp.numpy().transpose((1,2,0)) #转成numpy,然后转置mean = np.array([0.485, 0.456, 0.406])std = np.array([0.229, 0.224,0.225])inp = std * inp + meaninp = np.clip(inp, 0, 1)plt.imshow(inp)if title is not None:plt.title(title)plt.pause(0.001)plt.show()
def main():pass#3、定义超参数BATCH_SIZE = 8DEVICE = torch.device("gpu" if torch.cuda.is_available() else "cpu")#4、图片转换 使用字典进行转换data_transforms = {'train': transforms.Compose([transforms.Resize(300),transforms.RandomResizedCrop(300) ,#随机裁剪transforms.RandomHorizontalFlip(),transforms.CenterCrop(256),transforms.ToTensor(), #转为张量transforms.Normalize([0.485, 0.456, 0.406],[0.229, 0.224, 0.225]) #正则化]),'val': transforms.Compose([transforms.Resize(300),transforms.CenterCrop(256),transforms.ToTensor(), #转为张量transforms.Normalize([0.485, 0.456, 0.406],[0.229, 0.224, 0.225]) #正则化])}#5、操作数据集# 5.1、数据集路径data_path = "D:/chest_xray/"#5.2、加载数据集的train valimg_datasets = { x : datasets.ImageFolder(os.path.join(data_path,x),data_transforms[x]) for x in ["train","val"]}#5.3、为数据集创建一个迭代器,读取数据dataloaders = {x : DataLoader(img_datasets[x], shuffle=True,batch_size= BATCH_SIZE) for x in ["train","val"]}# 5.4、训练集和验证集的大小(图片的数量)data_sizes = {x : len(img_datasets[x]) for x in ["train","val"]}# 5.5、获取标签类别名称 NORMAL 正常 -- PNEUMONIA 感染target_names = img_datasets['train'].classes#6 显示一个batch_size 的图片(8张图片)#6.1 读取8张图片datas ,targets = next(iter(dataloaders['train'])) #iter把对象变为可迭代对象,next去迭代#6.2、将若干正图片平成一副图像out = make_grid(datas, norm = 4, padding = 10)
这篇关于基于Pytorch肺部感染识别案例(采用ResNet网络结构)的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!