pmf-automl源码分析

2024-04-20 23:38
文章标签 分析 源码 automl pmf

本文主要是介绍pmf-automl源码分析,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

  • arxiv论文(有附录,但是字小)
    Probabilistic Matrix Factorization for Automated Machine Learning
  • NIPS2018论文(字大但是没有附录)
    Probabilistic Matrix Factorization for Automated Machine Learning
  • 代码
    https://github.com/rsheth80/pmf-automl

文章目录

  • 初窥项目文件
  • PMF模型训练
    • 数据切分
    • 初始隐变量
    • 模型的定义与训练
    • D个高斯过程的定义
    • 后验分布协方差矩阵的求解
      • transform_forward与transform_backward函数
      • get_cov函数的顶层设计
      • kernel的RBF
      • kernel的White
      • 求协方差矩阵复盘
    • GP前向函数的返回值的含义

初窥项目文件

用jupyter lab打开all_normalized_accuracy_with_pipelineID.csv
在这里插入图片描述

all_normalized_accuracy_with_pipelineID.zip contains the performance observations from running 42K pipelines on 553 OpenML datasets. The task was classification and the performance metric was balanced accuracy. Unzip prior to running code.

行表示pipeline id,列表示dataset id,元素表示balanced accuracy

在这里插入图片描述
简单查阅了一下pipelines.json,基本只有pcapolynomial两种preprocessor。

PMF模型训练

数据切分

Ytrain, Ytest, Ftrain, Ftest = get_data()
>>> Ytrain.shape
Out[2]: (42000, 464)
>>> Ytest.shape
Out[3]: (42000, 89)
>>> Ftrain.shape
Out[4]: (464, 46)
>>> Ftest.shape
Out[5]: (89, 46)

训练测试集切分,89个数据集作为测试集,464个训练集

初始隐变量

    imp = sklearn.impute.SimpleImputer(missing_values=np.nan, strategy='mean')X = sklearn.decomposition.PCA(Q).fit_transform(imp.fit(Ytrain).transform(Ytrain))
>>> X.shape
Out[7]: (42000, 20)

根据目前的理解,整个训练过程就是根据GP来训练X的隐变量。这个隐变量是用PCA初始化的。

处理训练集的缺失值,并降维为20维(42K个pipelines,数据集从553降为20个隐变量)

论文:the elements of Y Y Y are given by as nonlinear function of the latent variables, y n , d = f d ( x n ) + ϵ y_{n,d}=f_d(x_n)+\epsilon yn,d=fd(xn)+ϵ, where ϵ \epsilon ϵ is independent Gaussian noise.

这里的 Y Y Y指的是整个 42000 × 464 42000\times464 42000×464矩阵,那么 X X X就是pipeline空间的隐变量,这里隐变量维度 Q = 20 Q=20 Q=20 X X X的shape为 42000 × 20 42000\times20 42000×20

模型的定义与训练

模型的顶层定义:

    kernel = kernels.Add(kernels.RBF(Q, lengthscale=None), kernels.White(Q))m = gplvm.GPLVM(Q, X, Ytrain, kernel, N_max=N_max, D_max=batch_size)optimizer = torch.optim.SGD(m.parameters(), lr=lr)m = train(m, optimizer, f_callback=f_callback, f_stop=f_stop)

f_callbackf_stop都是两个local函数

    def f_callback(m, v, it, t):varn_list.append(transform_forward(m.variance).item())logpr_list.append(m().item()/m.D)if it == 1:t_list.append(t)else:t_list.append(t_list[-1] + t)if save_checkpoint and not (it % checkpoint_period):torch.save(m.state_dict(), fn_checkpoint + '_it%d.pt' % it)print('it=%d, f=%g, varn=%g, t: %g'% (it, logpr_list[-1], transform_forward(m.variance), t_list[-1]))
    def f_stop(m, v, it, t):if it >= maxiter-1:print('maxiter (%d) reached' % maxiter)return Truereturn False

看到训练函数train

def train(m, optimizer, f_callback=None, f_stop=None):it = 0while True:try:t = time.time()optimizer.zero_grad()nll = m()nll.backward()optimizer.step()it += 1t = time.time() - tif f_callback is not None:f_callback(m, nll, it, t)# f_stop should not be a substantial portion of total iteration timeif f_stop is not None and f_stop(m, nll, it, t):breakexcept KeyboardInterrupt:breakreturn m

论文公式(5):

N L L d = 1 2 ( N d l o g ( 2 π ) + l o g ∣ C d ∣ + Y c ( d )

这篇关于pmf-automl源码分析的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/921629

相关文章

性能分析之MySQL索引实战案例

文章目录 一、前言二、准备三、MySQL索引优化四、MySQL 索引知识回顾五、总结 一、前言 在上一讲性能工具之 JProfiler 简单登录案例分析实战中已经发现SQL没有建立索引问题,本文将一起从代码层去分析为什么没有建立索引? 开源ERP项目地址:https://gitee.com/jishenghua/JSH_ERP 二、准备 打开IDEA找到登录请求资源路径位置

JAVA智听未来一站式有声阅读平台听书系统小程序源码

智听未来,一站式有声阅读平台听书系统 🌟 开篇:遇见未来,从“智听”开始 在这个快节奏的时代,你是否渴望在忙碌的间隙,找到一片属于自己的宁静角落?是否梦想着能随时随地,沉浸在知识的海洋,或是故事的奇幻世界里?今天,就让我带你一起探索“智听未来”——这一站式有声阅读平台听书系统,它正悄悄改变着我们的阅读方式,让未来触手可及! 📚 第一站:海量资源,应有尽有 走进“智听

Java ArrayList扩容机制 (源码解读)

结论:初始长度为10,若所需长度小于1.5倍原长度,则按照1.5倍扩容。若不够用则按照所需长度扩容。 一. 明确类内部重要变量含义         1:数组默认长度         2:这是一个共享的空数组实例,用于明确创建长度为0时的ArrayList ,比如通过 new ArrayList<>(0),ArrayList 内部的数组 elementData 会指向这个 EMPTY_EL

如何在Visual Studio中调试.NET源码

今天偶然在看别人代码时,发现在他的代码里使用了Any判断List<T>是否为空。 我一般的做法是先判断是否为null,再判断Count。 看了一下Count的源码如下: 1 [__DynamicallyInvokable]2 public int Count3 {4 [__DynamicallyInvokable]5 get

SWAP作物生长模型安装教程、数据制备、敏感性分析、气候变化影响、R模型敏感性分析与贝叶斯优化、Fortran源代码分析、气候数据降尺度与变化影响分析

查看原文>>>全流程SWAP农业模型数据制备、敏感性分析及气候变化影响实践技术应用 SWAP模型是由荷兰瓦赫宁根大学开发的先进农作物模型,它综合考虑了土壤-水分-大气以及植被间的相互作用;是一种描述作物生长过程的一种机理性作物生长模型。它不但运用Richard方程,使其能够精确的模拟土壤中水分的运动,而且耦合了WOFOST作物模型使作物的生长描述更为科学。 本文让更多的科研人员和农业工作者

MOLE 2.5 分析分子通道和孔隙

软件介绍 生物大分子通道和孔隙在生物学中发挥着重要作用,例如在分子识别和酶底物特异性方面。 我们介绍了一种名为 MOLE 2.5 的高级软件工具,该工具旨在分析分子通道和孔隙。 与其他可用软件工具的基准测试表明,MOLE 2.5 相比更快、更强大、功能更丰富。作为一项新功能,MOLE 2.5 可以估算已识别通道的物理化学性质。 软件下载 https://pan.quark.cn/s/57

工厂ERP管理系统实现源码(JAVA)

工厂进销存管理系统是一个集采购管理、仓库管理、生产管理和销售管理于一体的综合解决方案。该系统旨在帮助企业优化流程、提高效率、降低成本,并实时掌握各环节的运营状况。 在采购管理方面,系统能够处理采购订单、供应商管理和采购入库等流程,确保采购过程的透明和高效。仓库管理方面,实现库存的精准管理,包括入库、出库、盘点等操作,确保库存数据的准确性和实时性。 生产管理模块则涵盖了生产计划制定、物料需求计划、

衡石分析平台使用手册-单机安装及启动

单机安装及启动​ 本文讲述如何在单机环境下进行 HENGSHI SENSE 安装的操作过程。 在安装前请确认网络环境,如果是隔离环境,无法连接互联网时,请先按照 离线环境安装依赖的指导进行依赖包的安装,然后按照本文的指导继续操作。如果网络环境可以连接互联网,请直接按照本文的指导进行安装。 准备工作​ 请参考安装环境文档准备安装环境。 配置用户与安装目录。 在操作前请检查您是否有 sud

线性因子模型 - 独立分量分析(ICA)篇

序言 线性因子模型是数据分析与机器学习中的一类重要模型,它们通过引入潜变量( latent variables \text{latent variables} latent variables)来更好地表征数据。其中,独立分量分析( ICA \text{ICA} ICA)作为线性因子模型的一种,以其独特的视角和广泛的应用领域而备受关注。 ICA \text{ICA} ICA旨在将观察到的复杂信号

Spring 源码解读:自定义实现Bean定义的注册与解析

引言 在Spring框架中,Bean的注册与解析是整个依赖注入流程的核心步骤。通过Bean定义,Spring容器知道如何创建、配置和管理每个Bean实例。本篇文章将通过实现一个简化版的Bean定义注册与解析机制,帮助你理解Spring框架背后的设计逻辑。我们还将对比Spring中的BeanDefinition和BeanDefinitionRegistry,以全面掌握Bean注册和解析的核心原理。