运动想象 (MI) 迁移学习系列 (5) : SSMT

2024-03-11 12:28

本文主要是介绍运动想象 (MI) 迁移学习系列 (5) : SSMT,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

运动想象迁移学习系列:SSMT

  • 0. 引言
  • 1. 主要贡献
  • 2. 网络结构
  • 3. 算法
  • 4. 补充
    • 4.1 为什么设置一种新的适配器?
    • 4.2 动态加权融合机制究竟是干啥的?
  • 5. 实验结果
  • 6. 总结
  • 欢迎来稿

论文地址:https://link.springer.com/article/10.1007/s11517-024-03032-z
论文题目:Semi-supervised multi-source transfer learning for cross-subject EEG motor imagery classification
论文代码:无

0. 引言

脑电图(EEG)运动意象(MI)分类是指利用脑电信号对受试者的运动意象活动进行识别和分类;随着脑机接口(BCI)的发展,这项任务越来越受到关注。然而,脑电图数据的收集通常是耗时且劳动密集型的,这使得很难从新受试者那里获得足够的标记数据来训练新模型。此外,不同个体的脑电信号表现出显着差异,导致在直接对从新受试者获得的脑电信号进行分类时,在现有受试者上训练的模型的性能显着下降。因此,充分利用现有受试者的脑电数据和新目标受试者的未标记脑电数据,提高目标受试者达到的心肌梗死分类性能至关重要。本研究提出了一种半监督多源迁移(SSMT)学习模型来解决上述问题;该模型学习信息和域不变表示,以解决跨主题的 MI-EEG 分类任务。具体而言,该文提出了一种动态转移加权模式,通过整合从多源域派生的加权特征来获得最终预测。

文中主要解决方法是针对无监督的脑电数据迁移学习方案,是一个不错的角度,也提出了很有新意的算法设计!!!

1. 主要贡献

  1. 一种基于 MMDCMMD域适应方法,用于解决单个 MI-EEG信号差异的问题,对齐每个源域和靶域之间的条件和边际分布差异。此外,伪标签被应用于目标域的未标记数据,并在整个训练过程中迭代更新。通过这种方式,条件分布信息将更新为近似真实的条件分布。
  2. 基于域间差异度量设计了一种动态权重转移模型,使每个源域能够根据其与目标域的相似性为训练过程做出贡献。因此,通过减轻与目标域显著差异的源域的不利影响,可以进一步提高分类器对目标域的预测性能。
  3. 通过一系列实验,在两个公开可用的 BCI数据集上评估了所提出的方法。结果表明,所提方法的每一项创新都有助于提高解码性能,与基线相比,解码性能更好。

2. 网络结构

在这里插入图片描述
SSMT两个主要阶段组成。预训练阶段预训练所有可用于在特征提取任务和原始监督分类任务中训练的标记数据,以获得仅包含特征提取器和分类器的全局模型。然后,利用预训练模型对目标域的未标记数据进行伪标记;再训练阶段包括三个主要步骤。首先,域适配器旨在减少每个源域和目标域之间的差异。然后,使用伪标签信息并不断更新以优化模型。最后,最终决策由MLP分类器的转移权重融合产生。

3. 算法

符号说明
{ X s k , y s k } k = 1 n \{X_s^k, y_s^k\}_{k=1}^n {Xsk,ysk}k=1n 表示存在n个源域 X t X_t Xt 表示目标域,包含两个部分,分别是 X l X_l Xl X u X_u Xu; X l X_l Xl y l y_l yl 表示目标域中已知(标记)的样本 X u X_u Xu 表示目标域中未标记的样本,即也不知道其对应的类别。

SSMT算法步骤

输入: { X s k , y s k } k = 1 n , X l , y l , X u \{X_s^k, y_s^k\}_{k=1}^n, X_l, y_l, X_u {Xsk,ysk}k=1n,Xl,yl,Xu

  1. 初始化权重参数 θ f , θ c \theta_f, \theta_c θf,θc

  2. 通过输入 { X s k , y s k } k = 1 n , X l , y l \{X_s^k, y_s^k\}_{k=1}^n, X_l, y_l {Xsk,ysk}k=1n,Xl,yl 直接训练预训练模型中的特征提取器 G f G_f Gf 和MLP分类器 G c G_c Gc , 并根据下面等式更新参数 θ f , θ c \theta_f, \theta_c θf,θc L c = − ∑ k = 1 n y s k ⋅ log ⁡ ( G c ( G f ( X s k ; θ f ) ; θ c ) ) − y l ⋅ log ⁡ ( G c ( G f ( X l ; θ f ) ; θ c ) ) , \begin{aligned} L_c= & {} -\sum _{k=1}^n \textbf{y}^k_s\cdot \log (G_c(G_f(\textbf{X}^k_s;\theta _f);\theta _c))\nonumber \\{} & {} -\textbf{y}_l\cdot \log (G_c(G_f(\textbf{X}_l;\theta _f);\theta _c)), \end{aligned} Lc=k=1nysklog(Gc(Gf(Xsk;θf);θc))yllog(Gc(Gf(Xl;θf);θc)),

  3. 生成测试集的伪标签: y ^ u = G c ( G f ( X u ; θ f ) ; θ c ) , \begin{aligned} \hat{\textbf{y}}_u=G_c(G_f(\textbf{X}_u;\theta _f);\theta _c), \end{aligned} y^u=Gc(Gf(Xu;θf);θc), 预训练阶段结束

  4. X l X_l Xl X u X_u Xu 的数据合并为目标域 X t X_t Xt,并连接所有域的数据(将 X s k X_s^k Xsk X t X_t Xt 的数据进行连接)

  5. 重复

  6. 将连接的数据输入 G f G_f Gf 来得到所有域的特征:
    F = [ G f ( X s 1 ; θ f ) , . . . , G f ( X s n ; θ f ) , G f ( X t ; θ f ) ] T F=[G_f(X_s^1;\theta_f),...,G_f(X_s^n;\theta_f),G_f(X_t;\theta_f)]^T F=[Gf(Xs1;θf),...,Gf(Xsn;θf),Gf(Xt;θf)]T

  7. 根据以下公式获取每个源域的差异损失转移权重: L d k = M M D ( D s k , D t ) + C M M D ( D s k , D t ) . \begin{aligned} L_d^k=MMD(\mathcal {D}^k_s, \mathcal {D}_t)+CMMD(\mathcal {D}^k_s, \mathcal {D}_t). \end{aligned} Ldk=MMD(Dsk,Dt)+CMMD(Dsk,Dt). C M M D ( D s k , D t ) = ∑ c = 1 C ∥ 1 m c ∑ x s k , i ∣ y s k , i = c ϕ ( G f ( x s k , i ; θ f ) ) − 1 n ^ c + n c ( ∑ x l i ∣ y l i = c ϕ ( G f ( x l i ; θ f ) ) + ∑ x u i ∣ y ^ u i = c ϕ ( G f ( x u i ; θ f ) ) ∥ , \begin{aligned} CMMD(\mathcal {D}^k_s, \mathcal {D}_t)= & {} \sum _{c=1}^C\Vert \frac{1}{m_c} \sum _{\textbf{x}_s^{k,i} |y^{k,i}_s=c} \phi (G_f(\textbf{x}_s^{k,i};\theta _f))\nonumber \\{} & {} -\frac{1}{\hat{n}_c+n_c}(\sum _{\textbf{x}_l^i |{y}_l^i=c} \phi (G_f(\textbf{x}_l^i;\theta _f))\nonumber \\{} & {} +\sum _{\textbf{x}_u^i |\hat{y}_u^i=c} \phi (G_f(\textbf{x}_u^i;\theta _f))\Vert , \end{aligned} CMMD(Dsk,Dt)=c=1Cmc1xsk,iysk,i=cϕ(Gf(xsk,i;θf))n^c+nc1(xliyli=cϕ(Gf(xli;θf))+xuiy^ui=cϕ(Gf(xui;θf)), M M D ( D s k , D t ) = ∥ 1 n s k ∑ i = 1 n s k ϕ ( G f ( x s k , i ; θ f ) ) − 1 n t ∑ i = 1 n t ϕ ( G f ( x t i ; θ f ) ) ∥ , \begin{aligned} MMD\left( \mathcal {D}^k_s, \mathcal {D}_t\right)= & {} \Bigg \Vert \frac{1}{n^k_s} \sum _{i=1}^{n^k_s} \phi (G_f(\textbf{x}_s^{k,i};\theta _f))\nonumber \\{} & {} - \frac{1}{n_t} \sum _{i=1}^{n_t} \phi (G_f(\textbf{x}_t^i;\theta _f))\Bigg \Vert , \end{aligned} MMD(Dsk,Dt)= nsk1i=1nskϕ(Gf(xsk,i;θf))nt1i=1ntϕ(Gf(xti;θf)) ,

  8. 基于下面式子对每个域的特征进行动态加权,然后将 F ∗ F^* F 作为 G c G_c Gc 的输入:

    w = [ W d 1 , … , W d n ] ⊤ = [ K − L d 1 2 ∑ k = 1 n K − L d k 2 , … , K − L d n 2 ∑ k = 1 n K − L d k 2 ] ⊤ , \begin{aligned} \textbf{w}= & {} [W^1_d, \ldots , W^n_d]^{\top }\nonumber \\= & {} \left[ \frac{K^{- {L_d^1}^2}}{\sum _{k=1}^n K^{- {L_d^k}^2}}, \ldots , \frac{K^{- {L_d^n}^2}}{\sum _{k=1}^n K^{- {L_d^k}^2}}\right] ^{\top }, \end{aligned} w==[Wd1,,Wdn][k=1nKLdk2KLd12,,k=1nKLdk2KLdn2], F ∗ = [ F s 1 ∗ , … , F s n ∗ , F t ] ⊤ = [ W d 1 F s 1 , … , W d n F s n , F t ] ⊤ , \begin{aligned} \textbf{F}^*=[{\textbf{F}^1_s}^*,\ldots ,{\textbf{F}^n_s}^*,\textbf{F}_t]^\top =[W^1_d\textbf{F}^1_s,\ldots ,W^n_d\textbf{F}^n_s,\textbf{F}_t]^\top , \end{aligned} F=[Fs1,,Fsn,Ft]=[Wd1Fs1,,WdnFsn,Ft],

  9. 根据下面等式,通过最小化 L L L 更新参数 θ f , θ c \theta_f, \theta_c θf,θc

L = L c + λ L d , \begin{aligned} L=L_c+\lambda L_d, \end{aligned} L=Lc+λLd,

  1. 通过预测 X u X_u Xu 更新 y ^ u \hat{y}_u y^u

  2. 直到收敛

  3. 返回 y ^ u \hat{y}_u y^u

4. 补充

4.1 为什么设置一种新的适配器?

最近的研究表明,随着域间差异的增加,分类器对特征的可转移性显着降低,这表明直接转移提取的特征是一种不安全的策略。因此,在不考虑个体信号差异的情况下,使用所有可用数据进行预训练的模型可能会导致目标受试者分类的性能下降。为了防止传统两级流水线引起的分布过拟合问题,设计了一种域适配器来减轻单个信号差异的负面影响。

尽管经典MMD已被广泛用作分布差异度量,但现有研究表明,在处理类权重偏差(即类不平衡数据)时,MMD并不总是可靠的。调查发现类条件分布之间的差异 P s ( x s k , i ∣ y s k , i = c ) P_s\left( \textbf{x}_s^{k,i} \mid y^{k,i}_s=c\right) Ps(xsk,iysk,i=c) P t ( x l i ∣ y l i = c ) P_t\left( \textbf{x}_l^i \mid y_l^i=c\right) Pt(xliyli=c)可以提供更合适的域差异量表,并导致卓越的域适应性能。什么时候 P s ( x s k , i ∣ y s k , i = c ) = P t ( x l i ∣ y l i = c ) P_s\left( \textbf{x}_s^{k,i} \mid y^{k,i}_s=c\right) =P_t\left( \textbf{x}_l^i \mid y_l^i=c\right) Ps(xsk,iysk,i=c)=Pt(xliyli=c),在源域中学习的分类器可以更安全地应用于目标域。基于这一概念,引入了条件最大均值差异(CMMD)度量,以对齐所有源域和目标域特征的类条件分布.

4.2 动态加权融合机制究竟是干啥的?

从所有数据中获得的特征 G f G_f Gf 可直接用于输入 G c G_c Gc 用于训练,但分类器的这种无歧视训练输入可能会导致不良结果。这一结果可归因于负转移当通过蛮力利用与目标关系不相关的来源时,就会发生负转移,从而导致对目标域的分类器预测有偏差。

为了减轻负迁移的影响,分类器被赋予了动态加权特征,用于最终决策融合。

5. 实验结果

对比实验结果:
在这里插入图片描述
消融实验结果:

在这里插入图片描述

  • PT:PT是仅包含特征提取器和MLP分类器的基本模型,可以完成简单的特征提取和分类任务。
  • DA:域适配器 (DA) 基于 MMD 和 CMMD。特别是,DA 仅使用通过预训练生成的伪标签来计算域间差异。
  • SS:SS 是一个迭代标签更新器。它的作用是在重新训练过程中周期性地生成和更新伪标签。
  • WF:WF是指动态加权模型,它对来自多源域的加权特征进行动态加权和整合。

6. 总结

到此,使用 SSMT 已经介绍完毕了!!! 如果有什么疑问欢迎在评论区提出,对于共性问题可能会后续添加到文章介绍中。

如果觉得这篇文章对你有用,记得点赞、收藏并分享给你的小伙伴们哦😄。

欢迎来稿

欢迎投稿合作,投稿请遵循科学严谨、内容清晰明了的原则!!!! 有意者可以后台私信!!

这篇关于运动想象 (MI) 迁移学习系列 (5) : SSMT的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/797804

相关文章

HarmonyOS学习(七)——UI(五)常用布局总结

自适应布局 1.1、线性布局(LinearLayout) 通过线性容器Row和Column实现线性布局。Column容器内的子组件按照垂直方向排列,Row组件中的子组件按照水平方向排列。 属性说明space通过space参数设置主轴上子组件的间距,达到各子组件在排列上的等间距效果alignItems设置子组件在交叉轴上的对齐方式,且在各类尺寸屏幕上表现一致,其中交叉轴为垂直时,取值为Vert

Ilya-AI分享的他在OpenAI学习到的15个提示工程技巧

Ilya(不是本人,claude AI)在社交媒体上分享了他在OpenAI学习到的15个Prompt撰写技巧。 以下是详细的内容: 提示精确化:在编写提示时,力求表达清晰准确。清楚地阐述任务需求和概念定义至关重要。例:不用"分析文本",而用"判断这段话的情感倾向:积极、消极还是中性"。 快速迭代:善于快速连续调整提示。熟练的提示工程师能够灵活地进行多轮优化。例:从"总结文章"到"用

Spring Security 从入门到进阶系列教程

Spring Security 入门系列 《保护 Web 应用的安全》 《Spring-Security-入门(一):登录与退出》 《Spring-Security-入门(二):基于数据库验证》 《Spring-Security-入门(三):密码加密》 《Spring-Security-入门(四):自定义-Filter》 《Spring-Security-入门(五):在 Sprin

【前端学习】AntV G6-08 深入图形与图形分组、自定义节点、节点动画(下)

【课程链接】 AntV G6:深入图形与图形分组、自定义节点、节点动画(下)_哔哩哔哩_bilibili 本章十吾老师讲解了一个复杂的自定义节点中,应该怎样去计算和绘制图形,如何给一个图形制作不间断的动画,以及在鼠标事件之后产生动画。(有点难,需要好好理解) <!DOCTYPE html><html><head><meta charset="UTF-8"><title>06

学习hash总结

2014/1/29/   最近刚开始学hash,名字很陌生,但是hash的思想却很熟悉,以前早就做过此类的题,但是不知道这就是hash思想而已,说白了hash就是一个映射,往往灵活利用数组的下标来实现算法,hash的作用:1、判重;2、统计次数;

零基础学习Redis(10) -- zset类型命令使用

zset是有序集合,内部除了存储元素外,还会存储一个score,存储在zset中的元素会按照score的大小升序排列,不同元素的score可以重复,score相同的元素会按照元素的字典序排列。 1. zset常用命令 1.1 zadd  zadd key [NX | XX] [GT | LT]   [CH] [INCR] score member [score member ...]

科研绘图系列:R语言扩展物种堆积图(Extended Stacked Barplot)

介绍 R语言的扩展物种堆积图是一种数据可视化工具,它不仅展示了物种的堆积结果,还整合了不同样本分组之间的差异性分析结果。这种图形表示方法能够直观地比较不同物种在各个分组中的显著性差异,为研究者提供了一种有效的数据解读方式。 加载R包 knitr::opts_chunk$set(warning = F, message = F)library(tidyverse)library(phyl

【机器学习】高斯过程的基本概念和应用领域以及在python中的实例

引言 高斯过程(Gaussian Process,简称GP)是一种概率模型,用于描述一组随机变量的联合概率分布,其中任何一个有限维度的子集都具有高斯分布 文章目录 引言一、高斯过程1.1 基本定义1.1.1 随机过程1.1.2 高斯分布 1.2 高斯过程的特性1.2.1 联合高斯性1.2.2 均值函数1.2.3 协方差函数(或核函数) 1.3 核函数1.4 高斯过程回归(Gauss

【生成模型系列(初级)】嵌入(Embedding)方程——自然语言处理的数学灵魂【通俗理解】

【通俗理解】嵌入(Embedding)方程——自然语言处理的数学灵魂 关键词提炼 #嵌入方程 #自然语言处理 #词向量 #机器学习 #神经网络 #向量空间模型 #Siri #Google翻译 #AlexNet 第一节:嵌入方程的类比与核心概念【尽可能通俗】 嵌入方程可以被看作是自然语言处理中的“翻译机”,它将文本中的单词或短语转换成计算机能够理解的数学形式,即向量。 正如翻译机将一种语言

【学习笔记】 陈强-机器学习-Python-Ch15 人工神经网络(1)sklearn

系列文章目录 监督学习:参数方法 【学习笔记】 陈强-机器学习-Python-Ch4 线性回归 【学习笔记】 陈强-机器学习-Python-Ch5 逻辑回归 【课后题练习】 陈强-机器学习-Python-Ch5 逻辑回归(SAheart.csv) 【学习笔记】 陈强-机器学习-Python-Ch6 多项逻辑回归 【学习笔记 及 课后题练习】 陈强-机器学习-Python-Ch7 判别分析 【学