maxout简单理解

2023-11-11 08:18
文章标签 简单 理解 maxout

本文主要是介绍maxout简单理解,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

maxout出现在ICML2013上,作者Goodfellow将maxout和dropout结合后,号称在MNIST, CIFAR-10, CIFAR-100, SVHN这4个数据上都取得了start-of-art的识别率。

  从论文中可以看出,maxout其实一种激发函数形式。通常情况下,如果激发函数采用sigmoid函数的话,在前向传播过程中,隐含层节点的输出表达式为:

   

  其中W一般是2维的,这里表示取出的是第i列,下标i前的省略号表示对应第i列中的所有行。但如果是maxout激发函数,则其隐含层节点的输出表达式为:

    

  

  这里的W是3维的,尺寸为d*m*k,其中d表示输入层节点的个数,m表示隐含层节点的个数,k表示每个隐含层节点对应了k个”隐隐含层”节点,这k个”隐隐含层”节点都是线性输出的,而maxout的每个节点就是取这k个”隐隐含层”节点输出值中最大的那个值。因为激发函数中有了max操作,所以整个maxout网络也是一种非线性的变换。因此当我们看到常规结构的神经网络时,如果它使用了maxout激发,则我们头脑中应该自动将这个”隐隐含层”节点加入。参考一个日文的maxout ppt 中的一页ppt如下:

   

  ppt中箭头前后示意图大家应该可以明白什么是maxout激发函数了。

  maxout的拟合能力是非常强的,它可以拟合任意的的凸函数。最直观的解释就是任意的凸函数都可以由分段线性函数以任意精度拟合(学过高等数学应该能明白),而maxout又是取k个隐隐含层节点的最大值,这些”隐隐含层"节点也是线性的,所以在不同的取值范围下,最大值也可以看做是分段线性的(分段的个数与k值有关)。论文中的图1如下(它表达的意思就是可以拟合任意凸函数,当然也包括了ReLU了):

   

  作者从数学的角度上也证明了这个结论,即只需2个maxout节点就可以拟合任意的凸函数了(相减),前提是”隐隐含层”节点的个数可以任意多,如下图所示:

   

  下面来看下maxout源码,看其激发函数表达式是否符合我们的理解。找到库目录下的pylearn2/models/maxout.py文件,选择不带卷积的Maxout类,主要是其前向传播函数fprop():

复制代码
  def fprop(self, state_below): #前向传播,对linear分组进行max-pooling操作self.input_space.validate(state_below)if self.requires_reformat:if not isinstance(state_below, tuple):for sb in get_debug_values(state_below):if sb.shape[0] != self.dbm.batch_size:raise ValueError("self.dbm.batch_size is %d but got shape of %d" % (self.dbm.batch_size, sb.shape[0]))assert reduce(lambda x,y: x * y, sb.shape[1:]) == self.input_dimstate_below = self.input_space.format_as(state_below, self.desired_space) #统一好输入数据的格式z = self.transformer.lmul(state_below) + self.b # lmul()函数返回的是 return T.dot(x, self._W)if not hasattr(self, 'randomize_pools'):self.randomize_pools = Falseif not hasattr(self, 'pool_stride'):self.pool_stride = self.pool_size #默认情况下是没有重叠的poolingif self.randomize_pools:z = T.dot(z, self.permute)if not hasattr(self, 'min_zero'):self.min_zero = Falseif self.min_zero:p = T.zeros_like(z) #返回一个和z同样大小的矩阵,元素值为0,元素值类型和z的类型一样else:p = Nonelast_start = self.detector_layer_dim  - self.pool_sizefor i in xrange(self.pool_size): #xrange和reange的功能类似cur = z[:,i:last_start+i+1:self.pool_stride]  # L[start:end:step]是用来切片的,从[start,end)之间,每隔step取一次if p is None:p = curelse:p = T.maximum(cur, p) #将p进行迭代比较,因为每次取的是每个group里的元素,所以进行pool_size次后就可以获得每个group的最大值p.name = self.layer_name + '_p_'return p
复制代码

  仔细阅读上面的源码,发现和文章中描述基本是一致的,只是多了很多细节。

  由于没有GPU,所以只用CPU 跑了个mnist的简单实验,参考:maxout下的readme文件。(需先下载mnist dataset到PYLEARN2_DATA_PATA目录下)。

  执行../../train.py minist_pi.yaml

  此时的.yaml配置文件内容如下:

复制代码
!obj:pylearn2.train.Train {dataset: &train !obj:pylearn2.datasets.mnist.MNIST {which_set: 'train',one_hot: 1,start: 0,stop: 50000},model: !obj:pylearn2.models.mlp.MLP {layers: [!obj:pylearn2.models.maxout.Maxout {layer_name: 'h0',num_units: 240,num_pieces: 5,irange: .005,max_col_norm: 1.9365,},!obj:pylearn2.models.maxout.Maxout {layer_name: 'h1',num_units: 240,num_pieces: 5,irange: .005,max_col_norm: 1.9365,},!obj:pylearn2.models.mlp.Softmax {max_col_norm: 1.9365,layer_name: 'y',n_classes: 10,irange: .005}],nvis: 784,},algorithm: !obj:pylearn2.training_algorithms.sgd.SGD {batch_size: 100,learning_rate: .1,learning_rule: !obj:pylearn2.training_algorithms.learning_rule.Momentum {init_momentum: .5,},monitoring_dataset:{'train' : *train,'valid' : !obj:pylearn2.datasets.mnist.MNIST {which_set: 'train',one_hot: 1,start: 50000,stop:  60000},'test'  : !obj:pylearn2.datasets.mnist.MNIST {which_set: 'test',one_hot: 1,}},cost: !obj:pylearn2.costs.mlp.dropout.Dropout {input_include_probs: { 'h0' : .8 },input_scales: { 'h0': 1. }},termination_criterion: !obj:pylearn2.termination_criteria.MonitorBased {channel_name: "valid_y_misclass",prop_decrease: 0.,N: 100},update_callbacks: !obj:pylearn2.training_algorithms.sgd.ExponentialDecay {decay_factor: 1.000004,min_lr: .000001}},extensions: [!obj:pylearn2.train_extensions.best_params.MonitorBasedSaveBest {channel_name: 'valid_y_misclass',save_path: "${PYLEARN2_TRAIN_FILE_FULL_STEM}_best.pkl"},!obj:pylearn2.training_algorithms.learning_rule.MomentumAdjustor {start: 1,saturate: 250,final_momentum: .7}],save_path: "${PYLEARN2_TRAIN_FILE_FULL_STEM}.pkl",save_freq: 1
}
复制代码

  跑了一个晚上才迭代了210次,被我kill掉了(笔记本还得拿到别的地方干活),这时的误差率为1.22%。估计继续跑几个小时应该会降到作者的0.94%误差率。

  其monitor监控输出结果如下:

复制代码
Monitoring step:Epochs seen: 210Batches seen: 105000Examples seen: 10500000learning_rate: 0.0657047371741momentum: 0.667871485944monitor_seconds_per_epoch: 121.0test_h0_col_norms_max: 1.9364999test_h0_col_norms_mean: 1.09864382902test_h0_col_norms_min: 0.0935518826938test_h0_p_max_x.max_u: 3.97355476543test_h0_p_max_x.mean_u: 2.14463905251test_h0_p_max_x.min_u: 0.961549570265test_h0_p_mean_x.max_u: 0.878285389379test_h0_p_mean_x.mean_u: 0.131020009421test_h0_p_mean_x.min_u: -0.373017504665test_h0_p_min_x.max_u: -0.202480633479test_h0_p_min_x.mean_u: -1.31821964107test_h0_p_min_x.min_u: -2.52428183099test_h0_p_range_x.max_u: 5.56309069078test_h0_p_range_x.mean_u: 3.46285869357test_h0_p_range_x.min_u: 2.01775637301test_h0_row_norms_max: 2.67556467test_h0_row_norms_mean: 1.15743973628test_h0_row_norms_min: 0.0951322935423test_h1_col_norms_max: 1.12119975186test_h1_col_norms_mean: 0.595629304226test_h1_col_norms_min: 0.183531862659test_h1_p_max_x.max_u: 6.42944749321test_h1_p_max_x.mean_u: 3.74599401756test_h1_p_max_x.min_u: 2.03028191814test_h1_p_mean_x.max_u: 1.38424650414test_h1_p_mean_x.mean_u: 0.583690886644test_h1_p_mean_x.min_u: 0.0253866100292test_h1_p_min_x.max_u: -0.830110300894test_h1_p_min_x.mean_u: -1.73539242398test_h1_p_min_x.min_u: -3.03677525979test_h1_p_range_x.max_u: 8.63650239768test_h1_p_range_x.mean_u: 5.48138644154test_h1_p_range_x.min_u: 3.36428499068test_h1_row_norms_max: 1.95904749183test_h1_row_norms_mean: 1.40561339238test_h1_row_norms_min: 1.16953677471test_objective: 0.0959691806325test_y_col_norms_max: 1.93642459019test_y_col_norms_mean: 1.90996961714test_y_col_norms_min: 1.88659811751test_y_max_max_class: 1.0test_y_mean_max_class: 0.996910632311test_y_min_max_class: 0.824416386342test_y_misclass: 0.0114test_y_nll: 0.0609837733094test_y_row_norms_max: 0.536167736581test_y_row_norms_mean: 0.386866656967test_y_row_norms_min: 0.266996530755train_h0_col_norms_max: 1.9364999train_h0_col_norms_mean: 1.09864382902train_h0_col_norms_min: 0.0935518826938train_h0_p_max_x.max_u: 3.98463017313train_h0_p_max_x.mean_u: 2.16546276053train_h0_p_max_x.min_u: 0.986865505974train_h0_p_mean_x.max_u: 0.850944629066train_h0_p_mean_x.mean_u: 0.135825383808train_h0_p_mean_x.min_u: -0.354841456train_h0_p_min_x.max_u: -0.20750516843train_h0_p_min_x.mean_u: -1.32748375925train_h0_p_min_x.min_u: -2.49716541111train_h0_p_range_x.max_u: 5.61263186775train_h0_p_range_x.mean_u: 3.49294651978train_h0_p_range_x.min_u: 2.07324073262train_h0_row_norms_max: 2.67556467train_h0_row_norms_mean: 1.15743973628train_h0_row_norms_min: 0.0951322935423train_h1_col_norms_max: 1.12119975186train_h1_col_norms_mean: 0.595629304226train_h1_col_norms_min: 0.183531862659train_h1_p_max_x.max_u: 6.49689754011train_h1_p_max_x.mean_u: 3.77637040198train_h1_p_max_x.min_u: 2.03274038543train_h1_p_mean_x.max_u: 1.34966894021train_h1_p_mean_x.mean_u: 0.57555584546train_h1_p_mean_x.min_u: 0.0176827309146train_h1_p_min_x.max_u: -0.845786992369train_h1_p_min_x.mean_u: -1.74696425227train_h1_p_min_x.min_u: -3.05703072635train_h1_p_range_x.max_u: 8.73556577905train_h1_p_range_x.mean_u: 5.52333465425train_h1_p_range_x.min_u: 3.379501944train_h1_row_norms_max: 1.95904749183train_h1_row_norms_mean: 1.40561339238train_h1_row_norms_min: 1.16953677471train_objective: 0.0119584870103train_y_col_norms_max: 1.93642459019train_y_col_norms_mean: 1.90996961714train_y_col_norms_min: 1.88659811751train_y_max_max_class: 1.0train_y_mean_max_class: 0.999958965285train_y_min_max_class: 0.996295480193train_y_misclass: 0.0train_y_nll: 4.22109408992e-05train_y_row_norms_max: 0.536167736581train_y_row_norms_mean: 0.386866656967train_y_row_norms_min: 0.266996530755valid_h0_col_norms_max: 1.9364999valid_h0_col_norms_mean: 1.09864382902valid_h0_col_norms_min: 0.0935518826938valid_h0_p_max_x.max_u: 3.970333514valid_h0_p_max_x.mean_u: 2.15548653063valid_h0_p_max_x.min_u: 0.99228626325valid_h0_p_mean_x.max_u: 0.84583547397valid_h0_p_mean_x.mean_u: 0.143554208322valid_h0_p_mean_x.min_u: -0.349097300524valid_h0_p_min_x.max_u: -0.218285757389valid_h0_p_min_x.mean_u: -1.28008164111valid_h0_p_min_x.min_u: -2.41494612443valid_h0_p_range_x.max_u: 5.54136030367valid_h0_p_range_x.mean_u: 3.43556817173valid_h0_p_range_x.min_u: 2.03580165751valid_h0_row_norms_max: 2.67556467valid_h0_row_norms_mean: 1.15743973628valid_h0_row_norms_min: 0.0951322935423valid_h1_col_norms_max: 1.12119975186valid_h1_col_norms_mean: 0.595629304226valid_h1_col_norms_min: 0.183531862659valid_h1_p_max_x.max_u: 6.4820340666valid_h1_p_max_x.mean_u: 3.75160795812valid_h1_p_max_x.min_u: 2.00587987424valid_h1_p_mean_x.max_u: 1.38777592924valid_h1_p_mean_x.mean_u: 0.578550013139valid_h1_p_mean_x.min_u: 0.0232071426066valid_h1_p_min_x.max_u: -0.84151110053valid_h1_p_min_x.mean_u: -1.73734213646valid_h1_p_min_x.min_u: -3.09680505839valid_h1_p_range_x.max_u: 8.72732563235valid_h1_p_range_x.mean_u: 5.48895009458valid_h1_p_range_x.min_u: 3.32030803638valid_h1_row_norms_max: 1.95904749183valid_h1_row_norms_mean: 1.40561339238valid_h1_row_norms_min: 1.16953677471valid_objective: 0.104670540623valid_y_col_norms_max: 1.93642459019valid_y_col_norms_mean: 1.90996961714valid_y_col_norms_min: 1.88659811751valid_y_max_max_class: 1.0valid_y_mean_max_class: 0.99627268242valid_y_min_max_class: 0.767024730168valid_y_misclass: 0.0122valid_y_nll: 0.0682986195071valid_y_row_norms_max: 0.536167736581valid_y_row_norms_mean: 0.38686665696valid_y_row_norms_min: 0.266996530755
Saving to mnist_pi.pkl...
Saving to mnist_pi.pkl done. Time elapsed: 3.000000 seconds
Time this epoch: 0:02:08.747395

这篇关于maxout简单理解的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/388734

相关文章

认识、理解、分类——acm之搜索

普通搜索方法有两种:1、广度优先搜索;2、深度优先搜索; 更多搜索方法: 3、双向广度优先搜索; 4、启发式搜索(包括A*算法等); 搜索通常会用到的知识点:状态压缩(位压缩,利用hash思想压缩)。

csu 1446 Problem J Modified LCS (扩展欧几里得算法的简单应用)

这是一道扩展欧几里得算法的简单应用题,这题是在湖南多校训练赛中队友ac的一道题,在比赛之后请教了队友,然后自己把它a掉 这也是自己独自做扩展欧几里得算法的题目 题意:把题意转变下就变成了:求d1*x - d2*y = f2 - f1的解,很明显用exgcd来解 下面介绍一下exgcd的一些知识点:求ax + by = c的解 一、首先求ax + by = gcd(a,b)的解 这个

hdu2289(简单二分)

虽说是简单二分,但是我还是wa死了  题意:已知圆台的体积,求高度 首先要知道圆台体积怎么求:设上下底的半径分别为r1,r2,高为h,V = PI*(r1*r1+r1*r2+r2*r2)*h/3 然后以h进行二分 代码如下: #include<iostream>#include<algorithm>#include<cstring>#include<stack>#includ

usaco 1.3 Prime Cryptarithm(简单哈希表暴搜剪枝)

思路: 1. 用一个 hash[ ] 数组存放输入的数字,令 hash[ tmp ]=1 。 2. 一个自定义函数 check( ) ,检查各位是否为输入的数字。 3. 暴搜。第一行数从 100到999,第二行数从 10到99。 4. 剪枝。 代码: /*ID: who jayLANG: C++TASK: crypt1*/#include<stdio.h>bool h

uva 10387 Billiard(简单几何)

题意是一个球从矩形的中点出发,告诉你小球与矩形两条边的碰撞次数与小球回到原点的时间,求小球出发时的角度和小球的速度。 简单的几何问题,小球每与竖边碰撞一次,向右扩展一个相同的矩形;每与横边碰撞一次,向上扩展一个相同的矩形。 可以发现,扩展矩形的路径和在当前矩形中的每一段路径相同,当小球回到出发点时,一条直线的路径刚好经过最后一个扩展矩形的中心点。 最后扩展的路径和横边竖边恰好组成一个直

【生成模型系列(初级)】嵌入(Embedding)方程——自然语言处理的数学灵魂【通俗理解】

【通俗理解】嵌入(Embedding)方程——自然语言处理的数学灵魂 关键词提炼 #嵌入方程 #自然语言处理 #词向量 #机器学习 #神经网络 #向量空间模型 #Siri #Google翻译 #AlexNet 第一节:嵌入方程的类比与核心概念【尽可能通俗】 嵌入方程可以被看作是自然语言处理中的“翻译机”,它将文本中的单词或短语转换成计算机能够理解的数学形式,即向量。 正如翻译机将一种语言

poj 1113 凸包+简单几何计算

题意: 给N个平面上的点,现在要在离点外L米处建城墙,使得城墙把所有点都包含进去且城墙的长度最短。 解析: 韬哥出的某次训练赛上A出的第一道计算几何,算是大水题吧。 用convexhull算法把凸包求出来,然后加加减减就A了。 计算见下图: 好久没玩画图了啊好开心。 代码: #include <iostream>#include <cstdio>#inclu

uva 10130 简单背包

题意: 背包和 代码: #include <iostream>#include <cstdio>#include <cstdlib>#include <algorithm>#include <cstring>#include <cmath>#include <stack>#include <vector>#include <queue>#include <map>

【C++高阶】C++类型转换全攻略:深入理解并高效应用

📝个人主页🌹:Eternity._ ⏩收录专栏⏪:C++ “ 登神长阶 ” 🤡往期回顾🤡:C++ 智能指针 🌹🌹期待您的关注 🌹🌹 ❀C++的类型转换 📒1. C语言中的类型转换📚2. C++强制类型转换⛰️static_cast🌞reinterpret_cast⭐const_cast🍁dynamic_cast 📜3. C++强制类型转换的原因📝

深入理解RxJava:响应式编程的现代方式

在当今的软件开发世界中,异步编程和事件驱动的架构变得越来越重要。RxJava,作为响应式编程(Reactive Programming)的一个流行库,为Java和Android开发者提供了一种强大的方式来处理异步任务和事件流。本文将深入探讨RxJava的核心概念、优势以及如何在实际项目中应用它。 文章目录 💯 什么是RxJava?💯 响应式编程的优势💯 RxJava的核心概念