Simple-STNDT使用Transformer进行Spike信号的表征学习(一)数据处理篇

本文主要是介绍Simple-STNDT使用Transformer进行Spike信号的表征学习(一)数据处理篇,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

文章目录

    • 1.数据处理部分
      • 1.1 下载数据集
      • 1.2 数据集预处理
      • 1.3 划分train-val并创建Dataset对象
      • 1.4 掩码mask操作

数据、评估标准见NLB2021
https://neurallatents.github.io/

以下代码依据
https://github.com/trungle93/STNDT

原代码使用了 Ray+Config文件进行了参数搜索,库依赖较多,数据流过程不明显,代码冗杂,这里进行了抽丝剥茧,将其中最核心的部分提取出来。

1.数据处理部分

1.1 下载数据集

需要依赖 pip install dandi
downald.py

root = "D:/NeuralLatent/"
def downald_data():from dandi.download import downloaddownload("https://dandiarchive.org/dandiset/000128", root)download("https://dandiarchive.org/dandiset/000138", root)download("https://dandiarchive.org/dandiset/000139", root)download("https://dandiarchive.org/dandiset/000140", root)download("https://dandiarchive.org/dandiset/000129", root)download("https://dandiarchive.org/dandiset/000127", root)download("https://dandiarchive.org/dandiset/000130", root)

1.2 数据集预处理

需要依赖官方工具包pip install nlb_tools
主要是加载锋值序列数据,将其采样为5ms的时间槽
preprocess.py

## 以下为参数示例
# data_path = root + "/000129/sub-Indy/"
# dataset_name = "mc_rtt"
## 注意 "./data" 必须提前创建好from nlb_tools.make_tensors import make_train_input_tensors, make_eval_input_tensors, combine_h5def preprocess(data_path, dataset_name=None):dataset = NWBDataset(datapath)bin_width = 5dataset.resample(bin_width)make_train_input_tensors(dataset, dataset_name=dataset_name, trial_split="train", include_behavior=True, include_forward_pred=True, save_file=True,save_path=f"./data/{dataset_name}_train.h5")make_eval_input_tensors(dataset, dataset_name=dataset_name, trial_split="val", save_file=True, save_path=f"./data/{dataset_name}_val.h5")combine_h5([f"./data/{dataset_name}_train.h5", f"./data/{dataset_name}_val.h5"], save_path=f"./data/{dataset_name}_full.h5")## './data/mc_rtt_full.h5' 将成为后续的主要分析数据

1.3 划分train-val并创建Dataset对象

读取'./data/mc_rtt_full.h5'中的数据并创建dataset
dataset.py

import h5py
import numpy as np
import torch
from torch.utils import data
# data_path = "./data/mc_rtt_full.h5"class SpikesDataset(data.Dataset):def __init__(self, spikes, heldout_spikes, forward_spikes) -> None:self.spikes = spikesself.heldout_spikes = heldout_spikesself.forward_spikes = forward_spikesdef __len__(self):return self.spikes.size(0)def __getitem__(self, index):r"""Return spikes and rates, shaped T x N (num_neurons)"""return self.spikes[index], self.heldout_spikes[index], self.forward_spikes[index]def make_datasets(data_path):with h5py.File(data_path, 'r') as h5file:h5dict = {key: h5file[key][()] for key in h5file.keys()}if 'eval_spikes_heldin' in h5dict: # NLB dataget_key = lambda key: h5dict[key].astype(np.float32)train_data = get_key('train_spikes_heldin')train_data_fp = get_key('train_spikes_heldin_forward')train_data_heldout_fp = get_key('train_spikes_heldout_forward')train_data_all_fp = np.concatenate([train_data_fp, train_data_heldout_fp], -1)valid_data = get_key('eval_spikes_heldin')train_data_heldout = get_key('train_spikes_heldout')if 'eval_spikes_heldout' in h5dict:valid_data_heldout = get_key('eval_spikes_heldout')else:valid_data_heldout = np.zeros((valid_data.shape[0], valid_data.shape[1], train_data_heldout.shape[2]), dtype=np.float32)if 'eval_spikes_heldin_forward' in h5dict:valid_data_fp = get_key('eval_spikes_heldin_forward')valid_data_heldout_fp = get_key('eval_spikes_heldout_forward')valid_data_all_fp = np.concatenate([valid_data_fp, valid_data_heldout_fp], -1)else:valid_data_all_fp = np.zeros((valid_data.shape[0], train_data_fp.shape[1], valid_data.shape[2] + valid_data_heldout.shape[2]), dtype=np.float32)train_dataset = SpikesDataset(torch.tensor(train_data).long(),            # [810, 120, 98]torch.tensor(train_data_heldout).long(),    # [810, 120, 32]torch.tensor(train_data_all_fp).long(),     # [810, 40, 130])val_dataset = SpikesDataset(torch.tensor(valid_data).long(),            # [810, 120, 98]torch.tensor(valid_data_heldout).long(),    # [810, 120, 32]torch.tensor(valid_data_all_fp).long(),     # [810, 40, 130])return train_dataset, val_dataset

1.4 掩码mask操作

dataset.py

# Some infeasibly high spike count
UNMASKED_LABEL = -100def mask_batch(batch, heldout_spikes, forward_spikes):batch = batch.clone() # make sure we don't corrupt the input data (which is stored in memory)mask_ratio = 0.31254mask_random_ratio = 0.876mask_token_ratio = 0.527labels = batch.clone()mask_probs = torch.full(labels.shape, mask_ratio)# If we want any tokens to not get masked, do it here (but we don't currently have any)mask = torch.bernoulli(mask_probs)mask = mask.bool()labels[~mask] = UNMASKED_LABEL  # No ground truth for unmasked - use this to mask loss# We use random assignment so the model learns embeddings for non-mask tokens, and must rely on context# Most times, we replace tokens with MASK tokenindices_replaced = torch.bernoulli(torch.full(labels.shape, mask_token_ratio)).bool() & maskbatch[indices_replaced] = 0# Random % of the time, we replace masked input tokens with random value (the rest are left intact)indices_random = torch.bernoulli(torch.full(labels.shape, mask_random_ratio)).bool() & mask & ~indices_replacedrandom_spikes = torch.randint(batch.max(), labels.shape, dtype=torch.long)batch[indices_random] = random_spikes[indices_random]# heldout spikes are all maskedbatch = torch.cat([batch, torch.zeros_like(heldout_spikes)], -1)labels = torch.cat([labels, heldout_spikes.to(batch.device)], -1)batch = torch.cat([batch, torch.zeros_like(forward_spikes)], 1)labels = torch.cat([labels, forward_spikes.to(batch.device)], 1)# Leave the other 10% alonereturn batch, labels

下一篇: https://blog.csdn.net/weixin_46866349/article/details/139906187

这篇关于Simple-STNDT使用Transformer进行Spike信号的表征学习(一)数据处理篇的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1088377

相关文章

C语言中联合体union的使用

本文编辑整理自: http://bbs.chinaunix.net/forum.php?mod=viewthread&tid=179471 一、前言 “联合体”(union)与“结构体”(struct)有一些相似之处。但两者有本质上的不同。在结构体中,各成员有各自的内存空间, 一个结构变量的总长度是各成员长度之和。而在“联合”中,各成员共享一段内存空间, 一个联合变量

51单片机学习记录———定时器

文章目录 前言一、定时器介绍二、STC89C52定时器资源三、定时器框图四、定时器模式五、定时器相关寄存器六、定时器练习 前言 一个学习嵌入式的小白~ 有问题评论区或私信指出~ 提示:以下是本篇文章正文内容,下面案例可供参考 一、定时器介绍 定时器介绍:51单片机的定时器属于单片机的内部资源,其电路的连接和运转均在单片机内部完成。 定时器作用: 1.用于计数系统,可

问题:第一次世界大战的起止时间是 #其他#学习方法#微信

问题:第一次世界大战的起止时间是 A.1913 ~1918 年 B.1913 ~1918 年 C.1914 ~1918 年 D.1914 ~1919 年 参考答案如图所示

[word] word设置上标快捷键 #学习方法#其他#媒体

word设置上标快捷键 办公中,少不了使用word,这个是大家必备的软件,今天给大家分享word设置上标快捷键,希望在办公中能帮到您! 1、添加上标 在录入一些公式,或者是化学产品时,需要添加上标内容,按下快捷键Ctrl+shift++就能将需要的内容设置为上标符号。 word设置上标快捷键的方法就是以上内容了,需要的小伙伴都可以试一试呢!

LangChain转换链:让数据处理更精准

1. 转换链的概念 在开发AI Agent(智能体)时,我们经常需要对输入数据进行预处理,这样可以更好地利用LLM。LangChain提供了一个强大的工具——转换链(TransformChain),它可以帮我们轻松实现这一任务。 转换链(TransformChain)主要是将 给定的数据 按照某个函数进行转换,再将 转换后的结果 输出给LLM。 所以转换链的核心是:根据业务逻辑编写合适的转换函

Tolua使用笔记(上)

目录   1.准备工作 2.运行例子 01.HelloWorld:在C#中,创建和销毁Lua虚拟机 和 简单调用。 02.ScriptsFromFile:在C#中,对一个lua文件的执行调用 03.CallLuaFunction:在C#中,对lua函数的操作 04.AccessingLuaVariables:在C#中,对lua变量的操作 05.LuaCoroutine:在Lua中,

AssetBundle学习笔记

AssetBundle是unity自定义的资源格式,通过调用引擎的资源打包接口对资源进行打包成.assetbundle格式的资源包。本文介绍了AssetBundle的生成,使用,加载,卸载以及Unity资源更新的一个基本步骤。 目录 1.定义: 2.AssetBundle的生成: 1)设置AssetBundle包的属性——通过编辑器界面 补充:分组策略 2)调用引擎接口API

Javascript高级程序设计(第四版)--学习记录之变量、内存

原始值与引用值 原始值:简单的数据即基础数据类型,按值访问。 引用值:由多个值构成的对象即复杂数据类型,按引用访问。 动态属性 对于引用值而言,可以随时添加、修改和删除其属性和方法。 let person = new Object();person.name = 'Jason';person.age = 42;console.log(person.name,person.age);//'J

大学湖北中医药大学法医学试题及答案,分享几个实用搜题和学习工具 #微信#学习方法#职场发展

今天分享拥有拍照搜题、文字搜题、语音搜题、多重搜题等搜题模式,可以快速查找问题解析,加深对题目答案的理解。 1.快练题 这是一个网站 找题的网站海量题库,在线搜题,快速刷题~为您提供百万优质题库,直接搜索题库名称,支持多种刷题模式:顺序练习、语音听题、本地搜题、顺序阅读、模拟考试、组卷考试、赶快下载吧! 2.彩虹搜题 这是个老公众号了 支持手写输入,截图搜题,详细步骤,解题必备

Vim使用基础篇

本文内容大部分来自 vimtutor,自带的教程的总结。在终端输入vimtutor 即可进入教程。 先总结一下,然后再分别介绍正常模式,插入模式,和可视模式三种模式下的命令。 目录 看完以后的汇总 1.正常模式(Normal模式) 1.移动光标 2.删除 3.【:】输入符 4.撤销 5.替换 6.重复命令【. ; ,】 7.复制粘贴 8.缩进 2.插入模式 INSERT