NCCL集合通信算子DEMO及性能测试

2024-04-13 05:44

本文主要是介绍NCCL集合通信算子DEMO及性能测试,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

NCCL集合通信算子DEMO及性能测试

  • 一.复现代码

以下代码用于测试NCCL算子的性能及正确性

一.复现代码

tee ccl_benchmark.py <<-'EOF'
import os
import torch
import argparse
import torch.distributed as dist
from torch.distributed import ReduceOp
from datetime import datetime
import time
import argparse
import numpy as np
dev_type="cuda"class Timer:def __init__(self,duration):        self.duration=durationdef __enter__(self):dist.barrier()self.beg= datetime.now().timestamp() * 1e6def __exit__(self, exc_type, exc_val, exc_tb):dist.barrier()self.end=datetime.now().timestamp() * 1e6self.duration.append(self.end-self.beg)op_mapping={}
class ccl_benchmark:def __init__(self,func):global op_mapping  op_mapping[func.__name__]=funcself.func=funcdef __call__(self,*args,**kwargs):return self.func(*args,**kwargs)@ccl_benchmark
def all_gather(shape,device,rank,world_size,iters=5):'''将每个rank input_tensor的数据在dim 0维度拼接在一起'''duration=[]input_tensor=(torch.ones((shape[0]//world_size,shape[1]),dtype=torch.int64)*(100+rank)).to(device)gather_list=[torch.zeros((shape[0]//world_size,shape[1]),dtype=torch.int64).to(device) for _ in range(world_size)]for _ in range(iters):with Timer(duration):dist.all_gather(gather_list,input_tensor)   output=torch.cat(gather_list,dim=0)gt=[torch.ones((shape[0]//world_size,shape[1]),dtype=torch.int64)*(100+i) for i in range(world_size)]gt=torch.cat(gt,dim=0)return duration,(output.cpu()==gt).all()@ccl_benchmark
def scatter(shape,device,rank,world_size,iters=5):'''将每个rank从scatter_list[rank]取数据到output_tensor'''duration=[]output_tensor=torch.zeros((shape[0]//world_size,shape[1]),dtype=torch.int64).to(device)scatter_list=[(torch.ones((shape[0]//world_size,shape[1]),dtype=torch.int64)*i).to(device) for i in range(world_size)]for _ in range(iters):with Timer(duration):if rank == 0:dist.scatter(output_tensor,scatter_list=scatter_list,src =0)else:dist.scatter(output_tensor,src  = 0)gt=torch.ones((shape[0]//world_size,shape[1]),dtype=torch.int64)*rankreturn duration,(output_tensor.cpu()==gt).all()@ccl_benchmark
def gather(shape,device,rank,world_size,iters=5):'''将每个rank input_tensor的数据在dim 0维度拼接在一起 只在批定的rank做'''duration=[]input_tensor=(torch.ones((shape[0]//world_size,shape[1]),dtype=torch.int64)*(100+rank)).to(device)gather_list=[torch.zeros((shape[0]//world_size,shape[1]),dtype=torch.int64).to(device) for _ in range(world_size)]for _ in range(iters):with Timer(duration):if rank == 0:dist.gather(input_tensor,gather_list=gather_list,dst=0)else:dist.gather(input_tensor,dst=0)ret=Trueif rank==0:output=torch.cat(gather_list,dim=0)gt=[torch.ones((shape[0]//world_size,shape[1]),dtype=torch.int64)*(100+i) for i in range(world_size)]gt=torch.cat(gt,dim=0)ret=(output.cpu()==gt).all()return duration,ret@ccl_benchmark
def reduce(shape,device,rank,world_size,iters=5):'''将每个rank input_tensor的数据在dim 0维度拼接在一起 只在批定的rank做'''duration=[]   for _ in range(iters):input_tensor=(torch.ones((shape[0],shape[1]),dtype=torch.int64)*(100+rank)).to(device)# input_tensor的内容会被修改,所以放在循环里with Timer(duration):dist.reduce(input_tensor,dst=0,op=dist.ReduceOp.SUM)ret=Trueif rank==0:gt=[torch.ones((shape[0],shape[1]),dtype=torch.int64)*(100+i) for i in range(world_size)]gt_=gt[0]       for i in range(1,world_size):gt_=gt_+gt[i]ret=(input_tensor.cpu()==gt_).all()return duration,ret@ccl_benchmark
def broadcast(shape,device,rank,world_size,iters=5):'''将src的rank的数据广播到其它rank'''duration=[]   for _ in range(iters):input_tensor=(torch.ones((shape[0],shape[1]),dtype=torch.int64)*(100+rank)).to(device)with Timer(duration):dist.broadcast(input_tensor,src=0)gt=(torch.ones((shape[0],shape[1]),dtype=torch.int64)*(100+0)).to('cpu')ret=(input_tensor.cpu()==gt).all()return duration,ret@ccl_benchmark
def p2p(shape,device,rank,world_size,iters=5):'''将src的rank的数据广播到其它rank'''duration=[]   for _ in range(iters):input_tensor=(torch.ones((shape[0],shape[1]),dtype=torch.int64)*(100+rank)).to(device)with Timer(duration):if rank!=0:dist.recv(input_tensor,rank-1)               if rank!=world_size-1:               dist.send(input_tensor,dst=rank+1)   gt=(torch.ones((shape[0],shape[1]),dtype=torch.int64)*(100+0)).to('cpu')ret=(input_tensor.cpu()==gt).all()return duration,ret@ccl_benchmark
def all_reduce(shape,device,rank,world_size,iters=5):'''将每个rank input_tensor的数据在dim 0维度拼接在一起'''duration=[]   for _ in range(iters):input_tensor=(torch.ones((shape[0],shape[1]),dtype=torch.int64)*(100+rank)).to(device)# input_tensor的内容会被修改,所以放在循环里with Timer(duration):dist.all_reduce(input_tensor,op=dist.ReduceOp.SUM)gt=[torch.ones((shape[0],shape[1]),dtype=torch.int64)*(100+i) for i in range(world_size)]gt_=gt[0]       for i in range(1,world_size):gt_=gt_+gt[i]ret=(input_tensor.cpu()==gt_).all()return duration,ret@ccl_benchmark
def reduce_scatter(shape,device,rank,world_size,iters=5):''''''duration=[]output_tensor=torch.zeros((shape[0]//world_size,shape[1]),dtype=torch.int64).to(device)input_list=[(torch.ones((shape[0]//world_size,shape[1]),dtype=torch.int64)*(100*rank)+chunk_id).to(device) for chunk_id in range(world_size)]for _ in range(iters):with Timer(duration):dist.reduce_scatter(output_tensor,input_list=input_list,op=dist.ReduceOp.SUM)gt_list=[(torch.ones((shape[0]//world_size,shape[1]),dtype=torch.int64)*(100*rk)+rank).to('cpu') for rk in range(world_size)]gt_=gt_list[0]       for i in range(1,world_size):gt_=gt_+gt_list[i]    return duration,(output_tensor.cpu()==gt_).all()def main():dist.init_process_group(backend='nccl')if not torch.distributed.is_initialized():returnparser = argparse.ArgumentParser(description='test')parser.add_argument('--shape', type=str, default="(1024,8192)", help='Number of epochs to train.')parser.add_argument('--iters', type=int, default=5, help='Number of epochs to train.')parser.add_argument('--op', type=str, default="", help='Number of epochs to train.')args = parser.parse_args()global op_mappingif args.op in op_mapping:torch.manual_seed(1)world_size = torch.distributed.get_world_size()rank = torch.distributed.get_rank()local_rank=int(os.environ['LOCAL_RANK'])torch.cuda.set_device(local_rank)device = torch.device(dev_type,local_rank)shape=eval(args.shape)duration,passed=op_mapping[args.op](shape,device,rank,world_size,args.iters)time.sleep(0.1*rank)print("rank:{} op:{} shape:{} iters:{} mean(us):{:.3f} passed:{}".format(rank,args.op,shape,args.iters,np.mean(duration[len(duration)//2:]),passed))dist.destroy_process_group()if __name__=='__main__':main()EOFexport NCCL_DEBUG=error
export NCCL_SOCKET_IFNAME=ens8
export NCCL_IB_DISABLE=1  
torchrun -m --nnodes=1 --nproc_per_node=4 ccl_benchmark --op=all_gather --shape="(1024,4096)" --iters=5
torchrun -m --nnodes=1 --nproc_per_node=4 ccl_benchmark --op=scatter --shape="(1024,4096)" --iters=5
torchrun -m --nnodes=1 --nproc_per_node=4 ccl_benchmark --op=gather --shape="(1024,4096)" --iters=5
torchrun -m --nnodes=1 --nproc_per_node=4 ccl_benchmark --op=reduce --shape="(1024,4096)" --iters=5
torchrun -m --nnodes=1 --nproc_per_node=4 ccl_benchmark --op=broadcast --shape="(1024,4096)" --iters=5
torchrun -m --nnodes=1 --nproc_per_node=4 ccl_benchmark --op=p2p --shape="(1024,4096)" --iters=5
torchrun -m --nnodes=1 --nproc_per_node=4 ccl_benchmark --op=all_reduce --shape="(1024,4096)" --iters=5
torchrun -m --nnodes=1 --nproc_per_node=4 ccl_benchmark --op=reduce_scatter --shape="(1024,4096)" --iters=5

这篇关于NCCL集合通信算子DEMO及性能测试的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/899274

相关文章

从原理到实战解析Java Stream 的并行流性能优化

《从原理到实战解析JavaStream的并行流性能优化》本文给大家介绍JavaStream的并行流性能优化:从原理到实战的全攻略,本文通过实例代码给大家介绍的非常详细,对大家的学习或工作具有一定的... 目录一、并行流的核心原理与适用场景二、性能优化的核心策略1. 合理设置并行度:打破默认阈值2. 避免装箱

深度剖析SpringBoot日志性能提升的原因与解决

《深度剖析SpringBoot日志性能提升的原因与解决》日志记录本该是辅助工具,却为何成了性能瓶颈,SpringBoot如何用代码彻底破解日志导致的高延迟问题,感兴趣的小伙伴可以跟随小编一起学习一下... 目录前言第一章:日志性能陷阱的底层原理1.1 日志级别的“双刃剑”效应1.2 同步日志的“吞吐量杀手”

Java集合中的链表与结构详解

《Java集合中的链表与结构详解》链表是一种物理存储结构上非连续的存储结构,数据元素的逻辑顺序的通过链表中的引用链接次序实现,文章对比ArrayList与LinkedList的结构差异,详细讲解了链表... 目录一、链表概念与结构二、当向单链表的实现2.1 准备工作2.2 初始化链表2.3 打印数据、链表长

Java慢查询排查与性能调优完整实战指南

《Java慢查询排查与性能调优完整实战指南》Java调优是一个广泛的话题,它涵盖了代码优化、内存管理、并发处理等多个方面,:本文主要介绍Java慢查询排查与性能调优的相关资料,文中通过代码介绍的非... 目录1. 事故全景:从告警到定位1.1 事故时间线1.2 关键指标异常1.3 排查工具链2. 深度剖析:

深入解析Java NIO在高并发场景下的性能优化实践指南

《深入解析JavaNIO在高并发场景下的性能优化实践指南》随着互联网业务不断演进,对高并发、低延时网络服务的需求日益增长,本文将深入解析JavaNIO在高并发场景下的性能优化方法,希望对大家有所帮助... 目录简介一、技术背景与应用场景二、核心原理深入分析2.1 Selector多路复用2.2 Buffer

基于Python Playwright进行前端性能测试的脚本实现

《基于PythonPlaywright进行前端性能测试的脚本实现》在当今Web应用开发中,性能优化是提升用户体验的关键因素之一,本文将介绍如何使用Playwright构建一个自动化性能测试工具,希望... 目录引言工具概述整体架构核心实现解析1. 浏览器初始化2. 性能数据收集3. 资源分析4. 关键性能指

Python实现MQTT通信的示例代码

《Python实现MQTT通信的示例代码》本文主要介绍了Python实现MQTT通信的示例代码,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一... 目录1. 安装paho-mqtt库‌2. 搭建MQTT代理服务器(Broker)‌‌3. pytho

Zabbix在MySQL性能监控方面的运用及最佳实践记录

《Zabbix在MySQL性能监控方面的运用及最佳实践记录》Zabbix通过自定义脚本和内置模板监控MySQL核心指标(连接、查询、资源、复制),支持自动发现多实例及告警通知,结合可视化仪表盘,可有效... 目录一、核心监控指标及配置1. 关键监控指标示例2. 配置方法二、自动发现与多实例管理1. 实践步骤

MySQL深分页进行性能优化的常见方法

《MySQL深分页进行性能优化的常见方法》在Web应用中,分页查询是数据库操作中的常见需求,然而,在面对大型数据集时,深分页(deeppagination)却成为了性能优化的一个挑战,在本文中,我们将... 目录引言:深分页,真的只是“翻页慢”那么简单吗?一、背景介绍二、深分页的性能问题三、业务场景分析四、

MySQL 多列 IN 查询之语法、性能与实战技巧(最新整理)

《MySQL多列IN查询之语法、性能与实战技巧(最新整理)》本文详解MySQL多列IN查询,对比传统OR写法,强调其简洁高效,适合批量匹配复合键,通过联合索引、分批次优化提升性能,兼容多种数据库... 目录一、基础语法:多列 IN 的两种写法1. 直接值列表2. 子查询二、对比传统 OR 的写法三、性能分析