TensorFlow搭建搭建卷积神经网络EEGNet处理脑电数据过程代码

本文主要是介绍TensorFlow搭建搭建卷积神经网络EEGNet处理脑电数据过程代码,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

TensorFlow搭建搭建卷积神经网络EEGNet处理脑电数据过程代码
脑电信号采集设备是由NT9200-32D型号脑电图仪和NeuSen W系列无线脑电采集系统组成,采集后的信号用Matlab打开,保存在结构体数据中,采集到的原始信号形式是:16x640000 double,最开始对数据进行手动分段分成[280,16,1000],280指trials,22指channels,1000指 samples,
整个代码可分为:**数据切分,搭建网络,训练数据,测试数据,**四个部分
1.导入包

import numpy as np
from tensorflow.keras import utils as np_utils
from tensorflow.keras.callbacks import ModelCheckpoint
from tensorflow.keras import backend as K
# PyRiemann imports
from pyriemann.estimation import XdawnCovariances
from pyriemann.tangentspace import TangentSpace
from pyriemann.utils.viz import plot_confusion_matrix
from sklearn.pipeline import make_pipeline
from sklearn.linear_model import LogisticRegression
import scipy.io
from matplotlib import pyplot as plt

2.数据切分

K.set_image_data_format('channels_last')
samplesfile = scipy.io.loadmat('F:/holiday_code/attention/TSA/data/foursecond.mat')
X = samplesfile['eeg']#提取数组,结构体名称是eeg
event_id = dict(l=1, m=2, lm=3, ml=4)#四分类运动想象数据
# Setup for reading the raw data
labels = samplesfile['Mark']#加载标签数据
y = labels[:,-1]#标签数据
kernels, chans, samples = 1, 16, 1000# take 50/25/25 percent of the data to train/validate/test
X_train = X[0:140, ]
Y_train = y[0:140]
X_validate = X[140:210, ]
Y_validate = y[140:210]
X_test = X[210:, ]
Y_test = y[210:]
#把标签数据转换成one-hot编码
Y_train = np_utils.to_categorical(Y_train - 1)
Y_validate = np_utils.to_categorical(Y_validate - 1)
Y_test = np_utils.to_categorical(Y_test - 1)
#根据网络结构设置数据的输入形式(trials, channels, samples, kernels)
X_train = X_train.reshape(X_train.shape[0], chans, samples, kernels)
X_validate = X_validate.reshape(X_validate.shape[0], chans, samples, kernels)
X_test = X_test.reshape(X_test.shape[0], chans, samples, kernels)

4.搭建网络

#导入需要的库
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Dense, Activation, Permute, Dropout
from tensorflow.keras.layers import Conv2D, MaxPooling2D, AveragePooling2D
from tensorflow.keras.layers import SeparableConv2D, DepthwiseConv2D
from tensorflow.keras.layers import BatchNormalization
from tensorflow.keras.layers import SpatialDropout2D
from tensorflow.keras.regularizers import l1_l2
from tensorflow.keras.layers import Input, Flatten
from tensorflow.keras.constraints import max_norm
def EEGNet(nb_classes, Chans = 16, Samples = 1000,dropoutRate = 0.5, kernLength = 64, F1 = 8, D = 2, F2 = 16, norm_rate = 0.25, dropoutType = 'Dropout'):if dropoutType == 'SpatialDropout2D':dropoutType = SpatialDropout2Delif dropoutType == 'Dropout':dropoutType = Dropoutelse:raise ValueError('dropoutType must be one of SpatialDropout2D ''or Dropout, passed as a string.')input1 = Input(shape = (Chans, Samples, 1))print("input shape", input1.shape, Chans, Samples, kernLength)##################################################################block1 = Conv2D(F1, (1, kernLength), padding = 'same',input_shape = (Chans, Samples, 1),use_bias = False)(input1)block1 = BatchNormalization()(block1)block1 = DepthwiseConv2D((Chans, 1), use_bias = False,depth_multiplier = D,depthwise_constraint = max_norm(1.))(block1)block1 = BatchNormalization()(block1)block1 = Activation('elu')(block1)block1 = AveragePooling2D((1, 4))(block1)block1 = dropoutType(dropoutRate)(block1)block2 = SeparableConv2D(F2, (1, 16),use_bias = False, padding = 'same')(block1)block2 = BatchNormalization()(block2)block2 = Activation('elu')(block2)block2 = AveragePooling2D((1, 8))(block2)block2 = dropoutType(dropoutRate)(block2)flatten = Flatten(name = 'flatten')(block2)dense = Dense(nb_classes, name = 'dense',kernel_constraint = max_norm(norm_rate))(flatten)softmax = Activation('softmax', name = 'softmax')(dense)return Model(inputs=input1, outputs=softmax)

5.训练模型

model = EEGNet(nb_classes = 4, Chans = 16, Samples = 1000,dropoutRate = 0.5, kernLength = 32, F1 = 8, D = 2, F2 = 16,dropoutType = 'Dropout')
model.compile(loss='categorical_crossentropy', optimizer='adam',metrics=['accuracy'])
# count number of parameters in the model
numParams = model.count_params()
# set a valid path for your system to record model checkpoints
checkpointer = ModelCheckpoint(filepath='F:/holiday_code/attention/TSA/tmptwo/tmp/checkpoint.h5', verbose=1,save_best_only=True)
class_weights = {0: 1, 1: 1, 2: 1, 3: 1}
fittedModel = model.fit(X_train, Y_train, batch_size=16, epochs=300,verbose=2, validation_data=(X_validate, Y_validate),callbacks=[checkpointer], class_weight=class_weights)

6.测试模型

model.load_weights('F:/holiday_code/attention/TSA/tmptwo/tmp/checkpoint.h5')
probs = model.predict(X_test)
preds = probs.argmax(axis=-1)
acc = np.mean(preds == Y_test.argmax(axis=-1))
print("Classification accuracy: %f " % (acc))# plot the accuracy and loss graph
plt.plot(fittedModel.history['accuracy'])
plt.plot(fittedModel.history['val_accuracy'])
plt.plot(fittedModel.history['loss'])
plt.plot(fittedModel.history['val_loss'])
plt.title('acc & loss')
plt.xlabel('epoch')
plt.legend(['acc', 'val_acc','loss','val_loss'], loc='upper right')
plt.show()

7.分类结果
在这里插入图片描述
整个网络框架大概就是这样,这是其中一个被试的分类结果,属于分类效果比较好的,其他被试可能由于数据质量,网络结构等原因分类效果不是很理想,考虑数据增强以及网络结构优化去提高分类准确率。

这篇关于TensorFlow搭建搭建卷积神经网络EEGNet处理脑电数据过程代码的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/224835

相关文章

Spring Boot @RestControllerAdvice全局异常处理最佳实践

《SpringBoot@RestControllerAdvice全局异常处理最佳实践》本文详解SpringBoot中通过@RestControllerAdvice实现全局异常处理,强调代码复用、统... 目录前言一、为什么要使用全局异常处理?二、核心注解解析1. @RestControllerAdvice2

MySQL 删除数据详解(最新整理)

《MySQL删除数据详解(最新整理)》:本文主要介绍MySQL删除数据的相关知识,本文通过实例代码给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友参考下吧... 目录一、前言二、mysql 中的三种删除方式1.DELETE语句✅ 基本语法: 示例:2.TRUNCATE语句✅ 基本语

Java进程异常故障定位及排查过程

《Java进程异常故障定位及排查过程》:本文主要介绍Java进程异常故障定位及排查过程,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录一、故障发现与初步判断1. 监控系统告警2. 日志初步分析二、核心排查工具与步骤1. 进程状态检查2. CPU 飙升问题3. 内存

SpringBoot整合liteflow的详细过程

《SpringBoot整合liteflow的详细过程》:本文主要介绍SpringBoot整合liteflow的详细过程,本文给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋...  liteflow 是什么? 能做什么?总之一句话:能帮你规范写代码逻辑 ,编排并解耦业务逻辑,代码

Java中调用数据库存储过程的示例代码

《Java中调用数据库存储过程的示例代码》本文介绍Java通过JDBC调用数据库存储过程的方法,涵盖参数类型、执行步骤及数据库差异,需注意异常处理与资源管理,以优化性能并实现复杂业务逻辑,感兴趣的朋友... 目录一、存储过程概述二、Java调用存储过程的基本javascript步骤三、Java调用存储过程示

Visual Studio 2022 编译C++20代码的图文步骤

《VisualStudio2022编译C++20代码的图文步骤》在VisualStudio中启用C++20import功能,需设置语言标准为ISOC++20,开启扫描源查找模块依赖及实验性标... 默认创建Visual Studio桌面控制台项目代码包含C++20的import方法。右键项目的属性:

MyBatisPlus如何优化千万级数据的CRUD

《MyBatisPlus如何优化千万级数据的CRUD》最近负责的一个项目,数据库表量级破千万,每次执行CRUD都像走钢丝,稍有不慎就引起数据库报警,本文就结合这个项目的实战经验,聊聊MyBatisPl... 目录背景一、MyBATis Plus 简介二、千万级数据的挑战三、优化 CRUD 的关键策略1. 查

python实现对数据公钥加密与私钥解密

《python实现对数据公钥加密与私钥解密》这篇文章主要为大家详细介绍了如何使用python实现对数据公钥加密与私钥解密,文中的示例代码讲解详细,感兴趣的小伙伴可以跟随小编一起学习一下... 目录公钥私钥的生成使用公钥加密使用私钥解密公钥私钥的生成这一部分,使用python生成公钥与私钥,然后保存在两个文

mysql中的数据目录用法及说明

《mysql中的数据目录用法及说明》:本文主要介绍mysql中的数据目录用法及说明,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录1、背景2、版本3、数据目录4、总结1、背景安装mysql之后,在安装目录下会有一个data目录,我们创建的数据库、创建的表、插入的

MySQL中的InnoDB单表访问过程

《MySQL中的InnoDB单表访问过程》:本文主要介绍MySQL中的InnoDB单表访问过程,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录1、背景2、环境3、访问类型【1】const【2】ref【3】ref_or_null【4】range【5】index【6】