高斯混合模型(GMM)的EM算法实现

2024-09-05 10:38

本文主要是介绍高斯混合模型(GMM)的EM算法实现,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

在 聚类算法K-Means, K-Medoids, GMM, Spectral clustering,Ncut一文中我们给出了GMM算法的基本模型与似然函数,在EM算法原理中对EM算法的实现与收敛性证明进行了详细说明。本文主要针对如何用EM算法在混合高斯模型下进行聚类进行代码上的分析说明。

  1. GMM模型:
    每个 GMM 由 K 个 Gaussian 分布组成,每个 Gaussian 称为一个“Component”,这些 Component 线性加成在一起就组成了 GMM 的概率密度函数:

根据上面的式子,如果我们要从 GMM 的分布中随机地取一个点的话,实际上可以分为两步:首先随机地在这 K个Gaussian Component 之中选一个,每个 Component 被选中的概率实际上就是它的系数 pi(k) ,选中了 Component 之后,再单独地考虑从这个 Component 的分布中选取一个点就可以了──这里已经回到了普通的 Gaussian 分布,转化为了已知的问题。

那么如何用 GMM 来做 clustering 呢?其实很简单,现在我们有了数据,假定它们是由 GMM 生成出来的,那么我们只要根据数据推出 GMM 的概率分布来就可以了,然后 GMM 的 K 个 Component 实际上就对应了 K 个 cluster 了。根据数据来推算概率密度通常被称作 density estimation ,特别地,当我们在已知(或假定)了概率密度函数的形式,而要估计其中的参数的过程被称作“参数估计”。

  1. 参数与似然函数:

现在假设我们有 N 个数据点,并假设它们服从某个分布(记作 p(x) ),现在要确定里面的一些参数的值,例如,在 GMM 中,我们就需要确定 影响因子pi(k)、各类均值pMiu(k) 和 各类协方差pSigma(k) 这些参数。 我们的想法是,找到这样一组参数,它所确定的概率分布生成这些给定的数据点的概率最大,而这个概率实际上就等于 ,我们把这个乘积称作似然函数 (Likelihood Function)。通常单个点的概率都很小,许多很小的数字相乘起来在计算机里很容易造成浮点数下溢,因此我们通常会对其取对数,把乘积变成加和 \sum_{i=1}^N \log p(x_i),得到 log-likelihood function 。接下来我们只要将这个函数最大化(通常的做法是求导并令导数等于零,然后解方程),亦即找到这样一组参数值,它让似然函数取得最大值,我们就认为这是最合适的参数,这样就完成了参数估计的过程。

下面让我们来看一看 GMM 的 log-likelihood function :

由于在对数函数里面又有加和,我们没法直接用求导解方程的办法直接求得最大值。为了解决这个问题,我们采取之前从 GMM 中随机选点的办法:分成两步,实际上也就类似于K-means 的两步。

  1. 算法流程:

  2. 估计数据由每个 Component 生成的概率(并不是每个 Component 被选中的概率):对于每个数据 x_i 来说,它由第 k 个 Component 生成的概率为

其中N(xi | μk,Σk)就是后验概率。

  1. 通过极大似然估计可以通过求到令参数=0得到参数pMiu,pSigma的值。具体请见这篇文章第三部分。

其中 N_k = \sum_{i=1}^N \gamma(i, k) ,并且 \pi_k 也顺理成章地可以估计为 N_k/N 。

  1. 重复迭代前面两步,直到似然函数的值收敛为止。

  2. matlab实现GMM聚类代码与解释:

说明:fea为训练样本数据,gnd为样本标号。算法中的思想和上面写的一模一样,在最后的判断accuracy方面,由于聚类和分类不同,只是得到一些 cluster ,而并不知道这些 cluster 应该被打上什么标签,或者说。由于我们的目的是衡量聚类算法的 performance ,因此直接假定这一步能实现最优的对应关系,将每个 cluster 对应到一类上去。一种办法是枚举所有可能的情况并选出最优解,另外,对于这样的问题,我们还可以用 Hungarian algorithm 来求解。具体的Hungarian代码我放在了资源里,调用方法已经写在下面函数中了。

注意:资源里我放的是Kmeans的代码,大家下载的时候只要用bestMap.m等几个文件就好~

  1. gmm.m,最核心的函数,进行模型与参数确定。
    [cpp] view plaincopy
    function varargout = gmm(X, K_or_centroids)
    % ============================================================
    % Expectation-Maximization iteration implementation of
    % Gaussian Mixture Model.
    %
    % PX = GMM(X, K_OR_CENTROIDS)
    % [PX MODEL] = GMM(X, K_OR_CENTROIDS)
    %
    % - X: N-by-D data matrix.
    % - K_OR_CENTROIDS: either K indicating the number of
    % components or a K-by-D matrix indicating the
    % choosing of the initial K centroids.
    %
    % - PX: N-by-K matrix indicating the probability of each
    % component generating each point.
    % - MODEL: a structure containing the parameters for a GMM:
    % MODEL.Miu: a K-by-D matrix.
    % MODEL.Sigma: a D-by-D-by-K matrix.
    % MODEL.Pi: a 1-by-K vector.
    % ============================================================
    % @SourceCode Author: Pluskid (http://blog.pluskid.org)
    % @Appended by : Sophia_qing (http://blog.csdn.net/abcjennifer)

%% Generate Initial Centroids
threshold = 1e-15;
[N, D] = size(X);

if isscalar(K_or_centroids) %if K_or_centroid is a 1*1 number  K = K_or_centroids;  Rn_index = randperm(N); %random index N samples  centroids = X(Rn_index(1:K), :); %generate K random centroid  
else % K_or_centroid is a initial K centroid  K = size(K_or_centroids, 1);   centroids = K_or_centroids;  
end  %% initial values  
[pMiu pPi pSigma] = init_params();  Lprev = -inf; %上一次聚类的误差  %% EM Algorithm  
while true  %% Estimation Step  Px = calc_prob();  % new value for pGamma(N*k), pGamma(i,k) = Xi由第k个Gaussian生成的概率  % 或者说xi中有pGamma(i,k)是由第k个Gaussian生成的  pGamma = Px .* repmat(pPi, N, 1); %分子 = pi(k) * N(xi | pMiu(k), pSigma(k))  pGamma = pGamma ./ repmat(sum(pGamma, 2), 1, K); %分母 = pi(j) * N(xi | pMiu(j), pSigma(j))对所有j求和  %% Maximization Step - through Maximize likelihood Estimation  Nk = sum(pGamma, 1); %Nk(1*k) = 第k个高斯生成每个样本的概率的和,所有Nk的总和为N。  % update pMiu  pMiu = diag(1./Nk) * pGamma' * X; %update pMiu through MLE(通过令导数 = 0得到)  pPi = Nk/N;  % update k个 pSigma  for kk = 1:K   Xshift = X-repmat(pMiu(kk, :), N, 1);  pSigma(:, :, kk) = (Xshift' * ...  (diag(pGamma(:, kk)) * Xshift)) / Nk(kk);  end  % check for convergence  L = sum(log(Px*pPi'));  if L-Lprev < threshold  break;  end  Lprev = L;  
end  if nargout == 1  varargout = {Px};  
else  model = [];  model.Miu = pMiu;  model.Sigma = pSigma;  model.Pi = pPi;  varargout = {Px, model};  
end  %% Function Definition  function [pMiu pPi pSigma] = init_params()  pMiu = centroids; %k*D, 即k类的中心点  pPi = zeros(1, K); %k类GMM所占权重(influence factor)  pSigma = zeros(D, D, K); %k类GMM的协方差矩阵,每个是D*D的  % 距离矩阵,计算N*K的矩阵(x-pMiu)^2 = x^2+pMiu^2-2*x*Miu  distmat = repmat(sum(X.*X, 2), 1, K) + ... %x^2, N*1的矩阵replicateK列  repmat(sum(pMiu.*pMiu, 2)', N, 1) - ...%pMiu^2,1*K的矩阵replicateN行  2*X*pMiu';  [~, labels] = min(distmat, [], 2);%Return the minimum from each row  for k=1:K  Xk = X(labels == k, :);  pPi(k) = size(Xk, 1)/N;  pSigma(:, :, k) = cov(Xk);  end  
end  function Px = calc_prob()   %Gaussian posterior probability   %N(x|pMiu,pSigma) = 1/((2pi)^(D/2))*(1/(abs(sigma))^0.5)*exp(-1/2*(x-pMiu)'pSigma^(-1)*(x-pMiu))  Px = zeros(N, K);  for k = 1:K  Xshift = X-repmat(pMiu(k, :), N, 1); %X-pMiu  inv_pSigma = inv(pSigma(:, :, k));  tmp = sum((Xshift*inv_pSigma) .* Xshift, 2);  coef = (2*pi)^(-D/2) * sqrt(det(inv_pSigma));  Px(:, k) = coef * exp(-0.5*tmp);  end  
end  

end

  1. gmm_accuracy.m调用gmm.m,计算准确率:
    [cpp] view plaincopy
    function [ Accuracy ] = gmm_accuracy( Data_fea, gnd_label, K )
    %Calculate the accuracy Clustered by GMM model

px = gmm(Data_fea,K);
[~, cls_ind] = max(px,[],1); %cls_ind = cluster label
Accuracy = cal_accuracy(cls_ind, gnd_label);

function [acc] = cal_accuracy(gnd,estimate_label)  res = bestMap(gnd,estimate_label);  acc = length(find(gnd == res))/length(gnd);  
end  

end

  1. 主函数调用
    gmm_acc = gmm_accuracy(fea,gnd,N_classes);

写了本文进行总结后自己很受益,也希望大家可以好好YM下上面pluskid的gmm.m,不光是算法,其中的矩阵处理代码也写的很简洁,很值得学习。
另外看了两份东西非常受益,一个是pluskid大牛的《漫谈 Clustering (3): Gaussian Mixture Model》,一个是JerryLead的EM算法详解,大家有兴趣也可以看一下,写的很好。

转自http://blog.csdn.net/abcjennifer/article/details/8198352

这篇关于高斯混合模型(GMM)的EM算法实现的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1138725

相关文章

大模型研发全揭秘:客服工单数据标注的完整攻略

在人工智能(AI)领域,数据标注是模型训练过程中至关重要的一步。无论你是新手还是有经验的从业者,掌握数据标注的技术细节和常见问题的解决方案都能为你的AI项目增添不少价值。在电信运营商的客服系统中,工单数据是客户问题和解决方案的重要记录。通过对这些工单数据进行有效标注,不仅能够帮助提升客服自动化系统的智能化水平,还能优化客户服务流程,提高客户满意度。本文将详细介绍如何在电信运营商客服工单的背景下进行

不懂推荐算法也能设计推荐系统

本文以商业化应用推荐为例,告诉我们不懂推荐算法的产品,也能从产品侧出发, 设计出一款不错的推荐系统。 相信很多新手产品,看到算法二字,多是懵圈的。 什么排序算法、最短路径等都是相对传统的算法(注:传统是指科班出身的产品都会接触过)。但对于推荐算法,多数产品对着网上搜到的资源,都会无从下手。特别当某些推荐算法 和 “AI”扯上关系后,更是加大了理解的难度。 但,不了解推荐算法,就无法做推荐系

hdu1043(八数码问题,广搜 + hash(实现状态压缩) )

利用康拓展开将一个排列映射成一个自然数,然后就变成了普通的广搜题。 #include<iostream>#include<algorithm>#include<string>#include<stack>#include<queue>#include<map>#include<stdio.h>#include<stdlib.h>#include<ctype.h>#inclu

康拓展开(hash算法中会用到)

康拓展开是一个全排列到一个自然数的双射(也就是某个全排列与某个自然数一一对应) 公式: X=a[n]*(n-1)!+a[n-1]*(n-2)!+...+a[i]*(i-1)!+...+a[1]*0! 其中,a[i]为整数,并且0<=a[i]<i,1<=i<=n。(a[i]在不同应用中的含义不同); 典型应用: 计算当前排列在所有由小到大全排列中的顺序,也就是说求当前排列是第

csu 1446 Problem J Modified LCS (扩展欧几里得算法的简单应用)

这是一道扩展欧几里得算法的简单应用题,这题是在湖南多校训练赛中队友ac的一道题,在比赛之后请教了队友,然后自己把它a掉 这也是自己独自做扩展欧几里得算法的题目 题意:把题意转变下就变成了:求d1*x - d2*y = f2 - f1的解,很明显用exgcd来解 下面介绍一下exgcd的一些知识点:求ax + by = c的解 一、首先求ax + by = gcd(a,b)的解 这个

Andrej Karpathy最新采访:认知核心模型10亿参数就够了,AI会打破教育不公的僵局

夕小瑶科技说 原创  作者 | 海野 AI圈子的红人,AI大神Andrej Karpathy,曾是OpenAI联合创始人之一,特斯拉AI总监。上一次的动态是官宣创办一家名为 Eureka Labs 的人工智能+教育公司 ,宣布将长期致力于AI原生教育。 近日,Andrej Karpathy接受了No Priors(投资博客)的采访,与硅谷知名投资人 Sara Guo 和 Elad G

综合安防管理平台LntonAIServer视频监控汇聚抖动检测算法优势

LntonAIServer视频质量诊断功能中的抖动检测是一个专门针对视频稳定性进行分析的功能。抖动通常是指视频帧之间的不必要运动,这种运动可能是由于摄像机的移动、传输中的错误或编解码问题导致的。抖动检测对于确保视频内容的平滑性和观看体验至关重要。 优势 1. 提高图像质量 - 清晰度提升:减少抖动,提高图像的清晰度和细节表现力,使得监控画面更加真实可信。 - 细节增强:在低光条件下,抖

【数据结构】——原来排序算法搞懂这些就行,轻松拿捏

前言:快速排序的实现最重要的是找基准值,下面让我们来了解如何实现找基准值 基准值的注释:在快排的过程中,每一次我们要取一个元素作为枢纽值,以这个数字来将序列划分为两部分。 在此我们采用三数取中法,也就是取左端、中间、右端三个数,然后进行排序,将中间数作为枢纽值。 快速排序实现主框架: //快速排序 void QuickSort(int* arr, int left, int rig

【C++】_list常用方法解析及模拟实现

相信自己的力量,只要对自己始终保持信心,尽自己最大努力去完成任何事,就算事情最终结果是失败了,努力了也不留遗憾。💓💓💓 目录   ✨说在前面 🍋知识点一:什么是list? •🌰1.list的定义 •🌰2.list的基本特性 •🌰3.常用接口介绍 🍋知识点二:list常用接口 •🌰1.默认成员函数 🔥构造函数(⭐) 🔥析构函数 •🌰2.list对象

【Prometheus】PromQL向量匹配实现不同标签的向量数据进行运算

✨✨ 欢迎大家来到景天科技苑✨✨ 🎈🎈 养成好习惯,先赞后看哦~🎈🎈 🏆 作者简介:景天科技苑 🏆《头衔》:大厂架构师,华为云开发者社区专家博主,阿里云开发者社区专家博主,CSDN全栈领域优质创作者,掘金优秀博主,51CTO博客专家等。 🏆《博客》:Python全栈,前后端开发,小程序开发,人工智能,js逆向,App逆向,网络系统安全,数据分析,Django,fastapi