秃姐学AI系列之:残差网络 ResNet

2024-08-28 00:36

本文主要是介绍秃姐学AI系列之:残差网络 ResNet,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

目录

残差网络——ResNet

残差块思想

ResNet块细节

ResNet架构

总结

代码实现

残差块

两种 ResNet 块的情况 

ResNet 模型

QA


由上图发现,只有当较复杂的函数类包含较小的函数类时,才能确保提高它们的性能。

对于深度神经网络,如果我们能将新添加的层训练成恒等映射(identity function)f(x)=x,新模型和原模型将同样有效。 同时,由于新模型可能得出更优的解来拟合训练数据集,因此添加层似乎更容易降低训练误差。

针对这一问题,何恺明等人提出了残差网络(ResNet) 。它在2015年的 ImageNet 图像识别挑战赛夺魁,并深刻影响了后来的深度神经网络的设计。

残差网络——ResNet

残差网络的核心思想是:每个附加层都应该更容易地包含原始函数作为其元素之一。

于是,残差块(residual blocks)便诞生了,这个设计对如何建立深层神经网络产生了深远的影响。 凭借它,ResNet赢得了2015年ImageNet大规模视觉识别挑战赛。

残差块思想

残差块加入快速通道来得到f(x) = x + g(x) 的结构

如下图所示,假设我们的原始输入为x,而希望学出的理想映射为 f(x)(作为上方激活函数的输入)。

左图虚线框中的部分需要直接拟合出该映射 f(x),而右图虚线框中的部分则需要拟合出残差映射 f(x)−x。残差映射在现实中往往更容易优化

开头提到的恒等映射作为我们希望学出的理想映射 f(x),我们只需将右图虚线框内上方的加权运算(如仿射)的权重和偏置参数设成 0,那么 f(x) 即为恒等映射。 实际中,当理想映射 f(x) 极接近于恒等映射时,残差映射也易于捕捉恒等映射的细微波动。

右图是 ResNet 的基础架构–残差块(residual block)。 在残差块中,输入可通过跨层数据线路更快地向前传播。

相当于ResNet觉得,你就算虚线框里面所有层都没学到东西,下一层还是可以接收到这层的上一层传递来的东西(残差连接)即一个简单的直接传递的小模型。这个想法从函数的角度来说,可以认为更大、更复杂的模型里面包含一个小模型。

ResNet块细节

ResNet 是从 VGG 过来的,所以采用的是 3x3Conv

以下是 ResNet 块的两个不同的实现

右边存在的意义是:如果虚线的block对通道做了变换,那直接的X加不回去了,所以需要用卷积来对x做一个通道数的变换用于相加 。

ResNet架构

ResNet 最主要的思想就是单拎出来一条路让你可以把输入和输出加起来

抛开这个其他的你可以认为和 VGG 以及GoogLeNet 很像,也是由5个Stage拼成,只是把组合成网络的 Stage 替换成了 ResNet 块

  • 一个高宽减半的 ResNet 块(步幅为2)(那个支线上有Conv的Block,用来把输入的通道数翻一倍)
  • 重复多个高宽不变的 ResNet 块 

总结

  • 残差块使得很深的网络更加容易训练

    • 甚至可以训练以前层的网络

  • 残差网络对随后的深层神经网络设计产生了深远的影响,无论是卷积类网络还是全连接类网络 

  • 学习嵌套函数(nested function)是训练神经网络的理想情况。在深层神经网络中,学习另一层作为恒等映射(identity function)较容易(尽管这是一个极端情况)。

  • 残差映射可以更容易地学习同一函数,例如将权重层中的参数近似为零。

  • 利用残差块(residual blocks)可以训练出一个有效的深层神经网络:输入可以通过层间的残余连接更快地向前传播。

代码实现

残差块

import torch
from torch import nn
from torch.nn import functional as F
from d2l import torch as d2lclass Residual(nn.Module):  #@savedef __init__(self, input_channels, num_channels,use_1x1conv=False, strides=1):super().__init__()self.conv1 = nn.Conv2d(input_channels, num_channels,kernel_size=3, padding=1, stride=strides)self.conv2 = nn.Conv2d(num_channels, num_channels,kernel_size=3, padding=1)if use_1x1conv:self.conv3 = nn.Conv2d(input_channels, num_channels,kernel_size=1, stride=strides)else:self.conv3 = Noneself.bn1 = nn.BatchNorm2d(num_channels)self.bn2 = nn.BatchNorm2d(num_channels)def forward(self, X):Y = F.relu(self.bn1(self.conv1(X)))Y = self.bn2(self.conv2(Y))if self.conv3:X = self.conv3(X)Y += Xreturn F.relu(Y)

两种 ResNet 块的情况 

输入和输出形状一致

blk = Residual(3,3)
X = torch.rand(4, 3, 6, 6)
Y = blk(X)
Y.shape# 输出
torch.Size([4, 3, 6, 6])

增加输出通道数的同时,减半输出的高和宽

blk = Residual(3, 6, use_1x1conv=True, strides= 2)
blk(X).shape# 输出
torch.Size([4, 6, 3, 3])

ResNet 模型

ResNet 的前两层跟之前介绍的 GoogLeNet 中的一样:

在输出通道数为 64、步幅为 2 的 7×7 卷积层后,

接步幅为 2 的 3×3 的最大汇聚层。

不同之处在于 ResNet 每个卷积层后增加了批量规范化层。

b1 = nn.Sequential(nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3),nn.BatchNorm2d(64), nn.ReLU(),nn.MaxPool2d(kernel_size=3, stride=2, padding=1))

GoogLeNet 在后面接了 4 个由 Inception 块组成的模块。ResNet 则使用 4 个由残差块组成的模块,

每个模块使用若干个同样输出通道数的残差块。

第一个模块的通道数同输入通道数一致。由于之前已经使用了步幅为 2 的最大汇聚层,所以无须减小高和宽。

之后的每个模块在第一个残差块里将上一个模块的通道数翻倍,并将高和宽减半。

def resnet_block(input_channels, num_channels, num_residuals,first_block=False):blk = []for i in range(num_residuals):if i == 0 and not first_block:blk.append(Residual(input_channels, num_channels,use_1x1conv=True, strides=2))else:blk.append(Residual(num_channels, num_channels))return blkb2 = nn.Sequential(*resnet_block(64, 64, 2, first_block=True))
b3 = nn.Sequential(*resnet_block(64, 128, 2))
b4 = nn.Sequential(*resnet_block(128, 256, 2))
b5 = nn.Sequential(*resnet_block(256, 512, 2))net = nn.Sequential(b1, b2, b3, b4, b5,nn.AdaptiveAvgPool2d((1,1)),nn.Flatten(), nn.Linear(512, 10))

老规矩,不同模块的数据形状变化

X = torch.rand(size=(1, 1, 224, 224))
for layer in net:X = layer(X)print(layer.__class__.__name__,'output shape:\t', X.shape)

QA

  • 残差概念体现在哪里?

可以理解成,因为 f(x) 是由 x 和 g(x) 相加得来的,x 又是由上一层网络训练的来的,可以被视为一个小网络的输出。所以整个 ResNet 就是先训练小网络,然后小网络 fit 不到的(小的差距)再由上面的层去补充。这就是残差(残留的差距)的概念。

  •  为什么 BN 需要定义两个,而 ReLU 不需要?

BN 是两个独立的层,每个层都有自己需要学的不同的参数,而 ReLU 没有什么学习性,所以公用一个层就可以

  • 训练 ACC 是不是在不 overfitting 的情况下,永远大于测试 ACC?

不一定哦,后面会看到当你做了大量的数据噪音的时候,测试精度会高于训练精度,因为你测试的时候不会添加噪声。

  • 为什么 ResNet 可以训练 100 层网络? 

 假设 g(x) 是在 f(x) 之外新加的一个层,那对于梯度的计算公式根据链式求导法展开,多出来的第一项就是新套的那层的输入和输出求导。假设加的这个层的拟合能力比较强,这一项会很快的变得特别小。一个很小的值乘我们之前那一层的梯度,梯度就会变得比之前小很多。梯度变小之后可以选择增大学习率,但是很有可能增大学习率也没啥用。因为也不能增的太大,f 这一层比 g 更靠近数据,如果增加太大那 g 这一层会变得不稳定。这就是为什么之前模型变深之后会出现梯度消失的问题。

主要原因就是层数叠加,梯度是一直做乘法。回传的时候就会出现底部的梯度特别小。

而 ResNet 是怎么解决这个问题的呢?

因为 ResNet 的网络设计使得它的梯度计算是相加的,哪怕有哪一块比较小也没关系,哪怕当g(x)不存在的时候去拟合,也有 f(x) 的梯度存在。

大数 + 小数没问题,但是 大数*小数问题很大! 

这篇关于秃姐学AI系列之:残差网络 ResNet的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1113230

相关文章

Ilya-AI分享的他在OpenAI学习到的15个提示工程技巧

Ilya(不是本人,claude AI)在社交媒体上分享了他在OpenAI学习到的15个Prompt撰写技巧。 以下是详细的内容: 提示精确化:在编写提示时,力求表达清晰准确。清楚地阐述任务需求和概念定义至关重要。例:不用"分析文本",而用"判断这段话的情感倾向:积极、消极还是中性"。 快速迭代:善于快速连续调整提示。熟练的提示工程师能够灵活地进行多轮优化。例:从"总结文章"到"用

Spring Security 从入门到进阶系列教程

Spring Security 入门系列 《保护 Web 应用的安全》 《Spring-Security-入门(一):登录与退出》 《Spring-Security-入门(二):基于数据库验证》 《Spring-Security-入门(三):密码加密》 《Spring-Security-入门(四):自定义-Filter》 《Spring-Security-入门(五):在 Sprin

AI绘图怎么变现?想做点副业的小白必看!

在科技飞速发展的今天,AI绘图作为一种新兴技术,不仅改变了艺术创作的方式,也为创作者提供了多种变现途径。本文将详细探讨几种常见的AI绘图变现方式,帮助创作者更好地利用这一技术实现经济收益。 更多实操教程和AI绘画工具,可以扫描下方,免费获取 定制服务:个性化的创意商机 个性化定制 AI绘图技术能够根据用户需求生成个性化的头像、壁纸、插画等作品。例如,姓氏头像在电商平台上非常受欢迎,

从去中心化到智能化:Web3如何与AI共同塑造数字生态

在数字时代的演进中,Web3和人工智能(AI)正成为塑造未来互联网的两大核心力量。Web3的去中心化理念与AI的智能化技术,正相互交织,共同推动数字生态的变革。本文将探讨Web3与AI的融合如何改变数字世界,并展望这一新兴组合如何重塑我们的在线体验。 Web3的去中心化愿景 Web3代表了互联网的第三代发展,它基于去中心化的区块链技术,旨在创建一个开放、透明且用户主导的数字生态。不同于传统

AI一键生成 PPT

AI一键生成 PPT 操作步骤 作为一名打工人,是不是经常需要制作各种PPT来分享我的生活和想法。但是,你们知道,有时候灵感来了,时间却不够用了!😩直到我发现了Kimi AI——一个能够自动生成PPT的神奇助手!🌟 什么是Kimi? 一款月之暗面科技有限公司开发的AI办公工具,帮助用户快速生成高质量的演示文稿。 无论你是职场人士、学生还是教师,Kimi都能够为你的办公文

Andrej Karpathy最新采访:认知核心模型10亿参数就够了,AI会打破教育不公的僵局

夕小瑶科技说 原创  作者 | 海野 AI圈子的红人,AI大神Andrej Karpathy,曾是OpenAI联合创始人之一,特斯拉AI总监。上一次的动态是官宣创办一家名为 Eureka Labs 的人工智能+教育公司 ,宣布将长期致力于AI原生教育。 近日,Andrej Karpathy接受了No Priors(投资博客)的采访,与硅谷知名投资人 Sara Guo 和 Elad G

Linux 网络编程 --- 应用层

一、自定义协议和序列化反序列化 代码: 序列化反序列化实现网络版本计算器 二、HTTP协议 1、谈两个简单的预备知识 https://www.baidu.com/ --- 域名 --- 域名解析 --- IP地址 http的端口号为80端口,https的端口号为443 url为统一资源定位符。CSDNhttps://mp.csdn.net/mp_blog/creation/editor

科研绘图系列:R语言扩展物种堆积图(Extended Stacked Barplot)

介绍 R语言的扩展物种堆积图是一种数据可视化工具,它不仅展示了物种的堆积结果,还整合了不同样本分组之间的差异性分析结果。这种图形表示方法能够直观地比较不同物种在各个分组中的显著性差异,为研究者提供了一种有效的数据解读方式。 加载R包 knitr::opts_chunk$set(warning = F, message = F)library(tidyverse)library(phyl

AI hospital 论文Idea

一、Benchmarking Large Language Models on Communicative Medical Coaching: A Dataset and a Novel System论文地址含代码 大多数现有模型和工具主要迎合以患者为中心的服务。这项工作深入探讨了LLMs在提高医疗专业人员的沟通能力。目标是构建一个模拟实践环境,人类医生(即医学学习者)可以在其中与患者代理进行医学

【生成模型系列(初级)】嵌入(Embedding)方程——自然语言处理的数学灵魂【通俗理解】

【通俗理解】嵌入(Embedding)方程——自然语言处理的数学灵魂 关键词提炼 #嵌入方程 #自然语言处理 #词向量 #机器学习 #神经网络 #向量空间模型 #Siri #Google翻译 #AlexNet 第一节:嵌入方程的类比与核心概念【尽可能通俗】 嵌入方程可以被看作是自然语言处理中的“翻译机”,它将文本中的单词或短语转换成计算机能够理解的数学形式,即向量。 正如翻译机将一种语言