spikingjelly学习-训练网络

2024-04-10 04:04

本文主要是介绍spikingjelly学习-训练网络,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

【MNIST数据集包含若干尺寸为28*28的8位灰度图像,总共有0~9共10个类别。以MNIST的分类为例,一个简单的单层ANN网络如下

我们也可以用完全类似结构的SNN来进行分类任务。就这个网络而言,只需要先去掉所有的激活函数,再将尖峰神经元添加到原来激活函数的位置,这里我们选择的是LIF神经元。神经元之间的连接层需要用
spikingjelly.activation_based.layer包装:

在 spikingjelly 中,我们约定,只能输出脉冲,即0或1的神经元,都可以称之为“脉冲神经元”。使用脉冲神经元的网络,进而也可以称之为脉冲神经元网络(Spiking Neural Networks, SNNs)。这里使用了 neuron.IFNode() 来构建 IF 神经元层,该神经元层有如下构造函数:
  1. v_threshold – 神经元的阈值电压
  2. v_reset – 神经元的重置电压。
  3. surrogate_function – 反向传播时用来计算脉冲函数梯度的替代函数
    神经元的数量是在初始化或调用 reset() 函数重新初始化后,根据第一次接收的输入的 shape 自动决定的。此处则是10个神经元。其中膜电位衰减常数 需要通过参数tau设置,替代函数这里选择surrogate.ATan。
    然后是训练SNN网络,指定好训练参数如学习率等以及若干其他配置优化器默认使用Adam,以及使用泊松编码器,在每次输入图片时进行脉冲编码。

【训练代码的编写需要遵循以下三个要点:
 脉冲神经元的输出是二值的,而直接将单次运行的结果用于分类极易受到编码带来的噪声干扰。因此一般认为脉冲网络的输出是输出层一段时间内的发放频率(或称发放率),发放率的高低表示该类别的响应大小。因此网络需要运行一段时间,即使用T个时刻后的平均发放率作为分类依据。
 我们希望的理想结果是除了正确的神经元以最高频率发放,其他神经元保持静默。常常采用交叉熵损失或者MSE损失,这里我们使用实际效果更好的MSE损失。
 每次网络仿真结束后,需要重置网络状态

 # 保存绘图用数据net.eval()# 注册钩子output_layer = net.layer[-1] # 输出层output_layer.v_seq = []output_layer.s_seq = []def save_hook(m, x, y):m.v_seq.append(m.v.unsqueeze(0))m.s_seq.append(y.unsqueeze(0))output_layer.register_forward_hook(save_hook)with torch.no_grad():img, label = test_dataset[0]img = img.to(args.device)out_fr = 0.for t in range(args.T):encoded_img = encoder(img)out_fr += net(encoded_img)out_spikes_counter_frequency = (out_fr / args.T).cpu().numpy()print(f'Firing rate: {out_spikes_counter_frequency}')output_layer.v_seq = torch.cat(output_layer.v_seq)output_layer.s_seq = torch.cat(output_layer.s_seq)v_t_array = output_layer.v_seq.cpu().numpy().squeeze()  # v_t_array[i][j]表示神经元i在j时刻的电压值np.save("v_t_array.npy",v_t_array)s_t_array = output_layer.s_seq.cpu().numpy().squeeze()  # s_t_array[i][j]表示神经元i在j时刻释放的脉冲,为0或1np.save("s_t_array.npy",s_t_array)

在这里插入图片描述
【在PyTorch中,钩子(hooks)是一种强大的工具,允许你在模型的前向传播(forward pass)或反向传播(backward pass)过程中插入自定义操作。这些操作可以用于调试、可视化、保存中间状态等目的,而不需要修改模型的定义。
钩子的类型
前向钩子(Forward Hooks):在层的前向传播执行完毕后立即执行。它们通常用于检查、修改或记录从层输出的数据。
反向钩子(Backward Hooks):在层的梯度计算过程中执行。它们用于检查或修改梯度值。
这段代码中的钩子使用
在提供的代码段中,使用了一个前向钩子(save_hook)来保存神经网络某层在前向传播过程中的电压值(v)和脉冲值(s)。
这个钩子函数save_hook接收三个参数:
m:注册钩子的模块(在这个例子中是输出层)。
x:输入到该模块的数据。
y:从该模块输出的数据。
在钩子函数内部,它将模块m的电压值v和输出脉冲y保存到列表中。这里使用unsqueeze(0)是为了增加一个批次维度,使得每次迭代的数据可以被堆叠起来。
钩子的注册
这行代码将save_hook函数注册为output_layer(网络的最后一层)的前向钩子。这意味着每当output_layer完成前向传播时,save_hook函数都会被调用。
数据的保存
在所有测试图像通过网络并且钩子函数被调用之后,v_seq和s_seq列表中的数据被合并(使用torch.cat)并转换为NumPy数组,然后通过np.save保存到文件中。这些文件包含了在整个测试集上,输出层神经元的电压值和脉冲发放情况,可以用于进一步的分析和可视化。】
这段代码通过注册一个前向钩子来捕获并保存神经网络最后一层在前向传播过程中的电压和脉冲数据。这种方法非常有用,因为它允许在不修改网络结构的情况下收集内部状态信息,对于理解和分析网络的行为非常有帮助。

这篇关于spikingjelly学习-训练网络的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/890024

相关文章

Linux系统配置NAT网络模式的详细步骤(附图文)

《Linux系统配置NAT网络模式的详细步骤(附图文)》本文详细指导如何在VMware环境下配置NAT网络模式,包括设置主机和虚拟机的IP地址、网关,以及针对Linux和Windows系统的具体步骤,... 目录一、配置NAT网络模式二、设置虚拟机交换机网关2.1 打开虚拟机2.2 管理员授权2.3 设置子

揭秘Python Socket网络编程的7种硬核用法

《揭秘PythonSocket网络编程的7种硬核用法》Socket不仅能做聊天室,还能干一大堆硬核操作,这篇文章就带大家看看Python网络编程的7种超实用玩法,感兴趣的小伙伴可以跟随小编一起... 目录1.端口扫描器:探测开放端口2.简易 HTTP 服务器:10 秒搭个网页3.局域网游戏:多人联机对战4.

SpringBoot使用OkHttp完成高效网络请求详解

《SpringBoot使用OkHttp完成高效网络请求详解》OkHttp是一个高效的HTTP客户端,支持同步和异步请求,且具备自动处理cookie、缓存和连接池等高级功能,下面我们来看看SpringB... 目录一、OkHttp 简介二、在 Spring Boot 中集成 OkHttp三、封装 OkHttp

Linux系统之主机网络配置方式

《Linux系统之主机网络配置方式》:本文主要介绍Linux系统之主机网络配置方式,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录一、查看主机的网络参数1、查看主机名2、查看IP地址3、查看网关4、查看DNS二、配置网卡1、修改网卡配置文件2、nmcli工具【通用

使用Python高效获取网络数据的操作指南

《使用Python高效获取网络数据的操作指南》网络爬虫是一种自动化程序,用于访问和提取网站上的数据,Python是进行网络爬虫开发的理想语言,拥有丰富的库和工具,使得编写和维护爬虫变得简单高效,本文将... 目录网络爬虫的基本概念常用库介绍安装库Requests和BeautifulSoup爬虫开发发送请求解

Java进阶学习之如何开启远程调式

《Java进阶学习之如何开启远程调式》Java开发中的远程调试是一项至关重要的技能,特别是在处理生产环境的问题或者协作开发时,:本文主要介绍Java进阶学习之如何开启远程调式的相关资料,需要的朋友... 目录概述Java远程调试的开启与底层原理开启Java远程调试底层原理JVM参数总结&nbsMbKKXJx

如何通过海康威视设备网络SDK进行Java二次开发摄像头车牌识别详解

《如何通过海康威视设备网络SDK进行Java二次开发摄像头车牌识别详解》:本文主要介绍如何通过海康威视设备网络SDK进行Java二次开发摄像头车牌识别的相关资料,描述了如何使用海康威视设备网络SD... 目录前言开发流程问题和解决方案dll库加载不到的问题老旧版本sdk不兼容的问题关键实现流程总结前言作为

Java深度学习库DJL实现Python的NumPy方式

《Java深度学习库DJL实现Python的NumPy方式》本文介绍了DJL库的背景和基本功能,包括NDArray的创建、数学运算、数据获取和设置等,同时,还展示了如何使用NDArray进行数据预处理... 目录1 NDArray 的背景介绍1.1 架构2 JavaDJL使用2.1 安装DJL2.2 基本操

SSID究竟是什么? WiFi网络名称及工作方式解析

《SSID究竟是什么?WiFi网络名称及工作方式解析》SID可以看作是无线网络的名称,类似于有线网络中的网络名称或者路由器的名称,在无线网络中,设备通过SSID来识别和连接到特定的无线网络... 当提到 Wi-Fi 网络时,就避不开「SSID」这个术语。简单来说,SSID 就是 Wi-Fi 网络的名称。比如

Java实现任务管理器性能网络监控数据的方法详解

《Java实现任务管理器性能网络监控数据的方法详解》在现代操作系统中,任务管理器是一个非常重要的工具,用于监控和管理计算机的运行状态,包括CPU使用率、内存占用等,对于开发者和系统管理员来说,了解这些... 目录引言一、背景知识二、准备工作1. Maven依赖2. Gradle依赖三、代码实现四、代码详解五