This article mainly introduces ResNet (lesson J1). I hope it offers some useful reference for developers working on the same problem; feel free to follow along and learn together!
🍨 This article is a learning-record blog post from the 🔗 365-Day Deep Learning Training Camp (https://mp.weixin.qq.com/s/2Wc0B5c2SdivAR3WS_g1bA)
🍖 Original author: K同学啊 | tutoring and custom projects available (https://mtyjkh.blog.csdn.net/?type=blog)
Column topic: classic CNN algorithms
Purpose of this column: the previous articles gave you a working understanding of deep learning and the ability to write a complete deep learning program. This column walks through some classic algorithms; along the way, you can try integrating them into YOLOv5 and see whether doing so improves accuracy.
📌 Tasks for this week:
1. Based on the TensorFlow code in this article, write the corresponding PyTorch code
2. Understand the residual structure
3. Explore whether the residual module can be fused into the C3 module (open-ended; see the sketch after this list)
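For task 3, here is a minimal sketch of one way a ResNet-style residual bottleneck could be dropped into a C3-like block. The module names (C3Res, ResidualBottleneck) are my own simplified stand-ins, not the actual YOLOv5 code:

import torch
import torch.nn as nn

class ResidualBottleneck(nn.Module):
    # ResNet-style 1x1 -> 3x3 -> 1x1 bottleneck with an identity shortcut
    def __init__(self, c):
        super().__init__()
        self.m = nn.Sequential(
            nn.Conv2d(c, c // 4, 1, bias=False), nn.BatchNorm2d(c // 4), nn.ReLU(True),
            nn.Conv2d(c // 4, c // 4, 3, padding=1, bias=False), nn.BatchNorm2d(c // 4), nn.ReLU(True),
            nn.Conv2d(c // 4, c, 1, bias=False), nn.BatchNorm2d(c))

    def forward(self, x):
        return torch.relu(x + self.m(x))  # residual add, then activation

class C3Res(nn.Module):
    # C3-like block: two 1x1 branches, residual bottlenecks on one branch, 1x1 fuse
    def __init__(self, c1, c2, n=1):
        super().__init__()
        c_ = c2 // 2
        self.cv1 = nn.Conv2d(c1, c_, 1, bias=False)
        self.cv2 = nn.Conv2d(c1, c_, 1, bias=False)
        self.cv3 = nn.Conv2d(2 * c_, c2, 1, bias=False)
        self.m = nn.Sequential(*(ResidualBottleneck(c_) for _ in range(n)))

    def forward(self, x):
        return self.cv3(torch.cat((self.m(self.cv1(x)), self.cv2(x)), dim=1))

print(C3Res(64, 64, n=2)(torch.randn(1, 64, 80, 80)).shape)  # torch.Size([1, 64, 80, 80])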
Theoretical Background
The deep residual network (ResNet) was proposed by Kaiming He et al. in 2015. Because it is both simple and practical, much subsequent research has been built on top of ResNet-50 or ResNet-101.
ResNet mainly addresses the "degradation" problem that appears as deep convolutional networks grow deeper. In an ordinary CNN, the first problem that comes with greater depth is vanishing/exploding gradients, which was largely solved after Szegedy et al. introduced Batch Normalization (BN). A BN layer normalizes each layer's outputs so that gradients keep a stable magnitude as they propagate backward through many layers, neither shrinking nor blowing up. But the authors found that even with BN, very deep networks were still hard to train well, which points to a second problem: accuracy degradation. Once depth grows past a certain point, accuracy saturates and then drops rapidly. This drop is caused neither by vanishing gradients nor by overfitting; rather, the network becomes so complex that unconstrained, free-form training can hardly reach an ideal error rate. The degradation problem is thus not a flaw of the architecture itself but of current training methods: none of the widely used optimizers, whether SGD, RMSProp, or Adam, reaches the theoretically optimal solution once the network gets very deep. One can also argue that, given an ideal training method, a deeper network should never be worse than a shallower one. The argument is simple: append a few layers to a network A to form a new network B. If the added layers implement only an identity mapping, i.e. A's output passes through them unchanged to become B's output, then A and B have exactly the same error rate, which shows the deepened network need not perform worse than the original.
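To make the identity-mapping argument concrete, here is a minimal residual block sketch (not the ResNet-50 built later in this post): if the weight layers learn F(x) = 0, the block reduces to the identity, so in principle adding it can never hurt.

import torch
import torch.nn as nn

class ResidualBlock(nn.Module):
    # y = F(x) + x: the stacked layers only have to learn the residual F(x)
    def __init__(self, channels):
        super().__init__()
        self.f = nn.Sequential(
            nn.Conv2d(channels, channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(channels, channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(channels))

    def forward(self, x):
        # when F(x) == 0 this is exactly the identity (for non-negative x, e.g. after a prior ReLU)
        return torch.relu(self.f(x) + x)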
Development Environment
OS: Windows 10
Language environment: Python 3.8.2
IDE: none (run directly from cmd.exe)
Deep learning environment: PyTorch 1.8.1+cu111
GPU and memory: NVIDIA GeForce GTX 1660 Ti 12G
CUDA version: Release 10.2, V10.2.89 (run nvcc -V or nvcc --version to check)
YOLOv5 open-source repository: ()
Data: 🔗 Baidu Netdisk (https://pan.baidu.com/share/init?
Preliminary Work
1. Set up the GPU
import tensorflow as tf

if __name__ == '__main__':
    ''' Set up the GPU '''
    gpus = tf.config.list_physical_devices("GPU")
    if gpus:
        tf.config.experimental.set_memory_growth(gpus[0], True)  # allocate GPU memory on demand
        tf.config.set_visible_devices([gpus[0]], "GPU")
2023-03-10 19:44:49.081929: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /home/usr/liangjie/soft/netcdf-fortran-4.5.4/release/lib:/home/usr/liangjie/soft/netcdf-c-4.9.0/release/lib:/home/usr/liangjie/soft/zlib-1.2.12/release/lib:/home/usr/liangjie/soft/hdf/HDF5-1.12.2-Linux/HDF_Group/HDF5/1.12.2/lib:/home/usr/liangjie/soft/gdal-3.5.0/release/gdal-release/lib64:
2023-03-10 19:44:49.082031: I tensorflow/stream_executor/cuda/cudart_stub.cc:29] Ignore above cudart dlerror if you do not have a GPU set up on your machine.
2023-03-10 19:45:17.059985: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcuda.so.1'; dlerror: libcuda.so.1: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /home/usr/liangjie/soft/netcdf-fortran-4.5.4/release/lib:/home/usr/liangjie/soft/netcdf-c-4.9.0/release/lib:/home/usr/liangjie/soft/zlib-1.2.12/release/lib:/home/usr/liangjie/soft/hdf/HDF5-1.12.2-Linux/HDF_Group/HDF5/1.12.2/lib:/home/usr/liangjie/soft/gdal-3.5.0/release/gdal-release/lib64:
2023-03-10 19:45:17.060023: W tensorflow/stream_executor/cuda/cuda_driver.cc:269] failed call to cuInit: UNKNOWN ERROR (303)
2023-03-10 19:45:17.060040: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:156] kernel driver does not appear to be running on this host (ALPHA.ITPCAS.AC.CN): /proc/driver/nvidia/version does not exist
import torch
import torchvision

if __name__ == '__main__':
    ''' Set up the GPU '''
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Using {} device\n".format(device))
/home/liangjie/anaconda3/envs/newcdo/lib/python3.9/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html
from .autonotebook import tqdm as notebook_tqdm
Using cpu device
import matplotlib.pyplot as plt
# Chinese font support
plt.rcParams['font.sans-serif'] = ['SimHei']  # render Chinese labels correctly
plt.rcParams['axes.unicode_minus'] = False    # render the minus sign correctly

import os, PIL, pathlib
import numpy as np

from tensorflow import keras
from tensorflow.keras import layers, models

import torch
import torch.nn as nn
import torchvision
from torchvision import transforms, datasets
import torchsummary
import torch.optim

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device
device(type='cpu')
2. Import the Data
''' Import the data '''
# TensorFlow version: data_dir as a pathlib.Path
data_dir = "./data/bird_photos/"
data_dir = pathlib.Path(data_dir)

# PyTorch version: data_dir as a plain path string
root = './data'
output = 'output'
data_dir = os.path.join(root, 'bird_photos')
3. Inspect the Data
TensorFlow
''' Inspect the data '''
image_count = len(list(data_dir.glob('*/*')))
print("Total number of images:", image_count)
PyTorch
''' Read the local dataset and split it into training and test sets '''
def localDataset(data_dir):
    data_dir = pathlib.Path(data_dir)

    # Read the local dataset
    data_paths = list(data_dir.glob('*'))
    classeNames = [path.name for path in data_paths]  # path.name is cross-platform (the original split on "\\" only works on Windows)

    # More on transforms.Compose: https://blog.csdn.net/qq_38251616/article/details/124878863
    train_transforms = torchvision.transforms.Compose([
        torchvision.transforms.Resize([224, 224]),  # resize all input images to a uniform size
        # torchvision.transforms.RandomHorizontalFlip(),  # random horizontal flip
        torchvision.transforms.ToTensor(),  # convert PIL Image / numpy.ndarray to a tensor scaled to [0, 1]
        torchvision.transforms.Normalize(   # standardize towards a normal (Gaussian) distribution, which helps the model converge
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225])      # mean and std were estimated by random sampling from the dataset
    ])
    total_dataset = torchvision.datasets.ImageFolder(data_dir, transform=train_transforms)
    print(total_dataset, '\n')
    print(total_dataset.class_to_idx, '\n')

    # Split into training and test sets
    train_size = int(0.8 * len(total_dataset))
    test_size = len(total_dataset) - train_size
    print('train_size', train_size, 'test_size', test_size, '\n')
    train_dataset, test_dataset = torch.utils.data.random_split(total_dataset, [train_size, test_size])

    return classeNames, train_dataset, test_dataset

classeNames, train_ds, test_ds = localDataset(data_dir)
num_classes = len(classeNames)
print('num_classes', num_classes)

Dataset ImageFolder
    Number of datapoints: 565
    Root location: bird_photos
    StandardTransform
Transform: Compose(
    Resize(size=[224, 224], interpolation=bilinear, max_size=None, antialias=None)
    ToTensor()
    Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
)

{'Bananaquit': 0, 'Black Skimmer': 1, 'Black Throated Bushtiti': 2, 'Cockatoo': 3}

train_size 452 test_size 113

num_classes 4
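One caveat (my addition, not in the original code): torch.utils.data.random_split draws a fresh random split on every run. Passing a seeded generator makes the train/test split reproducible:

g = torch.Generator().manual_seed(123)
train_dataset, test_dataset = torch.utils.data.random_split(
    total_dataset, [train_size, test_size], generator=g)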
III. Data Preprocessing
1. Load the Data
TensorFlow
''' Load the data '''
batch_size = 8
img_height = 224
img_width = 224
'''
For a detailed introduction to image_dataset_from_directory(), see:
https://mtyjkh.blog.csdn.net/article/details/117018789
'''
train_ds = tf.keras.preprocessing.image_dataset_from_directory(
    data_dir,
    validation_split=0.2,
    subset="training",
    seed=123,
    image_size=(img_height, img_width),
    batch_size=batch_size)

val_ds = tf.keras.preprocessing.image_dataset_from_directory(
    data_dir,
    validation_split=0.2,
    subset="validation",
    seed=123,
    image_size=(img_height, img_width),
    batch_size=batch_size)

class_names = train_ds.class_names
print(class_names)  # ['Bananaquit', 'Black Skimmer', 'Black Throated Bushtiti', 'Cockatoo']
''' Check the data again '''
for image_batch, labels_batch in train_ds:
    print(image_batch.shape)   # (8, 224, 224, 3)
    print(labels_batch.shape)  # (8,)
    break

''' Configure the dataset '''
AUTOTUNE = tf.data.AUTOTUNE  # tf.data.experimental.AUTOTUNE on older TF versions
train_ds = train_ds.cache().shuffle(1000).prefetch(buffer_size=AUTOTUNE)
val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)
Found 565 files belonging to 4 classes.
Using 452 files for training.
Found 565 files belonging to 4 classes.
Using 113 files for validation.
['Bananaquit', 'Black Skimmer', 'Black Throated Bushtiti', 'Cockatoo']
(8, 224, 224, 3)
(8,)
PyTorch
''' Load the data and set the batch_size '''
def loadData(train_ds, test_ds, batch_size=32, root='', show_flag=False):
    # Load the training set from train_ds
    train_dl = torch.utils.data.DataLoader(train_ds,
                                           batch_size=batch_size,
                                           shuffle=True,
                                           num_workers=1)
    # Load the test set from test_ds
    test_dl = torch.utils.data.DataLoader(test_ds,
                                          batch_size=batch_size,
                                          shuffle=True,
                                          num_workers=1)

    # Take one batch to check the data format.
    # The shape is [batch_size, channel, height, width]:
    # batch_size is user-defined; channel, height and width are the image's channels, height and width.
    for X, y in test_dl:
        print('Shape of X [N, C, H, W]: ', X.shape)
        print('Shape of y: ', y.shape, y.dtype, '\n')
        break

    imgs, labels = next(iter(train_dl))
    print('Image shape: ', imgs.shape, '\n')
    # torch.Size([32, 3, 224, 224])  # every image in the dataset is a 224x224 RGB image

    displayData(imgs, labels, root, show_flag)  # displayData is defined in the visualization section below
    return train_dl, test_dl

batch_size = 8
train_dl, test_dl = loadData(train_ds, test_ds, batch_size, root, True)
Shape of X [N, C, H, W]: torch.Size([8, 3, 224, 224])
Shape of y: torch.Size([8]) torch.int64
2. Visualize the Data
TensorFlow
plt.figure(figsize=(10, 5))  # figure: 10 inches wide, 5 inches tall
plt.suptitle("微信公众号:K同学啊")

for images, labels in train_ds.take(1):
    for i in range(8):
        ax = plt.subplot(2, 4, i + 1)
        plt.imshow(images[i].numpy().astype("uint8"))
        plt.title(class_names[labels[i]])
        plt.axis("off")
findfont: Generic family 'sans-serif' not found because none of the following families were found: SimHei
UserWarning: Glyph 24494 (\N{CJK UNIFIED IDEOGRAPH-5FAE}) missing from current font.
(The two warnings above repeat once per CJK character in the suptitle, from both IPython/core/events.py and IPython/core/pylabtools.py: the SimHei font requested earlier is not installed on this Linux machine, so the Chinese title cannot be rendered.)
plt.imshow(images[1].numpy().astype("uint8"))
<matplotlib.image.AxesImage at 0x7f57d32d3d30>
(followed by more of the same repeated findfont warnings)
PyTorch
def displayData(imgs, labels, root='', flag=False):
    # Set the figure size: 10 inches wide, 5 inches tall
    plt.figure('Data Visualization', figsize=(10, 5))
    for i, img in enumerate(imgs[:8]):
        print(i, img.shape)
        # Reorder the dimensions: [3, 224, 224] -> [224, 224, 3]
        npimg = img.numpy().transpose((1, 2, 0))
        # Undo the Normalize transform and clip to [0, 1] so imshow renders sensibly
        npimg = np.clip(npimg * [0.229, 0.224, 0.225] + [0.485, 0.456, 0.406], 0, 1)
        # Split the figure into 2 rows and 4 columns and draw subplot i+1
        plt.subplot(2, 4, i + 1)
        plt.imshow(npimg)
        plt.title(list(classeNames)[labels[i]])
        plt.axis('off')
    plt.savefig(os.path.join(root, 'DatasetDisplay.png'))
    if flag:
        plt.show()
    else:
        plt.close('all')
IV. Introduction to Residual Networks (ResNet)
1. What problem do residual networks solve?
Residual networks were introduced to solve the network degradation problem that appears when a neural network has too many hidden layers. Degradation means that as hidden layers are added, the network's accuracy saturates and then degrades sharply, and this degradation is not caused by overfitting.
Extension: the "two dark clouds" over deep neural networks
Vanishing/exploding gradients
Put simply, when the network is too deep, training struggles to converge. This problem can be effectively controlled by careful weight initialization and normalization of intermediate layers. (For now it is enough to know this exists.)
Network degradation
As network depth increases, performance first improves to saturation and then drops rapidly, and this degradation is not caused by overfitting.
2. Introduction to ResNet-50
ResNet-50 is built from two basic blocks, the Conv Block and the Identity Block.
[Figure: structure of the Conv Block and the Identity Block (the original images did not survive)]
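Since the figures are lost, here is a minimal runnable sketch of the difference between the two blocks (the bottleneck helper is my own shorthand, not the code used later): the Identity Block adds the input directly to the 1x1 -> 3x3 -> 1x1 path, while the Conv Block passes the shortcut through its own strided 1x1 conv + BN so the shapes match before the addition.

import torch
import torch.nn as nn

def bottleneck(c_in, c_mid, c_out, stride=1):
    # the main 1x1 -> 3x3 -> 1x1 path shared by both block types
    return nn.Sequential(
        nn.Conv2d(c_in, c_mid, 1, stride=stride, bias=False), nn.BatchNorm2d(c_mid), nn.ReLU(True),
        nn.Conv2d(c_mid, c_mid, 3, padding=1, bias=False), nn.BatchNorm2d(c_mid), nn.ReLU(True),
        nn.Conv2d(c_mid, c_out, 1, bias=False), nn.BatchNorm2d(c_out))

x = torch.randn(1, 256, 56, 56)

# Identity Block: input and output shapes already match, so the shortcut is x itself
identity_out = torch.relu(bottleneck(256, 64, 256)(x) + x)

# Conv Block: the shortcut gets its own strided 1x1 conv + BN so shapes match before the add
shortcut = nn.Sequential(nn.Conv2d(256, 512, 1, stride=2, bias=False), nn.BatchNorm2d(512))
conv_out = torch.relu(bottleneck(256, 128, 512, stride=2)(x) + shortcut(x))

print(identity_out.shape, conv_out.shape)  # [1, 256, 56, 56] and [1, 512, 28, 28]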
V. Building the ResNet-50 Network Model
TensorFlow
''' Build ResNet-50 '''
from tensorflow.keras.layers import (Input, ZeroPadding2D, Conv2D, BatchNormalization,
                                     Activation, MaxPooling2D, AveragePooling2D, Flatten, Dense)
from tensorflow.keras.models import Model

def identity_block(input_tensor, kernel_size, filters, stage, block):
    filters1, filters2, filters3 = filters
    name_base = str(stage) + block + '_identity_block_'

    x = Conv2D(filters1, (1, 1), name=name_base+'conv1')(input_tensor)
    x = BatchNormalization(name=name_base+'bn1')(x)
    x = Activation('relu', name=name_base+'relu1')(x)

    x = Conv2D(filters2, kernel_size, padding='same', name=name_base+'conv2')(x)
    x = BatchNormalization(name=name_base+'bn2')(x)
    x = Activation('relu', name=name_base+'relu2')(x)

    x = Conv2D(filters3, (1, 1), name=name_base+'conv3')(x)
    x = BatchNormalization(name=name_base+'bn3')(x)

    x = layers.add([x, input_tensor], name=name_base+'add')
    x = Activation('relu', name=name_base+'relu4')(x)
    return x

def conv_block(input_tensor, kernel_size, filters, stage, block, strides=(2, 2)):
    filters1, filters2, filters3 = filters
    res_name_base = str(stage) + block + '_conv_block_res_'
    name_base = str(stage) + block + '_conv_block_'

    x = Conv2D(filters1, (1, 1), strides=strides, name=name_base+'conv1')(input_tensor)
    x = BatchNormalization(name=name_base+'bn1')(x)
    x = Activation('relu', name=name_base+'relu1')(x)

    x = Conv2D(filters2, kernel_size, padding='same', name=name_base+'conv2')(x)
    x = BatchNormalization(name=name_base+'bn2')(x)
    x = Activation('relu', name=name_base+'relu2')(x)

    x = Conv2D(filters3, (1, 1), name=name_base+'conv3')(x)
    x = BatchNormalization(name=name_base+'bn3')(x)

    # Shortcut branch: a strided 1x1 conv + BN so shapes match before the addition
    shortcut = Conv2D(filters3, (1, 1), strides=strides, name=res_name_base+'conv')(input_tensor)
    shortcut = BatchNormalization(name=res_name_base+'bn')(shortcut)

    x = layers.add([x, shortcut], name=name_base+'add')
    x = Activation('relu', name=name_base+'relu4')(x)
    return x

def ResNet50(input_shape=[224, 224, 3], classes=1000):
    img_input = Input(shape=input_shape)

    x = ZeroPadding2D((3, 3))(img_input)
    x = Conv2D(64, (7, 7), strides=(2, 2), name='conv1')(x)
    x = BatchNormalization(name='bn_conv1')(x)
    x = Activation('relu')(x)
    x = MaxPooling2D((3, 3), strides=(2, 2))(x)

    x = conv_block(x, 3, [64, 64, 256], stage=2, block='a', strides=(1, 1))
    x = identity_block(x, 3, [64, 64, 256], stage=2, block='b')
    x = identity_block(x, 3, [64, 64, 256], stage=2, block='c')

    x = conv_block(x, 3, [128, 128, 512], stage=3, block='a')
    x = identity_block(x, 3, [128, 128, 512], stage=3, block='b')
    x = identity_block(x, 3, [128, 128, 512], stage=3, block='c')
    x = identity_block(x, 3, [128, 128, 512], stage=3, block='d')

    x = conv_block(x, 3, [256, 256, 1024], stage=4, block='a')
    x = identity_block(x, 3, [256, 256, 1024], stage=4, block='b')
    x = identity_block(x, 3, [256, 256, 1024], stage=4, block='c')
    x = identity_block(x, 3, [256, 256, 1024], stage=4, block='d')
    x = identity_block(x, 3, [256, 256, 1024], stage=4, block='e')
    x = identity_block(x, 3, [256, 256, 1024], stage=4, block='f')

    x = conv_block(x, 3, [512, 512, 2048], stage=5, block='a')
    x = identity_block(x, 3, [512, 512, 2048], stage=5, block='b')
    x = identity_block(x, 3, [512, 512, 2048], stage=5, block='c')

    x = AveragePooling2D((7, 7), name='avg_pool')(x)
    x = Flatten()(x)
    x = Dense(classes, activation='softmax', name='fc1000')(x)

    model = Model(img_input, x, name='resnet50')
    # Load the pretrained weights
    model.load_weights("resnet50_weights_tf_dim_ordering_tf_kernels.h5")
    return model

model = ResNet50()
model.summary()
Model: "resnet50"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to
==================================================================================================
(per-layer table condensed: input_3 -> zero_padding2d -> conv1/bn_conv1/relu -> max_pooling2d,
then the stage-2 to stage-5 conv/identity blocks exactly as assembled above,
ending with avg_pool, flatten_2 and fc1000)
==================================================================================================
Total params: 25,636,712
Trainable params: 25,583,592
Non-trainable params: 53,120
__________________________________________________________________________________________________
PyTorch
n_class = 4

''' Same padding '''
def autopad(k, p=None):  # kernel, padding
    # Pad to 'same'
    if p is None:
        p = k // 2 if isinstance(k, int) else [x // 2 for x in k]  # auto-pad
    return p

''' Identity Block '''
class IdentityBlock(nn.Module):
    def __init__(self, in_channel, kernel_size, filters):
        super(IdentityBlock, self).__init__()
        filters1, filters2, filters3 = filters
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channel, filters1, 1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(filters1),
            nn.ReLU(True))
        self.conv2 = nn.Sequential(
            nn.Conv2d(filters1, filters2, kernel_size, stride=1, padding=autopad(kernel_size), bias=False),
            nn.BatchNorm2d(filters2),
            nn.ReLU(True))
        self.conv3 = nn.Sequential(
            nn.Conv2d(filters2, filters3, 1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(filters3))
        self.relu = nn.ReLU(True)

    def forward(self, x):
        x1 = self.conv1(x)
        x1 = self.conv2(x1)
        x1 = self.conv3(x1)
        x = x1 + x          # identity shortcut: add the input directly
        x = self.relu(x)
        return x

''' Conv Block '''
class ConvBlock(nn.Module):
    def __init__(self, in_channel, kernel_size, filters, stride=2):
        super(ConvBlock, self).__init__()
        filters1, filters2, filters3 = filters
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channel, filters1, 1, stride=stride, padding=0, bias=False),
            nn.BatchNorm2d(filters1),
            nn.ReLU(True))
        self.conv2 = nn.Sequential(
            nn.Conv2d(filters1, filters2, kernel_size, stride=1, padding=autopad(kernel_size), bias=False),
            nn.BatchNorm2d(filters2),
            nn.ReLU(True))
        self.conv3 = nn.Sequential(
            nn.Conv2d(filters2, filters3, 1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(filters3))
        # Shortcut branch: strided 1x1 conv + BN so shapes match before the addition
        self.conv4 = nn.Sequential(
            nn.Conv2d(in_channel, filters3, 1, stride=stride, padding=0, bias=False),
            nn.BatchNorm2d(filters3))
        self.relu = nn.ReLU(True)

    def forward(self, x):
        x1 = self.conv1(x)
        x1 = self.conv2(x1)
        x1 = self.conv3(x1)
        x2 = self.conv4(x)
        x = x1 + x2
        x = self.relu(x)
        return x

''' Build ResNet-50 '''
class ResNet50(nn.Module):
    def __init__(self, classes=1000):
        super(ResNet50, self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 64, 7, stride=2, padding=3, bias=False, padding_mode='zeros'),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=0))
        self.conv2 = nn.Sequential(
            ConvBlock(64, 3, [64, 64, 256], stride=1),
            IdentityBlock(256, 3, [64, 64, 256]),
            IdentityBlock(256, 3, [64, 64, 256]))
        self.conv3 = nn.Sequential(
            ConvBlock(256, 3, [128, 128, 512]),
            IdentityBlock(512, 3, [128, 128, 512]),
            IdentityBlock(512, 3, [128, 128, 512]),
            IdentityBlock(512, 3, [128, 128, 512]))
        self.conv4 = nn.Sequential(
            ConvBlock(512, 3, [256, 256, 1024]),
            IdentityBlock(1024, 3, [256, 256, 1024]),
            IdentityBlock(1024, 3, [256, 256, 1024]),
            IdentityBlock(1024, 3, [256, 256, 1024]),
            IdentityBlock(1024, 3, [256, 256, 1024]),
            IdentityBlock(1024, 3, [256, 256, 1024]))
        self.conv5 = nn.Sequential(
            ConvBlock(1024, 3, [512, 512, 2048]),
            IdentityBlock(2048, 3, [512, 512, 2048]),
            IdentityBlock(2048, 3, [512, 512, 2048]))
        self.pool = nn.AvgPool2d(kernel_size=7, stride=7, padding=0)
        self.fc = nn.Linear(2048, classes)  # use the constructor argument rather than the global n_class

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.conv5(x)
        x = self.pool(x)
        x = torch.flatten(x, start_dim=1)
        x = self.fc(x)
        return x

model = ResNet50(classes=n_class).to(device)
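As a quick sanity check of the blocks above (a hypothetical snippet, not in the original post): autopad(3) returns 1, so the 3x3 convs are 'same' convolutions, and a 224x224 input indeed comes out as n_class logits:

assert autopad(3) == 1                      # a 3x3 conv keeps H and W
x = torch.randn(1, 3, 224, 224)
print(ResNet50(classes=n_class)(x).shape)   # torch.Size([1, 4]); spatially 224->112->55->28->14->7->1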
''' Display the network structure '''
torchsummary.summary(model, (3, 224, 224))
#torchinfo.summary(model)
print(model)
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
(per-layer table condensed: the Conv2d/BatchNorm2d/ReLU/MaxPool2d stem, then the ConvBlock and
IdentityBlock stacks of conv2-conv5 as defined above, ending with AvgPool2d-173 [-1, 2048, 1, 1]
and Linear-174 [-1, 4])
================================================================
Total params: 23,516,228
Trainable params: 23,516,228
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.57
Forward/backward pass size (MB): 270.43
Params size (MB): 89.71
Estimated Total Size (MB): 360.71
----------------------------------------------------------------
ResNet50(
  (conv1): Sequential(
    (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (conv2): Sequential(
    (0): ConvBlock(
      (conv1): Sequential(
        (0): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
      )
      (conv2): Sequential(
        (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
      )
      (conv3): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (conv4): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (relu): ReLU(inplace=True)
    )
    (1): IdentityBlock(
      (conv1): Sequential(
        (0): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
      )
      (conv2): Sequential(
        (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
      )
      (conv3): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (relu): ReLU(inplace=True)
    )
    (2): IdentityBlock(
      (conv1): Sequential(
        (0): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
      )
      (conv2): Sequential(
        (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
      )
      (conv3): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (relu): ReLU(inplace=True)
    )
  )
  (conv3): Sequential(
    (0): ConvBlock(
      (conv1): Sequential(
        (0): Conv2d(256, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
      )
      (conv2): Sequential(
        (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
      )
      (conv3): Sequential(
        (0): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (conv4): Sequential(
        (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (relu): ReLU(inplace=True)
    )
    (1): IdentityBlock(
      (conv1): Sequential(
        (0): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
      )
      (conv2): Sequential(
        (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
      )
      (conv3): Sequential(
        (0): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (relu): ReLU(inplace=True)
    )
    (2): IdentityBlock(
      (conv1): Sequential(
        (0): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
      )
      (conv2): Sequential(
        (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
      )
      (conv3): Sequential(
        (0): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (relu): ReLU(inplace=True)
    )
    (3): IdentityBlock(
      (conv1): Sequential(
        (0): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
      )
      (conv2): Sequential(
        (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
      )
      (conv3): Sequential(
        (0): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (relu): ReLU(inplace=True)
    )
  )
  (conv4): Sequential(
    (0): ConvBlock(
      (conv1): Sequential(
        (0): Conv2d(512, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
      )
      (conv2): Sequential(
        (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
      )
      (conv3): Sequential(
        (0): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (conv4): Sequential(
        (0): Conv2d(512, 1024, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (relu): ReLU(inplace=True)
    )
    (1): IdentityBlock(
      (conv1): Sequential(
        (0): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
      )
      (conv2): Sequential(
        (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
      )
      (conv3): Sequential(
        (0): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (relu): ReLU(inplace=True)
    )
    (2): IdentityBlock(
      (conv1): Sequential(
        (0): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
      )
      (conv2): Sequential(
        (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
      )
      (conv3): Sequential(
        (0): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (relu): ReLU(inplace=True)
    )
    (3): IdentityBlock(
      (conv1): Sequential(
        (0): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
      )
      (conv2): Sequential(
        (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
      )
      (conv3): Sequential(
        (0): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (relu): ReLU(inplace=True)
    )
    (4): IdentityBlock(
      (conv1): Sequential(
        (0): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
      )
      (conv2): Sequential(
        (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
      )
      (conv3): Sequential(
        (0): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (relu): ReLU(inplace=True)
    )
    (5): IdentityBlock(
      (conv1): Sequential(
        (0): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
      )
      (conv2): Sequential(
        (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
      )
      (conv3): Sequential(
        (0): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (relu): ReLU(inplace=True)
    )
  )
  (conv5): Sequential(
    (0): ConvBlock(
      (conv1): Sequential(
        (0): Conv2d(1024, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
      )
      (conv2): Sequential(
        (0): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
      )
      (conv3): Sequential(
        (0): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (conv4): Sequential(
        (0): Conv2d(1024, 2048, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (relu): ReLU(inplace=True)
    )
    (1): IdentityBlock(
      (conv1): Sequential(
        (0): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
      )
      (conv2): Sequential(
        (0): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
      )
      (conv3): Sequential(
        (0): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (relu): ReLU(inplace=True)
    )
    (2): IdentityBlock(
      (conv1): Sequential(
        (0): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
      )
      (conv2): Sequential(
        (0): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
      )
      (conv3): Sequential(
        (0): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (relu): ReLU(inplace=True)
    )
  )
  (pool): AvgPool2d(kernel_size=7, stride=7, padding=0)
  (fc): Linear(in_features=2048, out_features=4, bias=True)
)
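For reference, the two dumps above come from torchsummary's summary() and from printing the model. A minimal sketch of how to reproduce them, assuming model is the ResNet50 instance built earlier and that the torchsummary package is installed:
from torchsummary import summary

# Layer-by-layer table, as shown above (pass device="cpu" if no GPU is available)
summary(model, input_size=(3, 224, 224))

# Nested module repr, as shown above
print(model)

# Sanity check: should match the table's total, 23,516,228
print(sum(p.numel() for p in model.parameters()))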
6. Compile
Before the model is ready for training, a few more settings are needed. The following are added at the model's compile step:
Loss function (loss): measures how accurate the model is during training; training aims to minimize this value.
Optimizer (optimizer): determines how the model is updated based on the data it sees and its loss function.
Metrics (metrics): used to monitor the training and testing steps. The example below uses accuracy, the fraction of images that are correctly classified.
Below, the PyTorch setup comes first, followed by the TensorFlow equivalent.
PyTorch
# Set up the optimizer (which parameters to train, and at what learning rate)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
# Decay the learning rate to 0.92 times its value every 5 epochs
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.92)
loss_fn = nn.CrossEntropyLoss()  # cross-entropy loss
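As a quick check on the scheduler: StepLR(step_size=5, gamma=0.92) multiplies the learning rate by 0.92 once every 5 epochs. A throwaway sketch (opt_demo and sched_demo are illustrative names, not part of the training code):
import torch
import torch.nn as nn

p = nn.Parameter(torch.zeros(1))  # dummy parameter just for the demo
opt_demo = torch.optim.Adam([p], lr=1e-4)
sched_demo = torch.optim.lr_scheduler.StepLR(opt_demo, step_size=5, gamma=0.92)
for epoch in range(1, 16):
    opt_demo.step()    # step the optimizer before the scheduler
    sched_demo.step()
    if epoch % 5 == 0:
        # prints roughly [9.2e-05], [8.464e-05], [7.787e-05]
        print(epoch, sched_demo.get_last_lr())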
# Training loop
def train(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)  # size of the training set, 900 images in total
    num_batches = len(dataloader)   # number of batches, 29 (900/32, rounded up)
    train_loss, train_acc = 0, 0    # initialize training loss and accuracy

    for X, y in dataloader:  # fetch a batch of images and their labels
        X, y = X.to(device), y.to(device)

        # Compute the prediction error
        pred = model(X)          # network output
        loss = loss_fn(pred, y)  # gap between the network output and the ground truth

        # Backpropagation
        optimizer.zero_grad()  # zero the gradients
        loss.backward()        # backpropagate
        optimizer.step()       # update the weights

        # Accumulate accuracy and loss
        train_acc += (pred.argmax(1) == y).type(torch.float).sum().item()
        train_loss += loss.item()

    train_acc /= size
    train_loss /= num_batches
    return train_acc, train_loss
def test(dataloader, model, loss_fn):
    size = len(dataloader.dataset)  # size of the test set, 255 images in total
    num_batches = len(dataloader)   # number of batches, 8 (255/32, rounded up)
    test_loss, test_acc = 0, 0

    # No gradient tracking when not training, which saves memory and compute
    with torch.no_grad():
        for imgs, target in dataloader:
            imgs, target = imgs.to(device), target.to(device)

            # Compute the loss
            target_pred = model(imgs)
            loss = loss_fn(target_pred, target)

            test_loss += loss.item()
            test_acc += (target_pred.argmax(1) == target).type(torch.float).sum().item()

    test_acc /= size
    test_loss /= num_batches
    return test_acc, test_loss
TensorFlow

# Set up the optimizer; I changed the learning rate here.
opt = tf.keras.optimizers.Adam(learning_rate=1e-7)
model.compile(optimizer=opt,
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])
PyTorch

''' Set hyperparameters '''
start_epoch = 0
epochs = 10
learn_rate = 1e-7  # initial learning rate
loss_fn = nn.CrossEntropyLoss()  # create the loss function
optimizer = torch.optim.Adam(model.parameters(), lr=learn_rate)
train_loss = []
train_acc = []
test_loss = []
test_acc = []
epoch_best_acc = 0
7. Train the Model
TensorFlow
epochs = 10
history = model.fit(train_ds,
                    validation_data=val_ds,
                    epochs=epochs)
PyTorch
epochs = 20
train_loss = []
train_acc = []
test_loss = []
test_acc = []
best_acc = 0

for epoch in range(epochs):
    model.train()
    epoch_train_acc, epoch_train_loss = train(train_dl, model, loss_fn, optimizer)
    scheduler.step()  # learning-rate decay

    model.eval()
    epoch_test_acc, epoch_test_loss = test(test_dl, model, loss_fn)

    # Save the best model
    if epoch_test_acc > best_acc:
        best_acc = epoch_test_acc
        state = {
            'state_dict': model.state_dict(),  # keys are the layer names, values the trained weights
            'best_acc': best_acc,
            'optimizer': optimizer.state_dict(),
        }
        torch.save(state, 'best_model.pth')  # write the checkpoint to disk (file name is illustrative)

    train_acc.append(epoch_train_acc)
    train_loss.append(epoch_train_loss)
    test_acc.append(epoch_test_acc)
    test_loss.append(epoch_test_loss)

    template = ('Epoch:{:2d}, Train_acc:{:.1f}%, Train_loss:{:.3f}, Test_acc:{:.1f}%, Test_loss:{:.3f}')
    print(template.format(epoch+1, epoch_train_acc*100, epoch_train_loss, epoch_test_acc*100, epoch_test_loss))
print('Done')
print('best_acc:',best_acc)
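To pick up the best checkpoint later, for inference or to resume training, it can be restored from the file written inside the loop; a minimal sketch, assuming the checkpoint was saved as best_model.pth as above:
checkpoint = torch.load('best_model.pth', map_location=device)
model.load_state_dict(checkpoint['state_dict'])     # restore the trained weights
optimizer.load_state_dict(checkpoint['optimizer'])  # restore the optimizer state
print('restored best_acc:', checkpoint['best_acc'])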
Visualize the Data
TensorFlow
''' Visualize the data '''
import matplotlib.pyplot as plt

plt.figure(figsize=(10, 5))
plt.suptitle("微信公众号:K同学啊")

for images, labels in train_ds.take(1):
    for i in range(8):
        ax = plt.subplot(2, 4, i+1)
        plt.imshow(images[i].numpy().astype("uint8"))
        plt.title(class_names[labels[i]])
        plt.axis("off")
PyTorch
''' Visualize the data '''
import os
import matplotlib.pyplot as plt

def displayData(imgs, root='', flag=False):
    # Set up a 10-inch-wide, 5-inch-high figure
    plt.figure('Data Visualization', figsize=(10, 5))
    for i, img in enumerate(imgs[:8]):
        # Reorder the dimensions: [3, 224, 224] -> [224, 224, 3]
        npimg = img.numpy().transpose((1, 2, 0))
        # Split the figure into 2 rows and 4 columns, and draw subplot i+1
        plt.subplot(2, 4, i+1)
        plt.imshow(npimg)  # cmap=plt.cm.binary
        # labels and classeNames are assumed to be in scope from the data-loading section
        plt.title(list(classeNames)[labels[i]])
        plt.axis('off')
    plt.savefig(os.path.join(root, 'DatasetDisplay.png'))
    if flag:
        plt.show()
    else:
        plt.close('all')
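A hypothetical call, assuming train_dl is the training DataLoader and that labels and classeNames were defined in the data-loading section:
imgs, labels = next(iter(train_dl))  # take one batch of images and labels
displayData(imgs, root='./', flag=True)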
8. Model Evaluation
''' Model evaluation '''
acc = history.history['accuracy']
val_acc = history.history['val_accuracy']
loss = history.history['loss']
val_loss = history.history['val_loss']
epochs_range = range(epochs)

plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
plt.suptitle('K同学啊')
plt.plot(epochs_range, acc, label='Training Accuracy')
plt.plot(epochs_range, val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')
plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss, label='Training Loss')
plt.plot(epochs_range, val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()
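The PyTorch side has no history object, but the train_acc/train_loss/test_acc/test_loss lists filled in the training loop play the same role; a minimal sketch of the equivalent plots:
import matplotlib.pyplot as plt

epochs_range = range(len(train_acc))  # one entry per training epoch

plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
plt.plot(epochs_range, train_acc, label='Training Accuracy')
plt.plot(epochs_range, test_acc, label='Test Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Test Accuracy')

plt.subplot(1, 2, 2)
plt.plot(epochs_range, train_loss, label='Training Loss')
plt.plot(epochs_range, test_loss, label='Test Loss')
plt.legend(loc='upper right')
plt.title('Training and Test Loss')
plt.show()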
That's all for this article on ResNet J1. I hope it proves helpful to fellow developers!