PaddlePaddle笔记5-心电图智能诊断

2023-10-23 03:40

本文主要是介绍PaddlePaddle笔记5-心电图智能诊断,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

睿洛医疗

参考资料:

AIWIN - 比赛

AIWIN 心电图智能诊断竞赛:TextCNN 精度0.67 - 飞桨AI Studio - 人工智能学习与实训社区

        心电图是临床最基础的一个检查项目,因为安全、便捷成为心脏病诊断的利器。每天都有大量的心电图诊断需求,但是全国范围内诊断心电图的专业医生数量不足,导致很多医院都面临专业心电图医生短缺的情况。人工智能技术的出现,为改善医生人力资源不足的问题带来了全新的可能。由于心电图数据与诊断的标准化程度较高,相对较易于运用人工智能技术进行智能诊断算法的开发。由于心电图可诊断的疾病类别特别丰富,目前,市面上出现较多的是针对某些特定类别的算法,尚没有看到能够按照临床诊断标准、在一定准确率标准下,提供类似医生的多标签多分类算法。

        需要识别的心电图包括12个类别:正常心电图、窦性心动过缓、窦性心动过速、窦性心律不齐、心房颤动、室性早搏、房性早搏、一度房室阻滞、完全性右束支阻滞、T波改变、ST改变、其它。

         心电数据的单位为mV,采样率为 500HZ,记录时长为 10 秒,存储格式为 MAT;文件中存储了 12 导联的电压信号(包含了I,II,III,aVR,aVL,aVF,V1,V2,V3,V4,V5 和 V6)。数据下载地址:心电图测试数据集https://download.csdn.net/download/zj850324/67399060https://download.csdn.net/download/zj850324/67399060

数据格式:

12导联的数据,保存matlab格式文件中。数据格式是(12, 5000);

采样500HZ,10S长度有效数据;

0..12是I, II, III, aVR, aVL, aVF, V1, V2, V3, V4, V5和V6数据,单位是mV。

1,准备模块(依赖库按需安装即可,不再赘述)

import scipy.io as sio
import matplotlib.pyplot as plt
import codecs, glob, os
import numpy as np
import pandas as pd
import paddle
import paddle.nn as nn
from paddle.io import DataLoader, Dataset
import paddle.optimizer as optim
from paddlenlp.data import Pad
import scipy.io as sio
from sklearn.model_selection import StratifiedKFold

2,构造数据集

class MyDataset(Dataset):def __init__(self, mat, label, mat_dim=3000):super(MyDataset, self).__init__()self.mat = matself.label = labelself.mat_dim = mat_dimdef __len__(self):return len(self.mat)def __getitem__(self, index):idx = np.random.randint(0, 5000-self.mat_dim)# idy = np.random.choice(range(12), 9)inputs=paddle.to_tensor(self.mat[index][:, :, idx:idx+self.mat_dim])label=paddle.to_tensor(self.label[index])return inputs,label

3,构造模型

class TextCNN_Plus(nn.Layer):def __init__(self, kernel_num=30, kernel_size=[3, 4, 5], dropout=0.5,mat_dim=3000):super(TextCNN_Plus, self).__init__()self.kernel_num = kernel_numself.kernel_size = kernel_sizeself.dropout = dropoutself.convs = nn.LayerList([nn.Conv2D(1, self.kernel_num, (kernel_size_, mat_dim))for kernel_size_ in self.kernel_size])self.dropout = nn.Dropout(self.dropout)self.linear = nn.Linear(3 * self.kernel_num, 1)def forward(self, x):convs = [nn.ReLU()(conv(x)).squeeze(3) for conv in self.convs]pool_out = [nn.MaxPool1D(block.shape[2])(block).squeeze(2) for block in convs]pool_out = paddle.concat(pool_out, 1)logits = self.linear(pool_out)return logits

4,加载数据集

def load_data(BATCH_SIZE):train_mat = glob.glob('./data/ecg/train/*.mat')train_mat.sort()train_mat = [sio.loadmat(x)['ecgdata'].reshape(1, 12, 5000) for x in train_mat]test_mat = glob.glob('./data/ecg/val/*.mat')test_mat.sort()test_mat = [sio.loadmat(x)['ecgdata'].reshape(1, 12, 5000) for x in test_mat]train_df = pd.read_csv('./data/ecg/trainreference.csv')train_df['tag'] = train_df['tag'].astype(np.float32)# 查看数据plt.plot(range(5000), train_mat[0][0][0])plt.plot(range(5000), train_mat[0][0][1])plt.plot(range(5000), train_mat[0][0][3])# plt.show()train_df.head()print(test_mat[0].shape)train_ds = MyDataset(train_mat, train_df['tag'])train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)for batch in train_loader:print(batch)breakreturn train_mat,train_df,test_mat

5,定义模型网络

def model_create(mat_dim):# model = TextCNN()model = TextCNN_Plus(mat_dim=mat_dim)paddle.summary(model, (64, 1, 9, mat_dim))return model

6,模型训练

def model_train(EPOCHS,BATCH_SIZE,LEARNING_RATE,num_splits,mat_dim,output_dir):train_mat, train_df, test_mat = load_data(BATCH_SIZE)skf = StratifiedKFold(n_splits=num_splits)if (not os.path.exists(output_dir)):os.mkdir(output_dir)fold_idx = 0for tr_idx, val_idx in skf.split(train_mat, train_df['tag'].values):train_ds = MyDataset(np.array(train_mat)[tr_idx], train_df['tag'].values[tr_idx], mat_dim=mat_dim)dev_ds = MyDataset(np.array(train_mat)[val_idx], train_df['tag'].values[val_idx], mat_dim=mat_dim)Train_Loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)Val_Loader = DataLoader(dev_ds, batch_size=BATCH_SIZE, shuffle=True)# model = TextCNN()model = TextCNN_Plus(mat_dim=mat_dim)optimizer = optim.Adam(parameters=model.parameters(), learning_rate=LEARNING_RATE)criterion = nn.BCEWithLogitsLoss()Test_best_Acc = 0for epoch in range(0, EPOCHS):Train_Loss, Test_Loss = [], []Train_Acc, Test_Acc = [], []model.train()for i, (x, y) in enumerate(Train_Loader):pred = model(x)loss = criterion(pred, y)Train_Loss.append(loss.item())pred = (paddle.nn.functional.sigmoid(pred) > 0.5).astype(int)Train_Acc.append((pred.numpy() == y.numpy()).mean())loss.backward()optimizer.step()optimizer.clear_grad()model.eval()for i, (x, y) in enumerate(Val_Loader):pred = model(x)Test_Loss.append(criterion(pred, y).item())pred = (paddle.nn.functional.sigmoid(pred) > 0.5).astype(int)Test_Acc.append((pred.numpy() == y.numpy()).mean())if epoch % 10 == 0:print("Epoch: [{}/{}] TrainLoss/TestLoss: {:.4f}/{:.4f} TrainAcc/TestAcc: {:.4f}/{:.4f}".format( \epoch + 1, EPOCHS, \np.mean(Train_Loss), np.mean(Test_Loss), \np.mean(Train_Acc), np.mean(Test_Acc) \))if Test_best_Acc < np.mean(Test_Acc):print(f'Fold {fold_idx} Acc imporve from {Test_best_Acc} to {np.mean(Test_Acc)} Save Model...')paddle.save(model.state_dict(), os.path.join(output_dir, "model_{}.pdparams".format(fold_idx)))Test_best_Acc = np.mean(Test_Acc)fold_idx += 1

 7,载入模型预测(得先训练)

def model_predict(modelpath,mat_dim,answer_path):train_mat, train_df, test_mat = load_data(BATCH_SIZE)test_perd = np.zeros(len(test_mat))tta_count = 20model = model_create(mat_dim)layer_state_dict = paddle.load(modelpath + '/model_4.pdparams')model.set_state_dict(layer_state_dict)for fold_idx in range(num_splits):test_ds = MyDataset(test_mat, [0] * len(test_mat), mat_dim=mat_dim)Test_Loader = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False)layer_state_dict = paddle.load(os.path.join(output_dir, "model_{}.pdparams".format(fold_idx)))model.set_state_dict(layer_state_dict)for tta in range(tta_count):test_pred_list = []for i, (x, y) in enumerate(Test_Loader):pred = model(x)test_pred_list.append(paddle.nn.functional.sigmoid(pred).numpy())test_perd += np.vstack(test_pred_list)[:, 0]test_perd /= tta_count * num_splits# 生成结果test_path = glob.glob('./data/ecg/val/*.mat')test_path = [os.path.basename(x)[:-4] for x in test_path]test_path.sort()test_answer = pd.DataFrame({'name': test_path,'tag': (test_perd > 0.5).astype(int)}).to_csv(answer_path, index=None)

8,预测结果

备注

调试步骤及临时配参

#### 测试
EPOCHS = 200
BATCH_SIZE = 30
LEARNING_RATE = 0.0005
num_splits = 5
mat_dim = 4000
output_dir = 'checkpoint'
answer_path = './data/ecg/answer.csv'# model_train(EPOCHS,BATCH_SIZE,LEARNING_RATE,num_splits,mat_dim,output_dir)
model_predict(output_dir,mat_dim,answer_path)

这篇关于PaddlePaddle笔记5-心电图智能诊断的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/265451

相关文章

嵌入式QT开发:构建高效智能的嵌入式系统

摘要: 本文深入探讨了嵌入式 QT 相关的各个方面。从 QT 框架的基础架构和核心概念出发,详细阐述了其在嵌入式环境中的优势与特点。文中分析了嵌入式 QT 的开发环境搭建过程,包括交叉编译工具链的配置等关键步骤。进一步探讨了嵌入式 QT 的界面设计与开发,涵盖了从基本控件的使用到复杂界面布局的构建。同时也深入研究了信号与槽机制在嵌入式系统中的应用,以及嵌入式 QT 与硬件设备的交互,包括输入输出设

让树莓派智能语音助手实现定时提醒功能

最初的时候是想直接在rasa 的chatbot上实现,因为rasa本身是带有remindschedule模块的。不过经过一番折腾后,忽然发现,chatbot上实现的定时,语音助手不一定会有响应。因为,我目前语音助手的代码设置了长时间无应答会结束对话,这样一来,chatbot定时提醒的触发就不会被语音助手获悉。那怎么让语音助手也具有定时提醒功能呢? 我最后选择的方法是用threading.Time

【学习笔记】 陈强-机器学习-Python-Ch15 人工神经网络(1)sklearn

系列文章目录 监督学习:参数方法 【学习笔记】 陈强-机器学习-Python-Ch4 线性回归 【学习笔记】 陈强-机器学习-Python-Ch5 逻辑回归 【课后题练习】 陈强-机器学习-Python-Ch5 逻辑回归(SAheart.csv) 【学习笔记】 陈强-机器学习-Python-Ch6 多项逻辑回归 【学习笔记 及 课后题练习】 陈强-机器学习-Python-Ch7 判别分析 【学

系统架构师考试学习笔记第三篇——架构设计高级知识(20)通信系统架构设计理论与实践

本章知识考点:         第20课时主要学习通信系统架构设计的理论和工作中的实践。根据新版考试大纲,本课时知识点会涉及案例分析题(25分),而在历年考试中,案例题对该部分内容的考查并不多,虽在综合知识选择题目中经常考查,但分值也不高。本课时内容侧重于对知识点的记忆和理解,按照以往的出题规律,通信系统架构设计基础知识点多来源于教材内的基础网络设备、网络架构和教材外最新时事热点技术。本课时知识

智能交通(二)——Spinger特刊推荐

特刊征稿 01  期刊名称: Autonomous Intelligent Systems  特刊名称: Understanding the Policy Shift  with the Digital Twins in Smart  Transportation and Mobility 截止时间: 开放提交:2024年1月20日 提交截止日

基于 YOLOv5 的积水检测系统:打造高效智能的智慧城市应用

在城市发展中,积水问题日益严重,特别是在大雨过后,积水往往会影响交通甚至威胁人们的安全。通过现代计算机视觉技术,我们能够智能化地检测和识别积水区域,减少潜在危险。本文将介绍如何使用 YOLOv5 和 PyQt5 搭建一个积水检测系统,结合深度学习和直观的图形界面,为用户提供高效的解决方案。 源码地址: PyQt5+YoloV5 实现积水检测系统 预览: 项目背景

论文阅读笔记: Segment Anything

文章目录 Segment Anything摘要引言任务模型数据引擎数据集负责任的人工智能 Segment Anything Model图像编码器提示编码器mask解码器解决歧义损失和训练 Segment Anything 论文地址: https://arxiv.org/abs/2304.02643 代码地址:https://github.com/facebookresear

数学建模笔记—— 非线性规划

数学建模笔记—— 非线性规划 非线性规划1. 模型原理1.1 非线性规划的标准型1.2 非线性规划求解的Matlab函数 2. 典型例题3. matlab代码求解3.1 例1 一个简单示例3.2 例2 选址问题1. 第一问 线性规划2. 第二问 非线性规划 非线性规划 非线性规划是一种求解目标函数或约束条件中有一个或几个非线性函数的最优化问题的方法。运筹学的一个重要分支。2

【C++学习笔记 20】C++中的智能指针

智能指针的功能 在上一篇笔记提到了在栈和堆上创建变量的区别,使用new关键字创建变量时,需要搭配delete关键字销毁变量。而智能指针的作用就是调用new分配内存时,不必自己去调用delete,甚至不用调用new。 智能指针实际上就是对原始指针的包装。 unique_ptr 最简单的智能指针,是一种作用域指针,意思是当指针超出该作用域时,会自动调用delete。它名为unique的原因是这个

查看提交历史 —— Git 学习笔记 11

查看提交历史 查看提交历史 不带任何选项的git log-p选项--stat 选项--pretty=oneline选项--pretty=format选项git log常用选项列表参考资料 在提交了若干更新,又或者克隆了某个项目之后,你也许想回顾下提交历史。 完成这个任务最简单而又有效的 工具是 git log 命令。 接下来的例子会用一个用于演示的 simplegit