Pytroch实现ResNet

2024-05-14 21:32
文章标签 实现 resnet pytroch

本文主要是介绍Pytroch实现ResNet,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

 

论文: Deep Residual Learning for Image Recognition

网络结构图:

 

 

import torch.nn as nn
import mathdef conv3x3(in_planes, out_planes, stride=1):return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,padding=1, bias=False)class BasicBlock(nn.Module):expansion = 1def __init__(self, inplanes, planes, stride=1, downsample=None):super(BasicBlock, self).__init__()self.conv1 = conv3x3(inplanes, planes, stride)self.bn1 = nn.BatchNorm2d(planes)self.relu = nn.ReLU(inplace=True)self.conv2 = conv3x3(planes, planes)self.bn2 = nn.BatchNorm2d(planes)self.downsample = downsampleself.stride = stridedef forward(self, x):residual = xout = self.conv1(x)out = self.bn1(out)out = self.relu(out)out = self.conv2(out)out = self.bn2(out)if self.downsample is not None:residual = self.downsample(x)out += residualout = self.relu(out)return outclass Bottleneck(nn.Module):expansion = 4def __init__(self, inplanes, planes, stride=1, downsample=None):super(Bottleneck, self).__init__()self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)self.bn1 = nn.BatchNorm2d(planes)self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,padding=1, bias=False)self.bn2 = nn.BatchNorm2d(planes)self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)self.bn3 = nn.BatchNorm2d(planes * 4)self.relu = nn.ReLU(inplace=True)self.downsample = downsampleself.stride = stridedef forward(self, x):residual = xout = self.conv1(x)out = self.bn1(out)out = self.relu(out)out = self.conv2(out)out = self.bn2(out)out = self.relu(out)out = self.conv3(out)out = self.bn3(out)if self.downsample is not None:residual = self.downsample(x)out += residualout = self.relu(out)return outclass ResNet(nn.Module):def __init__(self, block, layers, num_classes=1000):self.inplanes = 64super(ResNet, self).__init__()self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,bias=False)self.bn1 = nn.BatchNorm2d(64)self.relu = nn.ReLU(inplace=True)self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)self.layer1 = self._make_layer(block, 64, layers[0])self.layer2 = self._make_layer(block, 128, layers[1], stride=2)self.layer3 = self._make_layer(block, 256, layers[2], stride=2)self.layer4 = self._make_layer(block, 512, layers[3], stride=2)self.avgpool = nn.AvgPool2d(7, stride=1)self.fc = nn.Linear(512 * block.expansion, num_classes)for m in self.modules():if isinstance(m, nn.Conv2d):n = m.kernel_size[0] * m.kernel_size[1] * m.out_channelsm.weight.data.normal_(0, math.sqrt(2. / n))elif isinstance(m, nn.BatchNorm2d):m.weight.data.fill_(1)m.bias.data.zero_()def _make_layer(self, block, planes, blocks, stride=1):downsample = Noneif stride != 1 or self.inplanes != planes * block.expansion:downsample = nn.Sequential(nn.Conv2d(self.inplanes, planes * block.expansion,kernel_size=1, stride=stride, bias=False),nn.BatchNorm2d(planes * block.expansion),)layers = []layers.append(block(self.inplanes, planes, stride, downsample))self.inplanes = planes * block.expansionfor i in range(1, blocks):layers.append(block(self.inplanes, planes))return nn.Sequential(*layers)def forward(self, x):x = self.conv1(x)x = self.bn1(x)x = self.relu(x)x = self.maxpool(x)x = self.layer1(x)x = self.layer2(x)x = self.layer3(x)x = self.layer4(x)x = self.avgpool(x)x = x.view(x.size(0), -1)x = self.fc(x)return xdef resnet18(**kwargs):model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)return modeldef resnet34(**kwargs):model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)return modeldef resnet50(**kwargs):model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)return modeldef resnet101(**kwargs):model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)return modeldef resnet152(**kwargs):model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)return modelif __name__ == '__main__':# 'ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152'# Examplenet18 = resnet18()print(net18)

源码:https://pytorch.org/docs/0.4.0/_modules/torchvision/models/resnet.html

这篇关于Pytroch实现ResNet的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/989901

相关文章

C++对象布局及多态实现探索之内存布局(整理的很多链接)

本文通过观察对象的内存布局,跟踪函数调用的汇编代码。分析了C++对象内存的布局情况,虚函数的执行方式,以及虚继承,等等 文章链接:http://dev.yesky.com/254/2191254.shtml      论C/C++函数间动态内存的传递 (2005-07-30)   当你涉及到C/C++的核心编程的时候,你会无止境地与内存管理打交道。 文章链接:http://dev.yesky

通过SSH隧道实现通过远程服务器上外网

搭建隧道 autossh -M 0 -f -D 1080 -C -N user1@remotehost##验证隧道是否生效,查看1080端口是否启动netstat -tuln | grep 1080## 测试ssh 隧道是否生效curl -x socks5h://127.0.0.1:1080 -I http://www.github.com 将autossh 设置为服务,隧道开机启动

时序预测 | MATLAB实现LSTM时间序列未来多步预测-递归预测

时序预测 | MATLAB实现LSTM时间序列未来多步预测-递归预测 目录 时序预测 | MATLAB实现LSTM时间序列未来多步预测-递归预测基本介绍程序设计参考资料 基本介绍 MATLAB实现LSTM时间序列未来多步预测-递归预测。LSTM是一种含有LSTM区块(blocks)或其他的一种类神经网络,文献或其他资料中LSTM区块可能被描述成智能网络单元,因为

vue项目集成CanvasEditor实现Word在线编辑器

CanvasEditor实现Word在线编辑器 官网文档:https://hufe.club/canvas-editor-docs/guide/schema.html 源码地址:https://github.com/Hufe921/canvas-editor 前提声明: 由于CanvasEditor目前不支持vue、react 等框架开箱即用版,所以需要我们去Git下载源码,拿到其中两个主

android一键分享功能部分实现

为什么叫做部分实现呢,其实是我只实现一部分的分享。如新浪微博,那还有没去实现的是微信分享。还有一部分奇怪的问题:我QQ分享跟QQ空间的分享功能,我都没配置key那些都是原本集成就有的key也可以实现分享,谁清楚的麻烦详解下。 实现分享功能我们可以去www.mob.com这个网站集成。免费的,而且还有短信验证功能。等这分享研究完后就研究下短信验证功能。 开始实现步骤(新浪分享,以下是本人自己实现

基于Springboot + vue 的抗疫物质管理系统的设计与实现

目录 📚 前言 📑摘要 📑系统流程 📚 系统架构设计 📚 数据库设计 📚 系统功能的具体实现    💬 系统登录注册 系统登录 登录界面   用户添加  💬 抗疫列表展示模块     区域信息管理 添加物资详情 抗疫物资列表展示 抗疫物资申请 抗疫物资审核 ✒️ 源码实现 💖 源码获取 😁 联系方式 📚 前言 📑博客主页:

探索蓝牙协议的奥秘:用ESP32实现高质量蓝牙音频传输

蓝牙(Bluetooth)是一种短距离无线通信技术,广泛应用于各种电子设备之间的数据传输。自1994年由爱立信公司首次提出以来,蓝牙技术已经经历了多个版本的更新和改进。本文将详细介绍蓝牙协议,并通过一个具体的项目——使用ESP32实现蓝牙音频传输,来展示蓝牙协议的实际应用及其优点。 蓝牙协议概述 蓝牙协议栈 蓝牙协议栈是蓝牙技术的核心,定义了蓝牙设备之间如何进行通信。蓝牙协议

python实现最简单循环神经网络(RNNs)

Recurrent Neural Networks(RNNs) 的模型: 上图中红色部分是输入向量。文本、单词、数据都是输入,在网络里都以向量的形式进行表示。 绿色部分是隐藏向量。是加工处理过程。 蓝色部分是输出向量。 python代码表示如下: rnn = RNN()y = rnn.step(x) # x为输入向量,y为输出向量 RNNs神经网络由神经元组成, python

利用Frp实现内网穿透(docker实现)

文章目录 1、WSL子系统配置2、腾讯云服务器安装frps2.1、创建配置文件2.2 、创建frps容器 3、WSL2子系统Centos服务器安装frpc服务3.1、安装docker3.2、创建配置文件3.3 、创建frpc容器 4、WSL2子系统Centos服务器安装nginx服务 环境配置:一台公网服务器(腾讯云)、一台笔记本电脑、WSL子系统涉及知识:docker、Frp

基于 Java 实现的智能客服聊天工具模拟场景

服务端代码 import java.io.BufferedReader;import java.io.IOException;import java.io.InputStreamReader;import java.io.PrintWriter;import java.net.ServerSocket;import java.net.Socket;public class Serv