基于torch_dispatch机制生成Megatron-DeepSpeed调用关系图

2024-05-10 22:04

本文主要是介绍基于torch_dispatch机制生成Megatron-DeepSpeed调用关系图,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

基于torch_dispatch机制生成Megatron-DeepSpeed调用关系图

  • 一.局部效果图
  • 二.运行训练过程,拦截算子,生成调用关系信息
  • 三.可视化,生成SVG图像

想知道Megatron-DeepSpeed训练过程中各模块之间的调用关系。torch_dispatch机制可以拦截算子,inspect又能获取到调用栈(文件,类名,函数,行号).基于这些信息可以生成调用关系,最后用graphviz生成SVG图像。该思路也可以用来画其它pytorch工程的调用关系图

1.为了减少图像宽度,一行显示一级文件路径
2.没有显示具体的ATen算子。因为边太乱

一.局部效果图

在这里插入图片描述

二.运行训练过程,拦截算子,生成调用关系信息

# 前面构建模型的代码省略
from torch.utils._python_dispatch import TorchDispatchMode
import inspect
from dataclasses import dataclass
from typing import Any
import pickle@dataclass
class _ProfilerState:cls: Anyobject: Any = Noneclass TorchDumpDispatchMode(TorchDispatchMode):def __init__(self,parent):super().__init__()self.parent=parent        self.global_index=0        self.nodes=set()self.edges=set()def __del__(self):self.rank = torch.distributed.get_rank()graph={"nodes":self.nodes,"edges":self.edges}with open(f"call_graph_{self.rank}.pkl","wb") as f:pickle.dump(graph,f)def is_keep(self,node):# if node.function.find("wrapper")>=0:#     return False# if node.function.find("_call_impl")>=0:#     return Falsereturn Truedef __torch_dispatch__(self, func, types, args=(), kwargs=None):self.global_index+=1self.rank = torch.distributed.get_rank() func_packet = func._overloadpacket       if kwargs is None:kwargs = {}if self.rank==0:stacks=[i for i in inspect.stack() if self.is_keep(i)]stacks_sz=len(stacks)for idx in range(stacks_sz-1,1,-1):if "self" in stacks[idx].frame.f_locals:class_name = stacks[idx].frame.f_locals["self"].__class__.__name__else:class_name=""this_node=f"{stacks[idx].filename}:[{class_name}]:{stacks[idx].function}"if "self" in stacks[idx-1].frame.f_locals:class_name = stacks[idx-1].frame.f_locals["self"].__class__.__name__else:class_name=""                                    next_node=f"{stacks[idx-1].filename}:[{class_name}]:{stacks[idx-1].function}"self.nodes.add(this_node)self.nodes.add(next_node)self.edges.add(f"{this_node}->{next_node}")# if stacks_sz>1:#     if "self" in stacks[1].frame.f_locals:#         class_name = stacks[1].frame.f_locals["self"].__class__.__name__#     else:#         class_name=""                #     this_node=f"{stacks[1].filename}:[{class_name}]:{stacks[1].function}"#     next_node=f"{func_packet.__name__}"#     self.nodes.add(this_node)   #     self.nodes.add(next_node)            #     self.edges.add(f"{this_node}->{next_node}")ret= func(*args, **kwargs)return retclass TorchDumper:_CURRENT_Dumper = Nonedef __init__(self,schedule: Any):self.p= _ProfilerState(schedule)def __enter__(self):assert TorchDumper._CURRENT_Dumper is NoneTorchDumper._CURRENT_Dumper = selfif self.p.object is None:o = self.p.cls(self)o.__enter__()self.p.object = oelse:self.p.object.step()return selfdef __exit__(self, exc_type, exc_val, exc_tb):TorchDumper._CURRENT_Dumper = Noneif self.p.object is not None:self.p.object.__exit__(exc_type, exc_val, exc_tb)del self.p.object  #序列化保存def main():with TorchDumper(TorchDumpDispatchMode):#训练入口pretrain(train_valid_test_datasets_provider,model_provider,forward_step,extra_args_provider=llama_argument_handler,args_defaults={"tokenizer_type": "GPT2BPETokenizer"},)if __name__ == "__main__":main()

三.可视化,生成SVG图像

# coding=utf-8import os
from graphviz import Digraph,Graph
import pickle
import random
from distinctipy import distinctipydef generate_colors(N):'''生成N种有区别度的颜色'''result=[]for red, green, blue in distinctipy.get_colors(N):result.append("#{:02X}{:02X}{:02X}".format(int(red*255), int(green*255), int(blue*255)))return resultdef replace_name(name):'''修改节点名字(缩短,添加换行)'''if name.find("__torch_dispatch__")>=0:return Nonename=name.replace("/home/user/Megatron-DeepSpeed/","")name=name.replace("/home/anaconda3/envs/dev/lib/python3.10/site-packages/","")name=name.replace("/home/user/deepspeed/","")name=name.replace("/home/anaconda3/envs/dev/","")name=name.replace("/",r"\n")name=name.replace(":",r"\n")return name# 1.加载HOOK生成的调用关系文件
rank=0
with open(f"call_graph_{rank}.pkl","rb") as f:data=pickle.load(f)# 2.构建图,设置属性
dot = Digraph()
dot.node_attr = {"shape": "plaintext"}
dot.attr('graph', layout='dot')
dot.graph_attr.update(sep='4.0', ratio='compress')node_desc_id_map={}  #节点名与描述的关系映射表
src_node_color={}    #节点颜色映射表(同一个节点输出的边颜色一样)colors = generate_colors(10)
colors_sz=len(colors)fontsize="16"        #节点字体大小
penwidth="2.0"       #边宽度# 3.添加节点
for idx,v in enumerate(data["nodes"]):v=replace_name(v)if v is None:continuenode_desc_id_map[v]=f"{idx}"if v.find("megatron")>=0:dot.node(f"{idx}",v,style='filled',color='#73FBFD',fontsize=fontsize)elif v.find("deepspeed")>=0:dot.node(f"{idx}",v,style='filled',color='#FA8D89',fontsize=fontsize)else:dot.node(f"{idx}",v,style='filled',color='#C0C0C0',fontsize=fontsize)src_node_color[v]=colors[idx%colors_sz]# 4.添加边
for edge in data["edges"]:from_node,to_node=edge.split("->")from_node=replace_name(from_node)to_node=replace_name(to_node)if all([from_node,to_node]):color=src_node_color[from_node]dot.edge(node_desc_id_map[from_node], node_desc_id_map[to_node],color=color,penwidth=penwidth)# 5.保存SVG
save_path='megatron_deepspeed_callgraph'
dot.render(save_path,format='svg', view=False)# 6.修改背景色为灰色
import xml.etree.ElementTree as ET
svg_tree = ET.parse(f'{save_path}.svg')
root = svg_tree.getroot()
element = root.find(".//{http://www.w3.org/2000/svg}polygon")
element.set('fill', 'gray')
svg_tree.write(f'{save_path}.svg')

这篇关于基于torch_dispatch机制生成Megatron-DeepSpeed调用关系图的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/977679

相关文章

JVM 的类初始化机制

前言 当你在 Java 程序中new对象时,有没有考虑过 JVM 是如何把静态的字节码(byte code)转化为运行时对象的呢,这个问题看似简单,但清楚的同学相信也不会太多,这篇文章首先介绍 JVM 类初始化的机制,然后给出几个易出错的实例来分析,帮助大家更好理解这个知识点。 JVM 将字节码转化为运行时对象分为三个阶段,分别是:loading 、Linking、initialization

AI一键生成 PPT

AI一键生成 PPT 操作步骤 作为一名打工人,是不是经常需要制作各种PPT来分享我的生活和想法。但是,你们知道,有时候灵感来了,时间却不够用了!😩直到我发现了Kimi AI——一个能够自动生成PPT的神奇助手!🌟 什么是Kimi? 一款月之暗面科技有限公司开发的AI办公工具,帮助用户快速生成高质量的演示文稿。 无论你是职场人士、学生还是教师,Kimi都能够为你的办公文

如何在页面调用utility bar并传递参数至lwc组件

1.在app的utility item中添加lwc组件: 2.调用utility bar api的方式有两种: 方法一,通过lwc调用: import {LightningElement,api ,wire } from 'lwc';import { publish, MessageContext } from 'lightning/messageService';import Ca

pdfmake生成pdf的使用

实际项目中有时会有根据填写的表单数据或者其他格式的数据,将数据自动填充到pdf文件中根据固定模板生成pdf文件的需求 文章目录 利用pdfmake生成pdf文件1.下载安装pdfmake第三方包2.封装生成pdf文件的共用配置3.生成pdf文件的文件模板内容4.调用方法生成pdf 利用pdfmake生成pdf文件 1.下载安装pdfmake第三方包 npm i pdfma

poj 1258 Agri-Net(最小生成树模板代码)

感觉用这题来当模板更适合。 题意就是给你邻接矩阵求最小生成树啦。~ prim代码:效率很高。172k...0ms。 #include<stdio.h>#include<algorithm>using namespace std;const int MaxN = 101;const int INF = 0x3f3f3f3f;int g[MaxN][MaxN];int n

poj 1287 Networking(prim or kruscal最小生成树)

题意给你点与点间距离,求最小生成树。 注意点是,两点之间可能有不同的路,输入的时候选择最小的,和之前有道最短路WA的题目类似。 prim代码: #include<stdio.h>const int MaxN = 51;const int INF = 0x3f3f3f3f;int g[MaxN][MaxN];int P;int prim(){bool vis[MaxN];

poj 2349 Arctic Network uva 10369(prim or kruscal最小生成树)

题目很麻烦,因为不熟悉最小生成树的算法调试了好久。 感觉网上的题目解释都没说得很清楚,不适合新手。自己写一个。 题意:给你点的坐标,然后两点间可以有两种方式来通信:第一种是卫星通信,第二种是无线电通信。 卫星通信:任何两个有卫星频道的点间都可以直接建立连接,与点间的距离无关; 无线电通信:两个点之间的距离不能超过D,无线电收发器的功率越大,D越大,越昂贵。 计算无线电收发器D

hdu 1102 uva 10397(最小生成树prim)

hdu 1102: 题意: 给一个邻接矩阵,给一些村庄间已经修的路,问最小生成树。 解析: 把已经修的路的权值改为0,套个prim()。 注意prim 最外层循坏为n-1。 代码: #include <iostream>#include <cstdio>#include <cstdlib>#include <algorithm>#include <cstri

【生成模型系列(初级)】嵌入(Embedding)方程——自然语言处理的数学灵魂【通俗理解】

【通俗理解】嵌入(Embedding)方程——自然语言处理的数学灵魂 关键词提炼 #嵌入方程 #自然语言处理 #词向量 #机器学习 #神经网络 #向量空间模型 #Siri #Google翻译 #AlexNet 第一节:嵌入方程的类比与核心概念【尽可能通俗】 嵌入方程可以被看作是自然语言处理中的“翻译机”,它将文本中的单词或短语转换成计算机能够理解的数学形式,即向量。 正如翻译机将一种语言

Java ArrayList扩容机制 (源码解读)

结论:初始长度为10,若所需长度小于1.5倍原长度,则按照1.5倍扩容。若不够用则按照所需长度扩容。 一. 明确类内部重要变量含义         1:数组默认长度         2:这是一个共享的空数组实例,用于明确创建长度为0时的ArrayList ,比如通过 new ArrayList<>(0),ArrayList 内部的数组 elementData 会指向这个 EMPTY_EL