PaddlePaddle笔记5-心电图智能诊断

2023-10-23 03:40

本文主要是介绍PaddlePaddle笔记5-心电图智能诊断,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

睿洛医疗

参考资料:

AIWIN - 比赛

AIWIN 心电图智能诊断竞赛:TextCNN 精度0.67 - 飞桨AI Studio - 人工智能学习与实训社区

        心电图是临床最基础的一个检查项目,因为安全、便捷成为心脏病诊断的利器。每天都有大量的心电图诊断需求,但是全国范围内诊断心电图的专业医生数量不足,导致很多医院都面临专业心电图医生短缺的情况。人工智能技术的出现,为改善医生人力资源不足的问题带来了全新的可能。由于心电图数据与诊断的标准化程度较高,相对较易于运用人工智能技术进行智能诊断算法的开发。由于心电图可诊断的疾病类别特别丰富,目前,市面上出现较多的是针对某些特定类别的算法,尚没有看到能够按照临床诊断标准、在一定准确率标准下,提供类似医生的多标签多分类算法。

        需要识别的心电图包括12个类别:正常心电图、窦性心动过缓、窦性心动过速、窦性心律不齐、心房颤动、室性早搏、房性早搏、一度房室阻滞、完全性右束支阻滞、T波改变、ST改变、其它。

         心电数据的单位为mV,采样率为 500HZ,记录时长为 10 秒,存储格式为 MAT;文件中存储了 12 导联的电压信号(包含了I,II,III,aVR,aVL,aVF,V1,V2,V3,V4,V5 和 V6)。数据下载地址:心电图测试数据集https://download.csdn.net/download/zj850324/67399060https://download.csdn.net/download/zj850324/67399060

数据格式:

12导联的数据,保存matlab格式文件中。数据格式是(12, 5000);

采样500HZ,10S长度有效数据;

0..12是I, II, III, aVR, aVL, aVF, V1, V2, V3, V4, V5和V6数据,单位是mV。

1,准备模块(依赖库按需安装即可,不再赘述)

import scipy.io as sio
import matplotlib.pyplot as plt
import codecs, glob, os
import numpy as np
import pandas as pd
import paddle
import paddle.nn as nn
from paddle.io import DataLoader, Dataset
import paddle.optimizer as optim
from paddlenlp.data import Pad
import scipy.io as sio
from sklearn.model_selection import StratifiedKFold

2,构造数据集

class MyDataset(Dataset):def __init__(self, mat, label, mat_dim=3000):super(MyDataset, self).__init__()self.mat = matself.label = labelself.mat_dim = mat_dimdef __len__(self):return len(self.mat)def __getitem__(self, index):idx = np.random.randint(0, 5000-self.mat_dim)# idy = np.random.choice(range(12), 9)inputs=paddle.to_tensor(self.mat[index][:, :, idx:idx+self.mat_dim])label=paddle.to_tensor(self.label[index])return inputs,label

3,构造模型

class TextCNN_Plus(nn.Layer):def __init__(self, kernel_num=30, kernel_size=[3, 4, 5], dropout=0.5,mat_dim=3000):super(TextCNN_Plus, self).__init__()self.kernel_num = kernel_numself.kernel_size = kernel_sizeself.dropout = dropoutself.convs = nn.LayerList([nn.Conv2D(1, self.kernel_num, (kernel_size_, mat_dim))for kernel_size_ in self.kernel_size])self.dropout = nn.Dropout(self.dropout)self.linear = nn.Linear(3 * self.kernel_num, 1)def forward(self, x):convs = [nn.ReLU()(conv(x)).squeeze(3) for conv in self.convs]pool_out = [nn.MaxPool1D(block.shape[2])(block).squeeze(2) for block in convs]pool_out = paddle.concat(pool_out, 1)logits = self.linear(pool_out)return logits

4,加载数据集

def load_data(BATCH_SIZE):train_mat = glob.glob('./data/ecg/train/*.mat')train_mat.sort()train_mat = [sio.loadmat(x)['ecgdata'].reshape(1, 12, 5000) for x in train_mat]test_mat = glob.glob('./data/ecg/val/*.mat')test_mat.sort()test_mat = [sio.loadmat(x)['ecgdata'].reshape(1, 12, 5000) for x in test_mat]train_df = pd.read_csv('./data/ecg/trainreference.csv')train_df['tag'] = train_df['tag'].astype(np.float32)# 查看数据plt.plot(range(5000), train_mat[0][0][0])plt.plot(range(5000), train_mat[0][0][1])plt.plot(range(5000), train_mat[0][0][3])# plt.show()train_df.head()print(test_mat[0].shape)train_ds = MyDataset(train_mat, train_df['tag'])train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)for batch in train_loader:print(batch)breakreturn train_mat,train_df,test_mat

5,定义模型网络

def model_create(mat_dim):# model = TextCNN()model = TextCNN_Plus(mat_dim=mat_dim)paddle.summary(model, (64, 1, 9, mat_dim))return model

6,模型训练

def model_train(EPOCHS,BATCH_SIZE,LEARNING_RATE,num_splits,mat_dim,output_dir):train_mat, train_df, test_mat = load_data(BATCH_SIZE)skf = StratifiedKFold(n_splits=num_splits)if (not os.path.exists(output_dir)):os.mkdir(output_dir)fold_idx = 0for tr_idx, val_idx in skf.split(train_mat, train_df['tag'].values):train_ds = MyDataset(np.array(train_mat)[tr_idx], train_df['tag'].values[tr_idx], mat_dim=mat_dim)dev_ds = MyDataset(np.array(train_mat)[val_idx], train_df['tag'].values[val_idx], mat_dim=mat_dim)Train_Loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)Val_Loader = DataLoader(dev_ds, batch_size=BATCH_SIZE, shuffle=True)# model = TextCNN()model = TextCNN_Plus(mat_dim=mat_dim)optimizer = optim.Adam(parameters=model.parameters(), learning_rate=LEARNING_RATE)criterion = nn.BCEWithLogitsLoss()Test_best_Acc = 0for epoch in range(0, EPOCHS):Train_Loss, Test_Loss = [], []Train_Acc, Test_Acc = [], []model.train()for i, (x, y) in enumerate(Train_Loader):pred = model(x)loss = criterion(pred, y)Train_Loss.append(loss.item())pred = (paddle.nn.functional.sigmoid(pred) > 0.5).astype(int)Train_Acc.append((pred.numpy() == y.numpy()).mean())loss.backward()optimizer.step()optimizer.clear_grad()model.eval()for i, (x, y) in enumerate(Val_Loader):pred = model(x)Test_Loss.append(criterion(pred, y).item())pred = (paddle.nn.functional.sigmoid(pred) > 0.5).astype(int)Test_Acc.append((pred.numpy() == y.numpy()).mean())if epoch % 10 == 0:print("Epoch: [{}/{}] TrainLoss/TestLoss: {:.4f}/{:.4f} TrainAcc/TestAcc: {:.4f}/{:.4f}".format( \epoch + 1, EPOCHS, \np.mean(Train_Loss), np.mean(Test_Loss), \np.mean(Train_Acc), np.mean(Test_Acc) \))if Test_best_Acc < np.mean(Test_Acc):print(f'Fold {fold_idx} Acc imporve from {Test_best_Acc} to {np.mean(Test_Acc)} Save Model...')paddle.save(model.state_dict(), os.path.join(output_dir, "model_{}.pdparams".format(fold_idx)))Test_best_Acc = np.mean(Test_Acc)fold_idx += 1

 7,载入模型预测(得先训练)

def model_predict(modelpath,mat_dim,answer_path):train_mat, train_df, test_mat = load_data(BATCH_SIZE)test_perd = np.zeros(len(test_mat))tta_count = 20model = model_create(mat_dim)layer_state_dict = paddle.load(modelpath + '/model_4.pdparams')model.set_state_dict(layer_state_dict)for fold_idx in range(num_splits):test_ds = MyDataset(test_mat, [0] * len(test_mat), mat_dim=mat_dim)Test_Loader = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False)layer_state_dict = paddle.load(os.path.join(output_dir, "model_{}.pdparams".format(fold_idx)))model.set_state_dict(layer_state_dict)for tta in range(tta_count):test_pred_list = []for i, (x, y) in enumerate(Test_Loader):pred = model(x)test_pred_list.append(paddle.nn.functional.sigmoid(pred).numpy())test_perd += np.vstack(test_pred_list)[:, 0]test_perd /= tta_count * num_splits# 生成结果test_path = glob.glob('./data/ecg/val/*.mat')test_path = [os.path.basename(x)[:-4] for x in test_path]test_path.sort()test_answer = pd.DataFrame({'name': test_path,'tag': (test_perd > 0.5).astype(int)}).to_csv(answer_path, index=None)

8,预测结果

备注

调试步骤及临时配参

#### 测试
EPOCHS = 200
BATCH_SIZE = 30
LEARNING_RATE = 0.0005
num_splits = 5
mat_dim = 4000
output_dir = 'checkpoint'
answer_path = './data/ecg/answer.csv'# model_train(EPOCHS,BATCH_SIZE,LEARNING_RATE,num_splits,mat_dim,output_dir)
model_predict(output_dir,mat_dim,answer_path)

这篇关于PaddlePaddle笔记5-心电图智能诊断的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/265451

相关文章

基于Python实现智能天气提醒助手

《基于Python实现智能天气提醒助手》这篇文章主要来和大家分享一个实用的Python天气提醒助手开发方案,这个工具可以方便地集成到青龙面板或其他调度框架中使用,有需要的小伙伴可以参考一下... 目录项目概述核心功能技术实现1. 天气API集成2. AI建议生成3. 消息推送环境配置使用方法完整代码项目特点

JavaScript实战:智能密码生成器开发指南

本文通过JavaScript实战开发智能密码生成器,详解如何运用crypto.getRandomValues实现加密级随机密码生成,包含多字符组合、安全强度可视化、易混淆字符排除等企业级功能。学习密码强度检测算法与信息熵计算原理,获取可直接嵌入项目的完整代码,提升Web应用的安全开发能力 目录

利用Python实现Excel文件智能合并工具

《利用Python实现Excel文件智能合并工具》有时候,我们需要将多个Excel文件按照特定顺序合并成一个文件,这样可以更方便地进行后续的数据处理和分析,下面我们看看如何使用Python实现Exce... 目录运行结果为什么需要这个工具技术实现工具的核心功能代码解析使用示例工具优化与扩展有时候,我们需要将

基于Python打造一个智能单词管理神器

《基于Python打造一个智能单词管理神器》这篇文章主要为大家详细介绍了如何使用Python打造一个智能单词管理神器,从查询到导出的一站式解决,感兴趣的小伙伴可以跟随小编一起学习一下... 目录1. 项目概述:为什么需要这个工具2. 环境搭建与快速入门2.1 环境要求2.2 首次运行配置3. 核心功能使用指

Python实现word文档内容智能提取以及合成

《Python实现word文档内容智能提取以及合成》这篇文章主要为大家详细介绍了如何使用Python实现从10个左右的docx文档中抽取内容,再调整语言风格后生成新的文档,感兴趣的小伙伴可以了解一下... 目录核心思路技术路径实现步骤阶段一:准备工作阶段二:内容提取 (python 脚本)阶段三:语言风格调

利用Python快速搭建Markdown笔记发布系统

《利用Python快速搭建Markdown笔记发布系统》这篇文章主要为大家详细介绍了使用Python生态的成熟工具,在30分钟内搭建一个支持Markdown渲染、分类标签、全文搜索的私有化知识发布系统... 目录引言:为什么要自建知识博客一、技术选型:极简主义开发栈二、系统架构设计三、核心代码实现(分步解析

使用Python实现表格字段智能去重

《使用Python实现表格字段智能去重》在数据分析和处理过程中,数据清洗是一个至关重要的步骤,其中字段去重是一个常见且关键的任务,下面我们看看如何使用Python进行表格字段智能去重吧... 目录一、引言二、数据重复问题的常见场景与影响三、python在数据清洗中的优势四、基于Python的表格字段智能去重

Spring AI集成DeepSeek三步搞定Java智能应用的详细过程

《SpringAI集成DeepSeek三步搞定Java智能应用的详细过程》本文介绍了如何使用SpringAI集成DeepSeek,一个国内顶尖的多模态大模型,SpringAI提供了一套统一的接口,简... 目录DeepSeek 介绍Spring AI 是什么?Spring AI 的主要功能包括1、环境准备2

Spring AI与DeepSeek实战一之快速打造智能对话应用

《SpringAI与DeepSeek实战一之快速打造智能对话应用》本文详细介绍了如何通过SpringAI框架集成DeepSeek大模型,实现普通对话和流式对话功能,步骤包括申请API-KEY、项目搭... 目录一、概述二、申请DeepSeek的API-KEY三、项目搭建3.1. 开发环境要求3.2. mav

Python3脚本实现Excel与TXT的智能转换

《Python3脚本实现Excel与TXT的智能转换》在数据处理的日常工作中,我们经常需要将Excel中的结构化数据转换为其他格式,本文将使用Python3实现Excel与TXT的智能转换,需要的可以... 目录场景应用:为什么需要这种转换技术解析:代码实现详解核心代码展示改进点说明实战演练:从Excel到