稀疏变分高斯过程【超简单,全流程解析,案例应用,简单代码】

2024-05-06 17:52

本文主要是介绍稀疏变分高斯过程【超简单,全流程解析,案例应用,简单代码】,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

文章目录

  • 简介
      • 1. 定义和目标
      • 2. 协方差函数与引入点
      • 3. 变分分布
      • 4. 近似后验参数的计算
      • 5. 计算具体步骤
      • 6. 优势与应用(时间复杂度)
  • 应用案例
      • 1. 初始化参数
      • 2. 计算核矩阵
      • 3. 计算优化变分参数
      • 4. 预测新数据点
      • 5. 结果展示
    • 符号和参数说明
  • python代码

简介

稀疏变分高斯过程(Sparse Variational Gaussian Processes, SVGP)是一种高效的高斯过程(GP)近似方法,它使用一组称为引入点的固定数据点来近似整个数据集。这种方法大大减少了高斯过程模型的计算复杂度,使其能够适用于大数据集。下面是SVGP的详细数学过程。

1. 定义和目标

在标准高斯过程中,给定数据集 { ( x i , y i ) } i = 1 N \{(\mathbf{x}_i, y_i)\}_{i=1}^N {(xi,yi)}i=1N,目标是学习一个映射 f f f ,其中 f ∼ G P ( m , k ) f \sim \mathcal{GP}(m, k) fGP(m,k) m m m 是均值函数, k k k 是协方差函数。SVGP的目标是使用一组较小的引入点 Z = { z i } i = 1 M \mathbf{Z} = \{\mathbf{z}_i\}_{i=1}^M Z={zi}i=1M (其中 M ≪ N M \ll N MN)来近似这个映射。

2. 协方差函数与引入点

引入点 Z \mathbf{Z} Z 被用于构建一个近似的协方差矩阵 K M M \mathbf{K}_{MM} KMM,其中包含引入点之间的协方差。实际的观测点 X \mathbf{X} X 与引入点之间的协方差表示为 K N M \mathbf{K}_{NM} KNM

3. 变分分布

在SVGP中,我们设定变分分布 q ( f ) q(f) q(f) 来近似真实的后验分布。变分分布假设形式为:
q ( f ) = ∫ p ( f ∣ u ) q ( u ) d u q(\mathbf{f}) = \int p(\mathbf{f} | \mathbf{u}) q(\mathbf{u}) \, d\mathbf{u} q(f)=p(fu)q(u)du
其中 u \mathbf{u} u 是在引入点上的函数值, q ( u ) = N ( u ∣ m , S ) q(\mathbf{u}) = \mathcal{N}(\mathbf{u} | \mathbf{m}, \mathbf{S}) q(u)=N(um,S) 是定义在引入点上的高斯分布,具有均值 m \mathbf{m} m 和协方差矩阵 S \mathbf{S} S

4. 近似后验参数的计算

变分参数 m \mathbf{m} m S \mathbf{S} S 通过最小化KL散度(Kullback-Leibler divergence)来学习,这是变分推断中的常用方法。这要求我们计算如下的期望对数似然和KL散度:

ELBO = E q ( f ) [ log ⁡ p ( y ∣ f ) ] − KL ( q ( u ) ∥ p ( u ) ) \text{ELBO} = \mathbb{E}_{q(\mathbf{f})}[\log p(\mathbf{y}|\mathbf{f})] - \text{KL}(q(\mathbf{u}) \| p(\mathbf{u})) ELBO=Eq(f)[logp(yf)]KL(q(u)p(u))

其中,第一项是在变分分布下数据的对数似然的期望,第二项是变分分布和先验分布之间的KL散度。

5. 计算具体步骤

  • 计算协方差矩阵 K M M \mathbf{K}_{MM} KMM, K N M \mathbf{K}_{NM} KNM K N N \mathbf{K}_{NN} KNN
  • 变分分布更新:通过优化ELBO来学习变分参数 m \mathbf{m} m S \mathbf{S} S
  • 后验均值和协方差的更新:在测试点 X ∗ \mathbf{X}_* X 上的后验均值和方差可以通过变分参数和核矩阵计算得到。

6. 优势与应用(时间复杂度)

SVGP减少了与 N N N 成二次方的计算复杂度,变为与 M M M 成二次方的计算复杂度,其中 M M M 通常远小于 N N N。这使得SVGP可以应用于大规模数据集的概率建模和推断。

应用案例

让我们通过一个具体的数值例子来解释稀疏变分高斯过程(SVGP)的操作和计算。假设我们有一个简单的一维回归任务,数据集由以下观测点构成:

  • 训练数据点 X X X X = [ 0.5 , 1.5 , 2.5 , 3.5 , 4.5 ] \mathbf{X} = [0.5, 1.5, 2.5, 3.5, 4.5] X=[0.5,1.5,2.5,3.5,4.5]
  • 对应的目标值 y y y y = [ 1.0 , 2.0 , 3.0 , 2.5 , 1.5 ] \mathbf{y} = [1.0, 2.0, 3.0, 2.5, 1.5] y=[1.0,2.0,3.0,2.5,1.5]

我们的目标是使用SVGP来拟合这些数据。我们设定使用两个引入点( M = 2 M = 2 M=2),位置分布在输入空间内:

  • 引入点 Z Z Z Z = [ 1.0 , 4.0 ] \mathbf{Z} = [1.0, 4.0] Z=[1.0,4.0]

假设使用的核函数是平方指数核,其长度尺度设定为 1.0,输出方差也设定为 1.0。

1. 初始化参数

引入点的均值 m \mathbf{m} m 和协方差 S \mathbf{S} S 参数随机初始化:
m = [ 0.0 , 0.0 ] \mathbf{m} = [0.0, 0.0] m=[0.0,0.0]
S = [ 1.0 0.0 0.0 1.0 ] \mathbf{S} = \begin{bmatrix} 1.0 & 0.0 \\ 0.0 & 1.0 \end{bmatrix} S=[1.00.00.01.0]

2. 计算核矩阵

设定噪声水平 σ 2 = 0.1 \sigma^2 = 0.1 σ2=0.1

  • K M M \mathbf{K}_{MM} KMM:核矩阵在引入点之间:
    K M M = [ 1.0 e − 4.5 / 2 e − 4.5 / 2 1.0 ] \mathbf{K}_{MM} = \begin{bmatrix} 1.0 & e^{-4.5/2} \\ e^{-4.5/2} & 1.0 \end{bmatrix} KMM=[1.0e4.5/2e4.5/21.0]
  • K N M \mathbf{K}_{NM} KNM:核矩阵在观测点和引入点之间:
    K N M = [ e − 0.25 / 2 e − 12.25 / 2 e − 0.25 / 2 e − 6.25 / 2 e − 1.0 / 2 e − 2.25 / 2 e − 4.5 / 2 e − 0.25 / 2 e − 9.0 / 2 e − 0.25 / 2 ] \mathbf{K}_{NM} = \begin{bmatrix} e^{-0.25/2} & e^{-12.25/2} \\ e^{-0.25/2} & e^{-6.25/2} \\ e^{-1.0/2} & e^{-2.25/2} \\ e^{-4.5/2} & e^{-0.25/2} \\ e^{-9.0/2} & e^{-0.25/2} \end{bmatrix} KNM= e0.25/2e0.25/2e1.0/2e4.5/2e9.0/2e12.25/2e6.25/2e2.25/2e0.25/2e0.25/2

3. 计算优化变分参数

在这一步,我们利用变分推断来优化引入点的均值 m \mathbf{m} m 和协方差 S \mathbf{S} S 参数。假设我们使用期望传播(EP)或者自然梯度下降来优化。

(1). 计算引入点的后验分布参数:

  • 使用的核矩阵 K M M \mathbf{K}_{MM} KMM K N M \mathbf{K}_{NM} KNM 已经给出。
  • 计算精度矩阵(逆协方差矩阵) Λ \Lambda Λ
    Λ = K M M − 1 + K N M ⊤ diag ( 1 σ 2 + Var [ f n ] ) K N M \Lambda = \mathbf{K}_{MM}^{-1} + \mathbf{K}_{NM}^\top \text{diag}(\frac{1}{\sigma^2 + \text{Var}[\mathbf{f}_n]}) \mathbf{K}_{NM} Λ=KMM1+KNMdiag(σ2+Var[fn]1)KNM
  • 其中, Var [ f n ] \text{Var}[\mathbf{f}_n] Var[fn] 是每个数据点的方差,可以假设初始为零。
  • 更新均值 m \mathbf{m} m
    m = Λ − 1 K N M ⊤ diag ( 1 σ 2 + Var [ f n ] ) y \mathbf{m} = \Lambda^{-1} \mathbf{K}_{NM}^\top \text{diag}(\frac{1}{\sigma^2 + \text{Var}[\mathbf{f}_n]}) \mathbf{y} m=Λ1KNMdiag(σ2+Var[fn]1)y

(2). 优化变分参数:

  • 根据变分推断框架,我们最小化KL散度。这通常涉及到迭代更新 m \mathbf{m} m S \mathbf{S} S 直到收敛:
    S = Λ − 1 \mathbf{S} = \Lambda^{-1} S=Λ1

4. 预测新数据点

给定新的输入位置 x ∗ = [ 2.0 , 3.0 ] x_* = [2.0, 3.0] x=[2.0,3.0],我们使用更新后的变分参数进行预测:

(1). 计算新数据点与引入点之间的核矩阵 K ∗ M \mathbf{K}_{*M} KM
K ∗ M = [ e − 1.0 / 2 e − 9.0 / 2 e − 4.0 / 2 e − 1.0 / 2 ] \mathbf{K}_{*M} = \begin{bmatrix} e^{-1.0/2} & e^{-9.0/2} \\ e^{-4.0/2} & e^{-1.0/2} \end{bmatrix} KM=[e1.0/2e4.0/2e9.0/2e1.0/2]

(2). 计算新数据点自身的核矩阵 K ∗ ∗ \mathbf{K}_{**} K∗∗
K ∗ ∗ = [ 1.0 e − 1.0 / 2 e − 1.0 / 2 1.0 ] \mathbf{K}_{**} = \begin{bmatrix} 1.0 & e^{-1.0/2} \\ e^{-1.0/2} & 1.0 \end{bmatrix} K∗∗=[1.0e1.0/2e1.0/21.0]

(3). 使用变分后验公式计算预测均值和方差:

  • 均值:
    μ ∗ = K ∗ M K M M − 1 m \mu_* = \mathbf{K}_{*M} \mathbf{K}_{MM}^{-1} \mathbf{m} μ=KMKMM1m
  • 方差:
    Σ ∗ = K ∗ ∗ − K ∗ M K M M − 1 ( K M M − S ) K M M − 1 K M ∗ \Sigma_* = \mathbf{K}_{**} - \mathbf{K}_{*M} \mathbf{K}_{MM}^{-1} (\mathbf{K}_{MM} - \mathbf{S}) \mathbf{K}_{MM}^{-1} \mathbf{K}_{M*} Σ=K∗∗KMKMM1(KMMS)KMM1KM
  • 这里 K M ∗ \mathbf{K}_{M*} KM K ∗ M \mathbf{K}_{*M} KM 的转置。

5. 结果展示

这种方法给出了在新数据点 x ∗ = [ 2.0 , 3.0 ] x_* = [2.0, 3.0] x=[2.0,3.0] 处的预测分布,包括均值和方差,这些预测可以用于后续的分析或决策制定。

假设优化后,参数更新为:
m = [ 1.5 , 2.0 ] \mathbf{m} = [1.5, 2.0] m=[1.5,2.0]
S = [ 0.5 0.1 0.1 0.5 ] \mathbf{S} = \begin{bmatrix} 0.5 & 0.1 \\ 0.1 & 0.5 \end{bmatrix} S=[0.50.10.10.5]

  • 注意:在实际应用中,这些计算会通过自动化软件工具包如GPflow或GPyTorch来完成。

让我们详细解释上述例子中使用的每个符号和参数:

符号和参数说明

  1. X X X(训练数据点):

    • X = [ 0.5 , 1.5 , 2.5 , 3.5 , 4.5 ] \mathbf{X} = [0.5, 1.5, 2.5, 3.5, 4.5] X=[0.5,1.5,2.5,3.5,4.5]:一维输入空间中的训练数据点。
  2. y y y(目标值):

    • y = [ 1.0 , 2.0 , 3.0 , 2.5 , 1.5 ] \mathbf{y} = [1.0, 2.0, 3.0, 2.5, 1.5] y=[1.0,2.0,3.0,2.5,1.5]:对应于输入点 X X X 的目标输出。
  3. Z Z Z(引入点):

    • Z = [ 1.0 , 4.0 ] \mathbf{Z} = [1.0, 4.0] Z=[1.0,4.0]:在输入空间中人为设置的参考点,用于近似全空间的高斯过程。
  4. M M M(引入点数量):

    • M = 2 M = 2 M=2:引入点的总数。
  5. 核函数:

    • 平方指数核(Squared Exponential Kernel),核函数的选择直接影响模型的平滑性和灵活性。
  6. 长度尺度(Length Scale):

    • 控制核函数的宽度,这里假设为 1.0。
  7. 输出方差(Output Variance):

    • 核函数的高度,这里假设为 1.0。
  8. σ 2 \sigma^2 σ2(观测噪声方差):

    • σ 2 = 0.1 \sigma^2 = 0.1 σ2=0.1:数据的噪声水平,影响模型对数据的敏感程度。
  9. m \mathbf{m} m(引入点的均值参数):

    • 初始设置为 m = [ 0.0 , 0.0 ] \mathbf{m} = [0.0, 0.0] m=[0.0,0.0]
  10. S \mathbf{S} S(引入点的协方差参数):

  • 初始设置为 S = [ 1.0 0.0 0.0 1.0 ] \mathbf{S} = \begin{bmatrix} 1.0 & 0.0 \\ 0.0 & 1.0 \end{bmatrix} S=[1.00.00.01.0]
  1. K M M \mathbf{K}_{MM} KMM(引入点间的核矩阵):
  • 计算所有引入点之间的核值。
  1. K N M \mathbf{K}_{NM} KNM(训练点与引入点间的核矩阵):
  • 计算训练点与引入点之间的核值。
  1. K ∗ M \mathbf{K}_{*M} KM(新数据点与引入点间的核矩阵):
  • 计算新数据点与引入点之间的核值。
  1. K ∗ ∗ \mathbf{K}_{**} K∗∗(新数据点自身的核矩阵):
  • 计算新数据点之间的核值。
  1. 变分参数(Variational Parameters):
  • m \mathbf{m} m S \mathbf{S} S 是在变分推断中优化的参数,用于近似后验分布。

python代码

本代码包含参数更新和预测:

import torch
import torch.nn as nn
import numpy as np
import mathclass SVGP(nn.Module):def __init__(self, inducing_points, kernel_scale=1.0, jitter=1e-6):super(SVGP, self).__init__()self.inducing_points = nn.Parameter(torch.tensor(inducing_points, dtype=torch.float32))self.kernel_scale = kernel_scaleself.jitter = jitterself.variational_mean = nn.Parameter(torch.zeros(self.inducing_points.shape[0]))self.variational_cov = nn.Parameter(torch.eye(self.inducing_points.shape[0]))def rbf_kernel(self, X, Y):dist = torch.cdist(X, Y)**2return torch.exp(-0.5 / self.kernel_scale * dist)def forward(self, X, y=None):# Calculate kernel matricesK_mm = self.rbf_kernel(self.inducing_points, self.inducing_points) + self.jitter * torch.eye(self.inducing_points.shape[0])K_nm = self.rbf_kernel(X, self.inducing_points)K_mn = K_nm.TK_nn = self.rbf_kernel(X, X)# Compute the inverse of K_mmK_mm_inv = torch.inverse(K_mm)# If training mode, optimize the variational parametersif y is not None:noise = 0.1  # Fixed noise for simplicityA = torch.mm(torch.mm(K_nm, K_mm_inv), K_mn) + torch.eye(X.shape[0]) * noiseB = torch.mm(torch.mm(K_nm, K_mm_inv), self.variational_mean.unsqueeze(1)).squeeze()# Compute the variational lower bound and gradients for optimization# Placeholder for actual ELBO calculationloss = torch.mean((y - B)**2) + torch.trace(A)return loss# If not training mode, do the predictionelse:# Predictive mean and variancemu_star = torch.mm(torch.mm(K_nm, K_mm_inv), self.variational_mean.unsqueeze(1)).squeeze()v_star = K_nn - torch.mm(torch.mm(K_nm, K_mm_inv), K_mn)return mu_star, v_stardef train_model(self, X_train, y_train, learning_rate=0.01, epochs=100):optimizer = torch.optim.Adam(self.parameters(), lr=learning_rate)self.train()for epoch in range(epochs):optimizer.zero_grad()loss = self.forward(X_train, y_train)loss.backward()optimizer.step()print(f'Epoch {epoch+1}, Loss: {loss.item()}')def predict(self, X_test):self.eval()mu_star, v_star = self.forward(X_test)return mu_star, v_star.diag()# Example usage
X_train = torch.tensor([[0.5], [1.5], [2.5], [3.5], [4.5]])
y_train = torch.tensor([1.0, 2.0, 3.0, 2.5, 1.5])
inducing_points = torch.tensor([[1.0], [4.0]])model = SVGP(inducing_points=inducing_points)
model.train_model(X_train, y_train)
mu_star, v_star = model.predict(torch.tensor([[2.0], [3.0]]))
print(f'Mean predictions: {mu_star}, Variances: {v_star}')

这篇关于稀疏变分高斯过程【超简单,全流程解析,案例应用,简单代码】的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/964965

相关文章

Flutter监听当前页面可见与隐藏状态的代码详解

《Flutter监听当前页面可见与隐藏状态的代码详解》文章介绍了如何在Flutter中使用路由观察者来监听应用进入前台或后台状态以及页面的显示和隐藏,并通过代码示例讲解的非常详细,需要的朋友可以参考下... flutter 可以监听 app 进入前台还是后台状态,也可以监听当http://www.cppcn

Python使用PIL库将PNG图片转换为ICO图标的示例代码

《Python使用PIL库将PNG图片转换为ICO图标的示例代码》在软件开发和网站设计中,ICO图标是一种常用的图像格式,特别适用于应用程序图标、网页收藏夹图标等场景,本文将介绍如何使用Python的... 目录引言准备工作代码解析实践操作结果展示结语引言在软件开发和网站设计中,ICO图标是一种常用的图像

IDEA与JDK、Maven安装配置完整步骤解析

《IDEA与JDK、Maven安装配置完整步骤解析》:本文主要介绍如何安装和配置IDE(IntelliJIDEA),包括IDE的安装步骤、JDK的下载与配置、Maven的安装与配置,以及如何在I... 目录1. IDE安装步骤2.配置操作步骤3. JDK配置下载JDK配置JDK环境变量4. Maven配置下

Spring AI集成DeepSeek三步搞定Java智能应用的详细过程

《SpringAI集成DeepSeek三步搞定Java智能应用的详细过程》本文介绍了如何使用SpringAI集成DeepSeek,一个国内顶尖的多模态大模型,SpringAI提供了一套统一的接口,简... 目录DeepSeek 介绍Spring AI 是什么?Spring AI 的主要功能包括1、环境准备2

Spring AI与DeepSeek实战一之快速打造智能对话应用

《SpringAI与DeepSeek实战一之快速打造智能对话应用》本文详细介绍了如何通过SpringAI框架集成DeepSeek大模型,实现普通对话和流式对话功能,步骤包括申请API-KEY、项目搭... 目录一、概述二、申请DeepSeek的API-KEY三、项目搭建3.1. 开发环境要求3.2. mav

Java8需要知道的4个函数式接口简单教程

《Java8需要知道的4个函数式接口简单教程》:本文主要介绍Java8中引入的函数式接口,包括Consumer、Supplier、Predicate和Function,以及它们的用法和特点,文中... 目录什么是函数是接口?Consumer接口定义核心特点注意事项常见用法1.基本用法2.结合andThen链

SpringBoot集成图片验证码框架easy-captcha的详细过程

《SpringBoot集成图片验证码框架easy-captcha的详细过程》本文介绍了如何将Easy-Captcha框架集成到SpringBoot项目中,实现图片验证码功能,Easy-Captcha是... 目录SpringBoot集成图片验证码框架easy-captcha一、引言二、依赖三、代码1. Ea

linux环境openssl、openssh升级流程

《linux环境openssl、openssh升级流程》该文章详细介绍了在Ubuntu22.04系统上升级OpenSSL和OpenSSH的方法,首先,升级OpenSSL的步骤包括下载最新版本、安装编译... 目录一.升级openssl1.官网下载最新版openssl2.安装编译环境3.下载后解压安装4.备份

C#集成DeepSeek模型实现AI私有化的流程步骤(本地部署与API调用教程)

《C#集成DeepSeek模型实现AI私有化的流程步骤(本地部署与API调用教程)》本文主要介绍了C#集成DeepSeek模型实现AI私有化的方法,包括搭建基础环境,如安装Ollama和下载DeepS... 目录前言搭建基础环境1、安装 Ollama2、下载 DeepSeek R1 模型客户端 ChatBo

Python中配置文件的全面解析与使用

《Python中配置文件的全面解析与使用》在Python开发中,配置文件扮演着举足轻重的角色,它们允许开发者在不修改代码的情况下调整应用程序的行为,下面我们就来看看常见Python配置文件格式的使用吧... 目录一、INI配置文件二、YAML配置文件三、jsON配置文件四、TOML配置文件五、XML配置文件