MLX vs MPS vs CUDA:苹果新机器学习框架的基准测试

2023-12-25 08:15

本文主要是介绍MLX vs MPS vs CUDA:苹果新机器学习框架的基准测试,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

如果你是一个Mac用户和一个深度学习爱好者,你可能希望在某些时候Mac可以处理一些重型模型。苹果刚刚发布了MLX,一个在苹果芯片上高效运行机器学习模型的框架。

最近在PyTorch 1.12中引入MPS后端已经是一个大胆的步骤,但随着MLX的宣布,苹果还想在开源深度学习方面有更大的发展。

在本文中,我们将对这些新方法进行测试,在三种不同的Apple Silicon芯片和两个支持cuda的gpu上和传统CPU后端进行基准测试。

这里把基准测试集中在图卷积网络(GCN)模型上。这个模型主要由线性层组成,所以对于其他的模型也应该得到类似的结果。

创造环境

要为MLX构建环境,我们必须指定是使用i386还是arm架构。使用conda,可以使用:

 CONDA_SUBDIR=osx-arm64 conda create -n mlx python=3.10 numpy pytorch scipy requests -c conda-forgeconda activate mlx

如果检查你的env是否实际使用了arm,下面命令的输出应该是arm,而不是i386(因为我们用的Apple Silicon):

 python -c "import platform; print(platform.processor())"

然后就是使用pip安装MLX:

 pip install mlx

GCN模型

GCN模型是图神经网络(GNN)的一种,它使用邻接矩阵(表示图结构)和节点特征。它通过收集邻近节点的信息来计算节点嵌入。每个节点获得其邻居特征的平均值。这种平均是通过将节点特征与标准化邻接矩阵相乘来完成的,并根据节点度进行调整。为了学习这个过程,特征首先通过线性层投射到嵌入空间中。

我们将使用MLX实现一个GCN层和一个GCN模型:

 import mlx.nn as nnclass GCNLayer(nn.Module):def __init__(self, in_features, out_features, bias=True):super(GCNLayer, self).__init__()self.linear = nn.Linear(in_features, out_features, bias)def __call__(self, x, adj):x = self.linear(x)return adj @ xclass GCN(nn.Module):def __init__(self, x_dim, h_dim, out_dim, nb_layers=2, dropout=0.5, bias=True):super(GCN, self).__init__()layer_sizes = [x_dim] + [h_dim] * nb_layers + [out_dim]self.gcn_layers = [GCNLayer(in_dim, out_dim, bias)for in_dim, out_dim in zip(layer_sizes[:-1], layer_sizes[1:])]self.dropout = nn.Dropout(p=dropout)def __call__(self, x, adj):for layer in self.gcn_layers[:-1]:x = nn.relu(layer(x, adj))x = self.dropout(x)x = self.gcn_layers[-1](x, adj)return x

可以看到,mlx的模型开发方式与tf2基本一样,都是调用

__call__

进行前向传播,其实torch也一样,只不过它自定义了一个forward函数。

下面就是训练

 gcn = GCN(x_dim=x.shape[-1],h_dim=args.hidden_dim,out_dim=args.nb_classes,nb_layers=args.nb_layers,dropout=args.dropout,bias=args.bias,)mx.eval(gcn.parameters())optimizer = optim.Adam(learning_rate=args.lr)loss_and_grad_fn = nn.value_and_grad(gcn, forward_fn)# Training loopfor epoch in range(args.epochs):# Loss(loss, y_hat), grads = loss_and_grad_fn(gcn, x, adj, y, train_mask, args.weight_decay)optimizer.update(gcn, grads)mx.eval(gcn.parameters(), optimizer.state)# Validationval_loss = loss_fn(y_hat[val_mask], y[val_mask])val_acc = eval_fn(y_hat[val_mask], y[val_mask])

在MLX中,计算是惰性的,这意味着eval()通常用于在更新后实际计算新的模型参数。而另一个关键函数是nn.value_and_grad(),它生成一个计算参数损失的函数。第一个参数是保存当前参数的模型,第二个参数是用于前向传递和损失计算的可调用函数。它返回的函数接受与forward函数相同的参数(在本例中为forward_fn)。我们可以这样定义这个函数:

 def forward_fn(gcn, x, adj, y, train_mask, weight_decay):y_hat = gcn(x, adj)loss = loss_fn(y_hat[train_mask], y[train_mask], weight_decay, gcn.parameters())return loss, y_hat

它仅仅包括计算前向传递和计算损失。Loss_fn()和eval_fn()定义如下:

 def loss_fn(y_hat, y, weight_decay=0.0, parameters=None):l = mx.mean(nn.losses.cross_entropy(y_hat, y))if weight_decay != 0.0:assert parameters != None, "Model parameters missing for L2 reg."l2_reg = sum(mx.sum(p[1] ** 2) for p in tree_flatten(parameters)).sqrt()return l + weight_decay * l2_regreturn ldef eval_fn(x, y):return mx.mean(mx.argmax(x, axis=1) == y)

损失函数是计算预测和标签之间的交叉熵,并包括L2正则化。由于L2正则化还不是内置特性,需要手动实现。

本文的完整代码:https://github.com/TristanBilot/mlx-GCN

可以看到除了一些细节函数调用的差别,基本的训练流程与pytorch和tf都很类似,但是这里的一个很好的事情是消除了显式地将对象分配给特定设备的需要,就像我们在PyTorch中经常使用.cuda()和.to(device)那样。这是因为苹果硅芯片的统一内存架构,所有变量共存于同一空间,也就是说消除了CPU和GPU之间缓慢的数据传输,这样也可以保证不会再出现与设备不匹配相关的烦人的运行时错误。

基准测试

我们将使用MLX与MPS, CPU和GPU设备进行比较。我们的测试平台是一个2层GCN模型,应用于Cora数据集,其中包括2708个节点和5429条边。

对于MLX, MPS和CPU测试,我们对M1 Pro, M2 Ultra和M3 Max进行基准测试。在两款NVIDIA V100 PCIe和V100 NVLINK上进行测试

MPS:比M1 Pro的CPU快2倍以上,在其他两个芯片上,与CPU相比有30-50%的改进。

MLX:比M1 Pro上的MPS快2.34倍。与MPS相比,M2 Ultra的性能提高了24%。在M3 Pro上MPS和MLX之间没有真正的改进。

CUDA V100 PCIe & NVLINK:只有23%和34%的速度比M3 Max与MLX,这里的原因可能是因为我们的模型比较小,所以发挥不出V100和NVLINK的优势(NVLINK主要GPU之间的数据传输大的情况下会有提高)。这也说明了苹果的统一内存架构的确可以消除CPU和GPU之间缓慢的数据传输。

总结

与CPU和MPS相比,MLX可以说是非常大的金币,在小数据量的情况下它甚至接近特斯拉V100的性能。也就是说我们可以使用MLX跑一些不是那么大的模型,比如一些表格数据。

从上面的基准测试也可以看到,现在可以利用苹果芯片的全部力量在本地运行深度学习模型(我一直认为MPS还没发挥苹果的优势,这回MPS已经证明了这一点)。

MLX刚刚发布就已经取得了惊人的影响力,并展示了巨大的潜力。相信未来几年开源社区的进一步增强,可以期待在不久的将来更强大的苹果芯片,将MLX的性能提升到一个全新的水平。

另外也说明了MPS(虽然也发布不久)还是有巨大的发展空间的,毕竟切换框架是一件很麻烦的事情,如果MPS能达到MLX 80%或者90%的速度,我想不会有人去换框架的。

最后说到框架,现在已经有了Pytorch,TF,JAX,现在又多了一个MLX。各种设备、各种后端包括:TPU(pytorch使用的XLA),CUDA,ROCM,现在又多了一个MPS。

https://avoid.overfit.cn/post/eb87d12f29eb4665adb43ad59fd3d64f

这篇关于MLX vs MPS vs CUDA:苹果新机器学习框架的基准测试的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/534810

相关文章

HarmonyOS学习(七)——UI(五)常用布局总结

自适应布局 1.1、线性布局(LinearLayout) 通过线性容器Row和Column实现线性布局。Column容器内的子组件按照垂直方向排列,Row组件中的子组件按照水平方向排列。 属性说明space通过space参数设置主轴上子组件的间距,达到各子组件在排列上的等间距效果alignItems设置子组件在交叉轴上的对齐方式,且在各类尺寸屏幕上表现一致,其中交叉轴为垂直时,取值为Vert

Ilya-AI分享的他在OpenAI学习到的15个提示工程技巧

Ilya(不是本人,claude AI)在社交媒体上分享了他在OpenAI学习到的15个Prompt撰写技巧。 以下是详细的内容: 提示精确化:在编写提示时,力求表达清晰准确。清楚地阐述任务需求和概念定义至关重要。例:不用"分析文本",而用"判断这段话的情感倾向:积极、消极还是中性"。 快速迭代:善于快速连续调整提示。熟练的提示工程师能够灵活地进行多轮优化。例:从"总结文章"到"用

性能测试介绍

性能测试是一种测试方法,旨在评估系统、应用程序或组件在现实场景中的性能表现和可靠性。它通常用于衡量系统在不同负载条件下的响应时间、吞吐量、资源利用率、稳定性和可扩展性等关键指标。 为什么要进行性能测试 通过性能测试,可以确定系统是否能够满足预期的性能要求,找出性能瓶颈和潜在的问题,并进行优化和调整。 发现性能瓶颈:性能测试可以帮助发现系统的性能瓶颈,即系统在高负载或高并发情况下可能出现的问题

【前端学习】AntV G6-08 深入图形与图形分组、自定义节点、节点动画(下)

【课程链接】 AntV G6:深入图形与图形分组、自定义节点、节点动画(下)_哔哩哔哩_bilibili 本章十吾老师讲解了一个复杂的自定义节点中,应该怎样去计算和绘制图形,如何给一个图形制作不间断的动画,以及在鼠标事件之后产生动画。(有点难,需要好好理解) <!DOCTYPE html><html><head><meta charset="UTF-8"><title>06

字节面试 | 如何测试RocketMQ、RocketMQ?

字节面试:RocketMQ是怎么测试的呢? 答: 首先保证消息的消费正确、设计逆向用例,在验证消息内容为空等情况时的消费正确性; 推送大批量MQ,通过Admin控制台查看MQ消费的情况,是否出现消费假死、TPS是否正常等等问题。(上述都是临场发挥,但是RocketMQ真正的测试点,还真的需要探讨) 01 先了解RocketMQ 作为测试也是要简单了解RocketMQ。简单来说,就是一个分

学习hash总结

2014/1/29/   最近刚开始学hash,名字很陌生,但是hash的思想却很熟悉,以前早就做过此类的题,但是不知道这就是hash思想而已,说白了hash就是一个映射,往往灵活利用数组的下标来实现算法,hash的作用:1、判重;2、统计次数;

零基础学习Redis(10) -- zset类型命令使用

zset是有序集合,内部除了存储元素外,还会存储一个score,存储在zset中的元素会按照score的大小升序排列,不同元素的score可以重复,score相同的元素会按照元素的字典序排列。 1. zset常用命令 1.1 zadd  zadd key [NX | XX] [GT | LT]   [CH] [INCR] score member [score member ...]

Android平台播放RTSP流的几种方案探究(VLC VS ExoPlayer VS SmartPlayer)

技术背景 好多开发者需要遴选Android平台RTSP直播播放器的时候,不知道如何选的好,本文针对常用的方案,做个大概的说明: 1. 使用VLC for Android VLC Media Player(VLC多媒体播放器),最初命名为VideoLAN客户端,是VideoLAN品牌产品,是VideoLAN计划的多媒体播放器。它支持众多音频与视频解码器及文件格式,并支持DVD影音光盘,VCD影

【测试】输入正确用户名和密码,点击登录没有响应的可能性原因

目录 一、前端问题 1. 界面交互问题 2. 输入数据校验问题 二、网络问题 1. 网络连接中断 2. 代理设置问题 三、后端问题 1. 服务器故障 2. 数据库问题 3. 权限问题: 四、其他问题 1. 缓存问题 2. 第三方服务问题 3. 配置问题 一、前端问题 1. 界面交互问题 登录按钮的点击事件未正确绑定,导致点击后无法触发登录操作。 页面可能存在

【机器学习】高斯过程的基本概念和应用领域以及在python中的实例

引言 高斯过程(Gaussian Process,简称GP)是一种概率模型,用于描述一组随机变量的联合概率分布,其中任何一个有限维度的子集都具有高斯分布 文章目录 引言一、高斯过程1.1 基本定义1.1.1 随机过程1.1.2 高斯分布 1.2 高斯过程的特性1.2.1 联合高斯性1.2.2 均值函数1.2.3 协方差函数(或核函数) 1.3 核函数1.4 高斯过程回归(Gauss