【torch杂记】torch.nn.init.kaiming_normal_

2024-03-05 17:58
文章标签 init torch nn 杂记 normal kaiming

本文主要是介绍【torch杂记】torch.nn.init.kaiming_normal_,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

torch.nn.init.kaiming_normal_

文章目录

        • torch.nn.init.kaiming_normal_
          • 参考
          • 源码

参考
  • torch.nn.init.kaiming_normal_
  • python中的numel()函数
源码
  • 这个函数就是实现这个公式

    • std = gain fan_mode \text{std} = \frac{\text{gain}}{\sqrt{\text{fan\_mode}}} std=fan_mode gain
  • def kaiming_normal_(tensor, a=0, mode='fan_in', nonlinearity='leaky_relu'):r"""kaiming正态分布"""fan = _calculate_correct_fan(tensor, mode)gain = calculate_gain(nonlinearity, a)std = gain / math.sqrt(fan)with torch.no_grad():# 这句是返回指定区间内随机生成的正太分布的值的 return tensor.normal_(0, std)
    
  • _calculate_correct_fan(tensor, mode)是算出input和output feature map的元素总数,源码为:

    • def _calculate_correct_fan(tensor, mode):mode = mode.lower()valid_modes = ['fan_in', 'fan_out']if mode not in valid_modes:raise ValueError("Mode {} not supported, please use one of {}".format(mode, valid_modes))# 这里是fmap的大小fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)# 根据mode选择返回数据return fan_in if mode == 'fan_in' else fan_out
      
    • def _calculate_fan_in_and_fan_out(tensor):dimensions = tensor.dim()if dimensions < 2:raise ValueError("Fan in and fan out can not be computed for tensor with fewer than 2 dimensions")# 这里相当于输出了前两维的sizenum_input_fmaps = tensor.size(1)num_output_fmaps = tensor.size(0)# 这里相当于计算了后两维的元素总和receptive_field_size = 1if tensor.dim() > 2:# numel()的作用就是计算元素的个数receptive_field_size = tensor[0][0].numel()# 然后算出in/out的fmap的大小fan_in = num_input_fmaps * receptive_field_sizefan_out = num_output_fmaps * receptive_field_sizereturn fan_in, fan_out
      
    • 上面源码可以用下列例子解释:

      • 比如有tensor.size()=[3,48,11,11],前两者分布是output_channel和input_channel
      • fan_in =48*11*11=5808
      • fan_out=3*11*11=363
      • 然后根据mode匹配决定return哪个
  • 感谢评论区大佬指出错误,num_input_fmaps是用的size(1),num_output_fmaps用的size(0)

  • calculate_gain(nonlinearity, a)如果选的是relu,那么return math.sqrt(2.0),即根号2,下面是源码,其中注释给出了详细的gain值

    • def calculate_gain(nonlinearity, param=None):r"""Return the recommended gain value for the given nonlinearity function.The values are as follows:================= ====================================================nonlinearity      gain================= ====================================================Linear / Identity :math:`1`Conv{1,2,3}D      :math:`1`Sigmoid           :math:`1`Tanh              :math:`\frac{5}{3}`ReLU              :math:`\sqrt{2}`Leaky Relu        :math:`\sqrt{\frac{2}{1 + \text{negative\_slope}^2}}`SELU              :math:`\frac{3}{4}`================= ====================================================Args:nonlinearity: the non-linear function (`nn.functional` name)param: optional parameter for the non-linear functionExamples:>>> gain = nn.init.calculate_gain('leaky_relu', 0.2)  # leaky_relu with negative_slope=0.2"""linear_fns = ['linear', 'conv1d', 'conv2d', 'conv3d', 'conv_transpose1d', 'conv_transpose2d', 'conv_transpose3d']if nonlinearity in linear_fns or nonlinearity == 'sigmoid':return 1elif nonlinearity == 'tanh':return 5.0 / 3elif nonlinearity == 'relu':return math.sqrt(2.0)elif nonlinearity == 'leaky_relu':if param is None:negative_slope = 0.01elif not isinstance(param, bool) and isinstance(param, int) or isinstance(param, float):# True/False are instances of int, hence check abovenegative_slope = paramelse:raise ValueError("negative_slope {} not a valid number".format(param))return math.sqrt(2.0 / (1 + negative_slope ** 2))elif nonlinearity == 'selu':return 3.0 / 4  # Value found empirically (https://github.com/pytorch/pytorch/pull/50664)else:raise ValueError("Unsupported nonlinearity {}".format(nonlinearity))
      
  • tensor.normal_(0, std)

    • 大意是返回一个张量,张量里面的随机数是从相互独立的正态分布中随机生成的。
    • 0为均值,std为标准差

这篇关于【torch杂记】torch.nn.init.kaiming_normal_的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/777344

相关文章

Nn criterions don’t compute the gradient w.r.t. targets error「pytorch」 (debug笔记)

Nn criterions don’t compute the gradient w.r.t. targets error「pytorch」 ##一、 缘由及解决方法 把这个pytorch-ddpg|github搬到jupyter notebook上运行时,出现错误Nn criterions don’t compute the gradient w.r.t. targets error。注:我用

pytorch torch.nn.functional.one_hot函数介绍

torch.nn.functional.one_hot 是 PyTorch 中用于生成独热编码(one-hot encoding)张量的函数。独热编码是一种常用的编码方式,特别适用于分类任务或对离散的类别标签进行处理。该函数将整数张量的每个元素转换为一个独热向量。 函数签名 torch.nn.functional.one_hot(tensor, num_classes=-1) 参数 t

torch.nn 与 torch.nn.functional的区别?

区别 PyTorch中torch.nn与torch.nn.functional的区别是:1.继承方式不同;2.可训练参数不同;3.实现方式不同;4.调用方式不同。 1.继承方式不同 torch.nn 中的模块大多数是通过继承torch.nn.Module 类来实现的,这些模块都是Python 类,需要进行实例化才能使用。而torch.nn.functional 中的函数是直接调用的,无需

4.15 版本内核调用 init_timer()函数出错

linux/include/linux/timer.h4.15 之前版本struct timer_list {14 /*15 * All fields that change during normal runtime grouped to the16 * same cacheline17 */18 struct hl

Python方法:__init__,__new__,__class__的使用详解

转自:https://blog.csdn.net/qq_26442553/article/details/82464682 因为python中所有类默认继承object类。而object类提供了了很多原始的内建属性和方法,所以用户自定义的类在Python中也会继承这些内建属性。可以使用dir()函数可以查看,虽然python提供了很多内建属性但实际开发中常用的不多。而很多系统提供的内建属性实际

生活杂记1

生命中,总有一些事需要你一生去治愈,我把这些杂记写出来,写完了就不再想了,太内耗了…hahaha~ 因为嘴馋,小时候经常去老姑家,她家有各类零食及平时很少吃的“山珍海味”。去的次数多了,就和她家附近的邻居小孩也混的熟络了。再后来上了高中去的就少了,当年七中统招线521自费线491。我刚好压自费线,举全家之力花了15000读了七中,也没争气,后面高考也一塌糊涂。高二那会,一次去老姑家做客,经

【杂记】裂脑人实验和语言模型幻觉

【杂记】裂脑人实验和语言模型幻觉 模型的自主意识在哪里,人的自我认知在哪里?自然而然的,“裂脑人” 这个词突然出现在我脑海里。然后随意翻了翻相关的文章,觉得这个问题和目前大模型面临的幻觉问题也高度相关,遂随笔记录。 裂脑人 什么是裂脑人?人的大脑左右半脑本来是一个整体,因为先天或者后天的原因让左右半脑分开不产生连接,就是“裂脑”。过去这个方法被作为控制恶性癫痫的治疗手段。 一些铺垫知识

torch.backends.cudnn.benchmark和torch.use_deterministic_algorithms总结学习记录

经常使用PyTorch框架的应该对于torch.backends.cudnn.benchmark和torch.use_deterministic_algorithms这两个语句并不陌生,在以往开发项目的时候可能专门化花时间去了解过,也可能只是浅尝辄止简单有关注过,正好今天再次遇到了就想着总结梳理一下。 torch.backends.cudnn.benchmark 是 PyTorch 中的一个设置

_no_init的作用

__no_init用于禁止系统启动时的变量初始化,什么情况下需要用这个关键字使系统禁止变量的初始化,禁止变量初始化用在什么场合,为什么要这样做,有什么意义吗? 1、看门狗复位的现场恢复,如果初始化了就完全不可恢复了 2、使用nvram保存数据,需要连续记录的。    我有个变量,需要在系统意外复位时,这个变量值能保留,所以采用__no_init来实现,只是上电的时候这个值不是零

力士乐驱动主板CSB01.1N-SE-ENS-NNN-NN-S-N-FW

力士乐驱动主板CSB01.1N-SE-ENS-NNN-NN-S-N-FW ‌力士乐驱动器的使用说明主要涉及软件安装、参数配置、PID调节等方面。‌  ‌软件安装‌:安装过程涉及多个步骤,首先需要打开安装文件夹中的CD1,双击setup.exe进行安装。在安装过程中,需要选择语言、接受许可协议、输入安装名称、选择安装目录等。整个安装过程可能需要10多分钟,取决于电脑性能。安装完成后,需要重启计算