TensorFlow2实战-系列教程14:Resnet实战2

2024-02-01 09:52

本文主要是介绍TensorFlow2实战-系列教程14:Resnet实战2,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

🧡💛💚TensorFlow2实战-系列教程 总目录

有任何问题欢迎在下面留言
本篇文章的代码运行界面均在Jupyter Notebook中进行
本篇文章配套的代码资源已经上传

Resnet实战1
Resnet实战2
Resnet实战3

4、训练脚本train.py解读------创建模型

def get_model():model = resnet50.ResNet50()if config.model == "resnet34":model = resnet34.ResNet34()if config.model == "resnet101":model = resnet101.ResNet101()if config.model == "resnet152":model = resnet152.ResNet152()model.build(input_shape=(None, config.image_height, config.image_width, config.channels))model.summary()tf.keras.utils.plot_model(model, to_file='model.png')return model# create model
model = get_model()

调用get_model()函数构建模型

get_model()函数:

  1. 通过resnet50.py调用ResNet50类,构建ResNet50模型
  2. 如果在配置参数中设置的是"resnet34"、“resnet101”、“resnet152”,则会对应使用(resnet34.py调用ResNet34类,构建ResNet34模型)、(resnet101.py调用ResNet101类,构建ResNet101模型)、(resnet152.py调用ResNet152类,构建ResNet152模型)
  3. 准备模型以供训练或评估,
  4. 输出模型的概览
  5. 创建了模型的结构图,plot_model 函数从 Keras 工具包中生成模型的可视化表示,指定了保存路径

5、模型构建解析------models/resnet50.py

import tensorflow as tf
from models.residual_block import build_res_block_2
from config import NUM_CLASSESclass ResNet50(tf.keras.Model):def __init__(self, num_classes=NUM_CLASSES):super(ResNet50, self).__init__()self.pre1 = tf.keras.layers.Conv2D(filters=64, kernel_size=(7, 7), strides=2, padding='same')self.pre2 = tf.keras.layers.BatchNormalization()self.pre3 = tf.keras.layers.Activation(tf.keras.activations.relu)self.pre4 = tf.keras.layers.MaxPool2D(pool_size=(3, 3), strides=2)self.layer1 = build_res_block_2(filter_num=64, blocks=3)self.layer2 = build_res_block_2(filter_num=128, blocks=4, stride=2)self.layer3 = build_res_block_2(filter_num=256, blocks=6, stride=2)self.layer4 = build_res_block_2(filter_num=512, blocks=3, stride=2)self.avgpool = tf.keras.layers.GlobalAveragePooling2D()self.fc1 = tf.keras.layers.Dense(units=1000, activation=tf.keras.activations.relu)self.drop_out = tf.keras.layers.Dropout(rate=0.5)self.fc2 = tf.keras.layers.Dense(units=num_classes, activation=tf.keras.activations.softmax)def call(self, inputs, training=None, mask=None):pre1 = self.pre1(inputs)pre2 = self.pre2(pre1, training=training)pre3 = self.pre3(pre2)pre4 = self.pre4(pre3)l1 = self.layer1(pre4, training=training)l2 = self.layer2(l1, training=training)l3 = self.layer3(l2, training=training)l4 = self.layer4(l3, training=training)avgpool = self.avgpool(l4)fc1 = self.fc1(avgpool)drop = self.drop_out(fc1)out = self.fc2(drop)return out

class ResNet50(tf.keras.Model),这个类定义了ResNet50模型的结构,以及前向传播的方式、顺序

ResNet50类解析:

  1. 构造函数,传入了预测的类别数
  2. 初始化
  3. pre1 ,定义一个二维卷积,输出64个特征图,7x7的卷积,步长为2
  4. pre2 ,定义一个批归一化
  5. pre3,定义一个ReLU激活函数
  6. pre4,一个二维的最大池化
  7. 依次通过build_res_block_2()函数定义4个残差块
  8. 定义一个全局平均池化
  9. 定义一个全连接层,输出维度为1000
  10. 定义一个dropout
  11. 定义一个输出层的全连接层
  12. 前向传播函数,传入输入值
  13. 依次经过pre1、pre2、pre3、pre4,即卷积、批归一化、ReLU、最大池化
  14. 依次经过layer1 、layer2 、layer3 、layer4 等四个残差块
  15. 将layer4 的输出经过平局池化
  16. 依次经过两个全连接层

6、模型构建解析------models/residual_block.py

  • BottleNeck类
  • build_res_block_2()函数
  • build_res_block_2()函数通过调用BottleNeck类构建残差块
class BottleNeck(tf.keras.layers.Layer):def __init__(self, filter_num, stride=1,with_downsample=True):super(BottleNeck, self).__init__()self.with_downsample = with_downsampleself.conv1 = tf.keras.layers.Conv2D(filters=filter_num, kernel_size=(1, 1), strides=1, padding='same')self.bn1 = tf.keras.layers.BatchNormalization()self.conv2 = tf.keras.layers.Conv2D(filters=filter_num, kernel_size=(3, 3), strides=stride, padding='same')self.bn2 = tf.keras.layers.BatchNormalization()self.conv3 = tf.keras.layers.Conv2D(filters=filter_num * 4, kernel_size=(1, 1), strides=1, padding='same')self.bn3 = tf.keras.layers.BatchNormalization()self.downsample = tf.keras.Sequential()self.downsample.add(tf.keras.layers.Conv2D(filters=filter_num * 4, kernel_size=(1, 1), strides=stride))self.downsample.add(tf.keras.layers.BatchNormalization())def call(self, inputs, training=None):identity = self.downsample(inputs)conv1 = self.conv1(inputs)bn1 = self.bn1(conv1, training=training)relu1 = tf.nn.relu(bn1)conv2 = self.conv2(relu1)bn2 = self.bn2(conv2, training=training)relu2 = tf.nn.relu(bn2)conv3 = self.conv3(relu2)bn3 = self.bn3(conv3, training=training)if self.with_downsample == True:output = tf.nn.relu(tf.keras.layers.add([identity, bn3]))else:output = tf.nn.relu(tf.keras.layers.add([inputs, bn3]))return output

BottleNeck类解析:

  1. 继承tf.keras.layers.Layer
  2. 构造函数,传入 特征图个数、步长、是否下采样等参数
  3. 初始化
  4. 是否进行下采样参数
  5. 定义一个1x1,步长为1的二维卷积conv1
  6. conv1 对应的批归一化
  7. 定义一个3x3,步长为1的二维卷积conv2
  8. conv2 对应的批归一化
  9. 定义一个3x3,步长为1的二维卷积conv2
  10. conv3 对应的批归一化
  11. 定义一个下采样层(self.downsample),这个层是一个包含卷积层和批量归一化的 Sequential 模型,用于匹配输入和残差的维度
  12. call()函数为前向传播
  13. 应用下采样
  14. 应用三层卷积和批量归一化以及对应的ReLU
  15. with_downsample == True:
  16. 启用下采样,将下采样后的输入(identity)与最后一个卷积层的输出(bn3)相加
  17. 没有启用下采样,将原始输入(inputs)与最后一个卷积层的输出(bn3)相加
def build_res_block_2(filter_num, blocks, stride=1):res_block = tf.keras.Sequential()res_block.add(BottleNeck(filter_num, stride=stride))for _ in range(1, blocks):res_block.add(BottleNeck(filter_num, stride=1,with_downsample=False))    return res_block

build_res_block_2函数解析:

  1. 这个函数构建了一个包含多个BottleNeck层的残差块
  2. filter_num 是每个瓶颈层内卷积层的过滤器数量
  3. blocks 是要添加到顺序模型中的瓶颈层的数量
  4. stride 是卷积的步长,默认为 1
  5. 该函数初始化一个 Sequential 模型,并添加一个 BottleNeck 层作为第一层
  6. 然后,它迭代地添加额外的 BottleNeck 层,每个层的 stride=1 且
    with_downsample=False(除第一个之外)
  7. 此函数返回组装好的顺序模型,代表一个残差块

Resnet实战1
Resnet实战2
Resnet实战3

这篇关于TensorFlow2实战-系列教程14:Resnet实战2的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/666884

相关文章

Spring Boot + MyBatis Plus 高效开发实战从入门到进阶优化(推荐)

《SpringBoot+MyBatisPlus高效开发实战从入门到进阶优化(推荐)》本文将详细介绍SpringBoot+MyBatisPlus的完整开发流程,并深入剖析分页查询、批量操作、动... 目录Spring Boot + MyBATis Plus 高效开发实战:从入门到进阶优化1. MyBatis

MyBatis 动态 SQL 优化之标签的实战与技巧(常见用法)

《MyBatis动态SQL优化之标签的实战与技巧(常见用法)》本文通过详细的示例和实际应用场景,介绍了如何有效利用这些标签来优化MyBatis配置,提升开发效率,确保SQL的高效执行和安全性,感... 目录动态SQL详解一、动态SQL的核心概念1.1 什么是动态SQL?1.2 动态SQL的优点1.3 动态S

Pandas使用SQLite3实战

《Pandas使用SQLite3实战》本文主要介绍了Pandas使用SQLite3实战,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学... 目录1 环境准备2 从 SQLite3VlfrWQzgt 读取数据到 DataFrame基础用法:读

Linux卸载自带jdk并安装新jdk版本的图文教程

《Linux卸载自带jdk并安装新jdk版本的图文教程》在Linux系统中,有时需要卸载预装的OpenJDK并安装特定版本的JDK,例如JDK1.8,所以本文给大家详细介绍了Linux卸载自带jdk并... 目录Ⅰ、卸载自带jdkⅡ、安装新版jdkⅠ、卸载自带jdk1、输入命令查看旧jdkrpm -qa

Java使用Curator进行ZooKeeper操作的详细教程

《Java使用Curator进行ZooKeeper操作的详细教程》ApacheCurator是一个基于ZooKeeper的Java客户端库,它极大地简化了使用ZooKeeper的开发工作,在分布式系统... 目录1、简述2、核心功能2.1 CuratorFramework2.2 Recipes3、示例实践3

springboot简单集成Security配置的教程

《springboot简单集成Security配置的教程》:本文主要介绍springboot简单集成Security配置的教程,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,... 目录集成Security安全框架引入依赖编写配置类WebSecurityConfig(自定义资源权限规则

MySQL Workbench 安装教程(保姆级)

《MySQLWorkbench安装教程(保姆级)》MySQLWorkbench是一款强大的数据库设计和管理工具,本文主要介绍了MySQLWorkbench安装教程,文中通过图文介绍的非常详细,对大... 目录前言:详细步骤:一、检查安装的数据库版本二、在官网下载对应的mysql Workbench版本,要是

通过Docker Compose部署MySQL的详细教程

《通过DockerCompose部署MySQL的详细教程》DockerCompose作为Docker官方的容器编排工具,为MySQL数据库部署带来了显著优势,下面小编就来为大家详细介绍一... 目录一、docker Compose 部署 mysql 的优势二、环境准备与基础配置2.1 项目目录结构2.2 基

Linux安装MySQL的教程

《Linux安装MySQL的教程》:本文主要介绍Linux安装MySQL的教程,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录linux安装mysql1.Mysql官网2.我的存放路径3.解压mysql文件到当前目录4.重命名一下5.创建mysql用户组和用户并修

Python实战之屏幕录制功能的实现

《Python实战之屏幕录制功能的实现》屏幕录制,即屏幕捕获,是指将计算机屏幕上的活动记录下来,生成视频文件,本文主要为大家介绍了如何使用Python实现这一功能,希望对大家有所帮助... 目录屏幕录制原理图像捕获音频捕获编码压缩输出保存完整的屏幕录制工具高级功能实时预览增加水印多平台支持屏幕录制原理屏幕