在pytorch模型中如何获得BatchNorm2d层的各个mean和var(平均值和方差)

2024-03-10 09:08

本文主要是介绍在pytorch模型中如何获得BatchNorm2d层的各个mean和var(平均值和方差),希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

这个内容是将随便做了一个网络结构,然后简单的训练几次,生成模型,并且存储起来,主要是为了学习获得pytorch中的BatchNorm2d层的各个特征图的平均值和方差。代码如下:

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets,transforms
from torch.optim import lr_scheduler
import torch.optim as optimclass VGG(nn.Module):def __init__(self):super(VGG,self).__init__()self.conv1 = nn.Conv2d(3,64,3,padding=(1,1))self.bn1 = nn.BatchNorm2d(64)self.maxpool1 = nn.MaxPool2d((2,2))self.conv2 = nn.Conv2d(64,128,3,padding=(1,1))# self.bn2 = nn.BatchNorm2d(128)self.maxpool2 = nn.MaxPool2d((2,2))self.conv3 = nn.Conv2d(128,256,3,padding=(1,1))# self.bn3 = nn.BatchNorm2d(256)self.maxpool3 = nn.MaxPool2d((2,2))self.fc1 = nn.Linear(256*16*8,4096)self.fc2 = nn.Linear(4096,1000)self.fc3 = nn.Linear(1000,10)def forward(self,x):in_size = x.size(0)out = self.conv1(x)out = self.bn1(out)out = F.relu(out)out = self.maxpool1(out)out = self.conv2(out)out = F.relu(out)out = self.maxpool2(out)out = self.conv3(out)out = F.relu(out)out = self.maxpool3(out)out = out.view(out.size(0),-1)out = self.fc1(out)out = F.relu(out)out = self.fc2(out)out = F.relu(out)out = self.fc3(out)return outtransform_train_list = transforms.Compose([transforms.Resize( (256,128),interpolation=3 ),transforms.RandomCrop((128,64)),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])train_dataset = datasets.ImageFolder('./train',transform_train_list)
dataloaders = torch.utils.data.DataLoader(train_dataset,batch_size=2,num_workers=0)dataset_size = len(train_dataset)
class_names = train_dataset.classesprint(dataset_size)
print(class_names)
net=VGG()
optimizer = optim.SGD(net.parameters(),lr=0.001,momentum=0.9)
criterion = nn.CrossEntropyLoss()for epoch in range(2):#训练模型print(epoch)net.train(True)running_loss = 0.0running_corrects = 0.0for data in dataloaders:inputs,labels = datanow_batch,c,h,w = inputs.shapeoptimizer.zero_grad()outputs = net(inputs)# print(outputs)_,preds = torch.max(outputs.data,1)loss = criterion(outputs,labels)loss.backward()optimizer.step()running_loss = running_loss + loss.item() * now_batchrunning_corrects += float( torch.sum( preds == labels.data ) )epoch_loss = running_loss/dataset_sizeepoch_acc = running_corrects/dataset_sizeprint(epoch_loss)print(epoch_acc)torch.save(net.cpu().state_dict(),'first.pth')  ##将训练好的模型保存起来net = VGG()
net.load_state_dict( torch.load('first.pth') )
net.eval()  #产生一个模型并且加载已经训练好的模型的参数# for data in dataloaders:
#     inputs,labels = data
#     # print(inputs)
#     print(labels)
#     outputs = net(inputs)
#     print(outputs)
#     breakm = VGG()
# m.eval()
m.load_state_dict( torch.load('second.pth') )
print(m.bn1.running_mean.size()) ##获得一共有多少个mean  要是想获得var只要将mean改为var即可
print(m.bn1.running_mean.data[0])
print(m.bn1.running_mean.data[1])
print(m.bn1.running_mean.data[2])print(m.bn1.running_var.data[0])
print(type(m.bn1.running_mean.data[0]))m.bn1.running_mean.data[0] = m.bn1.running_mean.data[2]  ##可以对模型参数进行更改,然后保存更改后的模型
m.bn1.running_mean.data[1] = m.bn1.running_mean.data[2]torch.save(m.cpu().state_dict(),'second.pth')

对于输入到BatchNorm2d层的数据格式为(batch_size,channels_size,h,w),channels_size为多少,就会生成多少个mean和var。

举个例子,如果输入的数据是batch_size=16,channels_size=64,h=32,w=16,则每对mean和var都是 16张某一个特征图中的所有数据的mean和var

这篇关于在pytorch模型中如何获得BatchNorm2d层的各个mean和var(平均值和方差)的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/793745

相关文章

Java的IO模型、Netty原理解析

《Java的IO模型、Netty原理解析》Java的I/O是以流的方式进行数据输入输出的,Java的类库涉及很多领域的IO内容:标准的输入输出,文件的操作、网络上的数据传输流、字符串流、对象流等,这篇... 目录1.什么是IO2.同步与异步、阻塞与非阻塞3.三种IO模型BIO(blocking I/O)NI

基于Flask框架添加多个AI模型的API并进行交互

《基于Flask框架添加多个AI模型的API并进行交互》:本文主要介绍如何基于Flask框架开发AI模型API管理系统,允许用户添加、删除不同AI模型的API密钥,感兴趣的可以了解下... 目录1. 概述2. 后端代码说明2.1 依赖库导入2.2 应用初始化2.3 API 存储字典2.4 路由函数2.5 应

使用PyTorch实现手写数字识别功能

《使用PyTorch实现手写数字识别功能》在人工智能的世界里,计算机视觉是最具魅力的领域之一,通过PyTorch这一强大的深度学习框架,我们将在经典的MNIST数据集上,见证一个神经网络从零开始学会识... 目录当计算机学会“看”数字搭建开发环境MNIST数据集解析1. 认识手写数字数据库2. 数据预处理的

Pytorch微调BERT实现命名实体识别

《Pytorch微调BERT实现命名实体识别》命名实体识别(NER)是自然语言处理(NLP)中的一项关键任务,它涉及识别和分类文本中的关键实体,BERT是一种强大的语言表示模型,在各种NLP任务中显著... 目录环境准备加载预训练BERT模型准备数据集标记与对齐微调 BERT最后总结环境准备在继续之前,确

pytorch+torchvision+python版本对应及环境安装

《pytorch+torchvision+python版本对应及环境安装》本文主要介绍了pytorch+torchvision+python版本对应及环境安装,安装过程中需要注意Numpy版本的降级,... 目录一、版本对应二、安装命令(pip)1. 版本2. 安装全过程3. 命令相关解释参考文章一、版本对

C#集成DeepSeek模型实现AI私有化的流程步骤(本地部署与API调用教程)

《C#集成DeepSeek模型实现AI私有化的流程步骤(本地部署与API调用教程)》本文主要介绍了C#集成DeepSeek模型实现AI私有化的方法,包括搭建基础环境,如安装Ollama和下载DeepS... 目录前言搭建基础环境1、安装 Ollama2、下载 DeepSeek R1 模型客户端 ChatBo

从零教你安装pytorch并在pycharm中使用

《从零教你安装pytorch并在pycharm中使用》本文详细介绍了如何使用Anaconda包管理工具创建虚拟环境,并安装CUDA加速平台和PyTorch库,同时在PyCharm中配置和使用PyTor... 目录背景介绍安装Anaconda安装CUDA安装pytorch报错解决——fbgemm.dll连接p

pycharm远程连接服务器运行pytorch的过程详解

《pycharm远程连接服务器运行pytorch的过程详解》:本文主要介绍在Linux环境下使用Anaconda管理不同版本的Python环境,并通过PyCharm远程连接服务器来运行PyTorc... 目录linux部署pytorch背景介绍Anaconda安装Linux安装pytorch虚拟环境安装cu

SpringBoot快速接入OpenAI大模型的方法(JDK8)

《SpringBoot快速接入OpenAI大模型的方法(JDK8)》本文介绍了如何使用AI4J快速接入OpenAI大模型,并展示了如何实现流式与非流式的输出,以及对函数调用的使用,AI4J支持JDK8... 目录使用AI4J快速接入OpenAI大模型介绍AI4J-github快速使用创建SpringBoot

0基础租个硬件玩deepseek,蓝耘元生代智算云|本地部署DeepSeek R1模型的操作流程

《0基础租个硬件玩deepseek,蓝耘元生代智算云|本地部署DeepSeekR1模型的操作流程》DeepSeekR1模型凭借其强大的自然语言处理能力,在未来具有广阔的应用前景,有望在多个领域发... 目录0基础租个硬件玩deepseek,蓝耘元生代智算云|本地部署DeepSeek R1模型,3步搞定一个应