mxnet symbol 解析

2024-04-24 11:08
文章标签 解析 symbol mxnet

本文主要是介绍mxnet symbol 解析,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

mxnet symbol类定义:https://github.com/apache/incubator-mxnet/blob/master/python/mxnet/symbol/symbol.py

对于一个symbol,可分为non-grouped和grouped。且symbol具有输出,和输出属性。比如,对于Variable而言,其输入和输出就是它自己。对于c = a+b,c的内部有个_plus0 symbol,对于_plus0这个symbol,它的输入是a,b,输出是_plus0_output。

class Symbol(SymbolBase):"""Symbol is symbolic graph of the mxnet."""# disable dictionary storage, also do not have parent type.# pylint: disable=no-member

其中,Symbol还不是最基础的类,Symbol类继承了SymbolBase这个类。
而SymbolBase这个类实际是在

https://github.com/apache/incubator-mxnet/blob/master/python/mxnet/symbol/_internal.py

中引用的,通过以下方式引用:

from .._ctypes.symbol import SymbolBase, _set_symbol_class, _set_np_symbol_class

而SymbolBase的定义是在:https://github.com/apache/incubator-mxnet/blob/master/python/mxnet/_ctypes/symbol.py
这里暂时先不管SymbolBase,这应该是是python调用c++接口创建的一个类。

回到Symbol中来,对于mxnet符号式编程而言,定义的任何网络,或者变量,都是symbol类型,所以,了解这个类就显得很重要。

Symbol类中有几类函数:
1、普通函数
2、__xx__ 函数
3、@property 修饰的函数
4、函数名为xx,实际调用op.xx的函数

1、普通函数
attr
根据key返回symbol对应的属性字符串,只对non-grouped symbols起作用。

    def attr(self, key):"""Returns the attribute string for corresponding input key from the symbol.

list_attr
得到symbol的所有属性

    def list_attr(self, recursive=False):"""Gets all attributes from the symbol.

attr_dict
递归的得到symbol和孩子的属性

    def attr_dict(self):"""Recursively gets all attributes from the symbol and its children.Example------->>> a = mx.sym.Variable('a', attr={'a1':'a2'})>>> b = mx.sym.Variable('b', attr={'b1':'b2'})>>> c = a+b>>> c.attr_dict(){'a': {'a1': 'a2'}, 'b': {'b1': 'b2'}}

_set_attr
通过key-value方式,对attr进行设置

    def _set_attr(self, **kwargs):"""Sets an attribute of the symbol.For example. A._set_attr(foo="bar") adds the mapping ``"{foo: bar}"``to the symbol's attribute dictionary.

get_internals
获取symbol的所有内部节点symbol,是一个group类型(包括输入,输出节点symbol)。如果我们想阶段一个network,应该获取它某内部节点的输出,这样才能作为新增加的symbol的输入。

    def get_internals(self):"""Gets a new grouped symbol `sgroup`. The output of `sgroup` is a list ofoutputs of all of the internal nodes.

get_children
获取当前symbol输出节点的inputs

    def get_children(self):"""Gets a new grouped symbol whose output containsinputs to output nodes of the original symbol.

list_arguments
列出当前symbol的所有参数(可以配合call对symbol进行改造)

    def list_arguments(self):"""Lists all the arguments in the symbol.

list_outputs
列出当前smybol的所有输出,如果当前symbol是grouped类型,回遍历输出每一个symbol的输出

    def list_outputs(self):"""Lists all the outputs in the symbol.

list_auxiliary_states
列出symbol中的辅助状态参数,比如BN

    def list_auxiliary_states(self):"""Lists all the auxiliary states in the symbol.Example------->>> a = mx.sym.var('a')>>> b = mx.sym.var('b')>>> c = a + b>>> c.list_auxiliary_states()[]Example of auxiliary states in `BatchNorm`.

list_inputs
列出当前symbol的所有输入参数,和辅助状态,等价于 list_arguments和 list_auxiliary_states

    def list_inputs(self):"""Lists all arguments and auxiliary states of this Symbol.

2、__xx__函数

__repr__
对于gruop symbol,它是没有name属性的,print或者回车,结果就是其内部symbol节点的name
在这里插入图片描述
__iter__(self):
普通的symbol长度都只有1,只有Grouped 的symbol,长度才大于1:return (self[i] for i in range(len(self)))
算数及逻辑运算:
+,-,*, /,%,abs,**, 取负(-x),==,!=,>,>=,<,<=, # 使用时,要注意Broadcasting 是否支持

    def __abs__(self):"""x.__abs__() <=> abs(x) <=> x.abs() <=> mx.symbol.abs(x, y)"""return self.abs()def __add__(self, other):"""x.__add__(y) <=> x+y其他   

__copy__和__deep_copy__
通过deep_copy,创建一个深拷贝,返回输入对象的一个拷贝,包括它当前所有参数的当前状态,比如weight,bias等
在这里插入图片描述
__call__
表示symbol的实例是一个可调用对象。可以返回一个新的symbol,这个symbol继承了之前symbol的权重啥的,但是和之前的symbol是不同的对象,可以输入参数对symbol进行组合。

    def __call__(self, *args, **kwargs):"""Composes symbol using inputs.Returns-------The resulting symbol."""s = self.__copy__()  #  这里对symbol实例做了一次深拷贝,返回的新的symbols._compose(*args, **kwargs) # 实际调用的_compose函数return s# 对当前的symbol进行编译,返回一个新的symbol,可以指定新symbol的name,其他输入参数必须是symbol类型# 当前symbol的输入参数,可以通过 .list_arguments()获取def _compose(self, *args, **kwargs):"""Composes symbol using inputs.x._compose(y, z) <=> x(y,z)This function mutates the current symbol.Example-------Returns-------The resulting symbol."""name = kwargs.pop('name', None)if name:name = c_str(name)if len(args) != 0 and len(kwargs) != 0:raise TypeError('compose only accept input Symbols \either as positional or keyword arguments, not both')

这里,我改变了b,将其输入参数的x的值变为了tt。
在这里插入图片描述

__getitem__
如果symbol的长度只有1,那么返回的就是它的输出symbol,如果symbol长度>1,可以通过切片访问其输出symbol,返回的也是一个Group symbol。symbol可以分为non-grouped和grouped。
获取内部节点symbol还可以输入str,但输入的str必须属于list_outputs(),

    def __getitem__(self, index):"""x.__getitem__(i) <=> x[i]Returns a sliced view of the input symbol.Parameters----------index : int or strIndexing key"""output_count = len(self)if isinstance(index, py_slice):# 输入切片if isinstance(index, string_types):# 输入字符串# Returning this list of names is expensive. Some symbols may have hundreds of outputsoutput_names = self.list_outputs()idx = Nonefor i, name in enumerate(output_names):if name == index:if idx is not None:raise ValueError('There are multiple outputs with name \"%s\"' % index)idx = iif idx is None:raise ValueError('Cannot find output that matches name \"%s\"' % index)index = idx

symbol.py 除了Symbol这个类之外,还有游离在外的函数:

1def var(name, attr=None, shape=None, lr_mult=None, wd_mult=None, dtype=None,init=None, stype=None, **kwargs):"""Creates a symbolic variable with specified name.
# for back compatibility
Variable = var  #  调用 mx.sym.var和mx.sym.Variable 等价2、
def Group(symbols, create_fn=Symbol):"""Creates a symbol that contains a collection of other symbols, grouped together.A classic symbol (`mx.sym.Symbol`) will be returned if all the symbols in the listare of that type; a numpy symbol (`mx.sym.np._Symbol`) will be returned if all thesymbols in the list are of that type. A type error will be raised if a list of mixedclassic and numpy symbols are provided.Example------->>> a = mx.sym.Variable('a')>>> b = mx.sym.Variable('b')>>> mx.sym.Group([a,b])<Symbol Grouped>Parameters----------symbols : listList of symbols to be grouped.3def load(fname):"""Loads symbol from a JSON file.You also get the benefit being able to directly load/save from cloud storage(S3, HDFS).Returns-------sym : SymbolThe loaded symbol.See Also--------Symbol.save : Used to save symbol into file.
# 输入文件可以是hdfs文件
4、
数学相关函数,输入可为scalar或者是symbol
def pow(base, exp):"""Returns element-wise result of base element raised to powers from exp element.base 和 exp可以是数字或者symbol
# def power(base, exp):  #  实际调用pow
def maximum(left, right):
def minimum(left, right):
def hypot(left, right):  #  返回直角三角形的斜边
def eye(N, M=0, k=0, dtype=None, **kwargs):"""Returns a new symbol of 2-D shpae, filled with ones on the diagonal and zeros elsewhere.  #  返回2D shape的symbol,对角线为1,其余位置为0
def zeros(shape, dtype=None, **kwargs):"""Returns a new symbol of given shape and type, filled with zeros.  # 返回一个shape的全0 symbol
def ones(shape, dtype=None, **kwargs):"""Returns a new symbol of given shape and type, filled with ones.
def full(shape, val, dtype=None, **kwargs):"""Returns a new array of given shape and type, filled with the given value `val`.
def arange(start, stop=None, step=1.0, repeat=1, infer_range=False, name=None, dtype=None):"""Returns evenly spaced values within a given interval.
def arange(start, stop=None, step=1.0, repeat=1, infer_range=False, name=None, dtype=None):"""Returns evenly spaced values within a given interval.
def linspace(start, stop, num, endpoint=True, name=None, dtype=None):"""Return evenly spaced numbers within a specified interval.
def histogram(a, bins=10, range=None, **kwargs):"""Compute the histogram of the input data.
def split_v2(ary, indices_or_sections, axis=0, squeeze_axis=False):"""Split an array into multiple sub-arrays.

这篇关于mxnet symbol 解析的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/931579

相关文章

线上Java OOM问题定位与解决方案超详细解析

《线上JavaOOM问题定位与解决方案超详细解析》OOM是JVM抛出的错误,表示内存分配失败,:本文主要介绍线上JavaOOM问题定位与解决方案的相关资料,文中通过代码介绍的非常详细,需要的朋... 目录一、OOM问题核心认知1.1 OOM定义与技术定位1.2 OOM常见类型及技术特征二、OOM问题定位工具

深度解析Python中递归下降解析器的原理与实现

《深度解析Python中递归下降解析器的原理与实现》在编译器设计、配置文件处理和数据转换领域,递归下降解析器是最常用且最直观的解析技术,本文将详细介绍递归下降解析器的原理与实现,感兴趣的小伙伴可以跟随... 目录引言:解析器的核心价值一、递归下降解析器基础1.1 核心概念解析1.2 基本架构二、简单算术表达

深度解析Java @Serial 注解及常见错误案例

《深度解析Java@Serial注解及常见错误案例》Java14引入@Serial注解,用于编译时校验序列化成员,替代传统方式解决运行时错误,适用于Serializable类的方法/字段,需注意签... 目录Java @Serial 注解深度解析1. 注解本质2. 核心作用(1) 主要用途(2) 适用位置3

Java MCP 的鉴权深度解析

《JavaMCP的鉴权深度解析》文章介绍JavaMCP鉴权的实现方式,指出客户端可通过queryString、header或env传递鉴权信息,服务器端支持工具单独鉴权、过滤器集中鉴权及启动时鉴权... 目录一、MCP Client 侧(负责传递,比较简单)(1)常见的 mcpServers json 配置

从原理到实战解析Java Stream 的并行流性能优化

《从原理到实战解析JavaStream的并行流性能优化》本文给大家介绍JavaStream的并行流性能优化:从原理到实战的全攻略,本文通过实例代码给大家介绍的非常详细,对大家的学习或工作具有一定的... 目录一、并行流的核心原理与适用场景二、性能优化的核心策略1. 合理设置并行度:打破默认阈值2. 避免装箱

Maven中生命周期深度解析与实战指南

《Maven中生命周期深度解析与实战指南》这篇文章主要为大家详细介绍了Maven生命周期实战指南,包含核心概念、阶段详解、SpringBoot特化场景及企业级实践建议,希望对大家有一定的帮助... 目录一、Maven 生命周期哲学二、default生命周期核心阶段详解(高频使用)三、clean生命周期核心阶

深入解析C++ 中std::map内存管理

《深入解析C++中std::map内存管理》文章详解C++std::map内存管理,指出clear()仅删除元素可能不释放底层内存,建议用swap()与空map交换以彻底释放,针对指针类型需手动de... 目录1️、基本清空std::map2️、使用 swap 彻底释放内存3️、map 中存储指针类型的对象

Java Scanner类解析与实战教程

《JavaScanner类解析与实战教程》JavaScanner类(java.util包)是文本输入解析工具,支持基本类型和字符串读取,基于Readable接口与正则分隔符实现,适用于控制台、文件输... 目录一、核心设计与工作原理1.底层依赖2.解析机制A.核心逻辑基于分隔符(delimiter)和模式匹

Java+AI驱动实现PDF文件数据提取与解析

《Java+AI驱动实现PDF文件数据提取与解析》本文将和大家分享一套基于AI的体检报告智能评估方案,详细介绍从PDF上传、内容提取到AI分析、数据存储的全流程自动化实现方法,感兴趣的可以了解下... 目录一、核心流程:从上传到评估的完整链路二、第一步:解析 PDF,提取体检报告内容1. 引入依赖2. 封装

深度解析Python yfinance的核心功能和高级用法

《深度解析Pythonyfinance的核心功能和高级用法》yfinance是一个功能强大且易于使用的Python库,用于从YahooFinance获取金融数据,本教程将深入探讨yfinance的核... 目录yfinance 深度解析教程 (python)1. 简介与安装1.1 什么是 yfinance?