TensorFlow搭建搭建卷积神经网络EEGNet处理脑电数据过程代码

本文主要是介绍TensorFlow搭建搭建卷积神经网络EEGNet处理脑电数据过程代码,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

TensorFlow搭建搭建卷积神经网络EEGNet处理脑电数据过程代码
脑电信号采集设备是由NT9200-32D型号脑电图仪和NeuSen W系列无线脑电采集系统组成,采集后的信号用Matlab打开,保存在结构体数据中,采集到的原始信号形式是:16x640000 double,最开始对数据进行手动分段分成[280,16,1000],280指trials,22指channels,1000指 samples,
整个代码可分为:**数据切分,搭建网络,训练数据,测试数据,**四个部分
1.导入包

import numpy as np
from tensorflow.keras import utils as np_utils
from tensorflow.keras.callbacks import ModelCheckpoint
from tensorflow.keras import backend as K
# PyRiemann imports
from pyriemann.estimation import XdawnCovariances
from pyriemann.tangentspace import TangentSpace
from pyriemann.utils.viz import plot_confusion_matrix
from sklearn.pipeline import make_pipeline
from sklearn.linear_model import LogisticRegression
import scipy.io
from matplotlib import pyplot as plt

2.数据切分

K.set_image_data_format('channels_last')
samplesfile = scipy.io.loadmat('F:/holiday_code/attention/TSA/data/foursecond.mat')
X = samplesfile['eeg']#提取数组,结构体名称是eeg
event_id = dict(l=1, m=2, lm=3, ml=4)#四分类运动想象数据
# Setup for reading the raw data
labels = samplesfile['Mark']#加载标签数据
y = labels[:,-1]#标签数据
kernels, chans, samples = 1, 16, 1000# take 50/25/25 percent of the data to train/validate/test
X_train = X[0:140, ]
Y_train = y[0:140]
X_validate = X[140:210, ]
Y_validate = y[140:210]
X_test = X[210:, ]
Y_test = y[210:]
#把标签数据转换成one-hot编码
Y_train = np_utils.to_categorical(Y_train - 1)
Y_validate = np_utils.to_categorical(Y_validate - 1)
Y_test = np_utils.to_categorical(Y_test - 1)
#根据网络结构设置数据的输入形式(trials, channels, samples, kernels)
X_train = X_train.reshape(X_train.shape[0], chans, samples, kernels)
X_validate = X_validate.reshape(X_validate.shape[0], chans, samples, kernels)
X_test = X_test.reshape(X_test.shape[0], chans, samples, kernels)

4.搭建网络

#导入需要的库
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Dense, Activation, Permute, Dropout
from tensorflow.keras.layers import Conv2D, MaxPooling2D, AveragePooling2D
from tensorflow.keras.layers import SeparableConv2D, DepthwiseConv2D
from tensorflow.keras.layers import BatchNormalization
from tensorflow.keras.layers import SpatialDropout2D
from tensorflow.keras.regularizers import l1_l2
from tensorflow.keras.layers import Input, Flatten
from tensorflow.keras.constraints import max_norm
def EEGNet(nb_classes, Chans = 16, Samples = 1000,dropoutRate = 0.5, kernLength = 64, F1 = 8, D = 2, F2 = 16, norm_rate = 0.25, dropoutType = 'Dropout'):if dropoutType == 'SpatialDropout2D':dropoutType = SpatialDropout2Delif dropoutType == 'Dropout':dropoutType = Dropoutelse:raise ValueError('dropoutType must be one of SpatialDropout2D ''or Dropout, passed as a string.')input1 = Input(shape = (Chans, Samples, 1))print("input shape", input1.shape, Chans, Samples, kernLength)##################################################################block1 = Conv2D(F1, (1, kernLength), padding = 'same',input_shape = (Chans, Samples, 1),use_bias = False)(input1)block1 = BatchNormalization()(block1)block1 = DepthwiseConv2D((Chans, 1), use_bias = False,depth_multiplier = D,depthwise_constraint = max_norm(1.))(block1)block1 = BatchNormalization()(block1)block1 = Activation('elu')(block1)block1 = AveragePooling2D((1, 4))(block1)block1 = dropoutType(dropoutRate)(block1)block2 = SeparableConv2D(F2, (1, 16),use_bias = False, padding = 'same')(block1)block2 = BatchNormalization()(block2)block2 = Activation('elu')(block2)block2 = AveragePooling2D((1, 8))(block2)block2 = dropoutType(dropoutRate)(block2)flatten = Flatten(name = 'flatten')(block2)dense = Dense(nb_classes, name = 'dense',kernel_constraint = max_norm(norm_rate))(flatten)softmax = Activation('softmax', name = 'softmax')(dense)return Model(inputs=input1, outputs=softmax)

5.训练模型

model = EEGNet(nb_classes = 4, Chans = 16, Samples = 1000,dropoutRate = 0.5, kernLength = 32, F1 = 8, D = 2, F2 = 16,dropoutType = 'Dropout')
model.compile(loss='categorical_crossentropy', optimizer='adam',metrics=['accuracy'])
# count number of parameters in the model
numParams = model.count_params()
# set a valid path for your system to record model checkpoints
checkpointer = ModelCheckpoint(filepath='F:/holiday_code/attention/TSA/tmptwo/tmp/checkpoint.h5', verbose=1,save_best_only=True)
class_weights = {0: 1, 1: 1, 2: 1, 3: 1}
fittedModel = model.fit(X_train, Y_train, batch_size=16, epochs=300,verbose=2, validation_data=(X_validate, Y_validate),callbacks=[checkpointer], class_weight=class_weights)

6.测试模型

model.load_weights('F:/holiday_code/attention/TSA/tmptwo/tmp/checkpoint.h5')
probs = model.predict(X_test)
preds = probs.argmax(axis=-1)
acc = np.mean(preds == Y_test.argmax(axis=-1))
print("Classification accuracy: %f " % (acc))# plot the accuracy and loss graph
plt.plot(fittedModel.history['accuracy'])
plt.plot(fittedModel.history['val_accuracy'])
plt.plot(fittedModel.history['loss'])
plt.plot(fittedModel.history['val_loss'])
plt.title('acc & loss')
plt.xlabel('epoch')
plt.legend(['acc', 'val_acc','loss','val_loss'], loc='upper right')
plt.show()

7.分类结果
在这里插入图片描述
整个网络框架大概就是这样,这是其中一个被试的分类结果,属于分类效果比较好的,其他被试可能由于数据质量,网络结构等原因分类效果不是很理想,考虑数据增强以及网络结构优化去提高分类准确率。

这篇关于TensorFlow搭建搭建卷积神经网络EEGNet处理脑电数据过程代码的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/224835

相关文章

SpringBoot集成Milvus实现数据增删改查功能

《SpringBoot集成Milvus实现数据增删改查功能》milvus支持的语言比较多,支持python,Java,Go,node等开发语言,本文主要介绍如何使用Java语言,采用springboo... 目录1、Milvus基本概念2、添加maven依赖3、配置yml文件4、创建MilvusClient

浅析Java中如何优雅地处理null值

《浅析Java中如何优雅地处理null值》这篇文章主要为大家详细介绍了如何结合Lambda表达式和Optional,让Java更优雅地处理null值,感兴趣的小伙伴可以跟随小编一起学习一下... 目录场景 1:不为 null 则执行场景 2:不为 null 则返回,为 null 则返回特定值或抛出异常场景

深入理解Apache Kafka(分布式流处理平台)

《深入理解ApacheKafka(分布式流处理平台)》ApacheKafka作为现代分布式系统中的核心中间件,为构建高吞吐量、低延迟的数据管道提供了强大支持,本文将深入探讨Kafka的核心概念、架构... 目录引言一、Apache Kafka概述1.1 什么是Kafka?1.2 Kafka的核心概念二、Ka

SpringValidation数据校验之约束注解与分组校验方式

《SpringValidation数据校验之约束注解与分组校验方式》本文将深入探讨SpringValidation的核心功能,帮助开发者掌握约束注解的使用技巧和分组校验的高级应用,从而构建更加健壮和可... 目录引言一、Spring Validation基础架构1.1 jsR-380标准与Spring整合1

MySQL 中查询 VARCHAR 类型 JSON 数据的问题记录

《MySQL中查询VARCHAR类型JSON数据的问题记录》在数据库设计中,有时我们会将JSON数据存储在VARCHAR或TEXT类型字段中,本文将详细介绍如何在MySQL中有效查询存储为V... 目录一、问题背景二、mysql jsON 函数2.1 常用 JSON 函数三、查询示例3.1 基本查询3.2

使用Python实现全能手机虚拟键盘的示例代码

《使用Python实现全能手机虚拟键盘的示例代码》在数字化办公时代,你是否遇到过这样的场景:会议室投影电脑突然键盘失灵、躺在沙发上想远程控制书房电脑、或者需要给长辈远程协助操作?今天我要分享的Pyth... 目录一、项目概述:不止于键盘的远程控制方案1.1 创新价值1.2 技术栈全景二、需求实现步骤一、需求

SpringBatch数据写入实现

《SpringBatch数据写入实现》SpringBatch通过ItemWriter接口及其丰富的实现,提供了强大的数据写入能力,本文主要介绍了SpringBatch数据写入实现,具有一定的参考价值,... 目录python引言一、ItemWriter核心概念二、数据库写入实现三、文件写入实现四、多目标写入

Java中Date、LocalDate、LocalDateTime、LocalTime、时间戳之间的相互转换代码

《Java中Date、LocalDate、LocalDateTime、LocalTime、时间戳之间的相互转换代码》:本文主要介绍Java中日期时间转换的多种方法,包括将Date转换为LocalD... 目录一、Date转LocalDateTime二、Date转LocalDate三、LocalDateTim

使用Python将JSON,XML和YAML数据写入Excel文件

《使用Python将JSON,XML和YAML数据写入Excel文件》JSON、XML和YAML作为主流结构化数据格式,因其层次化表达能力和跨平台兼容性,已成为系统间数据交换的通用载体,本文将介绍如何... 目录如何使用python写入数据到Excel工作表用Python导入jsON数据到Excel工作表用

Mysql如何将数据按照年月分组的统计

《Mysql如何将数据按照年月分组的统计》:本文主要介绍Mysql如何将数据按照年月分组的统计方式,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录mysql将数据按照年月分组的统计要的效果方案总结Mysql将数据按照年月分组的统计要的效果方案① 使用 DA