ResNet J1

2023-10-12 05:59

🍨 This post is a study log from the 🔗365-Day Deep Learning Training Camp (https://mp.weixin.qq.com/s/2Wc0B5c2SdivAR3WS_g1bA)
🍖 Original author: K同学啊 | tutoring and custom projects available (https://mtyjkh.blog.csdn.net/?type=blog)

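About this column: classic CNN algorithms
Purpose of this column: earlier posts gave you a basic understanding of deep learning and the ability to write a complete deep learning program. This column walks through some classic algorithms. Along the way you can try integrating them into YOLOv5 and see whether that improves accuracy.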

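📌 This week's tasks:

1. Write the PyTorch code corresponding to the TensorFlow code in this post
2. Understand the residual structure
3. Explore whether the residual module can be integrated into C3 (open-ended)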

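Theoretical background

The deep residual network (ResNet) was proposed by Kaiming He et al. in 2015. Because it is both simple and practical, much subsequent research has been built on ResNet-50 or ResNet-101.

ResNet mainly addresses the "degradation" problem that deep convolutional networks run into as depth increases. In an ordinary convolutional network, the first problem caused by greater depth is vanishing or exploding gradients, which was largely resolved after Ioffe and Szegedy proposed batch normalization (BN). A BN layer normalizes the output of each layer, so gradients keep a stable magnitude as they propagate backward layer by layer instead of becoming too small or too large. However, the ResNet authors found that even with BN, deeper networks were still hard to train well, and they described a second problem, the drop in accuracy: once the number of layers grows beyond a certain point, accuracy saturates and then degrades rapidly. This drop is caused neither by vanishing gradients nor by overfitting (the training error itself gets worse). Rather, the network becomes so complex that unconstrained, free-range training struggles to reach the desired error rate.

The accuracy drop is therefore not a flaw of the network structure itself but a shortcoming of existing training methods. None of the widely used optimizers, whether SGD, RMSProp, or Adam, reaches the theoretically optimal result once the network becomes very deep. One can also argue that, given an ideal training method, a deeper network should perform no worse than a shallower one. The argument is simple: take a network A and append a few layers to form a new network B. If the added layers merely perform an identity mapping on the output of A, that is, the output of A passes through the new layers unchanged and becomes the output of B, then A and B have the same error rate. So the deeper network need not be worse than the shallower one.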
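The identity-mapping argument can be made concrete with a toy sketch. It is not the real ResNet: it uses plain Python lists instead of feature maps, and the helper names `plain_layer` and `residual_layer` are made up for illustration. It only shows why adding the input back (a shortcut) makes "do nothing" easy: with all weights at zero, a residual layer passes its input through unchanged, while an ordinary layer wipes it out.

```python
def relu(v):
    """Element-wise ReLU on a list of floats."""
    return [max(0.0, x) for x in v]

def linear(weights, v):
    """Multiply a square weight matrix (a list of rows) by a vector."""
    return [sum(w * x for w, x in zip(row, v)) for row in weights]

def plain_layer(weights, v):
    """An ordinary layer: y = relu(W v)."""
    return relu(linear(weights, v))

def residual_layer(weights, v):
    """A residual layer: y = relu(W v + v). The '+ v' term is the shortcut."""
    return relu([f + x for f, x in zip(linear(weights, v), v)])

# Output of "network A"; non-negative, as it would be after a ReLU
features = [0.5, 1.2, 3.0]
# Weights of the extra layer appended to form "network B", all zero
zeros = [[0.0] * 3 for _ in range(3)]

print(plain_layer(zeros, features))     # [0.0, 0.0, 0.0]
print(residual_layer(zeros, features))  # [0.5, 1.2, 3.0]
```

In this sketch the residual layer only has to learn weights close to zero to behave like an identity mapping, whereas the plain layer would have to learn an exact identity matrix.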
Development environment
Operating system: Windows 10
Language: Python 3.8.2
IDE: none (run directly in cmd.exe)
Deep learning framework: PyTorch 1.8.1+cu111
GPU and memory: NVIDIA GeForce GTX 1660 Ti 12G
CUDA version: Release 10.2, V10.2.89 (run nvcc -V or nvcc --version in cmd to check)
YOLOv5 repository: YOLOv5 repository ()
Data: 🔗Baidu Netdisk (https://pan.baidu.com/share/init?surl=LfiEQ43HhSsm7qU3KOM0Bg)


Preliminary work

1. Set up the GPU

TensorFlow

import tensorflow as tf

if __name__ == '__main__':
    ''' Set up the GPU '''
    gpus = tf.config.list_physical_devices("GPU")
    if gpus:
        tf.config.experimental.set_memory_growth(gpus[0], True)  # allocate GPU memory on demand
        tf.config.set_visible_devices([gpus[0]], "GPU")

PyTorch

import torch
import torchvision

if __name__ == '__main__':
    ''' Set up the GPU '''
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Using {} device\n".format(device))

Using cpu device

import matplotlib.pyplot as plt
# Support Chinese characters in plots
plt.rcParams['font.sans-serif'] = ['SimHei']  # render Chinese labels correctly
plt.rcParams['axes.unicode_minus'] = False  # render the minus sign correctly

import os, PIL, pathlib
import numpy as np

from tensorflow import keras
from tensorflow.keras import layers, models
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision
from torchvision import transforms, datasets
import torchsummary
import torch.optim

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device
device(type='cpu')

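2. Import the data

''' Import the data '''
# TensorFlow version: data_dir is a pathlib.Path
data_dir = "./data/bird_photos/"
data_dir = pathlib.Path(data_dir)

# PyTorch version: data_dir is a plain string (localDataset() converts it to a Path)
root = './data'
output = 'output'
data_dir = os.path.join(root, 'bird_photos')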

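3. Inspect the data

TensorFlow

''' Inspect the data '''
image_count = len(list(pathlib.Path(data_dir).glob('*/*')))  # wrap in Path so this also works when data_dir is a string
print("Total number of images:", image_count)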

PyTorch

''' Read the local dataset and split it into training and test sets '''
def localDataset(data_dir):
    data_dir = pathlib.Path(data_dir)

    # Read the local dataset; each sub-directory is one class
    data_paths = sorted(p for p in data_dir.glob('*') if p.is_dir())
    classeNames = [path.name for path in data_paths]  # path.name works on both Windows and Linux

    # More on transforms.Compose: https://blog.csdn.net/qq_38251616/article/details/124878863
    train_transforms = torchvision.transforms.Compose([
        torchvision.transforms.Resize([224, 224]),  # resize every input image to the same size
        # torchvision.transforms.RandomHorizontalFlip(),  # random horizontal flip
        torchvision.transforms.ToTensor(),  # convert a PIL Image or numpy.ndarray to a tensor scaled to [0, 1]
        torchvision.transforms.Normalize(  # standardize each channel so the model converges more easily
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225])  # these are the commonly used ImageNet channel statistics
    ])

    total_dataset = torchvision.datasets.ImageFolder(data_dir, transform=train_transforms)
    print(total_dataset, '\n')
    print(total_dataset.class_to_idx, '\n')

    # Split into training and test sets
    train_size = int(0.8 * len(total_dataset))
    test_size = len(total_dataset) - train_size
    print('train_size', train_size, 'test_size', test_size, '\n')
    train_dataset, test_dataset = torch.utils.data.random_split(total_dataset, [train_size, test_size])

    return classeNames, train_dataset, test_dataset

classeNames, train_ds, test_ds = localDataset(data_dir)
num_classes = len(classeNames)
print('num_classes', num_classes)

Dataset ImageFolder
    Number of datapoints: 565
    Root location: bird_photos
    StandardTransform
Transform: Compose(
               Resize(size=[224, 224], interpolation=bilinear, max_size=None, antialias=None)
               ToTensor()
               Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
           )

{'Bananaquit': 0, 'Black Skimmer': 1, 'Black Throated Bushtiti': 2, 'Cockatoo': 3}

train_size 452 test_size 113

num_classes 4
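To see what the ToTensor() and Normalize() steps do to a single pixel, here is a small standard-library sketch of the arithmetic. It is an illustration only, not torchvision's implementation, and the helper names `to_tensor_value` and `normalize_pixel` are made up for this example. The mean and std values are the ones used in the transform above.

```python
MEAN = [0.485, 0.456, 0.406]
STD = [0.229, 0.224, 0.225]

def to_tensor_value(byte_value):
    """Mimic the ToTensor() scaling for one 8-bit value: [0, 255] -> [0.0, 1.0]."""
    return byte_value / 255.0

def normalize_pixel(rgb_bytes):
    """Scale one RGB pixel to [0, 1], then apply (x - mean) / std per channel."""
    scaled = [to_tensor_value(v) for v in rgb_bytes]
    return [(x - m) / s for x, m, s in zip(scaled, MEAN, STD)]

black = normalize_pixel([0, 0, 0])
white = normalize_pixel([255, 255, 255])
print([round(v, 3) for v in black])
print([round(v, 3) for v in white])
```

A black pixel maps to roughly -2.1 in the red channel and a white pixel to roughly 2.6 in the blue channel, so after this step the network sees values centered near zero rather than raw values in [0, 255].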

III. Data preprocessing

1. Load the data

TensorFlow

''' Load the data '''
batch_size = 8
img_height = 224
img_width = 224

'''
For a detailed introduction to image_dataset_from_directory(), see:
https://mtyjkh.blog.csdn.net/article/details/117018789
'''
train_ds = tf.keras.preprocessing.image_dataset_from_directory(
    data_dir,
    validation_split=0.2,
    subset="training",
    seed=123,
    image_size=(img_height, img_width),
    batch_size=batch_size)

val_ds = tf.keras.preprocessing.image_dataset_from_directory(
    data_dir,
    validation_split=0.2,
    subset="validation",
    seed=123,
    image_size=(img_height, img_width),
    batch_size=batch_size)

class_names = train_ds.class_names
print(class_names)  # ['Bananaquit', 'Black Skimmer', 'Black Throated Bushtiti', 'Cockatoo']

''' Check the data again '''
for image_batch, labels_batch in train_ds:
    print(image_batch.shape)   # (8, 224, 224, 3)
    print(labels_batch.shape)  # (8,)
    break

''' Configure the dataset '''
AUTOTUNE = tf.data.AUTOTUNE  # older versions: tf.data.experimental.AUTOTUNE
train_ds = train_ds.cache().shuffle(1000).prefetch(buffer_size=AUTOTUNE)
val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)

Found 565 files belonging to 4 classes.
Using 452 files for training.
Found 565 files belonging to 4 classes.
Using 113 files for validation.
['Bananaquit', 'Black Skimmer', 'Black Throated Bushtiti', 'Cockatoo']
(8, 224, 224, 3)
(8,)

PyTorch

''' Load the data and set batch_size '''
def loadData(train_ds, test_ds, batch_size=32, root='', show_flag=False):
    # Load the training set from train_ds
    train_dl = torch.utils.data.DataLoader(train_ds,
                                           batch_size=batch_size,
                                           shuffle=True,
                                           num_workers=1)
    # Load the test set from test_ds
    test_dl = torch.utils.data.DataLoader(test_ds,
                                          batch_size=batch_size,
                                          shuffle=True,
                                          num_workers=1)

    # Take one batch to inspect the data format.
    # The data shape is [batch_size, channel, height, width]:
    # batch_size is chosen by the user; channel, height, and width are
    # the number of image channels, the image height, and the image width.
    for X, y in test_dl:
        print('Shape of X [N, C, H, W]: ', X.shape)
        print('Shape of y: ', y.shape, y.dtype, '\n')
        break

    imgs, labels = next(iter(train_dl))
    print('Image shape: ', imgs.shape, '\n')
    # torch.Size([batch_size, 3, 224, 224]): every image is a 224*224 RGB image after resizing
    # displayData() is defined in the visualization section below and must exist before loadData() is called
    displayData(imgs, root, show_flag)
    return train_dl, test_dl

batch_size = 8
train_dl, test_dl = loadData(train_ds, test_ds, batch_size, root, True)
Shape of X [N, C, H, W]:  torch.Size([8, 3, 224, 224])
Shape of y:  torch.Size([8]) torch.int64
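The split sizes and the number of batches per epoch follow from simple arithmetic, sketched below with only the standard library. The helper names `split_sizes` and `batches_per_epoch` are hypothetical. The second one assumes the data loader keeps the final, smaller batch, which is what happens in loadData() above because it does not pass drop_last.

```python
import math

def split_sizes(total, train_fraction=0.8):
    """Same arithmetic as the 80/20 split in localDataset()."""
    train = int(train_fraction * total)
    return train, total - train

def batches_per_epoch(num_samples, batch_size, drop_last=False):
    """Number of batches a loader yields per epoch for the given sizes."""
    if drop_last:
        return num_samples // batch_size
    return math.ceil(num_samples / batch_size)

train_size, test_size = split_sizes(565)
print(train_size, test_size)             # 452 113
print(batches_per_epoch(train_size, 8))  # 57
print(batches_per_epoch(test_size, 8))   # 15
```

With 452 training images and a batch size of 8, the last training batch holds only 4 images, and the last test batch holds a single image.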

2. Visualize the data

TensorFlow

plt.figure(figsize=(10, 5))  # the figure is 10 wide and 5 high
plt.suptitle("WeChat official account: K同学啊")

for images, labels in train_ds.take(1):
    for i in range(8):
        ax = plt.subplot(2, 4, i + 1)
        plt.imshow(images[i].numpy().astype("uint8"))
        plt.title(class_names[labels[i]])
        plt.axis("off")

[Image: output_30_1.png]

plt.imshow(images[1].numpy().astype("uint8"))
<matplotlib.image.AxesImage at 0x7f57d32d3d30>

[Image: output_31_2.png]

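PyTorch

def displayData(imgs, root='', flag=False, labels=None):
    # Set the figure size: 10 wide by 5 high (in inches)
    plt.figure('Data Visualization', figsize=(10, 5))
    # Channel statistics used by Normalize() when the dataset was loaded
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    for i, img in enumerate(imgs[:8]):
        print(i, img.shape)
        # Reorder dimensions [3, 224, 224] -> [224, 224, 3], as imshow expects
        npimg = img.numpy().transpose((1, 2, 0))
        # Undo the normalization and clip to [0, 1] so the colors display correctly
        npimg = np.clip(npimg * std + mean, 0, 1)
        # Split the figure into 2 rows and 4 columns and draw the (i+1)-th subplot
        plt.subplot(2, 4, i + 1)
        plt.imshow(npimg)  # cmap=plt.cm.binary
        if labels is not None:  # class-name titles are drawn only when the batch labels are passed in
            plt.title(list(classeNames)[labels[i]])
        plt.axis('off')
    plt.savefig(os.path.join(root, 'DatasetDisplay.png'))
    if flag:
        plt.show()
    else:
        plt.close('all')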
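The dimension reordering [3, 224, 224] -> [224, 224, 3] mentioned in the comment is easier to follow on a tiny example. The sketch below uses nested Python lists instead of tensors, and `chw_to_hwc` is a made-up helper name. It shows that the reordering only regroups values so that the three channel values of each pixel sit together, which is the layout matplotlib's imshow expects.

```python
def chw_to_hwc(image):
    """Rearrange a nested list indexed [channel][row][col] into [row][col][channel]."""
    channels = len(image)
    height = len(image[0])
    width = len(image[0][0])
    return [[[image[c][h][w] for c in range(channels)]
             for w in range(width)]
            for h in range(height)]

# A toy "image" with 3 channels, 2 rows, and 2 columns
tiny = [
    [[1, 2], [3, 4]],          # channel 0 (R)
    [[10, 20], [30, 40]],      # channel 1 (G)
    [[100, 200], [300, 400]],  # channel 2 (B)
]

hwc = chw_to_hwc(tiny)
print(hwc[0][1])  # [2, 20, 200]: the R, G, B values of the pixel at row 0, column 1
```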

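IV. Introduction to residual networks (ResNet)

1. What problem do residual networks solve?
Residual networks were designed to solve the degradation problem that appears when a neural network has too many hidden layers. Degradation means that as the number of hidden layers grows, the network's accuracy saturates and then degrades sharply, and this degradation is not caused by overfitting.

Extension: the "two dark clouds" over deep neural networks

Vanishing/exploding gradients
Put simply, when a network is too deep, training becomes hard to converge. This problem can be effectively controlled with normalized initialization and intermediate normalization layers. (For now, it is enough to know that it exists.)

Network degradation

As network depth increases, performance first improves until it saturates and then drops rapidly. This degradation is not caused by overfitting.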

2. Introduction to ResNet-50
ResNet-50 has two basic building blocks, called the Conv Block and the Identity Block.
[Image: attachment cb78b7f2-d883-46b7-8039-7dd99c060a71.png] [Image: attachment fdce2700-88b8-483f-928c-1bbf40392fd9.png]
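One way to get a feel for the two block types is to trace feature-map sizes through the network by hand. The sketch below is standard-library Python, not the TensorFlow model itself. It assumes the layout used in the code of section V: a 7x7 stride-2 conv after 3 pixels of zero padding, an unpadded 3x3 stride-2 max pooling, and four stages of 3, 4, 6, and 3 bottleneck blocks. The helper `needs_projection` is hypothetical and encodes the rule of thumb that a block needs a conv on its shortcut (a Conv Block) whenever the channel count or spatial size changes.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Spatial output size of a convolution or pooling layer."""
    return (size + 2 * padding - kernel) // stride + 1

def conv_params(kernel, c_in, c_out):
    """Weights plus one bias per output channel of a square-kernel Conv2D."""
    return kernel * kernel * c_in * c_out + c_out

def needs_projection(c_in, c_out, stride):
    """Hypothetical helper: the shortcut needs its own 1x1 conv if the shape changes."""
    return c_in != c_out or stride != 1

def resnet50_plan(input_size=224):
    """Return (block type, spatial size, channels) for each of the 16 blocks."""
    size = conv_out(input_size, kernel=7, stride=2, padding=3)  # stem conv
    size = conv_out(size, kernel=3, stride=2)                   # max pooling, no padding
    channels = 64
    # (output channels, number of blocks, stride of the first block)
    stages = [(256, 3, 1), (512, 4, 2), (1024, 6, 2), (2048, 3, 2)]
    plan = []
    for out_channels, blocks, first_stride in stages:
        for i in range(blocks):
            stride = first_stride if i == 0 else 1
            if needs_projection(channels, out_channels, stride):
                kind = "conv_block"
            else:
                kind = "identity_block"
            size = conv_out(size, kernel=1, stride=stride)  # strided 1x1 conv sets the size
            channels = out_channels
            plan.append((kind, size, channels))
    return plan

for row in resnet50_plan():
    print(row)

# Parameter counts of a few Conv2D layers
print(conv_params(7, 3, 64))    # 9472  (stem 7x7 conv)
print(conv_params(1, 64, 64))   # 4160  (first 1x1 conv of stage 2)
print(conv_params(3, 64, 64))   # 36928 (3x3 conv of stage 2)
print(conv_params(1, 64, 256))  # 16640 (last 1x1 conv of a stage-2 block)
```

Under these assumptions each stage starts with one Conv Block followed by Identity Blocks, the spatial size goes 55, 28, 14, 7, and the parameter counts can be compared against the model summary printed later in this post.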

V. Building the ResNet-50 model

TensorFlow

''' Build ResNet-50 '''
from tensorflow.keras.layers import (Input, Conv2D, BatchNormalization, Activation, ZeroPadding2D,
                                     MaxPooling2D, AveragePooling2D, Flatten, Dense)
from tensorflow.keras.models import Model

def identity_block(input_tensor, kernel_size, filters, stage, block):
    # Bottleneck block whose shortcut is the unchanged input
    filters1, filters2, filters3 = filters
    name_base = str(stage) + block + '_identity_block_'

    x = Conv2D(filters1, (1, 1), name=name_base+'conv1')(input_tensor)
    x = BatchNormalization(name=name_base+'bn1')(x)
    x = Activation('relu', name=name_base+'relu1')(x)

    x = Conv2D(filters2, kernel_size, padding='same', name=name_base+'conv2')(x)
    x = BatchNormalization(name=name_base+'bn2')(x)
    x = Activation('relu', name=name_base+'relu2')(x)

    x = Conv2D(filters3, (1, 1), name=name_base+'conv3')(x)
    x = BatchNormalization(name=name_base+'bn3')(x)

    x = layers.add([x, input_tensor], name=name_base+'add')
    x = Activation('relu', name=name_base+'relu4')(x)
    return x

def conv_block(input_tensor, kernel_size, filters, stage, block, strides=(2, 2)):
    # Bottleneck block whose shortcut is a 1x1 conv + BN, so the output shape may differ from the input
    filters1, filters2, filters3 = filters
    res_name_base = str(stage) + block + '_conv_block_res_'
    name_base = str(stage) + block + '_conv_block_'

    x = Conv2D(filters1, (1, 1), strides=strides, name=name_base+'conv1')(input_tensor)
    x = BatchNormalization(name=name_base+'bn1')(x)
    x = Activation('relu', name=name_base+'relu1')(x)

    x = Conv2D(filters2, kernel_size, padding='same', name=name_base+'conv2')(x)
    x = BatchNormalization(name=name_base+'bn2')(x)
    x = Activation('relu', name=name_base+'relu2')(x)

    x = Conv2D(filters3, (1, 1), name=name_base+'conv3')(x)
    x = BatchNormalization(name=name_base+'bn3')(x)

    shortcut = Conv2D(filters3, (1, 1), strides=strides, name=res_name_base+'conv')(input_tensor)
    shortcut = BatchNormalization(name=res_name_base+'bn')(shortcut)

    x = layers.add([x, shortcut], name=name_base+'add')
    x = Activation('relu', name=name_base+'relu4')(x)
    return x

def ResNet50(input_shape=[224, 224, 3], classes=1000):
    img_input = Input(shape=input_shape)
    x = ZeroPadding2D((3, 3))(img_input)

    x = Conv2D(64, (7, 7), strides=(2, 2), name='conv1')(x)
    x = BatchNormalization(name='bn_conv1')(x)
    x = Activation('relu')(x)
    x = MaxPooling2D((3, 3), strides=(2, 2))(x)

    x =     conv_block(x, 3, [64, 64, 256], stage=2, block='a', strides=(1, 1))
    x = identity_block(x, 3, [64, 64, 256], stage=2, block='b')
    x = identity_block(x, 3, [64, 64, 256], stage=2, block='c')

    x =     conv_block(x, 3, [128, 128, 512], stage=3, block='a')
    x = identity_block(x, 3, [128, 128, 512], stage=3, block='b')
    x = identity_block(x, 3, [128, 128, 512], stage=3, block='c')
    x = identity_block(x, 3, [128, 128, 512], stage=3, block='d')

    x =     conv_block(x, 3, [256, 256, 1024], stage=4, block='a')
    x = identity_block(x, 3, [256, 256, 1024], stage=4, block='b')
    x = identity_block(x, 3, [256, 256, 1024], stage=4, block='c')
    x = identity_block(x, 3, [256, 256, 1024], stage=4, block='d')
    x = identity_block(x, 3, [256, 256, 1024], stage=4, block='e')
    x = identity_block(x, 3, [256, 256, 1024], stage=4, block='f')

    x =     conv_block(x, 3, [512, 512, 2048], stage=5, block='a')
    x = identity_block(x, 3, [512, 512, 2048], stage=5, block='b')
    x = identity_block(x, 3, [512, 512, 2048], stage=5, block='c')

    x = AveragePooling2D((7, 7), name='avg_pool')(x)
    x = Flatten()(x)
    x = Dense(classes, activation='softmax', name='fc1000')(x)

    model = Model(img_input, x, name='resnet50')

    # Load the pretrained weights (the .h5 file is read from the working directory)
    model.load_weights("resnet50_weights_tf_dim_ordering_tf_kernels.h5")
    return model

model = ResNet50()
model.summary()
Model: "resnet50"
__________________________________________________________________________________________________
Layer (type)                                 Output Shape            Param #    Connected to
==================================================================================================
input_3 (InputLayer)                         [(None, 224, 224, 3)]   0          []
zero_padding2d_2 (ZeroPadding2D)             (None, 230, 230, 3)     0          ['input_3[0][0]']
conv1 (Conv2D)                               (None, 112, 112, 64)    9472       ['zero_padding2d_2[0][0]']
bn_conv1 (BatchNormalization)                (None, 112, 112, 64)    256        ['conv1[0][0]']
activation_2 (Activation)                    (None, 112, 112, 64)    0          ['bn_conv1[0][0]']
max_pooling2d_2 (MaxPooling2D)               (None, 55, 55, 64)      0          ['activation_2[0][0]']
2a_conv_block_conv1 (Conv2D)                 (None, 55, 55, 64)      4160       ['max_pooling2d_2[0][0]']
2a_conv_block_bn1 (BatchNormalization)       (None, 55, 55, 64)      256        ['2a_conv_block_conv1[0][0]']
2a_conv_block_relu1 (Activation)             (None, 55, 55, 64)      0          ['2a_conv_block_bn1[0][0]']
2a_conv_block_conv2 (Conv2D)                 (None, 55, 55, 64)      36928      ['2a_conv_block_relu1[0][0]']
2a_conv_block_bn2 (BatchNormalization)       (None, 55, 55, 64)      256        ['2a_conv_block_conv2[0][0]']
2a_conv_block_relu2 (Activation)             (None, 55, 55, 64)      0          ['2a_conv_block_bn2[0][0]']
2a_conv_block_conv3 (Conv2D)                 (None, 55, 55, 256)     16640      ['2a_conv_block_relu2[0][0]']
2a_conv_block_res_conv (Conv2D)              (None, 55, 55, 256)     16640      ['max_pooling2d_2[0][0]']
2a_conv_block_bn3 (BatchNormalization)       (None, 55, 55, 256)     1024       ['2a_conv_block_conv3[0][0]']
2a_conv_block_res_bn (BatchNormalization)    (None, 55, 55, 256)     1024       ['2a_conv_block_res_conv[0][0]']
2a_conv_block_add (Add)                      (None, 55, 55, 256)     0          ['2a_conv_block_bn3[0][0]', '2a_conv_block_res_bn[0][0]']
2a_conv_block_relu4 (Activation)             (None, 55, 55, 256)     0          ['2a_conv_block_add[0][0]']
2b_identity_block_conv1 (Conv2D)             (None, 55, 55, 64)      16448      ['2a_conv_block_relu4[0][0]']
2b_identity_block_bn1 (BatchNormalization)   (None, 55, 55, 64)      256        ['2b_identity_block_conv1[0][0]']
2b_identity_block_relu1 (Activation)         (None, 55, 55, 64)      0          ['2b_identity_block_bn1[0][0]']
2b_identity_block_conv2 (Conv2D)             (None, 55, 55, 64)      36928      ['2b_identity_block_relu1[0][0]']
2b_identity_block_bn2 (BatchNormalization)   (None, 55, 55, 64)      256        ['2b_identity_block_conv2[0][0]']
2b_identity_block_relu2 (Activation)         (None, 55, 55, 64)      0          ['2b_identity_block_bn2[0][0]']
2b_identity_block_conv3 (Conv2D)             (None, 55, 55, 256)     16640      ['2b_identity_block_relu2[0][0]']
2b_identity_block_bn3 (BatchNormalization)   (None, 55, 55, 256)     1024       ['2b_identity_block_conv3[0][0]']
2b_identity_block_add (Add)                  (None, 55, 55, 256)     0          ['2b_identity_block_bn3[0][0]', '2a_conv_block_relu4[0][0]']
2b_identity_block_relu4 (Activation)         (None, 55, 55, 256)     0          ['2b_identity_block_add[0][0]']
2c_identity_block_conv1 (Conv2D)             (None, 55, 55, 64)      16448      ['2b_identity_block_relu4[0][0]']
2c_identity_block_bn1 (BatchNormalization)   (None, 55, 55, 64)      256        ['2c_identity_block_conv1[0][0]']
2c_identity_block_relu1 (Activation)         (None, 55, 55, 64)      0          ['2c_identity_block_bn1[0][0]']
2c_identity_block_conv2 (Conv2D)             (None, 55, 55, 64)      36928      ['2c_identity_block_relu1[0][0]']
2c_identity_block_bn2 (BatchNormalization)   (None, 55, 55, 64)      256        ['2c_identity_block_conv2[0][0]']
2c_identity_block_relu2 (Activation)         (None, 55, 55, 64)      0          ['2c_identity_block_bn2[0][0]']
                                                                     2c_identity_block_conv3 (Conv2  (None, 55, 55, 256)  16640      ['2c_identity_block_relu2[0][0]']D)                                                                                               2c_identity_block_bn3 (BatchNo  (None, 55, 55, 256)  1024       ['2c_identity_block_conv3[0][0]']rmalization)                                                                                     2c_identity_block_add (Add)    (None, 55, 55, 256)  0           ['2c_identity_block_bn3[0][0]',  '2b_identity_block_relu4[0][0]']2c_identity_block_relu4 (Activ  (None, 55, 55, 256)  0          ['2c_identity_block_add[0][0]']  ation)                                                                                           3a_conv_block_conv1 (Conv2D)   (None, 28, 28, 128)  32896       ['2c_identity_block_relu4[0][0]']3a_conv_block_bn1 (BatchNormal  (None, 28, 28, 128)  512        ['3a_conv_block_conv1[0][0]']    ization)                                                                                         3a_conv_block_relu1 (Activatio  (None, 28, 28, 128)  0          ['3a_conv_block_bn1[0][0]']      n)                                                                                               3a_conv_block_conv2 (Conv2D)   (None, 28, 28, 128)  147584      ['3a_conv_block_relu1[0][0]']    3a_conv_block_bn2 (BatchNormal  (None, 28, 28, 128)  512        ['3a_conv_block_conv2[0][0]']    ization)                                                                                         3a_conv_block_relu2 (Activatio  (None, 28, 28, 128)  0          ['3a_conv_block_bn2[0][0]']      n)                                                                                               3a_conv_block_conv3 (Conv2D)   (None, 28, 28, 512)  66048       ['3a_conv_block_relu2[0][0]']    3a_conv_block_res_conv (Conv2D  (None, 28, 28, 512)  131584     ['2c_identity_block_relu4[0][0]'])                                                       
                                         3a_conv_block_bn3 (BatchNormal  (None, 28, 28, 512)  2048       ['3a_conv_block_conv3[0][0]']    ization)                                                                                         3a_conv_block_res_bn (BatchNor  (None, 28, 28, 512)  2048       ['3a_conv_block_res_conv[0][0]'] malization)                                                                                      3a_conv_block_add (Add)        (None, 28, 28, 512)  0           ['3a_conv_block_bn3[0][0]',      '3a_conv_block_res_bn[0][0]']   3a_conv_block_relu4 (Activatio  (None, 28, 28, 512)  0          ['3a_conv_block_add[0][0]']      n)                                                                                               3b_identity_block_conv1 (Conv2  (None, 28, 28, 128)  65664      ['3a_conv_block_relu4[0][0]']    D)                                                                                               3b_identity_block_bn1 (BatchNo  (None, 28, 28, 128)  512        ['3b_identity_block_conv1[0][0]']rmalization)                                                                                     3b_identity_block_relu1 (Activ  (None, 28, 28, 128)  0          ['3b_identity_block_bn1[0][0]']  ation)                                                                                           3b_identity_block_conv2 (Conv2  (None, 28, 28, 128)  147584     ['3b_identity_block_relu1[0][0]']D)                                                                                               3b_identity_block_bn2 (BatchNo  (None, 28, 28, 128)  512        ['3b_identity_block_conv2[0][0]']rmalization)                                                                                     3b_identity_block_relu2 (Activ  (None, 28, 28, 128)  0          ['3b_identity_block_bn2[0][0]']  ation)                                                                                           3b_identity_block_conv3 (Conv2  (None, 28, 28, 512)  66048      
['3b_identity_block_relu2[0][0]']D)                                                                                               3b_identity_block_bn3 (BatchNo  (None, 28, 28, 512)  2048       ['3b_identity_block_conv3[0][0]']rmalization)                                                                                     3b_identity_block_add (Add)    (None, 28, 28, 512)  0           ['3b_identity_block_bn3[0][0]',  '3a_conv_block_relu4[0][0]']    3b_identity_block_relu4 (Activ  (None, 28, 28, 512)  0          ['3b_identity_block_add[0][0]']  ation)                                                                                           3c_identity_block_conv1 (Conv2  (None, 28, 28, 128)  65664      ['3b_identity_block_relu4[0][0]']D)                                                                                               3c_identity_block_bn1 (BatchNo  (None, 28, 28, 128)  512        ['3c_identity_block_conv1[0][0]']rmalization)                                                                                     3c_identity_block_relu1 (Activ  (None, 28, 28, 128)  0          ['3c_identity_block_bn1[0][0]']  ation)                                                                                           3c_identity_block_conv2 (Conv2  (None, 28, 28, 128)  147584     ['3c_identity_block_relu1[0][0]']D)                                                                                               3c_identity_block_bn2 (BatchNo  (None, 28, 28, 128)  512        ['3c_identity_block_conv2[0][0]']rmalization)                                                                                     3c_identity_block_relu2 (Activ  (None, 28, 28, 128)  0          ['3c_identity_block_bn2[0][0]']  ation)                                                                                           3c_identity_block_conv3 (Conv2  (None, 28, 28, 512)  66048      ['3c_identity_block_relu2[0][0]']D)                                                                                          
     3c_identity_block_bn3 (BatchNo  (None, 28, 28, 512)  2048       ['3c_identity_block_conv3[0][0]']rmalization)                                                                                     3c_identity_block_add (Add)    (None, 28, 28, 512)  0           ['3c_identity_block_bn3[0][0]',  '3b_identity_block_relu4[0][0]']3c_identity_block_relu4 (Activ  (None, 28, 28, 512)  0          ['3c_identity_block_add[0][0]']  ation)                                                                                           3d_identity_block_conv1 (Conv2  (None, 28, 28, 128)  65664      ['3c_identity_block_relu4[0][0]']D)                                                                                               3d_identity_block_bn1 (BatchNo  (None, 28, 28, 128)  512        ['3d_identity_block_conv1[0][0]']rmalization)                                                                                     3d_identity_block_relu1 (Activ  (None, 28, 28, 128)  0          ['3d_identity_block_bn1[0][0]']  ation)                                                                                           3d_identity_block_conv2 (Conv2  (None, 28, 28, 128)  147584     ['3d_identity_block_relu1[0][0]']D)                                                                                               3d_identity_block_bn2 (BatchNo  (None, 28, 28, 128)  512        ['3d_identity_block_conv2[0][0]']rmalization)                                                                                     3d_identity_block_relu2 (Activ  (None, 28, 28, 128)  0          ['3d_identity_block_bn2[0][0]']  ation)                                                                                           3d_identity_block_conv3 (Conv2  (None, 28, 28, 512)  66048      ['3d_identity_block_relu2[0][0]']D)                                                                                               3d_identity_block_bn3 (BatchNo  (None, 28, 28, 512)  2048       ['3d_identity_block_conv3[0][0]']rmalization)           
                                                                          3d_identity_block_add (Add)    (None, 28, 28, 512)  0           ['3d_identity_block_bn3[0][0]',  '3c_identity_block_relu4[0][0]']3d_identity_block_relu4 (Activ  (None, 28, 28, 512)  0          ['3d_identity_block_add[0][0]']  ation)                                                                                           4a_conv_block_conv1 (Conv2D)   (None, 14, 14, 256)  131328      ['3d_identity_block_relu4[0][0]']4a_conv_block_bn1 (BatchNormal  (None, 14, 14, 256)  1024       ['4a_conv_block_conv1[0][0]']    ization)                                                                                         4a_conv_block_relu1 (Activatio  (None, 14, 14, 256)  0          ['4a_conv_block_bn1[0][0]']      n)                                                                                               4a_conv_block_conv2 (Conv2D)   (None, 14, 14, 256)  590080      ['4a_conv_block_relu1[0][0]']    4a_conv_block_bn2 (BatchNormal  (None, 14, 14, 256)  1024       ['4a_conv_block_conv2[0][0]']    ization)                                                                                         4a_conv_block_relu2 (Activatio  (None, 14, 14, 256)  0          ['4a_conv_block_bn2[0][0]']      n)                                                                                               4a_conv_block_conv3 (Conv2D)   (None, 14, 14, 1024  263168      ['4a_conv_block_relu2[0][0]']    )                                                                 4a_conv_block_res_conv (Conv2D  (None, 14, 14, 1024  525312     ['3d_identity_block_relu4[0][0]'])                              )                                                                 4a_conv_block_bn3 (BatchNormal  (None, 14, 14, 1024  4096       ['4a_conv_block_conv3[0][0]']    ization)                       )                                                                 4a_conv_block_res_bn (BatchNor  (None, 14, 14, 1024  4096       
['4a_conv_block_res_conv[0][0]'] malization)                    )                                                                 4a_conv_block_add (Add)        (None, 14, 14, 1024  0           ['4a_conv_block_bn3[0][0]',      )                                 '4a_conv_block_res_bn[0][0]']   4a_conv_block_relu4 (Activatio  (None, 14, 14, 1024  0          ['4a_conv_block_add[0][0]']      n)                             )                                                                 4b_identity_block_conv1 (Conv2  (None, 14, 14, 256)  262400     ['4a_conv_block_relu4[0][0]']    D)                                                                                               4b_identity_block_bn1 (BatchNo  (None, 14, 14, 256)  1024       ['4b_identity_block_conv1[0][0]']rmalization)                                                                                     4b_identity_block_relu1 (Activ  (None, 14, 14, 256)  0          ['4b_identity_block_bn1[0][0]']  ation)                                                                                           4b_identity_block_conv2 (Conv2  (None, 14, 14, 256)  590080     ['4b_identity_block_relu1[0][0]']D)                                                                                               4b_identity_block_bn2 (BatchNo  (None, 14, 14, 256)  1024       ['4b_identity_block_conv2[0][0]']rmalization)                                                                                     4b_identity_block_relu2 (Activ  (None, 14, 14, 256)  0          ['4b_identity_block_bn2[0][0]']  ation)                                                                                           4b_identity_block_conv3 (Conv2  (None, 14, 14, 1024  263168     ['4b_identity_block_relu2[0][0]']D)                             )                                                                 4b_identity_block_bn3 (BatchNo  (None, 14, 14, 1024  4096       ['4b_identity_block_conv3[0][0]']rmalization)                   )                          
                                       4b_identity_block_add (Add)    (None, 14, 14, 1024  0           ['4b_identity_block_bn3[0][0]',  )                                 '4a_conv_block_relu4[0][0]']    4b_identity_block_relu4 (Activ  (None, 14, 14, 1024  0          ['4b_identity_block_add[0][0]']  ation)                         )                                                                 4c_identity_block_conv1 (Conv2  (None, 14, 14, 256)  262400     ['4b_identity_block_relu4[0][0]']D)                                                                                               4c_identity_block_bn1 (BatchNo  (None, 14, 14, 256)  1024       ['4c_identity_block_conv1[0][0]']rmalization)                                                                                     4c_identity_block_relu1 (Activ  (None, 14, 14, 256)  0          ['4c_identity_block_bn1[0][0]']  ation)                                                                                           4c_identity_block_conv2 (Conv2  (None, 14, 14, 256)  590080     ['4c_identity_block_relu1[0][0]']D)                                                                                               4c_identity_block_bn2 (BatchNo  (None, 14, 14, 256)  1024       ['4c_identity_block_conv2[0][0]']rmalization)                                                                                     4c_identity_block_relu2 (Activ  (None, 14, 14, 256)  0          ['4c_identity_block_bn2[0][0]']  ation)                                                                                           4c_identity_block_conv3 (Conv2  (None, 14, 14, 1024  263168     ['4c_identity_block_relu2[0][0]']D)                             )                                                                 4c_identity_block_bn3 (BatchNo  (None, 14, 14, 1024  4096       ['4c_identity_block_conv3[0][0]']rmalization)                   )                                                                 4c_identity_block_add (Add)    (None, 14, 14, 1024  
0           ['4c_identity_block_bn3[0][0]',  )                                 '4b_identity_block_relu4[0][0]']4c_identity_block_relu4 (Activ  (None, 14, 14, 1024  0          ['4c_identity_block_add[0][0]']  ation)                         )                                                                 4d_identity_block_conv1 (Conv2  (None, 14, 14, 256)  262400     ['4c_identity_block_relu4[0][0]']D)                                                                                               4d_identity_block_bn1 (BatchNo  (None, 14, 14, 256)  1024       ['4d_identity_block_conv1[0][0]']rmalization)                                                                                     4d_identity_block_relu1 (Activ  (None, 14, 14, 256)  0          ['4d_identity_block_bn1[0][0]']  ation)                                                                                           4d_identity_block_conv2 (Conv2  (None, 14, 14, 256)  590080     ['4d_identity_block_relu1[0][0]']D)                                                                                               4d_identity_block_bn2 (BatchNo  (None, 14, 14, 256)  1024       ['4d_identity_block_conv2[0][0]']rmalization)                                                                                     4d_identity_block_relu2 (Activ  (None, 14, 14, 256)  0          ['4d_identity_block_bn2[0][0]']  ation)                                                                                           4d_identity_block_conv3 (Conv2  (None, 14, 14, 1024  263168     ['4d_identity_block_relu2[0][0]']D)                             )                                                                 4d_identity_block_bn3 (BatchNo  (None, 14, 14, 1024  4096       ['4d_identity_block_conv3[0][0]']rmalization)                   )                                                                 4d_identity_block_add (Add)    (None, 14, 14, 1024  0           ['4d_identity_block_bn3[0][0]',  )                                 
'4c_identity_block_relu4[0][0]']4d_identity_block_relu4 (Activ  (None, 14, 14, 1024  0          ['4d_identity_block_add[0][0]']  ation)                         )                                                                 4e_identity_block_conv1 (Conv2  (None, 14, 14, 256)  262400     ['4d_identity_block_relu4[0][0]']D)                                                                                               4e_identity_block_bn1 (BatchNo  (None, 14, 14, 256)  1024       ['4e_identity_block_conv1[0][0]']rmalization)                                                                                     4e_identity_block_relu1 (Activ  (None, 14, 14, 256)  0          ['4e_identity_block_bn1[0][0]']  ation)                                                                                           4e_identity_block_conv2 (Conv2  (None, 14, 14, 256)  590080     ['4e_identity_block_relu1[0][0]']D)                                                                                               4e_identity_block_bn2 (BatchNo  (None, 14, 14, 256)  1024       ['4e_identity_block_conv2[0][0]']rmalization)                                                                                     4e_identity_block_relu2 (Activ  (None, 14, 14, 256)  0          ['4e_identity_block_bn2[0][0]']  ation)                                                                                           4e_identity_block_conv3 (Conv2  (None, 14, 14, 1024  263168     ['4e_identity_block_relu2[0][0]']D)                             )                                                                 4e_identity_block_bn3 (BatchNo  (None, 14, 14, 1024  4096       ['4e_identity_block_conv3[0][0]']rmalization)                   )                                                                 4e_identity_block_add (Add)    (None, 14, 14, 1024  0           ['4e_identity_block_bn3[0][0]',  )                                 '4d_identity_block_relu4[0][0]']4e_identity_block_relu4 (Activ  (None, 14, 14, 1024  0     
     ['4e_identity_block_add[0][0]']  ation)                         )                                                                 4f_identity_block_conv1 (Conv2  (None, 14, 14, 256)  262400     ['4e_identity_block_relu4[0][0]']D)                                                                                               4f_identity_block_bn1 (BatchNo  (None, 14, 14, 256)  1024       ['4f_identity_block_conv1[0][0]']rmalization)                                                                                     4f_identity_block_relu1 (Activ  (None, 14, 14, 256)  0          ['4f_identity_block_bn1[0][0]']  ation)                                                                                           4f_identity_block_conv2 (Conv2  (None, 14, 14, 256)  590080     ['4f_identity_block_relu1[0][0]']D)                                                                                               4f_identity_block_bn2 (BatchNo  (None, 14, 14, 256)  1024       ['4f_identity_block_conv2[0][0]']rmalization)                                                                                     4f_identity_block_relu2 (Activ  (None, 14, 14, 256)  0          ['4f_identity_block_bn2[0][0]']  ation)                                                                                           4f_identity_block_conv3 (Conv2  (None, 14, 14, 1024  263168     ['4f_identity_block_relu2[0][0]']D)                             )                                                                 4f_identity_block_bn3 (BatchNo  (None, 14, 14, 1024  4096       ['4f_identity_block_conv3[0][0]']rmalization)                   )                                                                 4f_identity_block_add (Add)    (None, 14, 14, 1024  0           ['4f_identity_block_bn3[0][0]',  )                                 '4e_identity_block_relu4[0][0]']4f_identity_block_relu4 (Activ  (None, 14, 14, 1024  0          ['4f_identity_block_add[0][0]']  ation)                         )                     
                                            5a_conv_block_conv1 (Conv2D)   (None, 7, 7, 512)    524800      ['4f_identity_block_relu4[0][0]']5a_conv_block_bn1 (BatchNormal  (None, 7, 7, 512)   2048        ['5a_conv_block_conv1[0][0]']    ization)                                                                                         5a_conv_block_relu1 (Activatio  (None, 7, 7, 512)   0           ['5a_conv_block_bn1[0][0]']      n)                                                                                               5a_conv_block_conv2 (Conv2D)   (None, 7, 7, 512)    2359808     ['5a_conv_block_relu1[0][0]']    5a_conv_block_bn2 (BatchNormal  (None, 7, 7, 512)   2048        ['5a_conv_block_conv2[0][0]']    ization)                                                                                         5a_conv_block_relu2 (Activatio  (None, 7, 7, 512)   0           ['5a_conv_block_bn2[0][0]']      n)                                                                                               5a_conv_block_conv3 (Conv2D)   (None, 7, 7, 2048)   1050624     ['5a_conv_block_relu2[0][0]']    5a_conv_block_res_conv (Conv2D  (None, 7, 7, 2048)  2099200     ['4f_identity_block_relu4[0][0]'])                                                                                                5a_conv_block_bn3 (BatchNormal  (None, 7, 7, 2048)  8192        ['5a_conv_block_conv3[0][0]']    ization)                                                                                         5a_conv_block_res_bn (BatchNor  (None, 7, 7, 2048)  8192        ['5a_conv_block_res_conv[0][0]'] malization)                                                                                      5a_conv_block_add (Add)        (None, 7, 7, 2048)   0           ['5a_conv_block_bn3[0][0]',      '5a_conv_block_res_bn[0][0]']   5a_conv_block_relu4 (Activatio  (None, 7, 7, 2048)  0           ['5a_conv_block_add[0][0]']      n)                                                                               
                5b_identity_block_conv1 (Conv2  (None, 7, 7, 512)   1049088     ['5a_conv_block_relu4[0][0]']    D)                                                                                               5b_identity_block_bn1 (BatchNo  (None, 7, 7, 512)   2048        ['5b_identity_block_conv1[0][0]']rmalization)                                                                                     5b_identity_block_relu1 (Activ  (None, 7, 7, 512)   0           ['5b_identity_block_bn1[0][0]']  ation)                                                                                           5b_identity_block_conv2 (Conv2  (None, 7, 7, 512)   2359808     ['5b_identity_block_relu1[0][0]']D)                                                                                               5b_identity_block_bn2 (BatchNo  (None, 7, 7, 512)   2048        ['5b_identity_block_conv2[0][0]']rmalization)                                                                                     5b_identity_block_relu2 (Activ  (None, 7, 7, 512)   0           ['5b_identity_block_bn2[0][0]']  ation)                                                                                           5b_identity_block_conv3 (Conv2  (None, 7, 7, 2048)  1050624     ['5b_identity_block_relu2[0][0]']D)                                                                                               5b_identity_block_bn3 (BatchNo  (None, 7, 7, 2048)  8192        ['5b_identity_block_conv3[0][0]']rmalization)                                                                                     5b_identity_block_add (Add)    (None, 7, 7, 2048)   0           ['5b_identity_block_bn3[0][0]',  '5a_conv_block_relu4[0][0]']    5b_identity_block_relu4 (Activ  (None, 7, 7, 2048)  0           ['5b_identity_block_add[0][0]']  ation)                                                                                           5c_identity_block_conv1 (Conv2  (None, 7, 7, 512)   1049088     ['5b_identity_block_relu4[0][0]']D)          
                                                                                     5c_identity_block_bn1 (BatchNo  (None, 7, 7, 512)   2048        ['5c_identity_block_conv1[0][0]']rmalization)                                                                                     5c_identity_block_relu1 (Activ  (None, 7, 7, 512)   0           ['5c_identity_block_bn1[0][0]']  ation)                                                                                           5c_identity_block_conv2 (Conv2  (None, 7, 7, 512)   2359808     ['5c_identity_block_relu1[0][0]']D)                                                                                               5c_identity_block_bn2 (BatchNo  (None, 7, 7, 512)   2048        ['5c_identity_block_conv2[0][0]']rmalization)                                                                                     5c_identity_block_relu2 (Activ  (None, 7, 7, 512)   0           ['5c_identity_block_bn2[0][0]']  ation)                                                                                           5c_identity_block_conv3 (Conv2  (None, 7, 7, 2048)  1050624     ['5c_identity_block_relu2[0][0]']D)                                                                                               5c_identity_block_bn3 (BatchNo  (None, 7, 7, 2048)  8192        ['5c_identity_block_conv3[0][0]']rmalization)                                                                                     5c_identity_block_add (Add)    (None, 7, 7, 2048)   0           ['5c_identity_block_bn3[0][0]',  '5b_identity_block_relu4[0][0]']5c_identity_block_relu4 (Activ  (None, 7, 7, 2048)  0           ['5c_identity_block_add[0][0]']  ation)                                                                                           avg_pool (AveragePooling2D)    (None, 1, 1, 2048)   0           ['5c_identity_block_relu4[0][0]']flatten_2 (Flatten)            (None, 2048)         0           ['avg_pool[0][0]']               fc1000 (Dense)                 (None, 
1000)         2049000     ['flatten_2[0][0]']              ==================================================================================================
Total params: 25,636,712
Trainable params: 25,583,592
Non-trainable params: 53,120
__________________________________________________________________________________________________
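
The totals above can be reproduced by hand. The sketch below assumes the usual Keras conventions: a Conv2D layer has k*k*C_in*C_out weights plus C_out biases, and a BatchNormalization layer has 4 parameters per channel, of which the moving mean and moving variance are non-trainable. The helper names (conv_params, bn_params, block_params, resnet50_param_counts) are made up for this illustration and are not part of the original code.

```python
def conv_params(kernel, c_in, c_out):
    """Conv2D weights plus one bias per output channel (Keras default)."""
    return kernel * kernel * c_in * c_out + c_out


def bn_params(channels):
    """gamma and beta (trainable) plus moving mean and variance (non-trainable)."""
    return 4 * channels


def block_bn_channels(filters, projection):
    """Total channels passing through BatchNormalization layers in one block."""
    f1, f2, f3 = filters
    return f1 + f2 + f3 + (f3 if projection else 0)


def block_params(c_in, filters, projection):
    """Parameters of one bottleneck block; projection=True means a conv block."""
    f1, f2, f3 = filters
    total = conv_params(1, c_in, f1) + conv_params(3, f1, f2) + conv_params(1, f2, f3)
    if projection:
        total += conv_params(1, c_in, f3)  # 1x1 conv on the shortcut branch
    return total + bn_params(block_bn_channels(filters, projection))


def resnet50_param_counts(num_classes=1000):
    stages = [  # (filters, number of blocks in the stage)
        ([64, 64, 256], 3),
        ([128, 128, 512], 4),
        ([256, 256, 1024], 6),
        ([512, 512, 2048], 3),
    ]
    total = conv_params(7, 3, 64) + bn_params(64)  # conv1 + bn_conv1
    bn_channels = 64
    c_in = 64
    for filters, n_blocks in stages:
        for i in range(n_blocks):
            projection = (i == 0)  # the first block of each stage is a conv block
            total += block_params(c_in, filters, projection)
            bn_channels += block_bn_channels(filters, projection)
            c_in = filters[2]
    total += c_in * num_classes + num_classes  # fc1000
    non_trainable = 2 * bn_channels
    return total, total - non_trainable, non_trainable


print(*resnet50_param_counts())  # 25636712 25583592 53120
```

For example, conv_params(7, 3, 64) gives the 9472 shown for conv1, and the 53,120 non-trainable parameters are exactly the BatchNormalization moving statistics.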

PyTorch

import torch
import torch.nn as nn
import torchsummary

# Use the GPU when one is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

n_class = 4  # number of output classes


''' Same Padding '''
def autopad(k, p=None):  # kernel, padding
    # Pad to 'same'
    if p is None:
        p = k // 2 if isinstance(k, int) else [x // 2 for x in k]  # auto-pad
    return p


''' Identity Block '''
class IdentityBlock(nn.Module):
    def __init__(self, in_channel, kernel_size, filters):
        super(IdentityBlock, self).__init__()
        filters1, filters2, filters3 = filters
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channel, filters1, 1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(filters1),
            nn.ReLU(True)
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(filters1, filters2, kernel_size, stride=1, padding=autopad(kernel_size), bias=False),
            nn.BatchNorm2d(filters2),
            nn.ReLU(True)
        )
        self.conv3 = nn.Sequential(
            nn.Conv2d(filters2, filters3, 1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(filters3)
        )
        self.relu = nn.ReLU(True)

    def forward(self, x):
        x1 = self.conv1(x)
        x1 = self.conv2(x1)
        x1 = self.conv3(x1)
        x = x1 + x  # shortcut: add the unchanged input
        # Use the returned tensor instead of relying on the in-place side effect
        x = self.relu(x)
        return x


''' Conv Block '''
class ConvBlock(nn.Module):
    def __init__(self, in_channel, kernel_size, filters, stride=2):
        super(ConvBlock, self).__init__()
        filters1, filters2, filters3 = filters
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channel, filters1, 1, stride=stride, padding=0, bias=False),
            nn.BatchNorm2d(filters1),
            nn.ReLU(True)
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(filters1, filters2, kernel_size, stride=1, padding=autopad(kernel_size), bias=False),
            nn.BatchNorm2d(filters2),
            nn.ReLU(True)
        )
        self.conv3 = nn.Sequential(
            nn.Conv2d(filters2, filters3, 1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(filters3)
        )
        # Projection shortcut: 1x1 conv so the shortcut matches the main branch's shape
        self.conv4 = nn.Sequential(
            nn.Conv2d(in_channel, filters3, 1, stride=stride, padding=0, bias=False),
            nn.BatchNorm2d(filters3)
        )
        self.relu = nn.ReLU(True)

    def forward(self, x):
        x1 = self.conv1(x)
        x1 = self.conv2(x1)
        x1 = self.conv3(x1)
        x2 = self.conv4(x)
        x = x1 + x2
        x = self.relu(x)
        return x


''' Build ResNet-50 '''
class ResNet50(nn.Module):
    def __init__(self, classes=1000):
        super(ResNet50, self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 64, 7, stride=2, padding=3, bias=False, padding_mode='zeros'),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=0)
        )
        self.conv2 = nn.Sequential(
            ConvBlock(64, 3, [64, 64, 256], stride=1),
            IdentityBlock(256, 3, [64, 64, 256]),
            IdentityBlock(256, 3, [64, 64, 256])
        )
        self.conv3 = nn.Sequential(
            ConvBlock(256, 3, [128, 128, 512]),
            IdentityBlock(512, 3, [128, 128, 512]),
            IdentityBlock(512, 3, [128, 128, 512]),
            IdentityBlock(512, 3, [128, 128, 512])
        )
        self.conv4 = nn.Sequential(
            ConvBlock(512, 3, [256, 256, 1024]),
            IdentityBlock(1024, 3, [256, 256, 1024]),
            IdentityBlock(1024, 3, [256, 256, 1024]),
            IdentityBlock(1024, 3, [256, 256, 1024]),
            IdentityBlock(1024, 3, [256, 256, 1024]),
            IdentityBlock(1024, 3, [256, 256, 1024])
        )
        self.conv5 = nn.Sequential(
            ConvBlock(1024, 3, [512, 512, 2048]),
            IdentityBlock(2048, 3, [512, 512, 2048]),
            IdentityBlock(2048, 3, [512, 512, 2048])
        )
        self.pool = nn.AvgPool2d(kernel_size=7, stride=7, padding=0)
        # Use the constructor argument rather than the global n_class
        self.fc = nn.Linear(2048, classes)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.conv5(x)
        x = self.pool(x)
        x = torch.flatten(x, start_dim=1)
        x = self.fc(x)
        return x


model = ResNet50(classes=n_class).to(device)

''' Show the network structure '''
torchsummary.summary(model, (3, 224, 224))
#torchinfo.summary(model)
print(model)
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1         [-1, 64, 112, 112]           9,408
       BatchNorm2d-2         [-1, 64, 112, 112]             128
              ReLU-3         [-1, 64, 112, 112]               0
         MaxPool2d-4           [-1, 64, 55, 55]               0
            Conv2d-5           [-1, 64, 55, 55]           4,096
       BatchNorm2d-6           [-1, 64, 55, 55]             128
              ReLU-7           [-1, 64, 55, 55]               0
            Conv2d-8           [-1, 64, 55, 55]          36,864
       BatchNorm2d-9           [-1, 64, 55, 55]             128
             ReLU-10           [-1, 64, 55, 55]               0
           Conv2d-11          [-1, 256, 55, 55]          16,384
      BatchNorm2d-12          [-1, 256, 55, 55]             512
           Conv2d-13          [-1, 256, 55, 55]          16,384
      BatchNorm2d-14          [-1, 256, 55, 55]             512
             ReLU-15          [-1, 256, 55, 55]               0
        ConvBlock-16          [-1, 256, 55, 55]               0
           Conv2d-17           [-1, 64, 55, 55]          16,384
      BatchNorm2d-18           [-1, 64, 55, 55]             128
             ReLU-19           [-1, 64, 55, 55]               0
           Conv2d-20           [-1, 64, 55, 55]          36,864
      BatchNorm2d-21           [-1, 64, 55, 55]             128
             ReLU-22           [-1, 64, 55, 55]               0
           Conv2d-23          [-1, 256, 55, 55]          16,384
      BatchNorm2d-24          [-1, 256, 55, 55]             512
             ReLU-25          [-1, 256, 55, 55]               0
    IdentityBlock-26          [-1, 256, 55, 55]               0
           Conv2d-27           [-1, 64, 55, 55]          16,384
      BatchNorm2d-28           [-1, 64, 55, 55]             128
             ReLU-29           [-1, 64, 55, 55]               0
           Conv2d-30           [-1, 64, 55, 55]          36,864
      BatchNorm2d-31           [-1, 64, 55, 55]             128
             ReLU-32           [-1, 64, 55, 55]               0
           Conv2d-33          [-1, 256, 55, 55]          16,384
      BatchNorm2d-34          [-1, 256, 55, 55]             512
             ReLU-35          [-1, 256, 55, 55]               0
    IdentityBlock-36          [-1, 256, 55, 55]               0
           Conv2d-37          [-1, 128, 28, 28]          32,768
      BatchNorm2d-38          [-1, 128, 28, 28]             256
             ReLU-39          [-1, 128, 28, 28]               0
           Conv2d-40          [-1, 128, 28, 28]         147,456
      BatchNorm2d-41          [-1, 128, 28, 28]             256
             ReLU-42          [-1, 128, 28, 28]               0
           Conv2d-43          [-1, 512, 28, 28]          65,536
      BatchNorm2d-44          [-1, 512, 28, 28]           1,024
           Conv2d-45          [-1, 512, 28, 28]         131,072
      BatchNorm2d-46          [-1, 512, 28, 28]           1,024
             ReLU-47          [-1, 512, 28, 28]               0
        ConvBlock-48          [-1, 512, 28, 28]               0
           Conv2d-49          [-1, 128, 28, 28]          65,536
      BatchNorm2d-50          [-1, 128, 28, 28]             256
             ReLU-51          [-1, 128, 28, 28]               0
           Conv2d-52          [-1, 128, 28, 28]         147,456
      BatchNorm2d-53          [-1, 128, 28, 28]             256
             ReLU-54          [-1, 128, 28, 28]               0
           Conv2d-55          [-1, 512, 28, 28]          65,536
      BatchNorm2d-56          [-1, 512, 28, 28]           1,024
             ReLU-57          [-1, 512, 28, 28]               0
    IdentityBlock-58          [-1, 512, 28, 28]               0
           Conv2d-59          [-1, 128, 28, 28]          65,536
      BatchNorm2d-60          [-1, 128, 28, 28]             256
             ReLU-61          [-1, 128, 28, 28]               0
           Conv2d-62          [-1, 128, 28, 28]         147,456
      BatchNorm2d-63          [-1, 128, 28, 28]             256
             ReLU-64          [-1, 128, 28, 28]               0
           Conv2d-65          [-1, 512, 28, 28]          65,536
      BatchNorm2d-66          [-1, 512, 28, 28]           1,024
             ReLU-67          [-1, 512, 28, 28]               0
    IdentityBlock-68          [-1, 512, 28, 28]               0
           Conv2d-69          [-1, 128, 28, 28]          65,536
      BatchNorm2d-70          [-1, 128, 28, 28]             256
             ReLU-71          [-1, 128, 28, 28]               0
           Conv2d-72          [-1, 128, 28, 28]         147,456
      BatchNorm2d-73          [-1, 128, 28, 28]             256
             ReLU-74          [-1, 128, 28, 28]               0
           Conv2d-75          [-1, 512, 28, 28]          65,536
      BatchNorm2d-76          [-1, 512, 28, 28]           1,024
             ReLU-77          [-1, 512, 28, 28]               0
    IdentityBlock-78          [-1, 512, 28, 28]               0
           Conv2d-79          [-1, 256, 14, 14]         131,072
      BatchNorm2d-80          [-1, 256, 14, 14]             512
             ReLU-81          [-1, 256, 14, 14]               0
           Conv2d-82          [-1, 256, 14, 14]         589,824
      BatchNorm2d-83          [-1, 256, 14, 14]             512
             ReLU-84          [-1, 256, 14, 14]               0
           Conv2d-85         [-1, 1024, 14, 14]         262,144
      BatchNorm2d-86         [-1, 1024, 14, 14]           2,048
           Conv2d-87         [-1, 1024, 14, 14]         524,288
      BatchNorm2d-88         [-1, 1024, 14, 14]           2,048
             ReLU-89         [-1, 1024, 14, 14]               0
        ConvBlock-90         [-1, 1024, 14, 14]               0
           Conv2d-91          [-1, 256, 14, 14]         262,144
      BatchNorm2d-92          [-1, 256, 14, 14]             512
             ReLU-93          [-1, 256, 14, 14]               0
           Conv2d-94          [-1, 256, 14, 14]         589,824
      BatchNorm2d-95          [-1, 256, 14, 14]             512
             ReLU-96          [-1, 256, 14, 14]               0
           Conv2d-97         [-1, 1024, 14, 14]         262,144
      BatchNorm2d-98         [-1, 1024, 14, 14]           2,048
             ReLU-99         [-1, 1024, 14, 14]               0
   IdentityBlock-100         [-1, 1024, 14, 14]               0
          Conv2d-101          [-1, 256, 14, 14]         262,144
     BatchNorm2d-102          [-1, 256, 14, 14]             512
            ReLU-103          [-1, 256, 14, 14]               0
          Conv2d-104          [-1, 256, 14, 14]         589,824
     BatchNorm2d-105          [-1, 256, 14, 14]             512
            ReLU-106          [-1, 256, 14, 14]               0
          Conv2d-107         [-1, 1024, 14, 14]         262,144
     BatchNorm2d-108         [-1, 1024, 14, 14]           2,048
            ReLU-109         [-1, 1024, 14, 14]               0
   IdentityBlock-110         [-1, 1024, 14, 14]               0
          Conv2d-111          [-1, 256, 14, 14]         262,144
     BatchNorm2d-112          [-1, 256, 14, 14]             512
            ReLU-113          [-1, 256, 14, 14]               0
          Conv2d-114          [-1, 256, 14, 14]         589,824
     BatchNorm2d-115          [-1, 256, 14, 14]             512
            ReLU-116          [-1, 256, 14, 14]               0
          Conv2d-117         [-1, 1024, 14, 14]         262,144
     BatchNorm2d-118         [-1, 1024, 14, 14]           2,048
            ReLU-119         [-1, 1024, 14, 14]               0
   IdentityBlock-120         [-1, 1024, 14, 14]               0
          Conv2d-121          [-1, 256, 14, 14]         262,144
     BatchNorm2d-122          [-1, 256, 14, 14]             512
            ReLU-123          [-1, 256, 14, 14]               0
          Conv2d-124          [-1, 256, 14, 14]         589,824
     BatchNorm2d-125          [-1, 256, 14, 14]             512
            ReLU-126          [-1, 256, 14, 14]               0
          Conv2d-127         [-1, 1024, 14, 14]         262,144
     BatchNorm2d-128         [-1, 1024, 14, 14]           2,048
            ReLU-129         [-1, 1024, 14, 14]               0
   IdentityBlock-130         [-1, 1024, 14, 14]               0
          Conv2d-131          [-1, 256, 14, 14]         262,144
     BatchNorm2d-132          [-1, 256, 14, 14]             512
            ReLU-133          [-1, 256, 14, 14]               0
          Conv2d-134          [-1, 256, 14, 14]         589,824
     BatchNorm2d-135          [-1, 256, 14, 14]             512
            ReLU-136          [-1, 256, 14, 14]               0
          Conv2d-137         [-1, 1024, 14, 14]         262,144
     BatchNorm2d-138         [-1, 1024, 14, 14]           2,048
            ReLU-139         [-1, 1024, 14, 14]               0
   IdentityBlock-140         [-1, 1024, 14, 14]               0
          Conv2d-141            [-1, 512, 7, 7]         524,288
     BatchNorm2d-142            [-1, 512, 7, 7]           1,024
            ReLU-143            [-1, 512, 7, 7]               0
          Conv2d-144            [-1, 512, 7, 7]       2,359,296
     BatchNorm2d-145            [-1, 512, 7, 7]           1,024
            ReLU-146            [-1, 512, 7, 7]               0
          Conv2d-147           [-1, 2048, 7, 7]       1,048,576
     BatchNorm2d-148           [-1, 2048, 7, 7]           4,096
          Conv2d-149           [-1, 2048, 7, 7]       2,097,152
     BatchNorm2d-150           [-1, 2048, 7, 7]           4,096
            ReLU-151           [-1, 2048, 7, 7]               0
       ConvBlock-152           [-1, 2048, 7, 7]               0
          Conv2d-153            [-1, 512, 7, 7]       1,048,576
     BatchNorm2d-154            [-1, 512, 7, 7]           1,024
            ReLU-155            [-1, 512, 7, 7]               0
          Conv2d-156            [-1, 512, 7, 7]       2,359,296
     BatchNorm2d-157            [-1, 512, 7, 7]           1,024
            ReLU-158            [-1, 512, 7, 7]               0
          Conv2d-159           [-1, 2048, 7, 7]       1,048,576
     BatchNorm2d-160           [-1, 2048, 7, 7]           4,096
            ReLU-161           [-1, 2048, 7, 7]               0
   IdentityBlock-162           [-1, 2048, 7, 7]               0
          Conv2d-163            [-1, 512, 7, 7]       1,048,576
     BatchNorm2d-164            [-1, 512, 7, 7]           1,024
            ReLU-165            [-1, 512, 7, 7]               0
          Conv2d-166            [-1, 512, 7, 7]       2,359,296
     BatchNorm2d-167            [-1, 512, 7, 7]           1,024
            ReLU-168            [-1, 512, 7, 7]               0
          Conv2d-169           [-1, 2048, 7, 7]       1,048,576
     BatchNorm2d-170           [-1, 2048, 7, 7]           4,096
            ReLU-171           [-1, 2048, 7, 7]               0
   IdentityBlock-172           [-1, 2048, 7, 7]               0
       AvgPool2d-173           [-1, 2048, 1, 1]               0
          Linear-174                    [-1, 4]           8,196
================================================================
Total params: 23,516,228
Trainable params: 23,516,228
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.57
Forward/backward pass size (MB): 270.43
Params size (MB): 89.71
Estimated Total Size (MB): 360.71
----------------------------------------------------------------
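The Keras summary reports 25,636,712 parameters, while torchsummary reports 23,516,228. The arithmetic below is a rough reconciliation of that gap, not something stated in the original post. It assumes that the Keras model keeps biases on its convolutions (the PyTorch code uses bias=False), that every convolution is followed by a BatchNorm layer, and that torchsummary does not count BatchNorm running statistics.

```python
# Figures copied from the two summaries in this post.
keras_total = 25_636_712
keras_non_trainable = 53_120   # BatchNorm moving mean and variance
keras_head = 2_049_000         # 1000-way dense layer: 2048 * 1000 + 1000
torch_total = 23_516_228
torch_head = 2048 * 4 + 4      # Linear(2048, 4), shown as 8,196 above

# Each normalized channel holds 2 non-trainable statistics.
bn_channels = keras_non_trainable // 2
# Assumption: one conv bias per normalized channel on the Keras side only.
conv_biases = bn_channels

explained = (keras_head - torch_head) + keras_non_trainable + conv_biases
print("gap:", keras_total - torch_total, "explained:", explained)
```

Under these assumptions the three terms (the smaller classification head, the running statistics, and the convolution biases) account for the whole difference.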
ResNet50((conv1): Sequential((0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)(1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(2): ReLU()(3): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False))(conv2): Sequential((0): ConvBlock((conv1): Sequential((0): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)(1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(2): ReLU(inplace=True))(conv2): Sequential((0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(2): ReLU(inplace=True))(conv3): Sequential((0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)(1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(conv4): Sequential((0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)(1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(relu): ReLU(inplace=True))(1): IdentityBlock((conv1): Sequential((0): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)(1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(2): ReLU(inplace=True))(conv2): Sequential((0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(2): ReLU(inplace=True))(conv3): Sequential((0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)(1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(relu): ReLU(inplace=True))(2): IdentityBlock((conv1): Sequential((0): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)(1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(2): ReLU(inplace=True))(conv2): Sequential((0): Conv2d(64, 64, kernel_size=(3, 3), 
stride=(1, 1), padding=(1, 1), bias=False)(1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(2): ReLU(inplace=True))(conv3): Sequential((0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)(1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(relu): ReLU(inplace=True)))(conv3): Sequential((0): ConvBlock((conv1): Sequential((0): Conv2d(256, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(2): ReLU(inplace=True))(conv2): Sequential((0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(2): ReLU(inplace=True))(conv3): Sequential((0): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)(1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(conv4): Sequential((0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)(1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(relu): ReLU(inplace=True))(1): IdentityBlock((conv1): Sequential((0): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(2): ReLU(inplace=True))(conv2): Sequential((0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(2): ReLU(inplace=True))(conv3): Sequential((0): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)(1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(relu): ReLU(inplace=True))(2): IdentityBlock((conv1): Sequential((0): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(2): 
ReLU(inplace=True))(conv2): Sequential((0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(2): ReLU(inplace=True))(conv3): Sequential((0): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)(1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(relu): ReLU(inplace=True))(3): IdentityBlock((conv1): Sequential((0): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(2): ReLU(inplace=True))(conv2): Sequential((0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(2): ReLU(inplace=True))(conv3): Sequential((0): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)(1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(relu): ReLU(inplace=True)))(conv4): Sequential((0): ConvBlock((conv1): Sequential((0): Conv2d(512, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)(1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(2): ReLU(inplace=True))(conv2): Sequential((0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(2): ReLU(inplace=True))(conv3): Sequential((0): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)(1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(conv4): Sequential((0): Conv2d(512, 1024, kernel_size=(1, 1), stride=(2, 2), bias=False)(1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(relu): ReLU(inplace=True))(1): IdentityBlock((conv1): Sequential((0): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), 
bias=False)(1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(2): ReLU(inplace=True))(conv2): Sequential((0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(2): ReLU(inplace=True))(conv3): Sequential((0): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)(1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(relu): ReLU(inplace=True))(2): IdentityBlock((conv1): Sequential((0): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)(1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(2): ReLU(inplace=True))(conv2): Sequential((0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(2): ReLU(inplace=True))(conv3): Sequential((0): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)(1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(relu): ReLU(inplace=True))(3): IdentityBlock((conv1): Sequential((0): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)(1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(2): ReLU(inplace=True))(conv2): Sequential((0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(2): ReLU(inplace=True))(conv3): Sequential((0): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)(1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(relu): ReLU(inplace=True))(4): IdentityBlock((conv1): Sequential((0): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)(1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(2): 
ReLU(inplace=True))(conv2): Sequential((0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(2): ReLU(inplace=True))(conv3): Sequential((0): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)(1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(relu): ReLU(inplace=True))(5): IdentityBlock((conv1): Sequential((0): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)(1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(2): ReLU(inplace=True))(conv2): Sequential((0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(2): ReLU(inplace=True))(conv3): Sequential((0): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)(1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(relu): ReLU(inplace=True)))(conv5): Sequential((0): ConvBlock((conv1): Sequential((0): Conv2d(1024, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)(1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(2): ReLU(inplace=True))(conv2): Sequential((0): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(2): ReLU(inplace=True))(conv3): Sequential((0): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)(1): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(conv4): Sequential((0): Conv2d(1024, 2048, kernel_size=(1, 1), stride=(2, 2), bias=False)(1): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(relu): ReLU(inplace=True))(1): IdentityBlock((conv1): Sequential((0): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), 
bias=False)(1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(2): ReLU(inplace=True))(conv2): Sequential((0): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(2): ReLU(inplace=True))(conv3): Sequential((0): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)(1): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(relu): ReLU(inplace=True))(2): IdentityBlock((conv1): Sequential((0): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)(1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(2): ReLU(inplace=True))(conv2): Sequential((0): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(2): ReLU(inplace=True))(conv3): Sequential((0): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)(1): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(relu): ReLU(inplace=True)))(pool): AvgPool2d(kernel_size=7, stride=7, padding=0)(fc): Linear(in_features=2048, out_features=4, bias=True)
)
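The output shapes in the summary (112, 55, 28, 14, 7) follow from the usual convolution size formula. The helper below is a small sketch for checking them by hand; `conv_out` is a name made up for this example, and it assumes floor rounding with dilation 1.

```python
def conv_out(size, kernel, stride, padding):
    """Output length of a conv or pooling layer along one axis."""
    return (size + 2 * padding - kernel) // stride + 1


size = 224
stem = conv_out(size, kernel=7, stride=2, padding=3)    # 7x7 stem convolution
pooled = conv_out(stem, kernel=3, stride=2, padding=0)  # max pool with padding=0, as in the code above

stages = []
size = pooled
for _ in range(3):  # conv3, conv4 and conv5 each start with a stride-2 1x1 convolution
    size = conv_out(size, kernel=1, stride=2, padding=0)
    stages.append(size)

print(stem, pooled, stages)  # 112 55 [28, 14, 7]
```

With padding=1 on the max pool, which is what torchvision's ResNet uses, the same formula gives 56 instead of 55. The final 7x7 map is the same either way, so the 7x7 average pool still reduces it to 1x1.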

6. Compile

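Before the model is ready for training, it needs a few more settings. These are added during the model's compile step:

Loss function (loss): measures how far the model's predictions are from the true labels during training. Training tries to minimize this value.

Optimizer (optimizer): determines how the model is updated based on the data it sees and its loss function.

Metrics (metrics): used to monitor the training and testing steps. The following example uses accuracy, the fraction of images that are correctly classified.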
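To make the loss and the metric concrete, here is a plain-Python sketch of what cross-entropy and accuracy compute for a small batch. The logits and labels are made up for illustration, and the two helper functions are not part of either framework.

```python
import math


def cross_entropy(logits, target):
    """Cross-entropy for one sample: -log(softmax(logits)[target])."""
    m = max(logits)  # subtract the maximum for numerical stability
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[target]


def accuracy(batch_logits, targets):
    """Fraction of samples whose largest logit is at the target index."""
    correct = 0
    for logits, target in zip(batch_logits, targets):
        predicted = max(range(len(logits)), key=logits.__getitem__)
        correct += int(predicted == target)
    return correct / len(targets)


# Made-up logits for three images and four classes.
batch = [[2.0, 0.5, 0.1, -1.0], [0.2, 0.1, 3.0, 0.0], [1.0, 1.5, 0.3, 0.2]]
labels = [0, 2, 0]

mean_loss = sum(cross_entropy(l, t) for l, t in zip(batch, labels)) / len(labels)
print("mean loss:", mean_loss)
print("accuracy:", accuracy(batch, labels))
```

The third sample is misclassified (its largest logit is at index 1), so the accuracy is 2/3 while the loss still reflects how confident each prediction was.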

PyTorch

# Set up the optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)  # the parameters to train
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.92)  # multiply the learning rate by 0.92 every 5 epochs
loss_fn = nn.CrossEntropyLoss()

# Training loop
def train(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)  # number of samples in the training set
    num_batches = len(dataloader)   # number of batches (samples / batch size, rounded up)

    train_loss, train_acc = 0, 0  # running loss and number of correct predictions

    for X, y in dataloader:  # fetch images and their labels
        X, y = X.to(device), y.to(device)

        # Compute the prediction error
        pred = model(X)          # network output
        loss = loss_fn(pred, y)  # loss between the network output and the true labels

        # Backpropagation
        optimizer.zero_grad()  # reset the gradients
        loss.backward()        # backpropagate
        optimizer.step()       # update the parameters

        # Record accuracy and loss
        train_acc  += (pred.argmax(1) == y).type(torch.float).sum().item()
        train_loss += loss.item()

    train_acc  /= size
    train_loss /= num_batches

    return train_acc, train_loss


def test(dataloader, model, loss_fn):
    size        = len(dataloader.dataset)  # number of samples in the test set
    num_batches = len(dataloader)          # number of batches (samples / batch size, rounded up)
    test_loss, test_acc = 0, 0

    # No training happens here, so disable gradient tracking to save memory
    with torch.no_grad():
        for imgs, target in dataloader:
            imgs, target = imgs.to(device), target.to(device)

            # Compute the loss
            target_pred = model(imgs)
            loss        = loss_fn(target_pred, target)

            test_loss += loss.item()
            test_acc  += (target_pred.argmax(1) == target).type(torch.float).sum().item()

    test_acc  /= size
    test_loss /= num_batches

    return test_acc, test_loss
TensorFlow

# Set up the optimizer; the learning rate is changed here.
opt = tf.keras.optimizers.Adam(learning_rate=1e-7)

model.compile(optimizer=opt,  # pass the configured optimizer; the string "adam" would use the default learning rate
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])
PyTorch

''' Set hyperparameters '''
start_epoch = 0
epochs      = 10
learn_rate  = 1e-7 # initial learning rate
loss_fn     = nn.CrossEntropyLoss()  # create the loss function
optimizer   = torch.optim.Adam(model.parameters(),lr=learn_rate)
train_loss  = []
train_acc   = []
test_loss   = []
test_acc    = []
epoch_best_acc = 0
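The StepLR scheduler configured earlier in this section uses step_size=5 and gamma=0.92, so the learning rate is multiplied by 0.92 once every five epochs. The function below is a plain-Python sketch of that rule for a quick look at the schedule; `step_lr` is a name made up for this example and is not a PyTorch API.

```python
def step_lr(base_lr, epoch, step_size=5, gamma=0.92):
    """Learning rate after `epoch` scheduler steps under a StepLR-style rule."""
    return base_lr * gamma ** (epoch // step_size)


for epoch in (0, 4, 5, 10, 19):
    print(f"epoch {epoch:2d}: lr = {step_lr(1e-4, epoch):.3e}")
```

With these settings the decay is gentle: after 20 epochs the rate has only been multiplied by 0.92 three times.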

7. Train the model

TensorFlow

epochs = 10

history = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=epochs
)
PyTorch
epochs     = 20
train_loss = []
train_acc  = []
test_loss  = []
test_acc   = []
best_acc   = 0

for epoch in range(epochs):
    model.train()
    epoch_train_acc, epoch_train_loss = train(train_dl, model, loss_fn, optimizer)
    scheduler.step()  # learning-rate decay

    model.eval()
    epoch_test_acc, epoch_test_loss = test(test_dl, model, loss_fn)

    # Keep the best model
    if epoch_test_acc > best_acc:
        best_acc = epoch_test_acc  # store the test accuracy that was just compared
        state = {
            'state_dict': model.state_dict(),  # keys are the layer names, values are the trained weights
            'best_acc': best_acc,
            'optimizer': optimizer.state_dict(),
        }

    train_acc.append(epoch_train_acc)
    train_loss.append(epoch_train_loss)
    test_acc.append(epoch_test_acc)
    test_loss.append(epoch_test_loss)

    template = ('Epoch:{:2d}, Train_acc:{:.1f}%, Train_loss:{:.3f}, Test_acc:{:.1f}%,Test_loss:{:.3f}')
    print(template.format(epoch+1, epoch_train_acc*100, epoch_train_loss, epoch_test_acc*100, epoch_test_loss))

print('Done')
print('best_acc:', best_acc)
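The "keep the best model" step only works if the value being compared and the value being stored are the same metric. The sketch below isolates that bookkeeping in plain Python. The accuracy list is made up for illustration and `best_epoch` is a hypothetical helper, not part of the training code.

```python
def best_epoch(test_accs):
    """Return (epoch_index, accuracy) for the highest test accuracy."""
    best_acc, best_idx = 0.0, None
    for epoch, acc in enumerate(test_accs):
        if acc > best_acc:   # compare against test accuracy
            best_acc = acc   # and store that same test accuracy
            best_idx = epoch
    return best_idx, best_acc


# Made-up per-epoch test accuracies.
history = [0.41, 0.58, 0.55, 0.72, 0.70]
print(best_epoch(history))  # (3, 0.72)
```

If training accuracy were stored instead, later epochs would be compared against the wrong threshold and a better checkpoint could be skipped.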

Visualize the data

TensorFlow

''' Visualize the data '''
plt.figure(figsize=(10, 5))
plt.suptitle("微信公众号:K同学啊")
for images, labels in train_ds.take(1):
    for i in range(8):
        ax = plt.subplot(2, 4, i+1)
        plt.imshow(images[i].numpy().astype("uint8"))
        plt.title(class_names[labels[i]])
        plt.axis("off")

plt.imshow(images[1].numpy().astype("uint8"))
PyTorch
''' Data visualization '''
def displayData(imgs, root='', flag=False):
    # `labels` and `classeNames` are read from the enclosing module scope
    # Create a figure 10 inches wide and 5 inches high
    plt.figure('Data Visualization', figsize=(10, 5))
    for i, img in enumerate(imgs[:8]):
        # Reorder the dimensions: [3, 224, 224] -> [224, 224, 3]
        npimg = img.numpy().transpose((1, 2, 0))
        # Split the figure into 2 rows and 4 columns and draw subplot i+1
        plt.subplot(2, 4, i+1)
        plt.imshow(npimg)  # cmap=plt.cm.binary
        plt.title(list(classeNames)[labels[i]])
        plt.axis('off')
    plt.savefig(os.path.join(root, 'DatasetDisplay.png'))
    if flag:
        plt.show()
    else:
        plt.close('all')
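The transpose((1, 2, 0)) call is needed because PyTorch stores images as [channels, height, width], while plt.imshow expects [height, width, channels]. The pure-Python sketch below shows the same reordering on a tiny nested list; `chw_to_hwc` is a name made up for this example.

```python
def chw_to_hwc(img):
    """Reorder a nested-list image from [C][H][W] to [H][W][C]."""
    channels, height, width = len(img), len(img[0]), len(img[0][0])
    return [
        [[img[c][h][w] for c in range(channels)] for w in range(width)]
        for h in range(height)
    ]


# A 3-channel image that is 1 pixel high and 2 pixels wide.
chw = [[[1, 2]], [[3, 4]], [[5, 6]]]
print(chw_to_hwc(chw))  # [[[1, 3, 5], [2, 4, 6]]]
```

Each pixel ends up with its three channel values side by side, which is the layout imshow reads as RGB.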


8. Evaluate the model

''' Model evaluation '''
acc = history.history['accuracy']
val_acc = history.history['val_accuracy']
loss = history.history['loss']
val_loss = history.history['val_loss']
epochs_range = range(epochs)

plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
plt.suptitle('K同学啊')
plt.plot(epochs_range, acc, label='Training Accuracy')
plt.plot(epochs_range, val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')
plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss, label='Training Loss')
plt.plot(epochs_range, val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()
