深度学习编译中间件之NNVM(十三)NNVM源代码阅读2

2023-10-29 05:08

本文主要是介绍深度学习编译中间件之NNVM(十三)NNVM源代码阅读2,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

参考文档

  1. 深度学习编译中间件之NNVM(十二)NNVM源代码阅读1

本系列文档涉及NNVM源代码阅读理解,本篇主要介绍一些NNVM的基础数据结构。

使用的C++命令空间为nnvm

相关代码位于
1. include/nnvm
2. src/core

class Op

代码位于

  • include/nnvm/op.h
  • include/nnvm/op_attr_types.h
  • src/core/op.cc

Op类主要用于记录操作符的一些信息

// 代码只是节选
class NNVM_DLL Op {
public:std::string name; // 操作符名称std::string description; // 操作符详细解释,可用于文档生成std::vector<ParamFieldInfo> arguments; // 带文字描述的参数数组uint32_t num_inputs = 1; // 操作符的输入数据个数uint32_t num_outputs = 1; // 操作符的输出数据个数uint32_t support_level = 10; // 支持优先级,数字越小越优先std::function<uint32_t(const NodeAttrs& attrs)> get_num_outputs = nullptr;std::function<uint32_t(const NodeAttrs& attrs)> get_num_inputs = nullptr;std::function<void(NodeAttrs* attrs)> attr_parser = nullptr;inline Op& describe(const std::string& descr);inline Op& add_argument(const std::string &name,const std::string &type,const std::string &description);inline Op& add_arguments(const std::vector<ParamFieldInfo> &args);inline Op& set_num_inputs(uint32_t n); inline Op& set_support_level(uint32_t level);inline Op& set_num_inputs(std::function<uint32_t (const NodeAttrs& attr)> fn);  inline Op& set_num_outputs(uint32_t n);inline Op& set_num_outputs(std::function<uint32_t (const NodeAttrs& attr)> fn);inline Op& set_attr_parser(std::function<void (NodeAttrs* attrs)> fn);template<typename ValueType>inline Op& set_attr(const std::string& attr_name, const ValueType& value,int plevel = 10);Op& add_alias(const std::string& alias); Op& include(const std::string& group_name);static const Op* Get(const std::string& op_name);template<typename ValueType>static const OpMap<ValueType>& GetAttr(const std::string& attr_name);private:template<typename ValueType>friend class OpMap;friend class OpGroup;friend class dmlc::Registry<Op>;uint32_t index_{0}; // 唯一操作符索引,用于OpManager区分Op
};

另外include/nnvm/op_attr_types.h中提供了操作符支持的属性类型定义

// 代码节选
using FInferShape = FInferNodeEntryAttr<TShape>;
using FInferType = FInferNodeEntryAttr<int>;// 得到操作符节点的梯度节点,这个函数用于生成反向传播计算图
using FGradient = std::function<std::vector<NodeEntry>(const NodePtr& nodeptr,const std::vector<NodeEntry>& out_grads)>;
...

class Node

代码位于

  • include/nnvm/node.h
  • src/core/node.cc

Node类用于在一个计算图中表示一个操作

class NNVM_DLL Node {
public:NodeAttrs attrs; // 节点属性std::vector<NodeEntry> inputs; // 节点输入向量 std::vector<NodePtr> control_deps; // 依赖节点,用于控制流依赖any info; // 节点额外信息inline const Op* op() const; // 返回节点包含的操作inline bool is_variable() const; // 判断节点是否是占位变量(即节点内不含操作,节点的作用只是占用)inline uint32_t num_outputs() const; inline uint32_t num_inputs() const;static NodePtr Create(); // 创建一个空节点(静态方法)
};

class Graph

代码位于

  • include/nnvm/graph.h
  • src/core/graph.cc

Graph类用于表示一个计算图,它是一个为了进行优化Pass的中间表示。

class Graph {
public:std::vector<NodeEntry> outputs; // 计算图的输出节点std::unordered_map<std::string, std::shared_ptr<any> > attrs; // 计算图属性集合template<typename T>inline const T& GetAttr(const std::string& attr_name) const;inline bool HasAttr(const std::string& attr_name) const;template<typename T>inline T MoveCopyAttr(const std::string& attr_name);const IndexedGraph& indexed_graph() const; // 获取当前计算图的索引图,如果不存在就按需创建
private:mutable std::shared_ptr<const IndexedGraph> indexed_graph_;
}// 下面介绍Graph的辅助类IndexedGraph/*!* IndexedGraph用于提供索引一个计算图的辅助数据结构* 它将图内部的节点们映射到一个连续整型变量node_id,而且将输出节点映射到一个连续整型变量entry_id。* 这样的方式允许将计算图的内部节点和输出节点存储在一个紧凑的向量结构中,并且可以做到快速存取* 节点的node_id和entry_id是和保存的JSON文件的顺序是一致的*/
class IndexedGraph {
public:/* 表示计算图中的一个数据 */struct NodeEntry {uint32_t node_id;uint32_t index;uint32_t version;};/* 表示计算图中的一个节点 */struct Node {const nnvm::Node* source; // 指向源节点的指针array_view<NodeEntry> inputs; // 节点的输入数据array_view<uint32_t> control_deps;std::weak_ptr<nnvm::Node> weak_ref; // 指向节点的弱引用};inline uint32_t entry_id(uint32_t node_id, uint32_t index); // 获取一个唯一的entry_idinline uint32_t entry_id(const NodeEntry& e); inline const std::vector<uint32_t>& input_nodes(); // 返回argument节点列表
private:friend class Graph;}

class PassFunctionReg

代码位于

  • include/nnvm/pass.h
  • src/core/pass.cc

PassFunctionReg类为DataIterator工厂函数提供注册入口

// PassFunctionReg继承自dmlc::FunctionRegEntryBase,这个类主要用于函数注册。
// PassFunctionReg在FunctionRegEntryBase类注册普通函数的基础上增加和Pass相关的属性和函数struct PassFunctionReg: public dmlc::FunctionRegEntryBase<PassFunctionReg,PassFunction> {bool change_graph{false}; // 标记pass是否会改变计算图的结构std::vector<std::string> graph_attr_dependency; // 记录pass在被应用之前哪些计算图属性必须处于可用std::vector<std::string> graph_attr_targets; // 记录pass在被应用之后将生成哪些计算图属性
}// 下面介绍一些辅助数据结构和函数/*!* \brief 一个PassFunction表示一个针对计算图所做的操作* 这个函数处理一个源计算图,返回一个目标计算图,这两个计算图可能一致也可能不一致* 一个PassFunction可能会改变图结构,也可能会增加图属性*/
typedef std::function<Graph (Graph src)> PassFunction;// 针对输入计算图应用一系列pass
Graph ApplyPasses(Graph src, const std::vector<std::string>& passes);

class Symbol

代码位于

  • include/nnvm/symbolic.h
  • src/core/symbolic.cc

Symbol类是一个帮助类,用于表示计算图中的操作节点。

Symbol类拥有一个利用Group/Functor/Variable这些组件来创建计算图的接口,Symbol类也会被导出到NNVM的Python前端,用于方便进行快速测试和部署。后面将有专门的文档讲解NNVM的Python接口的部分。

// 代码节选
class NNVM_DLL Symbol {
public:std::vector<NodeEntry> outputs;Symbol Copy() const;void Print(std::ostream &os) const;std::vector<NodePtr> ListInputs(ListInputOption option) const;std::vector<std::string> ListInputNames(ListInputOption option) const;std::vector<std::string> ListOutputNames() const;// 创建Symbol/Variable/Group Symbolstatic Symbol CreateFunctor(const Op* op,std::unordered_map<std::string, std::string> attrs);static Symbol CreateFunctor(const NodeAttrs& attrs);static Symbol CreateVariable(const std::string& name);static Symbol CreateGroup(const std::vector<Symbol>& symbols);
}

class Layout

代码位于

  • include/nnvm/layout.h

Layout类用于处理Layout表达式

layout由大写字母、小写字母和数字组成,其中大写字母表示一个维度,大写字母对应的小写字母表示一个split之后的子维度,小写字母之前的数字则表示split块的数量。

例如:NCHW16c

表示:[batch_size, channel, height, width, channel_block], channel_block=16

至此NNVM的基础数据结构就介绍完了,接下来的文档将会具体分析NNVM的重要组件

这篇关于深度学习编译中间件之NNVM(十三)NNVM源代码阅读2的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/298424

相关文章

idea maven编译报错Java heap space的解决方法

《ideamaven编译报错Javaheapspace的解决方法》这篇文章主要为大家详细介绍了ideamaven编译报错Javaheapspace的相关解决方法,文中的示例代码讲解详细,感兴趣的... 目录1.增加 Maven 编译的堆内存2. 增加 IntelliJ IDEA 的堆内存3. 优化 Mave

Java编译生成多个.class文件的原理和作用

《Java编译生成多个.class文件的原理和作用》作为一名经验丰富的开发者,在Java项目中执行编译后,可能会发现一个.java源文件有时会产生多个.class文件,从技术实现层面详细剖析这一现象... 目录一、内部类机制与.class文件生成成员内部类(常规内部类)局部内部类(方法内部类)匿名内部类二、

SpringCloud动态配置注解@RefreshScope与@Component的深度解析

《SpringCloud动态配置注解@RefreshScope与@Component的深度解析》在现代微服务架构中,动态配置管理是一个关键需求,本文将为大家介绍SpringCloud中相关的注解@Re... 目录引言1. @RefreshScope 的作用与原理1.1 什么是 @RefreshScope1.

Python 中的异步与同步深度解析(实践记录)

《Python中的异步与同步深度解析(实践记录)》在Python编程世界里,异步和同步的概念是理解程序执行流程和性能优化的关键,这篇文章将带你深入了解它们的差异,以及阻塞和非阻塞的特性,同时通过实际... 目录python中的异步与同步:深度解析与实践异步与同步的定义异步同步阻塞与非阻塞的概念阻塞非阻塞同步

Redis中高并发读写性能的深度解析与优化

《Redis中高并发读写性能的深度解析与优化》Redis作为一款高性能的内存数据库,广泛应用于缓存、消息队列、实时统计等场景,本文将深入探讨Redis的读写并发能力,感兴趣的小伙伴可以了解下... 目录引言一、Redis 并发能力概述1.1 Redis 的读写性能1.2 影响 Redis 并发能力的因素二、

最新Spring Security实战教程之表单登录定制到处理逻辑的深度改造(最新推荐)

《最新SpringSecurity实战教程之表单登录定制到处理逻辑的深度改造(最新推荐)》本章节介绍了如何通过SpringSecurity实现从配置自定义登录页面、表单登录处理逻辑的配置,并简单模拟... 目录前言改造准备开始登录页改造自定义用户名密码登陆成功失败跳转问题自定义登出前后端分离适配方案结语前言

Java进阶学习之如何开启远程调式

《Java进阶学习之如何开启远程调式》Java开发中的远程调试是一项至关重要的技能,特别是在处理生产环境的问题或者协作开发时,:本文主要介绍Java进阶学习之如何开启远程调式的相关资料,需要的朋友... 目录概述Java远程调试的开启与底层原理开启Java远程调试底层原理JVM参数总结&nbsMbKKXJx

Redis 内存淘汰策略深度解析(最新推荐)

《Redis内存淘汰策略深度解析(最新推荐)》本文详细探讨了Redis的内存淘汰策略、实现原理、适用场景及最佳实践,介绍了八种内存淘汰策略,包括noeviction、LRU、LFU、TTL、Rand... 目录一、 内存淘汰策略概述二、内存淘汰策略详解2.1 ​noeviction(不淘汰)​2.2 ​LR

Python与DeepSeek的深度融合实战

《Python与DeepSeek的深度融合实战》Python作为最受欢迎的编程语言之一,以其简洁易读的语法、丰富的库和广泛的应用场景,成为了无数开发者的首选,而DeepSeek,作为人工智能领域的新星... 目录一、python与DeepSeek的结合优势二、模型训练1. 数据准备2. 模型架构与参数设置3

IDEA编译报错“java: 常量字符串过长”的原因及解决方法

《IDEA编译报错“java:常量字符串过长”的原因及解决方法》今天在开发过程中,由于尝试将一个文件的Base64字符串设置为常量,结果导致IDEA编译的时候出现了如下报错java:常量字符串过长,... 目录一、问题描述二、问题原因2.1 理论角度2.2 源码角度三、解决方案解决方案①:StringBui