PaddlePaddle笔记5-心电图智能诊断

2023-10-23 03:40

本文主要是介绍PaddlePaddle笔记5-心电图智能诊断,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

睿洛医疗

参考资料:

AIWIN - 比赛

AIWIN 心电图智能诊断竞赛:TextCNN 精度0.67 - 飞桨AI Studio - 人工智能学习与实训社区

        心电图是临床最基础的一个检查项目,因为安全、便捷成为心脏病诊断的利器。每天都有大量的心电图诊断需求,但是全国范围内诊断心电图的专业医生数量不足,导致很多医院都面临专业心电图医生短缺的情况。人工智能技术的出现,为改善医生人力资源不足的问题带来了全新的可能。由于心电图数据与诊断的标准化程度较高,相对较易于运用人工智能技术进行智能诊断算法的开发。由于心电图可诊断的疾病类别特别丰富,目前,市面上出现较多的是针对某些特定类别的算法,尚没有看到能够按照临床诊断标准、在一定准确率标准下,提供类似医生的多标签多分类算法。

        需要识别的心电图包括12个类别:正常心电图、窦性心动过缓、窦性心动过速、窦性心律不齐、心房颤动、室性早搏、房性早搏、一度房室阻滞、完全性右束支阻滞、T波改变、ST改变、其它。

         心电数据的单位为mV,采样率为 500HZ,记录时长为 10 秒,存储格式为 MAT;文件中存储了 12 导联的电压信号(包含了I,II,III,aVR,aVL,aVF,V1,V2,V3,V4,V5 和 V6)。数据下载地址:心电图测试数据集https://download.csdn.net/download/zj850324/67399060https://download.csdn.net/download/zj850324/67399060

数据格式:

12导联的数据,保存matlab格式文件中。数据格式是(12, 5000);

采样500HZ,10S长度有效数据;

0..12是I, II, III, aVR, aVL, aVF, V1, V2, V3, V4, V5和V6数据,单位是mV。

1,准备模块(依赖库按需安装即可,不再赘述)

import scipy.io as sio
import matplotlib.pyplot as plt
import codecs, glob, os
import numpy as np
import pandas as pd
import paddle
import paddle.nn as nn
from paddle.io import DataLoader, Dataset
import paddle.optimizer as optim
from paddlenlp.data import Pad
import scipy.io as sio
from sklearn.model_selection import StratifiedKFold

2,构造数据集

class MyDataset(Dataset):def __init__(self, mat, label, mat_dim=3000):super(MyDataset, self).__init__()self.mat = matself.label = labelself.mat_dim = mat_dimdef __len__(self):return len(self.mat)def __getitem__(self, index):idx = np.random.randint(0, 5000-self.mat_dim)# idy = np.random.choice(range(12), 9)inputs=paddle.to_tensor(self.mat[index][:, :, idx:idx+self.mat_dim])label=paddle.to_tensor(self.label[index])return inputs,label

3,构造模型

class TextCNN_Plus(nn.Layer):def __init__(self, kernel_num=30, kernel_size=[3, 4, 5], dropout=0.5,mat_dim=3000):super(TextCNN_Plus, self).__init__()self.kernel_num = kernel_numself.kernel_size = kernel_sizeself.dropout = dropoutself.convs = nn.LayerList([nn.Conv2D(1, self.kernel_num, (kernel_size_, mat_dim))for kernel_size_ in self.kernel_size])self.dropout = nn.Dropout(self.dropout)self.linear = nn.Linear(3 * self.kernel_num, 1)def forward(self, x):convs = [nn.ReLU()(conv(x)).squeeze(3) for conv in self.convs]pool_out = [nn.MaxPool1D(block.shape[2])(block).squeeze(2) for block in convs]pool_out = paddle.concat(pool_out, 1)logits = self.linear(pool_out)return logits

4,加载数据集

def load_data(BATCH_SIZE):train_mat = glob.glob('./data/ecg/train/*.mat')train_mat.sort()train_mat = [sio.loadmat(x)['ecgdata'].reshape(1, 12, 5000) for x in train_mat]test_mat = glob.glob('./data/ecg/val/*.mat')test_mat.sort()test_mat = [sio.loadmat(x)['ecgdata'].reshape(1, 12, 5000) for x in test_mat]train_df = pd.read_csv('./data/ecg/trainreference.csv')train_df['tag'] = train_df['tag'].astype(np.float32)# 查看数据plt.plot(range(5000), train_mat[0][0][0])plt.plot(range(5000), train_mat[0][0][1])plt.plot(range(5000), train_mat[0][0][3])# plt.show()train_df.head()print(test_mat[0].shape)train_ds = MyDataset(train_mat, train_df['tag'])train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)for batch in train_loader:print(batch)breakreturn train_mat,train_df,test_mat

5,定义模型网络

def model_create(mat_dim):# model = TextCNN()model = TextCNN_Plus(mat_dim=mat_dim)paddle.summary(model, (64, 1, 9, mat_dim))return model

6,模型训练

def model_train(EPOCHS,BATCH_SIZE,LEARNING_RATE,num_splits,mat_dim,output_dir):train_mat, train_df, test_mat = load_data(BATCH_SIZE)skf = StratifiedKFold(n_splits=num_splits)if (not os.path.exists(output_dir)):os.mkdir(output_dir)fold_idx = 0for tr_idx, val_idx in skf.split(train_mat, train_df['tag'].values):train_ds = MyDataset(np.array(train_mat)[tr_idx], train_df['tag'].values[tr_idx], mat_dim=mat_dim)dev_ds = MyDataset(np.array(train_mat)[val_idx], train_df['tag'].values[val_idx], mat_dim=mat_dim)Train_Loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)Val_Loader = DataLoader(dev_ds, batch_size=BATCH_SIZE, shuffle=True)# model = TextCNN()model = TextCNN_Plus(mat_dim=mat_dim)optimizer = optim.Adam(parameters=model.parameters(), learning_rate=LEARNING_RATE)criterion = nn.BCEWithLogitsLoss()Test_best_Acc = 0for epoch in range(0, EPOCHS):Train_Loss, Test_Loss = [], []Train_Acc, Test_Acc = [], []model.train()for i, (x, y) in enumerate(Train_Loader):pred = model(x)loss = criterion(pred, y)Train_Loss.append(loss.item())pred = (paddle.nn.functional.sigmoid(pred) > 0.5).astype(int)Train_Acc.append((pred.numpy() == y.numpy()).mean())loss.backward()optimizer.step()optimizer.clear_grad()model.eval()for i, (x, y) in enumerate(Val_Loader):pred = model(x)Test_Loss.append(criterion(pred, y).item())pred = (paddle.nn.functional.sigmoid(pred) > 0.5).astype(int)Test_Acc.append((pred.numpy() == y.numpy()).mean())if epoch % 10 == 0:print("Epoch: [{}/{}] TrainLoss/TestLoss: {:.4f}/{:.4f} TrainAcc/TestAcc: {:.4f}/{:.4f}".format( \epoch + 1, EPOCHS, \np.mean(Train_Loss), np.mean(Test_Loss), \np.mean(Train_Acc), np.mean(Test_Acc) \))if Test_best_Acc < np.mean(Test_Acc):print(f'Fold {fold_idx} Acc imporve from {Test_best_Acc} to {np.mean(Test_Acc)} Save Model...')paddle.save(model.state_dict(), os.path.join(output_dir, "model_{}.pdparams".format(fold_idx)))Test_best_Acc = np.mean(Test_Acc)fold_idx += 1

 7,载入模型预测(得先训练)

def model_predict(modelpath,mat_dim,answer_path):train_mat, train_df, test_mat = load_data(BATCH_SIZE)test_perd = np.zeros(len(test_mat))tta_count = 20model = model_create(mat_dim)layer_state_dict = paddle.load(modelpath + '/model_4.pdparams')model.set_state_dict(layer_state_dict)for fold_idx in range(num_splits):test_ds = MyDataset(test_mat, [0] * len(test_mat), mat_dim=mat_dim)Test_Loader = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False)layer_state_dict = paddle.load(os.path.join(output_dir, "model_{}.pdparams".format(fold_idx)))model.set_state_dict(layer_state_dict)for tta in range(tta_count):test_pred_list = []for i, (x, y) in enumerate(Test_Loader):pred = model(x)test_pred_list.append(paddle.nn.functional.sigmoid(pred).numpy())test_perd += np.vstack(test_pred_list)[:, 0]test_perd /= tta_count * num_splits# 生成结果test_path = glob.glob('./data/ecg/val/*.mat')test_path = [os.path.basename(x)[:-4] for x in test_path]test_path.sort()test_answer = pd.DataFrame({'name': test_path,'tag': (test_perd > 0.5).astype(int)}).to_csv(answer_path, index=None)

8,预测结果

备注

调试步骤及临时配参

#### 测试
EPOCHS = 200
BATCH_SIZE = 30
LEARNING_RATE = 0.0005
num_splits = 5
mat_dim = 4000
output_dir = 'checkpoint'
answer_path = './data/ecg/answer.csv'# model_train(EPOCHS,BATCH_SIZE,LEARNING_RATE,num_splits,mat_dim,output_dir)
model_predict(output_dir,mat_dim,answer_path)

这篇关于PaddlePaddle笔记5-心电图智能诊断的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/265451

相关文章

使用Python实现表格字段智能去重

《使用Python实现表格字段智能去重》在数据分析和处理过程中,数据清洗是一个至关重要的步骤,其中字段去重是一个常见且关键的任务,下面我们看看如何使用Python进行表格字段智能去重吧... 目录一、引言二、数据重复问题的常见场景与影响三、python在数据清洗中的优势四、基于Python的表格字段智能去重

Spring AI集成DeepSeek三步搞定Java智能应用的详细过程

《SpringAI集成DeepSeek三步搞定Java智能应用的详细过程》本文介绍了如何使用SpringAI集成DeepSeek,一个国内顶尖的多模态大模型,SpringAI提供了一套统一的接口,简... 目录DeepSeek 介绍Spring AI 是什么?Spring AI 的主要功能包括1、环境准备2

Spring AI与DeepSeek实战一之快速打造智能对话应用

《SpringAI与DeepSeek实战一之快速打造智能对话应用》本文详细介绍了如何通过SpringAI框架集成DeepSeek大模型,实现普通对话和流式对话功能,步骤包括申请API-KEY、项目搭... 目录一、概述二、申请DeepSeek的API-KEY三、项目搭建3.1. 开发环境要求3.2. mav

Python3脚本实现Excel与TXT的智能转换

《Python3脚本实现Excel与TXT的智能转换》在数据处理的日常工作中,我们经常需要将Excel中的结构化数据转换为其他格式,本文将使用Python3实现Excel与TXT的智能转换,需要的可以... 目录场景应用:为什么需要这种转换技术解析:代码实现详解核心代码展示改进点说明实战演练:从Excel到

嵌入式QT开发:构建高效智能的嵌入式系统

摘要: 本文深入探讨了嵌入式 QT 相关的各个方面。从 QT 框架的基础架构和核心概念出发,详细阐述了其在嵌入式环境中的优势与特点。文中分析了嵌入式 QT 的开发环境搭建过程,包括交叉编译工具链的配置等关键步骤。进一步探讨了嵌入式 QT 的界面设计与开发,涵盖了从基本控件的使用到复杂界面布局的构建。同时也深入研究了信号与槽机制在嵌入式系统中的应用,以及嵌入式 QT 与硬件设备的交互,包括输入输出设

让树莓派智能语音助手实现定时提醒功能

最初的时候是想直接在rasa 的chatbot上实现,因为rasa本身是带有remindschedule模块的。不过经过一番折腾后,忽然发现,chatbot上实现的定时,语音助手不一定会有响应。因为,我目前语音助手的代码设置了长时间无应答会结束对话,这样一来,chatbot定时提醒的触发就不会被语音助手获悉。那怎么让语音助手也具有定时提醒功能呢? 我最后选择的方法是用threading.Time

【学习笔记】 陈强-机器学习-Python-Ch15 人工神经网络(1)sklearn

系列文章目录 监督学习:参数方法 【学习笔记】 陈强-机器学习-Python-Ch4 线性回归 【学习笔记】 陈强-机器学习-Python-Ch5 逻辑回归 【课后题练习】 陈强-机器学习-Python-Ch5 逻辑回归(SAheart.csv) 【学习笔记】 陈强-机器学习-Python-Ch6 多项逻辑回归 【学习笔记 及 课后题练习】 陈强-机器学习-Python-Ch7 判别分析 【学

系统架构师考试学习笔记第三篇——架构设计高级知识(20)通信系统架构设计理论与实践

本章知识考点:         第20课时主要学习通信系统架构设计的理论和工作中的实践。根据新版考试大纲,本课时知识点会涉及案例分析题(25分),而在历年考试中,案例题对该部分内容的考查并不多,虽在综合知识选择题目中经常考查,但分值也不高。本课时内容侧重于对知识点的记忆和理解,按照以往的出题规律,通信系统架构设计基础知识点多来源于教材内的基础网络设备、网络架构和教材外最新时事热点技术。本课时知识

智能交通(二)——Spinger特刊推荐

特刊征稿 01  期刊名称: Autonomous Intelligent Systems  特刊名称: Understanding the Policy Shift  with the Digital Twins in Smart  Transportation and Mobility 截止时间: 开放提交:2024年1月20日 提交截止日

基于 YOLOv5 的积水检测系统:打造高效智能的智慧城市应用

在城市发展中,积水问题日益严重,特别是在大雨过后,积水往往会影响交通甚至威胁人们的安全。通过现代计算机视觉技术,我们能够智能化地检测和识别积水区域,减少潜在危险。本文将介绍如何使用 YOLOv5 和 PyQt5 搭建一个积水检测系统,结合深度学习和直观的图形界面,为用户提供高效的解决方案。 源码地址: PyQt5+YoloV5 实现积水检测系统 预览: 项目背景