文献阅读报告 - Social GAN: Socially Acceptable Trajectories with Generative Adversarial Networks...

本文主要是介绍文献阅读报告 - Social GAN: Socially Acceptable Trajectories with Generative Adversarial Networks...,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

1616500-20190817220921995-982102610.png

  1. paper:Gupta A , Johnson J , Fei-Fei L , et al. Social GAN: Socially Acceptable Trajectories with Generative Adversarial Networks[J]. 2018.
  2. code:https://github.com/agrimgupta92/sgan


概览

文章提出了一种采用GAN架构进行训练的轨迹预测模型,Generator由Encoder-Decoder结构组成,Discriminator由Decoder组成,旨在从合理性、多样性和预测速度等多方面对现有模型进行提升。

解决问题点

  1. 符合社会规范的轨迹:关注了预测生成轨迹在社会规则上的可行性,在定性评估上相交其他模型生成路径更合理。
  2. 多样化的轨迹:传统评估模型时采用ADE和FDE指标,优化模型的量化评估虽好,但其往往导致温和单一的预测轨迹,这与现实场景中轨迹的多样化情况不符。
  3. 预测速度提升:Vanilla LSTM vs SGAN vs Soicla LSTM :56x vs 16x vs 1x,速度有了明显提升。


模型创新点

  1. 提出新的损失函数-Variety Loss:借鉴于Minimum Over N损失函数,该损失函数鼓励Generator生成多条可行的路径。- 多样化轨迹
  2. 提出新的池化模型:模型中的池化用于LSTMs交换信息,SGAN将Social LSTM模型每步池化变为已知轨迹变化阶段仅一次池化(预测阶段默认每步都进行池化),同时将池化范围由固定局部范围拓展至全局所有行人。- 符合社会规范的轨迹预测速度提升
  3. 将GAN模型应用在轨迹预测的序列生成任务上:GAN在视觉处理上已有大量使用,但对于自然语言处理等序列模型涉及较少,主要是因为生成器向判别器传递输出的操作是不可微的。


阅读疑问

  1. 文中提到影响GAN在序列模型领域的应用原因是生成器向判别器的操作是不可微的,为何?
  2. SGAN生成器的最终输出是Decoder的隐藏状态经MLP(多层感知机)得到的二维坐标轨迹,但Social LSTM中预测二维坐标是基于隐藏状态满足二维高斯分布,SGAN没有采用这样的假设是因为该方法在反向传播时不可微,为何?

2019.8.22 更新

经过阅读一些知乎上的文章,对于上述两个问题有了初步的解答:

  1. GAN作者早起就已有提及,GAN只适用于连续型数据的生成,对于离散型数据效果不佳。
  2. GAN网络在训练生成器(Generator)时,损失函数是在判别器处计算的,从数据流向上是(数据) -> (生成器权重) -> (判别器权重) -> (Loss)只是生成器权重可训练而判别器权重不可训练。若有使用反向传播更新权重,则整个运算过程必须是可导的
  3. 对于问题一:序列问题如NLP,常常在生成结果时有采样(sample)的行为,例如经过softmax得到词向量的概率,再将概率最大的置位1其余为0表示最终预测的单词,这个概率离散化的过程就是采样,是不可以从数学上求导的。
  4. 对于问题二:Social LSTM训练阶段直接基于二维高斯分布使用neg log-likehood计算损失,在生成阶段是基于二维高斯分布随机多次采样求均值得到最终位置。如要用到GAN网络上,传递给判别器(Discriminator)的数据须使用生成阶段的采样方法,但这种方法是不可导的。

link:https://zhuanlan.zhihu.com/p/29168803


SGAN模型整体架构

GAN与cGAN

GAN中文又称生成对抗式网络,是Goodfellow等人提出的一种方式,旨在最大化训练数据的可能性下界,其中包含较多的数学原理与推导,笔者在此不具体叙述,只在实现层面简述GAN的几个特征:

  1. GAN的组成部分:GAN由生成器和判别器组成,但并不要求生成器与判别器要由神经网络组成,也可以是其他的数学模型。因此GAN实际为一个训练的框架,其中实体因实际情况而异,例如在具体的Social GAN模型中,生成器和判别器均为神经网络,并在生成时采用Encoder-Decoder结构,判别时采用Encoder结构,核心属于序列模型

    \[\min_{G}\max_{D}V(G,D) = E_{x \sim p_{data}}[logD(x)]+E_{x \sim p(z)}[log(1- D(G(z)))]\]

  2. cGAN:基础的GAN网络中,生成器生成的结果是基于随机初始化的输入向量(例如LSTM模型中,输入因为随机初始化的Hidden State),但是该网络的目标是基于已知的轨迹生成预测轨迹,因此生成器的输入还需根据已有信息合成。

    下面的SGAN结构图中,在Generator中若要再细致一些的话,真正的生成器是由LSTN组成的Decoder部分,前段和中段的Encoder和Pooling Module实为为Decoder准备其初始化Hidden State的预处理部件。

1616500-20190817221005644-2069102379.png
  1. GAN的训练过程:GAN训练时的对象生成器和判别器,而测试时对象仅有生成器。
    1. 一次迭代(epoch/iteration)中,生成器和判别器将分别经过g_stepsd_steps步训练,每次迭代中,先单独训练判别器的d_steps次,再单独训练生成器。
    2. 训练判别器:每步训练中,对于同一段已知路径,判别器将接受来自数据库和生成器的真轨迹与假轨迹,并对两个轨迹真假性做出评估,对抗损失函数将基于判别器对于两个轨迹的判断
    3. 训练生成器:每步训练中,生成器将根据一段已知路径生成假轨迹,并交由判断器判断真假,对抗损失函数将基于判别器对假轨迹的判断


SGAN结构

Social GAN分为Generator和Discriminator:

1616500-20190817221005644-2069102379.png
  1. Generator:生成器由Encoder、Pooling Module和Decoder组成。

    1. Encoder使用LSTM序列模型实现,用于将行人的历史轨迹信息编码。最终输出的隐藏状态\(h_{ei}^{t_{obs}}\),将包含整个轨迹的信息。

      \[e_i^t = \phi (x_i^t, y_i^t, W_{ee})\]

      \[h_{ei}^t = LSTM(h_{ei}^{t-1}, e_i^t;W_{encoder})\]

    2. Pooling Module使用max pooling实现,用于共享行人间信息。最终输出的是\(c_i^t\),作为Decoder输入的一部分。

      \[P_i = PM(h_{e1}^{t_{obs}},h_{e2}^{t_{obs}},h_{e3}^{t_{obs}}...)\]

      \[c_i^t = \gamma (P_i, h_{ei}^{t_{obs}};W_c)\]

      *\(\gamma(.)\)是使用Relu的多层感知机(含有多个隐藏层的全连接层)

    3. Decoder使用LSTM序列模型实现,用于生成预测的轨迹。不同于其他LSTM,其Hidden State初始值并不随机,而是由\(h_{di}^t = [c_i^t, z]\)拼接而成,前者为PM生成的结果,后者是加入的随机噪音以便生成多种轨迹。Decoder实际可被看做是带输入条件的生成器。

      这其中需要:注意实验默认的Pooling Module在Decode阶段每步运行都会进行池化。

      \[e_i^t = \phi(x_i^{t-1}, y_i^{t-1}, W_{ed})\]

      \[P_i = PM(h_{d1}^{t-1},...,h_{dn}^{t-1})\]

      \[h_{di}^t = LSTM(\gamma(P_i,h_{di}^{t-1}),e_i^t,W_{decoder})\]

      \[(\hat{x_i^t},\hat{y_i^t}) = \gamma(h_{di}^t)\]

      *\(\gamma(.)\)是使用Relu的多层感知机

  2. Discriminator:判别器结构相对简单,由一个LSTM实现的Decoder和对[Decoder输出, 已知轨迹部分]进行多层感知的全连接层组成,最终输出对于路径真假性的评分。


模型特点与创新

损失函数

SGAN模型训练是分别针对生成器和判别器的,因而两部分的损失函数也需要分别定义,SGAN的损失函数基础量是Adversarial Loss,除此之外还附加了Variety Loss增加路径生成的多样性

  1. 生成器

    \[L_G = L_{adversarial}+L_{variety}\]

    1. \(L_{adversarial}\):惩罚“生成的轨迹被判别器判为假”:判别器对轨迹的scores与[0]向量的交叉熵。
    2. \(L_{variety} = \min_k||Y_i - \hat Y_i^{(k)}||_2\):这是基于\(L_2\)损失改进的,k指代Generator中在生成Decoder的初始隐藏状态时,\(z\)的随机取样次数。按原文来讲,该函数只惩罚\(L_2\)误差最小的预测路径,鼓励“hedge its bets”(多下注,留退路),生成多种可行的路径。(与MoN损失函数类似,但并未在此领域使用过)。
  2. 判别器

    \[L_D = L_{adversarial}\]

    1. 惩罚“生成的轨迹被判别器判为真”:判别器对轨迹的scores与[0.7-1.2]向量的交叉熵。
    2. 惩罚“真实的轨迹被判别器判为假”:判别器对轨迹的scores与[0]向量的交叉熵。


池化模块

SGAN提出了异于Social Pooling的新型池化模型,这种池化模型将全局行人的信息纳入考量,并且源信息在LSTMs的Hidden States基础上增加了行人间的位置信息。后续的实验结果表明新的池化模型在量化指标上稍逊Social Pooing,但生成轨迹更符合社会规则。

1616500-20190817221049479-752891060.png

此处有几点需要注意:

  1. Pooing Module的输入由[Hidden States, Relative Location]两部分组成——每个LSTMs的隐藏状态和其他行人对目标行人的相对位置x,原文中在两处分别提到了这两个数据源,但并没有统一结合说明。
  2. 由于不同场景的人数不相同,模型为保证对池化结果维数相同,使用的是max pooling,对于每个行人(num_ped,N)的张量变为(1, N)。
  3. 实现代码中有关相对位置的计算和批量矩阵化运算的实现细节比较巧妙,如有需求请参考代码model.py - PoolHiddenNet部分和实验代码解析。


其他

  1. 路径数据相对VS绝对:虽然文章中仅在Pooling Module部分重点提出过使用相对位置(不同人在同一时刻之间),但经过通过阅读实验代码,生成器从输入到最终的输出,都是相对位置(同一人某时刻相较于前一时刻的位置变化),而绝对位置虽传入模型但仅作为计算相对位置、合成绝对位置、计算grid等功能。

  2. 生成器输出:在Social LSTM中,作者基于LSTM最终输出的隐藏状态呈现位置信息的二维高斯分布,并以此预测位置和计算损失;但在SGAN中,文章以该方法在反向传播时不可微的原因使用多层感知机直接预测二维目标,并用\(L_2\)计算损失。


模型评估与实验

  1. 实验数据库:ETH和UCY,4种场景,1536条行人轨迹,未经归一化处理。

  2. 评价指标:ADE - Average Displacement Error,FDE - Final Displacement Error。鉴于SGAN生成路径的多样性,评价时将对一条路径的多种预测取最小误差作为结果

  3. 无关变量控制

    1. 预测时间:输入3.2秒,预测3.2秒或4.8秒。
    2. SGAN实验模型的编号为SGAN-kVP-N:kV表示训练时使用Variety Loss的生成次数(1表示没有使用Variety Loss);有无P表示是否使用新型池化结构;N表示计算Error前,对于一条已知路径生成了多少条备选路径。
  4. 定量实验结论

    1. 全场最佳:SGAN模型编号SGAN-20V-20整体表现最佳,SGAN-20VP-20在量化结果上稍逊前者(后文解释)。

    2. 多样化输出显著,模型对噪音敏感:若在评估时只取模型随机生成的一个轨迹,那么量化指标结果差于Social LSTM,这表示模型对噪音\(\alpha\)是敏感的。同时,随着评估参考轨迹的数量上升,评估结果也显著提高,最高在\(k=100\)(100条轨迹中选误差最小的)时能够降低33%的错误率。

    3. 速度提升显著:得益于池化结构简化,SGAN生成速度可达Social LSTM的16倍。

      注:Social LSTM整体表现比Vanilla LSTM差,原文章的实验结果使用真实数据训练+加强数据测试的策略无法复现。

1616500-20190817221102463-536674591.png
  1. 定性实验结论

    虽然具有新型池化结构的模型比原池化的模型的数据表现略逊一筹,但将轨迹数据可视化后,新型池化的预测要比原模型更符合社会规则性。文中特别提取几种常见社交场景进行对比,具体请参见原文:

    1. 冲突场景:一对一相遇、一对多相遇、追尾式相遇、带有角度的侧面相遇。
1616500-20190817221113871-888381134.png
  1. 人群聚合场景、人群回避场景(人人间互相回避)、人群跟随场景


转载于:https://www.cnblogs.com/sinoyou/p/11370602.html

这篇关于文献阅读报告 - Social GAN: Socially Acceptable Trajectories with Generative Adversarial Networks...的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/268197

相关文章

JAVA智听未来一站式有声阅读平台听书系统小程序源码

智听未来,一站式有声阅读平台听书系统 🌟 开篇:遇见未来,从“智听”开始 在这个快节奏的时代,你是否渴望在忙碌的间隙,找到一片属于自己的宁静角落?是否梦想着能随时随地,沉浸在知识的海洋,或是故事的奇幻世界里?今天,就让我带你一起探索“智听未来”——这一站式有声阅读平台听书系统,它正悄悄改变着我们的阅读方式,让未来触手可及! 📚 第一站:海量资源,应有尽有 走进“智听

【专题】2024飞行汽车技术全景报告合集PDF分享(附原数据表)

原文链接: https://tecdat.cn/?p=37628 6月16日,小鹏汇天旅航者X2在北京大兴国际机场临空经济区完成首飞,这也是小鹏汇天的产品在京津冀地区进行的首次飞行。小鹏汇天方面还表示,公司准备量产,并计划今年四季度开启预售小鹏汇天分体式飞行汽车,探索分体式飞行汽车城际通勤。阅读原文,获取专题报告合集全文,解锁文末271份飞行汽车相关行业研究报告。 据悉,业内人士对飞行汽车行业

计算机毕业设计 大学志愿填报系统 Java+SpringBoot+Vue 前后端分离 文档报告 代码讲解 安装调试

🍊作者:计算机编程-吉哥 🍊简介:专业从事JavaWeb程序开发,微信小程序开发,定制化项目、 源码、代码讲解、文档撰写、ppt制作。做自己喜欢的事,生活就是快乐的。 🍊心愿:点赞 👍 收藏 ⭐评论 📝 🍅 文末获取源码联系 👇🏻 精彩专栏推荐订阅 👇🏻 不然下次找不到哟~Java毕业设计项目~热门选题推荐《1000套》 目录 1.技术选型 2.开发工具 3.功能

论文阅读笔记: Segment Anything

文章目录 Segment Anything摘要引言任务模型数据引擎数据集负责任的人工智能 Segment Anything Model图像编码器提示编码器mask解码器解决歧义损失和训练 Segment Anything 论文地址: https://arxiv.org/abs/2304.02643 代码地址:https://github.com/facebookresear

Python:豆瓣电影商业数据分析-爬取全数据【附带爬虫豆瓣,数据处理过程,数据分析,可视化,以及完整PPT报告】

**爬取豆瓣电影信息,分析近年电影行业的发展情况** 本文是完整的数据分析展现,代码有完整版,包含豆瓣电影爬取的具体方式【附带爬虫豆瓣,数据处理过程,数据分析,可视化,以及完整PPT报告】   最近MBA在学习《商业数据分析》,大实训作业给了数据要进行数据分析,所以先拿豆瓣电影练练手,网络上爬取豆瓣电影TOP250较多,但对于豆瓣电影全数据的爬取教程很少,所以我自己做一版。 目

软件架构模式:5 分钟阅读

原文: https://orkhanscience.medium.com/software-architecture-patterns-5-mins-read-e9e3c8eb47d2 软件架构模式:5 分钟阅读 当有人潜入软件工程世界时,有一天他需要学习软件架构模式的基础知识。当我刚接触编码时,我不知道从哪里获得简要介绍现有架构模式的资源,这样它就不会太详细和混乱,而是非常抽象和易

开题报告中的研究方法设计:AI能帮你做什么?

AIPaperGPT,论文写作神器~ https://www.aipapergpt.com/ 大家都准备开题报告了吗?研究方法部分是不是已经让你头疼到抓狂? 别急,这可是大多数人都会遇到的难题!尤其是研究方法设计这一块,选定性还是定量,怎么搞才能符合老师的要求? 每次到这儿,头脑一片空白。 好消息是,现在AI工具火得一塌糊涂,比如ChatGPT,居然能帮你在研究方法这块儿上出点主意。是不

【干货分享】基于SSM的体育场管理系统的开题报告(附源码下载地址)

中秋送好礼 中秋佳节将至,祝福大家中秋快乐,阖家幸福。本期免费分享毕业设计作品:《基于SSM的体育场管理系统》。 基于SSM的体育场管理系统的开题报告 一、课题背景与意义 随着全民健身理念的深入人心,体育场已成为广大师生和社区居民进行体育锻炼的重要场所。然而,传统的体育场管理方式存在诸多问题,如资源分配不均、预约流程繁琐、数据统计不准确等,严重影响了体育场的使用效率和用户体验。

【阅读文献】一个使用大语言模型的端到端语音概要

摘要 ssum框架(Speech Summarization)为了 从说话人的语音提出对应的文本二题出。 ssum面临的挑战: 控制长语音的输入捕捉 the intricate cross-mdoel mapping 在长语音输入和短文本之间。 ssum端到端模型框架 使用 Q-Former 作为 语音和文本的中介连接 ,并且使用LLMs去从语音特征正确地产生文本。 采取 multi-st

生成对抗网络(GAN网络)

Generative Adversarial Nets 生成对抗网络GAN交互式可视化网站 1、GAN 基本结构 GAN 模型其实是两个网络的组合: 生成器(Generator) 负责生成模拟数据; 判别器(Discriminator) 负责判断输入的数据是真实的还是生成的。 生成器要不断优化自己生成的数据让判别网络判断不出来,判别器也要优化自己让自己判断得更准确。 二者关系形成