NCCL集合通信算子DEMO及性能测试

2024-04-13 05:44

本文主要是介绍NCCL集合通信算子DEMO及性能测试,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

NCCL集合通信算子DEMO及性能测试

  • 一.复现代码

以下代码用于测试NCCL算子的性能及正确性

一.复现代码

tee ccl_benchmark.py <<-'EOF'
import os
import torch
import argparse
import torch.distributed as dist
from torch.distributed import ReduceOp
from datetime import datetime
import time
import argparse
import numpy as np
dev_type="cuda"class Timer:def __init__(self,duration):        self.duration=durationdef __enter__(self):dist.barrier()self.beg= datetime.now().timestamp() * 1e6def __exit__(self, exc_type, exc_val, exc_tb):dist.barrier()self.end=datetime.now().timestamp() * 1e6self.duration.append(self.end-self.beg)op_mapping={}
class ccl_benchmark:def __init__(self,func):global op_mapping  op_mapping[func.__name__]=funcself.func=funcdef __call__(self,*args,**kwargs):return self.func(*args,**kwargs)@ccl_benchmark
def all_gather(shape,device,rank,world_size,iters=5):'''将每个rank input_tensor的数据在dim 0维度拼接在一起'''duration=[]input_tensor=(torch.ones((shape[0]//world_size,shape[1]),dtype=torch.int64)*(100+rank)).to(device)gather_list=[torch.zeros((shape[0]//world_size,shape[1]),dtype=torch.int64).to(device) for _ in range(world_size)]for _ in range(iters):with Timer(duration):dist.all_gather(gather_list,input_tensor)   output=torch.cat(gather_list,dim=0)gt=[torch.ones((shape[0]//world_size,shape[1]),dtype=torch.int64)*(100+i) for i in range(world_size)]gt=torch.cat(gt,dim=0)return duration,(output.cpu()==gt).all()@ccl_benchmark
def scatter(shape,device,rank,world_size,iters=5):'''将每个rank从scatter_list[rank]取数据到output_tensor'''duration=[]output_tensor=torch.zeros((shape[0]//world_size,shape[1]),dtype=torch.int64).to(device)scatter_list=[(torch.ones((shape[0]//world_size,shape[1]),dtype=torch.int64)*i).to(device) for i in range(world_size)]for _ in range(iters):with Timer(duration):if rank == 0:dist.scatter(output_tensor,scatter_list=scatter_list,src =0)else:dist.scatter(output_tensor,src  = 0)gt=torch.ones((shape[0]//world_size,shape[1]),dtype=torch.int64)*rankreturn duration,(output_tensor.cpu()==gt).all()@ccl_benchmark
def gather(shape,device,rank,world_size,iters=5):'''将每个rank input_tensor的数据在dim 0维度拼接在一起 只在批定的rank做'''duration=[]input_tensor=(torch.ones((shape[0]//world_size,shape[1]),dtype=torch.int64)*(100+rank)).to(device)gather_list=[torch.zeros((shape[0]//world_size,shape[1]),dtype=torch.int64).to(device) for _ in range(world_size)]for _ in range(iters):with Timer(duration):if rank == 0:dist.gather(input_tensor,gather_list=gather_list,dst=0)else:dist.gather(input_tensor,dst=0)ret=Trueif rank==0:output=torch.cat(gather_list,dim=0)gt=[torch.ones((shape[0]//world_size,shape[1]),dtype=torch.int64)*(100+i) for i in range(world_size)]gt=torch.cat(gt,dim=0)ret=(output.cpu()==gt).all()return duration,ret@ccl_benchmark
def reduce(shape,device,rank,world_size,iters=5):'''将每个rank input_tensor的数据在dim 0维度拼接在一起 只在批定的rank做'''duration=[]   for _ in range(iters):input_tensor=(torch.ones((shape[0],shape[1]),dtype=torch.int64)*(100+rank)).to(device)# input_tensor的内容会被修改,所以放在循环里with Timer(duration):dist.reduce(input_tensor,dst=0,op=dist.ReduceOp.SUM)ret=Trueif rank==0:gt=[torch.ones((shape[0],shape[1]),dtype=torch.int64)*(100+i) for i in range(world_size)]gt_=gt[0]       for i in range(1,world_size):gt_=gt_+gt[i]ret=(input_tensor.cpu()==gt_).all()return duration,ret@ccl_benchmark
def broadcast(shape,device,rank,world_size,iters=5):'''将src的rank的数据广播到其它rank'''duration=[]   for _ in range(iters):input_tensor=(torch.ones((shape[0],shape[1]),dtype=torch.int64)*(100+rank)).to(device)with Timer(duration):dist.broadcast(input_tensor,src=0)gt=(torch.ones((shape[0],shape[1]),dtype=torch.int64)*(100+0)).to('cpu')ret=(input_tensor.cpu()==gt).all()return duration,ret@ccl_benchmark
def p2p(shape,device,rank,world_size,iters=5):'''将src的rank的数据广播到其它rank'''duration=[]   for _ in range(iters):input_tensor=(torch.ones((shape[0],shape[1]),dtype=torch.int64)*(100+rank)).to(device)with Timer(duration):if rank!=0:dist.recv(input_tensor,rank-1)               if rank!=world_size-1:               dist.send(input_tensor,dst=rank+1)   gt=(torch.ones((shape[0],shape[1]),dtype=torch.int64)*(100+0)).to('cpu')ret=(input_tensor.cpu()==gt).all()return duration,ret@ccl_benchmark
def all_reduce(shape,device,rank,world_size,iters=5):'''将每个rank input_tensor的数据在dim 0维度拼接在一起'''duration=[]   for _ in range(iters):input_tensor=(torch.ones((shape[0],shape[1]),dtype=torch.int64)*(100+rank)).to(device)# input_tensor的内容会被修改,所以放在循环里with Timer(duration):dist.all_reduce(input_tensor,op=dist.ReduceOp.SUM)gt=[torch.ones((shape[0],shape[1]),dtype=torch.int64)*(100+i) for i in range(world_size)]gt_=gt[0]       for i in range(1,world_size):gt_=gt_+gt[i]ret=(input_tensor.cpu()==gt_).all()return duration,ret@ccl_benchmark
def reduce_scatter(shape,device,rank,world_size,iters=5):''''''duration=[]output_tensor=torch.zeros((shape[0]//world_size,shape[1]),dtype=torch.int64).to(device)input_list=[(torch.ones((shape[0]//world_size,shape[1]),dtype=torch.int64)*(100*rank)+chunk_id).to(device) for chunk_id in range(world_size)]for _ in range(iters):with Timer(duration):dist.reduce_scatter(output_tensor,input_list=input_list,op=dist.ReduceOp.SUM)gt_list=[(torch.ones((shape[0]//world_size,shape[1]),dtype=torch.int64)*(100*rk)+rank).to('cpu') for rk in range(world_size)]gt_=gt_list[0]       for i in range(1,world_size):gt_=gt_+gt_list[i]    return duration,(output_tensor.cpu()==gt_).all()def main():dist.init_process_group(backend='nccl')if not torch.distributed.is_initialized():returnparser = argparse.ArgumentParser(description='test')parser.add_argument('--shape', type=str, default="(1024,8192)", help='Number of epochs to train.')parser.add_argument('--iters', type=int, default=5, help='Number of epochs to train.')parser.add_argument('--op', type=str, default="", help='Number of epochs to train.')args = parser.parse_args()global op_mappingif args.op in op_mapping:torch.manual_seed(1)world_size = torch.distributed.get_world_size()rank = torch.distributed.get_rank()local_rank=int(os.environ['LOCAL_RANK'])torch.cuda.set_device(local_rank)device = torch.device(dev_type,local_rank)shape=eval(args.shape)duration,passed=op_mapping[args.op](shape,device,rank,world_size,args.iters)time.sleep(0.1*rank)print("rank:{} op:{} shape:{} iters:{} mean(us):{:.3f} passed:{}".format(rank,args.op,shape,args.iters,np.mean(duration[len(duration)//2:]),passed))dist.destroy_process_group()if __name__=='__main__':main()EOFexport NCCL_DEBUG=error
export NCCL_SOCKET_IFNAME=ens8
export NCCL_IB_DISABLE=1  
torchrun -m --nnodes=1 --nproc_per_node=4 ccl_benchmark --op=all_gather --shape="(1024,4096)" --iters=5
torchrun -m --nnodes=1 --nproc_per_node=4 ccl_benchmark --op=scatter --shape="(1024,4096)" --iters=5
torchrun -m --nnodes=1 --nproc_per_node=4 ccl_benchmark --op=gather --shape="(1024,4096)" --iters=5
torchrun -m --nnodes=1 --nproc_per_node=4 ccl_benchmark --op=reduce --shape="(1024,4096)" --iters=5
torchrun -m --nnodes=1 --nproc_per_node=4 ccl_benchmark --op=broadcast --shape="(1024,4096)" --iters=5
torchrun -m --nnodes=1 --nproc_per_node=4 ccl_benchmark --op=p2p --shape="(1024,4096)" --iters=5
torchrun -m --nnodes=1 --nproc_per_node=4 ccl_benchmark --op=all_reduce --shape="(1024,4096)" --iters=5
torchrun -m --nnodes=1 --nproc_per_node=4 ccl_benchmark --op=reduce_scatter --shape="(1024,4096)" --iters=5

这篇关于NCCL集合通信算子DEMO及性能测试的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/899274

相关文章

Springboot处理跨域的实现方式(附Demo)

《Springboot处理跨域的实现方式(附Demo)》:本文主要介绍Springboot处理跨域的实现方式(附Demo),具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不... 目录Springboot处理跨域的方式1. 基本知识2. @CrossOrigin3. 全局跨域设置4.

Python如何使用__slots__实现节省内存和性能优化

《Python如何使用__slots__实现节省内存和性能优化》你有想过,一个小小的__slots__能让你的Python类内存消耗直接减半吗,没错,今天咱们要聊的就是这个让人眼前一亮的技巧,感兴趣的... 目录背景:内存吃得满满的类__slots__:你的内存管理小助手举个大概的例子:看看效果如何?1.

Redis中高并发读写性能的深度解析与优化

《Redis中高并发读写性能的深度解析与优化》Redis作为一款高性能的内存数据库,广泛应用于缓存、消息队列、实时统计等场景,本文将深入探讨Redis的读写并发能力,感兴趣的小伙伴可以了解下... 目录引言一、Redis 并发能力概述1.1 Redis 的读写性能1.2 影响 Redis 并发能力的因素二、

Python容器类型之列表/字典/元组/集合方式

《Python容器类型之列表/字典/元组/集合方式》:本文主要介绍Python容器类型之列表/字典/元组/集合方式,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录1. 列表(List) - 有序可变序列1.1 基本特性1.2 核心操作1.3 应用场景2. 字典(D

Golang中拼接字符串的6种方式性能对比

《Golang中拼接字符串的6种方式性能对比》golang的string类型是不可修改的,对于拼接字符串来说,本质上还是创建一个新的对象将数据放进去,主要有6种拼接方式,下面小编就来为大家详细讲讲吧... 目录拼接方式介绍性能对比测试代码测试结果源码分析golang的string类型是不可修改的,对于拼接字

mysql线上查询之前要性能调优的技巧及示例

《mysql线上查询之前要性能调优的技巧及示例》文章介绍了查询优化的几种方法,包括使用索引、避免不必要的列和行、有效的JOIN策略、子查询和派生表的优化、查询提示和优化器提示等,这些方法可以帮助提高数... 目录避免不必要的列和行使用有效的JOIN策略使用子查询和派生表时要小心使用查询提示和优化器提示其他常

Java集合中的List超详细讲解

《Java集合中的List超详细讲解》本文详细介绍了Java集合框架中的List接口,包括其在集合中的位置、继承体系、常用操作和代码示例,以及不同实现类(如ArrayList、LinkedList和V... 目录一,List的继承体系二,List的常用操作及代码示例1,创建List实例2,增加元素3,访问元

SpringBoot中整合RabbitMQ(测试+部署上线最新完整)的过程

《SpringBoot中整合RabbitMQ(测试+部署上线最新完整)的过程》本文详细介绍了如何在虚拟机和宝塔面板中安装RabbitMQ,并使用Java代码实现消息的发送和接收,通过异步通讯,可以优化... 目录一、RabbitMQ安装二、启动RabbitMQ三、javascript编写Java代码1、引入

Nginx设置连接超时并进行测试的方法步骤

《Nginx设置连接超时并进行测试的方法步骤》在高并发场景下,如果客户端与服务器的连接长时间未响应,会占用大量的系统资源,影响其他正常请求的处理效率,为了解决这个问题,可以通过设置Nginx的连接... 目录设置连接超时目的操作步骤测试连接超时测试方法:总结:设置连接超时目的设置客户端与服务器之间的连接

Springboot中分析SQL性能的两种方式详解

《Springboot中分析SQL性能的两种方式详解》文章介绍了SQL性能分析的两种方式:MyBatis-Plus性能分析插件和p6spy框架,MyBatis-Plus插件配置简单,适用于开发和测试环... 目录SQL性能分析的两种方式:功能介绍实现方式:实现步骤:SQL性能分析的两种方式:功能介绍记录