【机器学习】基于Softmax松弛技术的离散数据采样

2024-06-22 17:52

本文主要是介绍【机器学习】基于Softmax松弛技术的离散数据采样,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

1.引言

1.1.离散数据采样的意义

离散数据采样在深度学习中起着至关重要的作用,它直接影响到模型的性能、泛化能力、训练效率、鲁棒性和解释性。

首先,采样方法能够有效地平衡数据集中不同类别的样本数量,使得模型在训练时能够更均衡地学习各个类别的特征,从而避免因数据不平衡导致的偏差。

其次,合理的采样策略可以确保模型在训练过程中能够接触到足够多的样本,避免过拟合和欠拟合问题,提高模型的泛化能力。

此外,通过随机选择部分样本来减少训练数据的规模,可以提高训练效率,使得深度学习模型在处理大规模数据集时更加高效。同时,离散数据采样还能增加数据集的多样性,使得模型在训练过程中能够接触到更多不同类型的样本,从而提高模型的鲁棒性,使其能够更好地适应各种实际应用场景。

最后,通过控制训练数据的分布来影响模型的决策过程,离散数据采样可以为深度学习模型提供一定的解释性,使得模型的决策过程更加可理解和可信任。

在实际应用中,选择合适的离散数据采样策略对于提高深度学习模型的性能和可解释性至关重要。

1.2.主要内容

本文探讨了如何从非结构化的向量数据中有效地采样出离散变量,并将这些变量转化为具有特定结构的实体,例如集合、序列或网络图等形式,进而将它们嵌入到可微分的模型框架中。

文章的核心在于应用连续的松弛技术来处理离散随机变量,尤其是二元和分类类型的变量。在第一部分中,我们详细介绍了利用Gumbel-Softmax技巧来实现从离散概率分布中进行采样。通过这种方法,我们成功地训练了一个变分自编码器模型,该模型具备了分类型的潜在变量。这种技巧为处理离散性和结构化数据提供了一种新颖的途径,并使得模型能够通过标准的反向传播算法进行训练和优化。

2.离散数据采样理论和实践

2.1.使用Gumbel-Argmax进行分类采样

Gumbel-Argmax,也称为Gumbel-Softmax trick,是一种在深度学习中处理离散变量的技巧,特别是在变分推断和生成模型中。这种方法允许我们从离散分布中进行可微分的采样,从而使得梯度下降算法可以应用于包含离散选择的模型。

2.1.1.基本原理

Gumbel-Argmax方法基于Gumbel分布,一种极端值分布,它可以被用来将离散的随机变量转换成连续的形式,从而便于梯度的传播。Gumbel-Argmax方法,也称为Gumbel-Softmax trick,在深度学习中是一种处理离散变量的技巧,尤其是在需要不同iable的随机采样时。这种方法利用Gumbel分布的特性,允许模型通过softmax函数进行梯度的反向传播。

Gumbel分布是一种极端值分布,常用于模拟独立随机变量的最大值或最小值。它的概率密度函数(PDF)和累积分布函数(CDF)具有以下形式:

  • PDF:
    f ( z ; μ , β ) = 1 β exp ⁡ ( − z − μ β ) exp ⁡ ( − exp ⁡ ( − z − μ β ) ) f(z; \mu, \beta) = \frac{1}{\beta} \exp\left(-\frac{z - \mu}{\beta}\right) \exp\left(-\exp\left(-\frac{z - \mu}{\beta}\right)\right) f(z;μ,β)=β1exp(βzμ)exp(exp(βzμ))
    其中 μ \mu μ是位置参数, b e t a beta beta是尺度参数。

  • CDF:
    F ( z ; μ , β ) = exp ⁡ ( − exp ⁡ ( − z − μ β ) ) F(z; \mu, \beta) = \exp\left(-\exp\left(-\frac{z - \mu}{\beta}\right)\right) F(z;μ,β)=exp(exp(βzμ))

在Gumbel-Argmax技巧中,通常将尺度参数 β \beta β设为1,以简化计算。

2.1.2.Gumbel-Argmax的工作原理

  1. Gumbel噪声的添加:对于每个离散选择的对数几率 ( x_k ),我们添加一个独立的Gumbel噪声 ( g_k ),得到 ( x_k + g_k )。

  2. Softmax归一化:将添加了噪声的对数几率通过softmax函数进行归一化,得到概率分布 ( \pi_k )。

  3. Argmax采样:在前向传播中,使用softmax得到的概率分布进行argmax操作,得到最可能的选择。在反向传播中,使用softmax的梯度进行传播。

2.1.3.梯度传播的实现

Gumbel-Argmax方法的关键优势在于它允许梯度通过离散采样过程进行传播。在反向传播时,尽管argmax操作本身不可微,但是可以通过Gumbel噪声的连续性来实现梯度的传播。这就是所谓的“直通估计器”(Straight-Through Estimator, STE)。

2.1.4.采样步骤

  1. Gumbel噪声采样:对于每个离散选择,我们首先从Gumbel分布中采样一个噪声项。Gumbel分布是一种以0为位置参数,1为尺度参数的分布。

  2. 对数几率调整:将Gumbel噪声加到原始的对数几率(logits)上,使得每个选择的值变为 logits + Gumbel噪声

  3. Softmax归一化:应用softmax函数对调整后的值进行归一化,得到一个概率分布。

  4. Argmax转换:在前向传播中,使用softmax得到的分布进行argmax操作,得到最可能的选择;在反向传播中,使用softmax的梯度进行传播。

2.1.5.采样数学模型

假设我们有一个具有 C C C个可能值的分类分布,每个值的权重为 w i ∈ ( 0 , ∞ ) w_i \in (0,\infty) wi(0,),我们的目标是从此分布中抽取一个样本,类别 c i c_i ci的概率由softmax分布决定,公式如下:

p i = exp ⁡ ( log ⁡ ( w i ) ) ∑ j exp ⁡ ( log ⁡ ( w j ) ) p_i = \frac{\exp(\log(w_i))}{\sum_{j} \exp(\log(w_j))} pi=jexp(log(wj))exp(log(wi))

Gumbel-Argmax采样方法的步骤是:

  1. 从均匀分布 U n i f o r m ( 0 , 1 ) Uniform(0,1) Uniform(0,1)中独立同分布地采样 U k U_k Uk,然后计算 r k = log ⁡ ( w i ) − log ⁡ ( − log ⁡ U k ) r_k = \log(w_i) - \log(-\log U_k) rk=log(wi)log(logUk)
  2. 选择使得 r k r_k rk最大的索引 i i i(即执行argmax操作),并返回一个1-hot编码的向量,其中第 i i i位为1,其余位置为0。
    添加到 r k r_k rk中的噪声项 − l o g ( − log ⁡ U k ) -log(-\log U_k) log(logUk)遵循Gumbel分布,这也是该方法名称的由来。Gumbel分布(位置参数为0,尺度参数为1)的累积分布函数定义为:

F ( z ) = exp ⁡ ( − exp ⁡ ( − z ) ) F(z) = \exp(-\exp(-z)) F(z)=exp(exp(z))

备注:有关这种方法确实能从softmax分布中采样的证明,如下:

在神经网络、广义线性模型、主题模型以及许多其他概率模型中,人们常常希望用一个无约束的向量来参数化一个离散分布,即一个不受单纯形限制、可以是负数等的向量。解决这个问题的一个非常常见的方法是使用“softmax”变换:
π k = exp ⁡ { x k } ∑ k ′ = 1 K exp ⁡ { x k ′ } \pi_k = \frac{\exp\{x_k\}}{\sum_{k'=1}^K\exp\{x_{k'}\}} πk=k=1Kexp{xk}exp{xk}其中 x k x_k xk R \mathbb{R} R 中是无约束的,但是 π k \pi_k πk 位于单纯形上,即 π k ≥ 0 \pi_k \geq 0 πk0 ∑ k π k = 1 \sum_{k}\pi_k=1 kπk=1 x k x_k xk 参数化了一个离散分布(不是唯一的),我们可以通过执行softmax变换然后进行通常的抽样来生成数据。有趣的是,实际上存在一种替代方法来获得这样的离散样本,而不需要构建离散分布。

这种方法是softmax-离散过程的等价物:向每个 x k x_k xk 添加Gumbel噪声,然后取argmax。也就是说,向每个 x k x_k xk 添加独立的噪声,然后进行最大值操作。这并没有改变算法的渐近复杂度,但是为一些有趣的实现可能性打开了大门。这是如何工作的呢?具有单位尺度和位置参数 μ \mu μ 的Gumbel分布具有以下概率密度函数(PDF):
f ( z ; μ ) = exp ⁡ { − ( z − μ ) − exp ⁡ { − ( z − μ ) } } . f(z\,;\,\mu) = \exp\{-(z-\mu) - \exp\{-(z-\mu)\}\}. f(z;μ)=exp{(zμ)exp{(zμ)}}.Gumbel的累积分布函数(CDF)是
F ( z ; μ ) = exp ⁡ { − exp ⁡ { − ( z − μ ) } } . F(z\,;\,\mu) = \exp\{-\exp\{-(z-\mu)\}\}. F(z;μ)=exp{exp{(zμ)}}.现在,假设我们的第 k k k 个Gumbel,位置参数为 x k x_k xk,结果为 z k z_k zk。所有其他的 z k ′ ≠ k z_{k'\neq k} zk=k 小于这个值的概率是
Pr ⁡ ( z k is largest ∣ z k , { x k ′ } k ′ = 1 K ) = ∏ k ′ ≠ k exp ⁡ { − exp ⁡ { − ( z k − x k ′ ) } } . \Pr(z_k \text{ is largest}\,|\, z_k, \{x_{k'}\}^K_{k'=1}) = \prod_{k'\neq k}\exp\{-\exp\{-(z_k-x_{k'})\}\}. Pr(zk is largestzk,{xk}k=1K)=k=kexp{exp{(zkxk)}}.我们知道 z k z_k zk 的边缘分布,我们需要积分它来找到整体概率:
Pr ⁡ ( k is largest ∣ { x k ′ } ) = ∫ exp ⁡ { − ( z k − x k ) − exp ⁡ { − ( z k − x k ) } } × ∏ k ′ ≠ k exp ⁡ { − exp ⁡ { − ( z k − x k ′ ) } } d z k . \Pr(\text{$k$ is largest}\,|\,\{x_{k'}\}) = \\ \int \exp\{-(z_k-x_k)-\exp\{-(z_k-x_k)\}\}\times\\ \prod_{k'\neq k}\exp\{-\exp\{-(z_k-x_{k'})\}\} \,\mathrm{d}z_k. Pr(k is largest{xk})=exp{(zkxk)exp{(zkxk)}}×k=kexp{exp{(zkxk)}}dzk.通过一些代数运算,我们得到:
Pr ⁡ ( k is largest ∣ { x k ′ } ) = exp ⁡ { x k } ∑ k ′ = 1 K exp ⁡ { x k ′ } . \Pr(\text{$k$ is largest}\,|\,\{x_{k'}\}) = \frac{\exp\{x_k\}}{\sum_{k'=1}^K\exp\{x_{k'}\}}. Pr(k is largest{xk})=k=1Kexp{xk}exp{xk}.我们可以看到,这正是softmax概率.

简而言之,使用Gumbel重新参数化技巧对分类变量进行采样的步骤如下:

  1. 给定权重 w i w_i wi,计算 r i = w i + g i r_i = w_i + g_i ri=wi+gi,其中 g i g_i gi是从Gumbel分布中独立同分布采样得到的。
  2. 执行Argmax操作:返回最大的 r i r_i ri对应的索引,作为1-hot向量。

3.Softmax松弛技术

在深度学习中,当我们需要从一个分布中选取一个类别时,通常会使用softmax函数来得到每个类别的概率分布,然后使用argmax来选取概率最大的类别。然而,由于argmax操作在大多数深度学习框架中是不可导的,这就使得在训练过程中无法使用基于梯度的优化算法。

为了解决这个问题,我们采用softmax函数来近似argmax,因为它本身是连续且可导的。但是,原始的softmax输出是一个概率分布,而argmax输出的是一个离散值(即类别的索引)。为了控制softmax输出与1-hot向量的接近程度,我们引入了温度参数 τ \tau τ(希腊字母tau)。

引入温度参数 τ \tau τ后的softmax函数可以写作:

p i = exp ⁡ ( r i / τ ) ∑ j exp ⁡ ( r j / τ ) p_i = \frac{\exp(r_i / \tau)}{\sum_j \exp(r_j / \tau)} pi=jexp(rj/τ)exp(ri/τ)

其中, p i p_i pi是第 i i i个类别的概率, r i r_i ri是该类别的原始得分(在深度学习模型中通常是模型的输出层的线性变换结果), τ \tau τ是温度参数。

当温度参数 τ \tau τ较小时,softmax的输出会更加接近一个1-hot向量,即最大概率的类别概率接近1,而其他类别的概率接近0。这使得softmax的输出更接近于argmax的结果。

相反,当温度参数 τ \tau τ较大时,softmax的输出会更加平滑,即所有类别的概率都会相对均匀,没有哪一个类别的概率特别突出。这有助于在训练初期鼓励模型探索不同的类别,避免过早地陷入局部最优解。

因此,通过调整温度参数 τ \tau τ,我们可以控制softmax输出与1-hot向量的距离,从而实现对argmax操作的可导近似。在训练过程中, τ \tau τ可以是一个固定的值,也可以是一个可学习的参数,根据具体的任务和数据集进行调整。

温度参数对分布和样本的影响可以通过下图(图引用自论文:CATEGORICAL REPARAMETERIZATION WITH GUMBEL-SOFTMAX)观察到。
在这里插入图片描述

在训练过程中,可以通过调整 τ \tau τ来控制模型的“柔软度”。开始时,可以使用较大的 τ \tau τ值来帮助模型在类别之间进行探索,随着训练的进行,逐渐减小 τ \tau τ值以鼓励模型做出更确定的预测。

然而,如果你提到的是Gumbel-Softmax采样,这是一种用于离散潜变量(如分类变量)的可导近似方法。Gumbel-Softmax允许你通过softmax函数从离散分布中采样,同时保持整个过程的可导性。这种方法通过添加Gumbel噪声到原始logits(未归一化的概率)上,然后使用softmax进行归一化,并通过一个温度参数来控制分布的离散程度。

Gumbel-Softmax采样的基本步骤如下:

  1. 对于每个logits x i x_i xi,从其对应的Gumbel分布中抽取一个样本 g i g_i gi
  2. 计算 r i = x i + g i r_i = x_i + g_i ri=xi+gi
  3. 应用带有温度参数 τ \tau τ的softmax函数,得到 softmax τ ( r ) \text{softmax}_\tau(r) softmaxτ(r)
  4. 最后,通常使用softmax输出的概率作为权重,从类别中进行采样(在训练时通常使用softmax输出本身作为近似,而在测试时可能使用argmax或采样)。

注意,在训练过程中,Gumbel-Softmax通常用于反向传播梯度,因为它是可导的。然而,在评估或测试模型时,你可能想要一个离散的输出,这时可以使用argmax或基于softmax输出的采样方法。

4.分类变分自编码器(Categorical VAE)

我们编写了一个Gumbel-Softmax技巧应用的案例,呈现了一个变分自编码器(VAE),它专门设计用于处理MNIST数据集,并拥有一个由分类变量构成的潜在空间。这个潜在空间由多个分类变量组成,每个变量都能够取有限个可能的类别。

具体来说,在这个例子中,我们的潜在空间由30个独立的分类变量组成,每个变量都限定在10个可能的类别中。鉴于VAE模型的工作原理,我们需要为这些潜在变量定义一个先验分布,这里我们选择了一个均匀的分类分布作为先验,意味着在潜在空间中每个类别的出现概率是相同的。

通过这种方式,我们能够利用Gumbel-Softmax分布来实现对潜在变量的可微分采样,从而允许模型通过标准的反向传播算法进行训练。这种采样方法不仅提供了一种灵活的方式来处理离散数据,而且还保持了模型在概率建模上的理论完整性。

4.1.设置

我们从所需的导入和超参数定义开始。

import numpy as np  # 导入NumPy库,用于数学运算import torch  # 导入PyTorch库,用于构建和训练神经网络
import torch.nn.functional as F  # 导入PyTorch的功能性模块,包含一些神经网络中常用的函数
from torch import nn, optim  # 从PyTorch中导入神经网络模块(nn)和优化器模块(optim)
from torch.nn import functional as F  # 导入PyTorch神经网络的功能性方法,F是函数库的别名
from torchvision import datasets, transforms  # 从torchvision库中导入数据集和转换模块,用于加载和预处理数据
from torchvision.utils import save_image  # 从torchvision.utils中导入保存图像的函数
from torch.distributions.one_hot_categorical import OneHotCategorical  # 从PyTorch的分布库中导入OneHotCategorical分布import matplotlib  # 导入matplotlib库,用于绘图
import matplotlib.pyplot as plt  # 导入matplotlib的pyplot模块,用于创建图表
%matplotlib inline  # 使得matplotlib的图表可以在Jupyter笔记本中直接显示cuda=True  # 设置是否使用CUDA(GPU加速),如果设置为True且GPU可用,代码将使用GPU进行加速
batch_size = 100  # 设置每个批次的样本数量为100
epochs = 10  # 设置训练的轮数(epoch)为10
latent_dim = 30  # 设置潜在空间的维度为30
categorical_dim = 10  # 设置分类潜在变量的类别数量为10
temp = 1.0  # 设置Gumbel-Softmax采样中的温度参数为1.0,控制采样的随机性

4.2.Gumbel采样

我们现在转向实现Gumbel-Softmax采样的过程,这是一种在深度学习中处理离散变量的策略,允许模型利用梯度下降算法进行训练。以下是三个关键函数的定义,它们共同构成了Gumbel-Softmax采样方法:

  1. sample_gumbel 函数
    这个函数用于生成Gumbel分布的样本。它首先从均匀分布 U ( 0 , 1 ) U(0,1) U(0,1) 中抽取随机数,然后通过计算 − log ⁡ ( − log ⁡ ( U ( 0 , 1 ) ) ) -\log(-\log(U(0,1))) log(log(U(0,1))) 来得到Gumbel分布的样本。这些样本是按照尺度为0、位置参数为1的Gumbel分布生成的。

  2. gumbel_softmax_sample 函数
    该函数负责将Gumbel噪声添加到未经归一化的对数概率(即logits)上,然后通过设置一个温度参数来控制softmax函数的平滑程度,最终应用softmax函数来获取概率分布。

  3. gumbel_softmax 函数
    这个函数结合了上述的采样和softmax操作,并增加了评估模式下的行为。在评估模式下,它直接从由logits定义的分类分布中采样,而不会进行Gumbel-Softmax松弛。这允许我们在模型评估时获得确定性的样本。

这些函数的实现为深度学习模型中离散变量的处理提供了一种有效的方法,使模型能够通过连续的放松来优化通常不可微的离散采样步骤。通过这种方式,我们可以训练通常难以优化的模型,并且能够处理更广泛的数据类型和结构。

def sample_gumbel(shape, eps=1e-20):# 函数用于生成Gumbel分布的样本U = torch.rand(shape)  # 生成与shape相同形状的[0,1)之间的均匀分布随机数if cuda:U = U.cuda()  # 如果cuda为True,将数据转移到GPU上return -torch.log(-torch.log(U + eps) + eps)  # 通过变换得到Gumbel分布的样本,eps用于避免对数为负无穷def gumbel_softmax_sample(logits, temperature):# 函数用于在给定的对数几率(logits)和温度参数下,通过Gumbel-Softmax技巧采样y = logits + sample_gumbel(logits.size())  # 将Gumbel分布的样本加到logits上return F.softmax(y / temperature, dim=-1)  # 应用softmax函数,并除以温度参数,得到概率分布def gumbel_softmax(logits, temperature, evaluate=False):# 函数用于在评估模式下进行Gumbel-Softmax采样或者进行训练if evaluate:# 如果是在评估模式下,直接从分类分布中采样d = OneHotCategorical(logits=logits.view(-1, latent_dim, categorical_dim))  # 创建OneHotCategorical分布return d.sample().view(-1, latent_dim * categorical_dim)  # 采样并重塑形状# 如果不是评估模式,使用Gumbel-Softmax技巧进行采样y = gumbel_softmax_sample(logits, temperature)  # 调用上面定义的采样函数return y.view(-1, latent_dim * categorical_dim)  # 重塑采样结果的形状

这些函数中使用了一些PyTorch的函数和类,例如torch.rand生成均匀分布的随机数,F.softmax应用softmax函数,以及OneHotCategorical分布用于从分类变量中采样。cuda变量用于判断是否使用GPU加速运算。logits是分类概率的对数,temperature参数控制了采样的随机性,evaluate参数用于判断当前是在训练模式还是评估模式。在评估模式下,直接采样并返回离散的样本;在训练模式下,使用Gumbel-Softmax技巧进行梯度估计和优化。

4.3.VAE模型

我们现在转向构建一个变分自编码器(VAE)模型,该模型采用了Gumbel-Softmax技巧来处理潜在空间中的离散变量:

  1. 模型定义
    我们创建了一个名为 VAE_gumbel 的类,它基于PyTorch的 nn.Module 构建。这个类实现了一个VAE模型,其中潜在变量通过Gumbel-Softmax分布进行采样,从而允许梯度下降算法的应用。

  2. 初始化 (__init__ 方法)
    在类的构造函数中,我们初始化了模型所需的网络层,包括线性层、ReLU激活函数和Sigmoid激活函数。这些层构成了模型的编码器和解码器部分。

  3. 编码过程 (encode 方法)
    encode 方法负责将输入数据 x 转换成潜在变量的对数几率。这一过程涉及多个线性层和非线性激活函数,最终输出潜在变量的未归一化对数概率。

  4. 解码过程 (decode 方法)
    decode 方法接收潜在变量作为输入,并通过一个由线性层和激活函数组成的网络结构来重建输入数据。最终,该方法输出重构数据的概率分布。

  5. 前向传播 (forward 方法)
    forward 方法实现了模型的前向传播。它首先通过编码器获取潜在变量的对数几率,然后根据提供的 temp 温度参数和 evaluate 标志,使用Gumbel-Softmax技巧或确定性采样来生成潜在变量。最后,解码器将潜在变量转换为重构数据。

  6. 温度参数和评估标志

    • temp 参数控制Gumbel-Softmax采样的随机性。较低的温度值使得采样更接近确定性选择,而较高的温度值则增加随机性。
    • evaluate 标志用于确定是否在评估模式下运行模型。在评估模式下,模型使用确定性采样来直接从潜在变量的分布中获取样本,以便进行模型的评估和测试。

通过这种方式,VAE_gumbel 类提供了一种灵活的方法来训练和评估VAE模型,同时处理潜在空间中的离散性质。这种模型特别适用于那些需要学习离散潜在表示的任务,例如处理分类数据或进行结构化数据建模。

class VAE_gumbel(nn.Module):def __init__(self, temp):super(VAE_gumbel, self).__init__()  # 调用基类的初始化方法# 定义模型的层self.fc1 = nn.Linear(784, 512)  # 定义一个线性层,输入维度784,输出维度512self.fc2 = nn.Linear(512, 256)  # 定义一个线性层,输入维度512,输出维度256self.fc3 = nn.Linear(256, latent_dim * categorical_dim)  # 定义一个线性层,输出维度为潜在维度乘以分类维度self.fc4 = nn.Linear(latent_dim * categorical_dim, 256)  # 定义一个线性层,输入维度为潜在维度乘以分类维度self.fc5 = nn.Linear(256, 512)  # 定义一个线性层,输入维度256,输出维度512self.fc6 = nn.Linear(512, 784)  # 定义一个线性层,输出维度784# 定义激活函数self.relu = nn.ReLU()  # ReLU激活函数self.sigmoid = nn.Sigmoid()  # Sigmoid激活函数def encode(self, x):# 定义编码器h1 = self.relu(self.fc1(x))  # 通过线性层和ReLU激活函数h2 = self.relu(self.fc2(h1))  # 通过第二个线性层和ReLU激活函数return self.relu(self.fc3(h2))  # 通过第三个线性层和ReLU激活函数得到潜在变量的对数几率def decode(self, z):# 定义解码器h4 = self.relu(self.fc4(z))  # 通过线性层和ReLU激活函数h5 = self.relu(self.fc5(h4))  # 通过第二个线性层和ReLU激活函数return self.sigmoid(self.fc6(h5))  # 通过第三个线性层和Sigmoid激活函数得到重构图像的概率def forward(self, x, temp, evaluate=False):# 定义前向传播过程q = self.encode(x.view(-1, 784))  # 对输入x进行编码,得到潜在变量的对数几率q_y = q.view(q.size(0), latent_dim, categorical_dim)  # 重塑编码后的形状z = gumbel_softmax(q_y, temp, evaluate)  # 使用Gumbel-Softmax技巧进行采样return self.decode(z), F.softmax(q_y, dim=-1).reshape(*q.size())  # 解码并返回重构图像和潜在变量的概率分布

4.2.计算KL散度

在变分自编码器(VAE)的训练过程中,除了重建输入数据,模型还需要确保潜在变量的分布与先验分布保持一致。这通常通过最小化潜在分布与均匀先验分布之间的Kullback-Leibler (KL) 散度来实现。

  1. KL散度的计算
    VAE模型中的KL散度衡量了潜在变量的概率分布 $ q(x) $ 与均匀先验分布 $ p(x) = \frac{1}{C} $ 之间的差异。对于离散的潜在变量,KL散度可以表示为:
    KLD ( q ∣ ∣ p ) = ∑ i = 1 C q ( x i ) log ⁡ ( C ⋅ q ( x i ) 1 ) \text{KLD}(q||p) = \sum_{i=1}^{C} q(x_i) \log \left(\frac{C \cdot q(x_i)}{1}\right) KLD(q∣∣p)=i=1Cq(xi)log(1Cq(xi))
    其中, C C C 是潜在变量可能取值的总数, q ( x i ) q(x_i) q(xi) 是潜在变量取第 i i i 个值的概率。

  2. 重构损失
    除了KL散度,VAE的损失函数还包括重构损失,它通常采用二元交叉熵(Binary Cross-Entropy, BCE)来衡量模型重构的图像 recon _ x \text{recon}\_x recon_x 与原始输入图像 x x x 之间的差异。

  3. VAE损失的组成
    VAE的总损失是重构损失和KL散度的结合,可以表示为:
    VAE Loss = BCE ( recon x , x ) + KLD ( q ∣ ∣ p ) \text{VAE Loss} = \text{BCE}(\text{recon}_x, x) + \text{KLD}(q||p) VAE Loss=BCE(reconx,x)+KLD(q∣∣p)
    这种损失函数的设计旨在使模型在保持数据重构精度的同时,也让潜在变量的分布接近于先验分布。

  4. 损失函数的作用
    通过最小化这个损失函数,VAE模型学习到如何将输入数据有效地编码到潜在空间,并且在这个空间中探索数据的分布。KL散度正则化确保了潜在变量的分布不会偏离均匀分布太远,这有助于模型学习到更加泛化的特征表示。

  5. 实现考虑
    在实际实现中,为了数值稳定性和计算效率,我们通常会对KL散度进行一些调整,例如通过减去一个常数项或使用log-trick来避免数值下溢。

通过这种方式,VAE模型不仅能够学习到数据的重构,还能够学习到数据的潜在结构,使其能够在生成任务或特征学习中发挥重要作用。

def loss_function(recon_x, x, qy):# 定义VAE的损失函数,包括重构损失和KL散度BCE = F.binary_cross_entropy(  # 计算重构损失recon_x,                  # 模型重构的图像x.view(-1, 784),           # 原始图像,调整为匹配重构图像的形状size_average=False         # 不使用大小平均,直接使用所有样本的总损失) / x.shape[0]               # 将损失平均到每个样本log_ratio = torch.log(qy * categorical_dim + 1e-20)  # 计算潜在变量分布的对数比率,1e-20用于数值稳定性KLD = torch.sum(               # 计算KL散度,即潜在变量分布与先验分布之间的差异qy * log_ratio,            # 每个潜在维度上的KL散度贡献dim=-1                     # 沿着最后一个维度(潜在变量的类别维度)求和).mean()                      # 计算所有样本的平均KL散度return BCE + KLD               # 返回总损失,即重构损失和KL散度的和

4.3.建立并训练模型

为了构建一个变分自编码器(VAE)并进行训练,我们需要定义模型结构、损失函数、优化器,以及数据加载器。

4.3.1.训练准备

在本阶段,我们的主要任务是设置变分自编码器(VAE)模型的训练环境,并准备相应的数据加载器。

  1. 模型初始化
    我们首先实例化VAE模型,这是执行数据学习和重构的核心组件。

  2. 数据加载器的准备
    利用PyTorch的DataLoader工具,我们能够高效地加载训练和测试数据。DataLoader的优势在于它支持批量数据处理、多线程加载,以及自动的数据随机化,这些都是训练过程中的重要特性。

  3. GPU加速
    通过检查CUDA的可用性,我们有条件地将模型和数据迁移到GPU,这样做可以显著提升计算速度,尤其是在处理大规模数据集或复杂模型时。

  4. 优化器配置
    我们选择了Adam优化器,这是深度学习中一个非常流行的选择,因为它自适应地调整学习率,通常能够更快地收敛。通过optim.Adam,我们初始化了优化器,并设置了一个初始学习率0.001。

  5. 损失目标的计算
    在训练过程中,我们不仅计算放松目标,也评估未放松目标。虽然未放松目标不直接用于训练,但它提供了一个基准,帮助我们理解当前模型的状态,并指导我们如何调整模型参数,特别是温度参数,以平衡训练性能和目标的接近度。

  6. 训练与评估的平衡
    通过评估放松目标与实际目标之间的差异,我们可以调整模型的温度参数,确保在训练过程中,模型既能有效学习数据的分布,又能保持潜在空间的离散性。

通过这些步骤,我们建立了一个完整的训练框架,它不仅包括模型的构建和数据的准备,还包括了训练过程中的监控和调整机制,确保模型能够在学习数据的同时,保持良好的泛化能力和生成能力。

model = VAE_gumbel(temp)  # 创建VAE模型实例,temp是温度参数
if cuda:model.cuda()  # 如果cuda为True,将模型转移到GPU上
optimizer = optim.Adam(model.parameters(), lr=1e-3)  # 创建Adam优化器,用于模型参数的优化,学习率为0.001# 设置数据加载器的参数,如果使用GPU,则使用多线程和锁定内存
kwargs = {'num_workers': 1, 'pin_memory': True} if cuda else {}# 创建训练数据加载器,使用MNIST数据集
train_loader = torch.utils.data.DataLoader(datasets.MNIST('./data/MNIST', train=True, download=True,  # MNIST训练数据集路径和下载选项transform=transforms.ToTensor()),  # 将图像转换为Tensorbatch_size=batch_size,  # 每个批次的样本数量shuffle=True,  # 在每个epoch开始时打乱数据**kwargs  # 根据是否使用GPU设置多线程和锁定内存的参数
)# 创建测试数据加载器,使用MNIST数据集
test_loader = torch.utils.data.DataLoader(datasets.MNIST('./data/MNIST', train=False,  # MNIST测试数据集路径transform=transforms.ToTensor()),batch_size=batch_size,  # 每个批次的样本数量shuffle=True,  # 在每个epoch开始时打乱数据**kwargs  # 根据是否使用GPU设置多线程和锁定内存的参数
)

4.3.2.训练模型

在训练阶段,我们通过编写两个关键函数traintest来对变分自编码器(VAE)模型进行训练和评估。

  1. 训练过程 (train 函数)

    • 每个epoch期间,模型遍历整个训练数据集。
    • 损失函数计算当前批次的重构误差和潜在变量分布的KL散度。
    • 执行反向传播来计算损失相对于模型参数的梯度。
    • 更新模型参数以最小化损失。
  2. 评估过程 (test 函数)

    • 在评估模式下运行模型,此时模型不会进行梯度更新。
    • 计算并输出模型在测试集上的性能,通常是通过测试损失来衡量。
  3. GPU加速

    • 使用cuda变量检查GPU是否可用,以决定是否将模型和数据迁移到GPU上,从而加速训练。
  4. 梯度管理

    • 在每次迭代前,optimizer.zero_grad()确保梯度被清零,避免累积。
    • loss.backward()根据当前损失计算参数梯度。
    • optimizer.step()根据计算出的梯度更新模型参数。
  5. 评估模式

    • evaluate=True参数指示模型在评估模式下运行,此时Gumbel-Softmax采样是确定性的,直接从潜在变量的分布中采样。
  6. 温度参数的影响

    • 训练时,可以尝试不同的温度值来观察其对放松目标与真实目标之间关系的影响,以及如何求解最接近的值。

通过这些步骤,我们能够系统地训练VAE模型,并通过调整温度参数来平衡模型的训练性能和目标接近度。这种方法不仅有助于提高模型的重构能力,还能够确保模型在潜在空间中学习到有意义的分布。

def train(epoch):# 定义训练循环model.train()  # 设置模型为训练模式train_loss = 0  # 初始化训练损失为0train_loss_unrelaxed = 0  # 初始化未放松训练目标的损失为0for batch_idx, (data, _) in enumerate(train_loader):  # 遍历训练数据加载器if cuda:data = data.cuda()  # 如果使用GPU,将数据转移到GPU上optimizer.zero_grad()  # 清除之前的梯度recon_batch, qy = model(data, temp)  # 通过模型得到重构图像和潜在变量的分布loss = loss_function(recon_batch, data, qy)  # 计算损失loss.backward()  # 反向传播,计算梯度train_loss += loss.item() * len(data)  # 累加损失optimizer.step()  # 更新模型参数# 评估未放松训练目标(不用于训练,仅用于比较)recon_batch_eval, qy_eval = model(data, temp, evaluate=True)loss_eval = loss_function(recon_batch_eval, data, qy_eval)train_loss_unrelaxed += loss_eval.item() * len(data)print('Epoch: {} Average loss relaxed: {:.4f} Unrelaxed: {:.4f}'.format(epoch, train_loss / len(train_loader.dataset),  # 打印平均放松损失train_loss_unrelaxed / len(train_loader.dataset)))  # 打印平均未放松损失def test(epoch):# 定义评估循环model.eval()  # 设置模型为评估模式test_loss = 0  # 初始化测试损失为0for i, (data, _) in enumerate(test_loader):  # 遍历测试数据加载器if cuda:data = data.cuda()  # 如果使用GPU,将数据转移到GPU上recon_batch, qy = model(data, temp, evaluate=True)  # 通过模型得到重构图像和潜在变量的分布(评估模式)test_loss += loss_function(recon_batch, data, qy).item() * len(data)  # 累加损失test_loss /= len(test_loader.dataset)  # 计算平均测试损失print('Eval loss: {:.4f}'.format(test_loss))  # 打印平均测试损失

4.3.3.评估模型

本节定义了一个名为 run 的函数,它负责驱动VAE模型的训练和评估周期。

  1. 训练周期 (epoch)epoch 是指模型完整地在训练集上学习一次的周期。
  2. 训练轮数指定:使用 range(1, epochs + 1) 来确定训练过程需要进行的总周期数。
  3. 执行训练 (train(epoch)):在每个周期内,调用 train 函数来执行训练任务。此函数通过反向传播算法更新模型参数,目的是减少训练损失。
  4. 执行评估 (test(epoch)): 每个周期训练结束后,调用 test 函数对模型进行评估。此函数计算模型在测试集上的损失表现,并输出结果,但不会对模型参数进行更新。
    通过连续的训练和评估循环,我们可以监测模型性能的变化。理想情况下,我们期望训练损失逐渐降低,测试损失也应保持稳定或降低,这显示了模型正在有效地学习数据特征,同时避免了对训练集的过拟合。
def run():# 定义run函数来运行整个训练和评估过程for epoch in range(1, epochs + 1):  # 从第1个epoch到epochs变量指定的epoch数train(epoch)  # 调用train函数进行训练test(epoch)   # 调用test函数对模型在测试集上进行评估run()  # 调用run函数开始训练和评估过程

4.3.4.生成样本

现在,我们将从训练完成的解码器中生成图像样本。这一过程涉及从先验分布中采样均匀分类变量,并将这些变量输入解码器。我们定义了两个函数:generate_samples用于从VAE模型生成样本,而show_gray_image_grid则用于展示这些生成的图像。
generate_samples函数

  • model.eval():设置模型为评估模式,确保在生成样本时不应用如Dropout或Batch Normalization等只在训练时使用的层。
  • 创建一个形状为 [64, latent_dim, categorical_dim] 的张量,填充为1,然后每个元素乘以 1/categorical_dim,得到均匀的概率分布。这里的64代表生成样本的数量。
  • 使用 OneHotCategorical 分布根据概率 probs 采样,并通过 .cuda() 将概率张量转移到GPU(如果使用CUDA)。
  • model.decode(cat_samples):将one-hot编码的潜在变量通过解码器转换回图像数据。
  • output.view(-1,28,28).detach().cpu().numpy():将解码器的输出调整为28x28像素的图像尺寸,从PyTorch张量分离出来,并转换为NumPy数组,以便于展示和处理。

show_gray_image_grid函数

  • plt.subplots(x, y, figsize=size):根据提供的行数 x、列数 y 和图像大小 size 创建一个图像网格。
  • axs.flatten():将二维轴数组展平为一维,以便于迭代。
  • ax.imshow(np.squeeze(img), cmap='gray'):在每个轴上展示图像,使用灰度色彩映射,np.squeeze(img) 用于去除不必要的单维度。
  • ax.set_axis_off():关闭坐标轴显示,以便更清晰地展示图像。
  • 如果提供了 path 参数,plt.savefig(path) 会将图像保存到文件;否则,plt.show() 会直接展示图像。

最后,通过调用 show_gray_image_grid(samples, 8,8),我们以8行8列的格式展示生成的样本图像。这允许我们直观地评估模型生成图像的质量和多样性。

def generate_samples():# 生成样本的函数model.eval()  # 设置模型为评估模式probs = torch.ones([64, latent_dim, categorical_dim])*(1/categorical_dim)  # 创建一个均匀概率向量cat_samples = OneHotCategorical(probs=probs.cuda()).sample().view(-1, latent_dim*categorical_dim)  # 从均匀分布采样得到分类样本output = model.decode(cat_samples)  # 使用模型的解码器部分生成图像return output.view(-1,28,28).detach().cpu().numpy()  # 将输出转换为numpy数组并返回samples = generate_samples()  # 调用函数生成样本def show_gray_image_grid(imgs, x=2, y=5, size=(8,8), path=None, save=False):# 展示灰度图像网格的函数fig, axs = plt.subplots(x, y, figsize=size)  # 创建图像展示的网格axs = axs.flatten()  # 将轴对象扁平化为一维数组for img, ax in zip(imgs, axs):  # 遍历每个图像和对应的轴对象ax.imshow(np.squeeze(img), cmap='gray')  # 在轴上展示图像,使用灰度色彩映射ax.set_axis_off()  # 不展示坐标轴if save:  # 如果指定了保存路径plt.savefig(path)  # 保存图像到文件else:plt.show()  # 直接展示图像show_gray_image_grid(samples, 8,8)  # 展示生成的样本图像,8行8列的网格布局

4.4.Gumbel直通(Gumbel Straight-Through)

在深度学习中,我们经常遇到需要训练包含离散变量的模型的情况。然而,这些离散变量的放松值可能不适合作为模型输入,或者我们可能需要在优化过程中使用分类/离散输入。为了解决这个问题,我们可以使用一种称为直通估计器(Straight-Through Estimator, STE)的启发式方法。

4.4.1.直通估计器(STE)的概念:

  1. 预激活与采样
    给定预激活值 y y y,我们首先通过非可微的采样操作(例如,从分类或伯努利分布中采样)来计算样本 z z z

  2. 下游函数计算
    使用得到的硬样本 z z z来计算下游函数 f f f

  3. 直通梯度
    在反向传播过程中,我们采用直通梯度,忽略非可微的采样步骤,直接将相对于 z z z的梯度 ∂ z f \partial_z f zf作为相对于 y y y的梯度 ∂ y f \partial_y f yf传递回去。

    ∂ y f : = ∂ z f \partial_y f := \partial_z f yf:=zf

4.4.2.Gumbel-Softmax与直通估计器的结合:

  1. Gumbel-Softmax松弛
    我们使用Gumbel-Softmax技巧来生成放松的离散样本,这些样本在训练过程中是连续的,但在前向传递中可以被视为硬离散值。

  2. 硬采样与直通梯度结合
    给定硬向量 y hard y_{\text{hard}} yhard和软向量 y y y,我们使用以下技巧来结合直通梯度:

    y = ( y hard − y ) .detach() + y y = (y_{\text{hard}} - y) \text{.detach()} + y y=(yhardy).detach()+y

    这样,在前向传递中使用 y hard y_{\text{hard}} yhard,在反向传递中则使用 y y y的梯度。

  3. gumbel_softmax 函数的变体
    我们定义了一个名为 gumbel_softmax 的函数,它是Gumbel-Softmax技巧的一个变体,允许在评估模式下进行硬采样,或者在训练模式下使用直通梯度进行优化。

通过这种方法,我们可以有效地训练包含离散变量的模型,即使这些变量在模型的某些部分需要以确定性的方式进行处理。直通估计器提供了一种在反向传播中处理非可微操作的有效手段,使得模型能够学习到离散变量的分布,同时保持梯度的流动。这种技巧在处理强化学习策略、序列生成模型以及其他需要离散决策的领域中非常有用。

def gumbel_softmax(logits, temperature, evaluate=False, hard=False):# logits: 分类概率的对数几率# temperature: 控制采样随机性的温度参数# evaluate: 是否处于评估模式,如果是,则直接采样# hard: 是否执行直通梯度估计,如果是,则返回硬采样(one-hot编码)if evaluate:# 如果处于评估模式,直接从潜在变量的分布中采样d = OneHotCategorical(logits=logits.view(-1, latent_dim, categorical_dim))return d.sample().view(-1, latent_dim * categorical_dim)y = gumbel_softmax_sample(logits, temperature)  # 应用Gumbel-Softmax采样if hard:# 如果需要硬采样,执行直通梯度估计# 取得每个样本最大值的索引,并将其转换为one-hot向量shape = logits.size()_, k = y.max(-1)  # 取得最大值的索引y_hard = torch.zeros_like(logits)  # 创建一个和logits形状相同的零张量y_hard = y_hard.zero_().scatter_(-1, k.view(shape[:-1] + (1,)), 1.0)  # 在最大值索引位置插入1# 直通梯度技巧:在前向传播中使用硬采样值,在反向传播中使用放松梯度y = (y_hard - y).detach() + y  # 将y的梯度固定,只传递y_hard的梯度return y.view(-1, latent_dim * categorical_dim)  # 返回重塑为(N, latent_dim * categorical_dim)形状的张量

在这个函数中:

  • logits 是指每个潜在类别的对数几率。
  • temperature 是控制Gumbel-Softmax分布随机性的超参数。
  • evaluate 标志指示是否处于评估模式,在评估模式下,模型会直接从潜在变量的分布中采样,而不是使用Gumbel-Softmax技巧。
  • hard 标志指示是否执行直通梯度估计,在这种情况下,函数会返回一个硬采样的one-hot向量,但在反向传播时使用Gumbel-Softmax采样的梯度。

直通梯度技巧允许我们在前向传播中使用硬采样的离散值,而在反向传播中使用放松的连续梯度,这有助于训练包含离散选择的模型。当然,你也可以使用上述函数在模型定义中使用Gumbel-Straight-Through来训练上述VAE模型,并将hard设置为True

5.总结和展望

5.1.总结

Gumbel-Softmax提供了一种强大的方法,用于在深度学习模型中引入离散性和结构化变量。通过使用这种技巧,我们可以训练通常难以优化的模型,并且能够处理更广泛的数据类型和结构。随着深度学习领域的不断发展,Gumbel-Softmax和其他相关技术将继续在开发新的模型和算法中发挥关键作用。

5.2.未来方向

未来的研究可能会探索如何将Gumbel-Softmax与其他类型的梯度估计技术结合使用,以进一步提高模型的性能和稳定性。此外,研究者可能会探索如何将这些技术应用于强化学习、序列建模和其他需要离散决策的领域。

最后,随着对Gumbel-Softmax和其他相关技术的深入理解,我们可能会发现新的应用场景,这些场景以前由于计算和优化的限制而无法实现。这将为开发更智能、更灵活的AI系统开辟新的可能性。

这篇关于【机器学习】基于Softmax松弛技术的离散数据采样的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1084996

相关文章

51单片机学习记录———定时器

文章目录 前言一、定时器介绍二、STC89C52定时器资源三、定时器框图四、定时器模式五、定时器相关寄存器六、定时器练习 前言 一个学习嵌入式的小白~ 有问题评论区或私信指出~ 提示:以下是本篇文章正文内容,下面案例可供参考 一、定时器介绍 定时器介绍:51单片机的定时器属于单片机的内部资源,其电路的连接和运转均在单片机内部完成。 定时器作用: 1.用于计数系统,可

问题:第一次世界大战的起止时间是 #其他#学习方法#微信

问题:第一次世界大战的起止时间是 A.1913 ~1918 年 B.1913 ~1918 年 C.1914 ~1918 年 D.1914 ~1919 年 参考答案如图所示

[word] word设置上标快捷键 #学习方法#其他#媒体

word设置上标快捷键 办公中,少不了使用word,这个是大家必备的软件,今天给大家分享word设置上标快捷键,希望在办公中能帮到您! 1、添加上标 在录入一些公式,或者是化学产品时,需要添加上标内容,按下快捷键Ctrl+shift++就能将需要的内容设置为上标符号。 word设置上标快捷键的方法就是以上内容了,需要的小伙伴都可以试一试呢!

乐鑫 Matter 技术体验日|快速落地 Matter 产品,引领智能家居生态新发展

随着 Matter 协议的推广和普及,智能家居行业正迎来新的发展机遇,众多厂商纷纷投身于 Matter 产品的研发与验证。然而,开发者普遍面临技术门槛高、认证流程繁琐、生产管理复杂等诸多挑战。  乐鑫信息科技 (688018.SH) 凭借深厚的研发实力与行业洞察力,推出了全面的 Matter 解决方案,包含基于乐鑫 SoC 的 Matter 硬件平台、基于开源 ESP-Matter SDK 的一

AssetBundle学习笔记

AssetBundle是unity自定义的资源格式,通过调用引擎的资源打包接口对资源进行打包成.assetbundle格式的资源包。本文介绍了AssetBundle的生成,使用,加载,卸载以及Unity资源更新的一个基本步骤。 目录 1.定义: 2.AssetBundle的生成: 1)设置AssetBundle包的属性——通过编辑器界面 补充:分组策略 2)调用引擎接口API

Javascript高级程序设计(第四版)--学习记录之变量、内存

原始值与引用值 原始值:简单的数据即基础数据类型,按值访问。 引用值:由多个值构成的对象即复杂数据类型,按引用访问。 动态属性 对于引用值而言,可以随时添加、修改和删除其属性和方法。 let person = new Object();person.name = 'Jason';person.age = 42;console.log(person.name,person.age);//'J

一份LLM资源清单围观技术大佬的日常;手把手教你在美国搭建「百万卡」AI数据中心;为啥大模型做不好简单的数学计算? | ShowMeAI日报

👀日报&周刊合集 | 🎡ShowMeAI官网 | 🧡 点赞关注评论拜托啦! 1. 为啥大模型做不好简单的数学计算?从大模型高考数学成绩不及格说起 司南评测体系 OpenCompass 选取 7 个大模型 (6 个开源模型+ GPT-4o),组织参与了 2024 年高考「新课标I卷」的语文、数学、英语考试,然后由经验丰富的判卷老师评判得分。 结果如上图所

大学湖北中医药大学法医学试题及答案,分享几个实用搜题和学习工具 #微信#学习方法#职场发展

今天分享拥有拍照搜题、文字搜题、语音搜题、多重搜题等搜题模式,可以快速查找问题解析,加深对题目答案的理解。 1.快练题 这是一个网站 找题的网站海量题库,在线搜题,快速刷题~为您提供百万优质题库,直接搜索题库名称,支持多种刷题模式:顺序练习、语音听题、本地搜题、顺序阅读、模拟考试、组卷考试、赶快下载吧! 2.彩虹搜题 这是个老公众号了 支持手写输入,截图搜题,详细步骤,解题必备

持久层 技术选型如何决策?JPA,Hibernate,ibatis(mybatis)

转自:http://t.51jdy.cn/thread-259-1-1.html 持久层 是一个项目 后台 最重要的部分。他直接 决定了 数据读写的性能,业务编写的复杂度,数据结构(对象结构)等问题。 因此 架构师在考虑 使用那个持久层框架的时候 要考虑清楚。 选择的 标准: 1,项目的场景。 2,团队的技能掌握情况。 3,开发周期(开发效率)。 传统的 业务系统,通常业

《offer来了》第二章学习笔记

1.集合 Java四种集合:List、Queue、Set和Map 1.1.List:可重复 有序的Collection ArrayList: 基于数组实现,增删慢,查询快,线程不安全 Vector: 基于数组实现,增删慢,查询快,线程安全 LinkedList: 基于双向链实现,增删快,查询慢,线程不安全 1.2.Queue:队列 ArrayBlockingQueue: