【论文泛读】OCT-GAN(WWW’21)

2024-03-22 02:30
文章标签 21 oct 论文 www gan 泛读

本文主要是介绍【论文泛读】OCT-GAN(WWW’21),希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

Jayoung Kim, Jinsung Jeon, Jaehoon Lee, Jihyeon Hyeong, Noseong Park

Yonsei University

原文传送

摘要

  • 表格数据的生成,为人们增加了训练数据。最先进的方法在数据不平衡分布和模式崩溃问题上还不令人满意。
  • 主要贡献:
    • 鉴别器有一个ODE层来提取一个隐藏的向量演化轨迹进行分类;
    • 轨迹由在不同层(或时间)𝑡𝑖上提取的一系列隐藏向量表示。还训练了这些提取时间点;
    • 基于轨迹的分类给鉴别器带来了重要的好处,因为不仅可以使用最后一个隐藏向量,还可以使用轨迹中包含的所有信息;
    • 生成器采用一个初始ODE层将𝒛⊕𝒄转换为另一个适合生成的潜在隐藏向量𝒛‘,同时保持原语义(同态映射);
    • 总共对13个数据集进行了深入的实验(一部分仅为likelihood实验),从保险欺诈检测到在线新闻文章传播预测等。评估任务包括生成假据用于似然估计、分类、回归和聚类,并且方法在许多情况下都大大优于现有的方法。

介绍

  • web-based 应用大多使用表格型数据,并且许多企业系统使用关系数据库管理系统。
  • 表格数据通常具有不规则分布和多峰性,基于引入Neural ODEs的想法,设计生成器和鉴别器,显著提高了效用。
  • 基于node的鉴别器,执行一个基于隐藏向量进化轨迹的分类;
  • 结构如图:
    在这里插入图片描述

Backgrounds

  • Neural Ordinary Differential Equations (NODEs)
    • 在NODEs中,一个具有一组参数的神经网络f,记为 θ ( f ) \theta(f) θ(f)
    • h ( t m ) = h ( t 0 ) + ∫ t 0 t m f ( h ( t ) , t ; θ f ) d t h(t_m)=h(t_0)+\int_{t_0}^{t_m}f(h(t),t;\theta_f)dt h(tm)=h(t0)+t0tmf(h(t),t;θf)dt, 其中 f ( h ( t ) , t ; θ f ) = d h ( t ) / d t f(h(t),t;\theta_f)=dh(t)/dt f(h(t),t;θf)=dh(t)/dt通常的神经网络中,t是离散的,但是在NODEs中t是连续的
    • 依赖ODE求解器解决积分问题。文章依赖DOPRI方法,将积分转为一系列的加分,DOPRI能够动态地控制其步长。
  • Conditional GAN
    • 整体还是基于MLP;
    • 原生GAN中,G的目的是欺骗D获得高分。但是CGAN的G需要输入的condition,D除了打分外还需要判断是否满足condition。
  • Tabular Data Synthesis
    • RGAN 生成连续的时间序列医疗保健记录;
    • EhrGAN 使用半监督学习生成看似合理的标记记录,以增加有限的训练数据;
    • PATE-GAN在不危及原始数据隐私的情况下生成合成数据;
    • TableGAN利用卷积神经网络改进了表格数据的合成,以最大限度地提高标签列上的预测精度。

Main methods

数据预处理

主要考虑两类数据

  • 离散型数据 D 1 , D 2 , D 3 . . . D N D D_1,D_2,D_3...D_{N_D} D1,D2,D3...DND,被转换为one-hot向量
  • 连续型数据 C 1 , C 2 , . . . , C N C C_1,C_2,...,C_{N_C} C1,C2,...,CNC,使用mode-specifific normalization(和CTGAN方法一致)进行预处理

第i行的数据 r i r_i ri可以被写为 d i , 1 ⨁ d i , 2 . . . . . . ⨁ d i , N D ⨁ c i , 1 ⨁ c i , 2 . . . . . . ⨁ c i , N C d_{i,1}\bigoplus d_{i,2}......\bigoplus d_{i,N_D}\bigoplus c_{i,1}\bigoplus c_{i,2}......\bigoplus c_{i,N_C} di,1di,2......di,NDci,1ci,2......ci,NC,通过以下三个步骤将数据 r i r_i ri预处理为 x i x_i xi

  1. 将每个离散值 d i , 1 , d i , 2 , . . . , d i , N D d_{i,1},d_{i,2},...,d_{i,N_D} di,1,di,2,...,di,ND转换为一个one-hot向量 d o i , 1 , d o i , 2 , . . . , d o i , N D d_{oi,1},d_{oi,2},...,d_{oi,N_D} doi,1,doi,2,...,doi,ND

  2. 利用变分高斯混合(VGM)模型,将每个连续列𝐶𝑗拟合到一个高斯混合,高斯混合的表示为: P r j ( c i , j ) = Σ k = 1 n j w j , k N ( c i , j ; μ j , k , σ j , k ) Pr_j(c_{i,j})=\Sigma_{k=1}^{n_j}w_{j,k}\Nu(c_{i,j};\mu_{j,k},\sigma_{j,k}) Prj(ci,j)=Σk=1njwj,kN(ci,j;μj,k,σj,k)其中, n j n_j nj C j C_j Cj列中的模数(即高斯分布的数)。 w j , k , μ j , k , σ j , k w_{j,k},\mu_{j,k},\sigma_{j,k} wj,k,μj,k,σj,k k t h k_{th} kth高斯分布的拟合权值、均值和标准差

  3. 对于 P r j ( k ) = w j , k N ( c i , j ; μ j , k , σ j , k ) Σ p = 1 n j w j , p N ( c i , j ; μ j , p , σ j , p ) Pr_j(k)=\frac{w_{j,k}\Nu(c_{i,j};\mu_{j,k},\sigma_{j,k})}{\Sigma_{p=1}^{n_j}w_{j,p}\Nu(c_{i,j};\mu_{j,p},\sigma_{j,p})} Prj(k)=Σp=1njwj,pN(ci,j;μj,p,σj,p)wj,kN(ci,j;μj,k,σj,k),以合适的模式k对 c i , j c_{i,j} ci,j采样,然后将模式k中的 c i , j c_{i,j} ci,j及其拟合的标准差进行归一化,保存归一化值 α i , j \alpha_{i,j} αi,j和模式信息 β i , j \beta_{i,j} βi,j

  4. 最后, r i r_i ri转化为 x i x_i xi,其表示为: x i = α i , 1 ⨁ β i , 1 ⨁ . . . ⨁ α i , N c ⨁ β i , N c ⨁ d o i , 1 ⨁ . . . ⨁ d o i , N c x_i=\alpha_{i,1}\bigoplus\beta_{i,1}\bigoplus...\bigoplus\alpha_{i,N_c}\bigoplus\beta_{i,N_c}\bigoplus d_{oi,1}\bigoplus...\bigoplus d_{oi,N_c} xi=αi,1βi,1...αi,Ncβi,Ncdoi,1...doi,Nc

    x i x_i xi中包含 r i r_i ri基于模式的信息,GAN的G和D可以使用 x i x_i xi来分辨模式。同时,使用高斯混合的拟合参数, x i x_i xi可以很容易地更改为 r i r_i ri

鉴别器 Discriminator

在这里插入图片描述

基于ODE的鉴别器,并在预测输入样本 x ( t ) x(t) x(t)的真否同时,考虑了 h h h的轨迹。

  1. h ( 0 ) = D r o p ( L e a k y ( F C 2 ( D r o p ( L e a k y ( F C 1 ( x ) ) ) ) ) ) h(0)=Drop(Leaky(FC2(Drop(Leaky(FC1(x)))))) h(0)=Drop(Leaky(FC2(Drop(Leaky(FC1(x))))))
  2. h ( t 1 ) = h ( 0 ) + ∫ 0 t 1 f ( h ( 0 ) , t ; θ f ) d t h(t_1)=h(0)+\intop_0^{t_1}f(h(0),t;\theta_f)dt h(t1)=h(0)+0t1f(h(0),t;θf)dt
  3. h ( t 2 ) = h ( t 1 ) + ∫ t 1 t 2 f ( h ( t 1 ) , t ; θ f ) d t h(t_2)=h(t_1)+\intop_{t_1}^{t_2}f(h(t_1),t;\theta_f)dt h(t2)=h(t1)+t1t2f(h(t1),t;θf)dt
  4. h ( t m ) = h ( t m − 1 ) + ∫ t m − 1 t m f ( h ( t m − 1 ) , t ; θ f ) d t h(t_m)=h(t_{m-1})+\intop_{t_{m-1}}^{t_{m}}f(h(t_{m-1}),t;\theta_f)dt h(tm)=h(tm1)+tm1tmf(h(tm1),t;θf)dt
  • h x = h ( 0 ) ⨁ h ( t 1 ) ⨁ h ( t 2 ) ⨁ . . . ⨁ h ( t m ) h_x=h(0)\bigoplus h(t_1)\bigoplus h(t_2)\bigoplus ...\bigoplus h(t_m) hx=h(0)h(t1)h(t2)...h(tm)

  • D ( x ) = F C 5 ( L e a k y ( F C 4 ( L e a k y ( F C 3 ( h x ) ) ) ) ) D(x)=FC5(Leaky(FC4(Leaky(FC3(h_x))))) D(x)=FC5(Leaky(FC4(Leaky(FC3(hx)))))

    其中,m是超参数,(3)到(6)共享相同的参数 θ f \theta_f θf,构成单一的ODE系统,定义 a t ( t ) = d L d t a_t(t) = \frac{dL}{dt} at(t)=dtdL

损失Loss的计算
在这里插入图片描述
综上,只需要保存一个伴随的数位 a h ( t m ) a_h(t_m) ah(tm),并使用两个函数𝑓和 a h ( t ) a_h(t) ah(t)计算∇𝑡𝑖L

h ( t m ) h(t_m) h(tm)是最后隐藏向量。使用 h ( t m ) h(t_m) h(tm)和整个轨迹来进行分类。

  • 通过寻找关键的时间点来区分轨迹,进一步提高了该方法的有效性;
  • 在通常的神经网络中,训练𝑡𝑖是不可能的,因为它们的层结构是离散的,利用ODE的性质,可以选择最佳的 t i t_i ti节点。

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-E4Cje7o3-1666617360117)(https://s3-us-west-2.amazonaws.com/secure.notion-static.com/41452542-71db-40c9-ac43-51cbfd3fdbca/Untitled.png)]

条件生成器 Conditional Generator

OCT-GAN是一个条件GAN,其生成器读取一个噪声向量和一个条件向量来生成一个假样本。

给定一个初始输入𝒑(0)=𝒛⊕𝒄,将其送入一个ODE层,以转换为另一个潜在向量。

在这里插入图片描述
ODE是一个同态映射。可以利用这个特性来设计一个语义上可靠的映射。
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-3mwpI7Tu-1666617360118)(https://s3-us-west-2.amazonaws.com/secure.notion-static.com/c877d9a5-b0d6-4865-84b1-e5a553100a2c/Untitled.png)]

  1. ODE层在初始输入分布和真实数据分布之间找到了一个平衡分布

  2. 生成了真实的假样本

    特别地,加入ODE的变换使合成样本的插值平滑。

    即,给定两个相似的初始输入,生成器生成两个相似的合成样本(如格隆沃尔-贝尔曼不等式所证明的)——实验部分展示了这些平滑的插值。

训练算法

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-3n02RHqg-1666617360118)(https://s3-us-west-2.amazonaws.com/secure.notion-static.com/24677b79-0aaa-4e2c-992e-80c5b4a511b7/Untitled.png)]

实验

仿真数据的 likelihood fitness

  • 首先模拟数据集:

    • 收集了各种预先训练过的贝叶斯网络和高斯网络;
    • 使用预训练模型生成 T t r a i n T_{train} Ttrain T t e s t T_{test} Ttest
  • 评估方法:使用模拟数据的好处是可以评估对于给定预训练模型S的合成可能性。

    • T t r a i n T_{train} Ttrain训练包含OCT-GAN的生成模型;
    • 从每个训练模型生成的合成数据;记F为仿真数据
    • 测量F给定S的可能性 P r ( F ∣ S ) Pr(F|S) Pr(FS)
    • 从头开始用F训练模型S’
    • 测量 P r ( T t e s t ∣ S ′ ) Pr(T_{test}|S') Pr(TtestS)

    两种可能性估计应该足够好, P r ( T t e s t ∣ S ′ ) Pr(T_{test}|S') Pr(TtestS)的低值表示F包含模式崩溃等情况。
    在这里插入图片描述
    在表2和表3中,包含了所有的似然估计结果。CLBN和PrivBN的性能出现了波动。CLBN和PrivBN在Ring和Asia中分别表现较好,而PrivBN在Grid和Gridr中表现较差。TVAE在许多情况下对 P r ( F ∣ S ) Pr(F|S) Pr(FS)表现出良好的性能,但在Grid和Insurance中对 P r ( T t e s t ∣ S ′ ) Pr(T_{test}|S') Pr(TtestS)的性能相对较差,这意味着模式崩溃。同时,TVAE对Grid也表现出了很好的性能。总而言之,TVAE在这些实验中表现出了reasonable performance。

    在除OCT-GAN外的许多GAN模型中,TGAN和TableGAN表现出合理的性能,其他GAN在许多情况下不如它们。例如Insurance数据集的 P r ( T t e s t ∣ S ′ ) Pr(T_{test}|S') Pr(TtestS) , -14.3 for TableGAN,-14.8 for TGAN -18.1 for VEEGAN。(However, all these models are significantly outperformed by our proposed OCT-GAN. In all cases, OCT-GAN is better than TGAN, the state-of-the-art GAN model.)

真实数据的分类任务

数据集:Adult, Census, Covertype, Credit, Intrusion

Adult: 从美国1994年人口普查调查中提取的不同人口统计信息,预测了两类高收入(>5万美元)和低收入(≤5万美元)收入

Census: 与Adult相似,具有不同列

Covertype: 从制图变量中预测森林覆盖类型,并收集自科罗拉多州北部的罗斯福国家森林

Credit: 用于信用卡欺诈检测,于2013年9月从欧洲持卡人处收集

Intrusion: 被应用于国际知识发现和数据挖掘竞赛中,其中包含了许多网络入侵检测样本

Adult Census Credit 二元分类,而其他是多分类

如图4所示,除了OCT-GAN外的各种方法显示出不合理的准确性。

  • 评估方法:

    1)首先训练各种生成模型,包括OCT-GAN

    2)用训练的模型生成仿真数据F

    3)用假数据训练Adaboost、DecisionTree and MLP

    4)test with T t e s t T_{test} Ttest

    在这里插入图片描述

    除TGAN和OCT-GAN外,许多GAN模型在许多情况下都显示出较低的得分。许多GAN模型在许多情况下都显示出较低的得分。在Census中,VEEGAN的F-1得分为0.094。

回归实验

数据集:News(UCI **Online News Popularity Data Set)**它包含了从在线新闻文章中提取的许多特征来预测社交网络中的分享数量,例如,推文、转发等等。很好地展示了该方法在基于web的应用程序中的有效性

使用线性回归和MLP作为基础回归模型,并使用𝑅2作为评价度量

在这里插入图片描述

除OCT-GAN外,所有方法的精度都不合理。用 T t r a i n T_{train} Ttrain训练的原始模型显示了一个𝑅2分数为0.14,而OCT-GAN显示了一个接近它的分数。只有OCT-GAN 和原来标有 T t r a i n T_{train} Ttrain标记的模型则显示出正分。

聚类实验

使用了5个分类实验的数据集;

K = ∣ C ∣ , ∣ 2 C ∣ , ∣ 3 C ∣ K={|C|,|2C|,|3C|} K=C,∣2C,∣3C, C C C是一组类标签,对假数据 F F F运行 K − M e a n s + + K-Means++ KMeans++,选择一个得到最高的剪影轮廓得分的𝐾值。通过假数据 F F F的质心,计算应用不同GAN的 T t r a i n T_{train} Ttrain T t e s t T_{test} Ttest的Silhouette score

在这里插入图片描述
OCT-GAN在几乎所有情况下都优于TGAN。

噪声向量插值

为了进一步展示基于ode的转换在生成器中的有效性,在Adult中可视化了几个插值结果。选择两个噪声向量𝒛1,𝒛2,并通过𝑒𝒛1+(1−𝑒)𝒛2对多个中间向量进行插值。0<𝑒<1
在这里插入图片描述

如图,TGAN和OCT-GAN(only_D)表现出相似的插值模式,而OCT-GAN可以以平滑的方式进行插值。

消融实验

  1. 在OCT-GAN(fixed),不训练 t i t_i ti而将其设置为等间距的结点
  2. 只向生成器添加ODE层,即 D ( x ) = F C 5 ( l e a k y ( F C 4 ( l e a k y ( F C 3 ( h ( 0 ) ) ) ) ) ) D(x)=FC5(leaky(FC4(leaky(FC3(h(0)))))) D(x)=FC5(leaky(FC4(leaky(FC3(h(0))))))
  3. 只向鉴别器增加ODE层,即直接 z ⨁ c z\bigoplus c zc输入生成器

在表2和表3中,这些消融研究模型令人惊讶地显示了比完整模型OCT-GAN更好的似然估计
在这里插入图片描述

在Adult数据上,OCT-GAN(only_G)得分比其他模型要低得多。由此可得,在Adult数据上,鉴别器中的ODE层起着关键作用。

结论

一般来说,简单的模型,如PrivBN,TVAE和消融研究模型,显示出更好的似然估计,而复杂的模型显示出更好的机器学习任务分数;In real-world environments, however, we think that task-specific data utility is more important than likelihood. Therefore, OCT-GAN can benefit many applications.

所有的方法都没有显示出接近于标记为 T t r a i n T_{train} Ttrain的原始模型的分数,这说明了数据合成的难度。它们都是多类分类数据集。作者认为,对于复杂的机器学习任务,数据合成的质量(效用)还有一个提高的空间。

这篇关于【论文泛读】OCT-GAN(WWW’21)的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/834344

相关文章

AI hospital 论文Idea

一、Benchmarking Large Language Models on Communicative Medical Coaching: A Dataset and a Novel System论文地址含代码 大多数现有模型和工具主要迎合以患者为中心的服务。这项工作深入探讨了LLMs在提高医疗专业人员的沟通能力。目标是构建一个模拟实践环境,人类医生(即医学学习者)可以在其中与患者代理进行医学

论文翻译:arxiv-2024 Benchmark Data Contamination of Large Language Models: A Survey

Benchmark Data Contamination of Large Language Models: A Survey https://arxiv.org/abs/2406.04244 大规模语言模型的基准数据污染:一项综述 文章目录 大规模语言模型的基准数据污染:一项综述摘要1 引言 摘要 大规模语言模型(LLMs),如GPT-4、Claude-3和Gemini的快

论文阅读笔记: Segment Anything

文章目录 Segment Anything摘要引言任务模型数据引擎数据集负责任的人工智能 Segment Anything Model图像编码器提示编码器mask解码器解决歧义损失和训练 Segment Anything 论文地址: https://arxiv.org/abs/2304.02643 代码地址:https://github.com/facebookresear

【LabVIEW学习篇 - 21】:DLL与API的调用

文章目录 DLL与API调用DLLAPIDLL的调用 DLL与API调用 LabVIEW虽然已经足够强大,但不同的语言在不同领域都有着自己的优势,为了强强联合,LabVIEW提供了强大的外部程序接口能力,包括DLL、CIN(C语言接口)、ActiveX、.NET、MATLAB等等。通过DLL可以使用户很方便地调用C、C++、C#、VB等编程语言写的程序以及windows自带的大

论文翻译:ICLR-2024 PROVING TEST SET CONTAMINATION IN BLACK BOX LANGUAGE MODELS

PROVING TEST SET CONTAMINATION IN BLACK BOX LANGUAGE MODELS https://openreview.net/forum?id=KS8mIvetg2 验证测试集污染在黑盒语言模型中 文章目录 验证测试集污染在黑盒语言模型中摘要1 引言 摘要 大型语言模型是在大量互联网数据上训练的,这引发了人们的担忧和猜测,即它们可能已

OmniGlue论文详解(特征匹配)

OmniGlue论文详解(特征匹配) 摘要1. 引言2. 相关工作2.1. 广义局部特征匹配2.2. 稀疏可学习匹配2.3. 半稠密可学习匹配2.4. 与其他图像表示匹配 3. OmniGlue3.1. 模型概述3.2. OmniGlue 细节3.2.1. 特征提取3.2.2. 利用DINOv2构建图形。3.2.3. 信息传播与新的指导3.2.4. 匹配层和损失函数3.2.5. 与Super

BERT 论文逐段精读【论文精读】

BERT: 近 3 年 NLP 最火 CV: 大数据集上的训练好的 NN 模型,提升 CV 任务的性能 —— ImageNet 的 CNN 模型 NLP: BERT 简化了 NLP 任务的训练,提升了 NLP 任务的性能 BERT 如何站在巨人的肩膀上的?使用了哪些 NLP 已有的技术和思想?哪些是 BERT 的创新? 1标题 + 作者 BERT: Pre-trainin

[论文笔记]LLM.int8(): 8-bit Matrix Multiplication for Transformers at Scale

引言 今天带来第一篇量化论文LLM.int8(): 8-bit Matrix Multiplication for Transformers at Scale笔记。 为了简单,下文中以翻译的口吻记录,比如替换"作者"为"我们"。 大语言模型已被广泛采用,但推理时需要大量的GPU内存。我们开发了一种Int8矩阵乘法的过程,用于Transformer中的前馈和注意力投影层,这可以将推理所需

【JavaScript】LeetCode:21-25

文章目录 21 最大子数组和22 合并区间23 轮转数组24 除自身以外数组的乘积25 缺失的第一个正数 21 最大子数组和 贪心 / 动态规划贪心:连续和(count)< 0时,放弃当前起点的连续和,将下一个数作为新起点,这里提供使用贪心算法解决本题的代码。动态规划:dp[i]:以nums[i]为结尾的最长连续子序列(子数组)和。 dp[i] = max(dp[i - 1]

react笔记 8-21 约束性 表单

1、约束性组件和非约束性组件 非约束性组件<input type="text" name="" defaultValue={this.state.msg}></input>这里他的value是用户输入的值 并没有执行操作 只是获取到了msg的值 用户输入不会改变数据非约束性组件需要使用defaultValue获取数据 否则会报错约束性组件<input type="text