C++卷积神经网络实例:tiny_cnn代码详解(10)——layer_base和layer类结构分析

本文主要是介绍C++卷积神经网络实例:tiny_cnn代码详解(10)——layer_base和layer类结构分析,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

  在之前的博文中,我们已经队大部分层结构类都进行了分析,在这篇博文中我们准备针对最后两个,也是处于层结构类继承体系中最底层的两个基类layer_base和layer做一下简要分析。由于layer类只是对layer_base的一个简单实例化,因此这里着重分析layer_base类。

  首先,给出layer_base类的基本结构框图:

  一、成员变量

  由于layer_base是这个类体系结构的基类,是构建网络层的基石,因此其内部封装了网络层的基本属性,相应的也有大量对应的成员变量:

  接下来一一对这些成员变量的基本含义做一下大致介绍:

  (1)in_size_、out_size_:保存了当前层的输入数据尺寸和输出数据尺寸。

  (2)parallelize_:布尔类型标志位,用以标记当前工程是否使用TBB多线程加速。

  (3)next_、prev_:两个指向layer_base类型的指针,用以指向当前层的下一层以及当前层的上一层,是维持层间联系的关键纽带。

  (4)a_:保留当前层卷积运算的中间结果。

  (5)output_:经过激活函数处理之后的当前层的最终特征输出。

  (6)prev_delta_:有前一层传播过来的误差灵敏度(梯度下降法过程中使用)。

  (7)W_、b_:当前层的卷积核权重以及偏置。

  (8)dW_、db_:权重的导数和偏置的导数,用以对权重和偏置进行更新。

  (9)Whessian_、bhessian_:海森矩阵的相关变量,具体含义在后续博文中会详细解释。

  (10)prev_delta2_:误差相对于输入的二阶导数,主要用于全连接层中的误差计算。

  二、构造函数

  构造函数的功能十分简单,通过调用set_size()成员函数来完成网络层中各个相关变量的初始化: 

layer_base(layer_size_t in_dim, layer_size_t out_dim, size_t weight_dim, size_t bias_dim) : parallelize_(true), next_(nullptr), prev_(nullptr) 
{set_size(in_dim, out_dim, weight_dim, bias_dim);//初始化神经网络层的参数
}

   需要注意的一点是这里默认将parallelize_标志位初始化为true,即默认使用TBB加速。至于set_size()函数,主要是通过调用vector的成员函数resize()来对各个参数进行初始化。

  三、权重初始化

  权重初始化主要通过set_size()函数完成(注意,这个函数不仅仅在构造函数中有所调用),正如上文所说,这个函数本质上就是在调用resize():

        void set_size(layer_size_t in_dim, layer_size_t out_dim, size_t weight_dim, size_t bias_dim) {in_size_ = in_dim;out_size_ = out_dim;W_.resize(weight_dim);b_.resize(bias_dim);Whessian_.resize(weight_dim);bhessian_.resize(bias_dim);prev_delta2_.resize(in_dim);for (auto& o : output_)     o.resize(out_dim);for (auto& a : a_)          a.resize(out_dim);for (auto& p : prev_delta_) p.resize(in_dim);for (auto& dw : dW_) dw.resize(weight_dim);for (auto& db : db_) db.resize(bias_dim);}

  需要注意的一点就是这里使用了范围for循环来完成这个vector容器中元素的遍历和操作,这算是C++11的一个特点,需要慢慢体会,不过单纯的从遍历的角度讲,这的确比传统的for循环更为方便而安全。

  四、纯虚函数集

  由于layer_base是一个公共基类,有必要定义一些虚函数以及纯虚函数供派生出来的不同类型的子类进行改写。这里作者选择将与激活函数和前向/反向传播算法定义成纯虚函数,原因很明确:不同层的前向/反向传播算法是不同的,并且激活函数也是可有可无: 

        /**********将激活函数、前向传播和反向传播全部声明为纯虚函数,在子类中进行定义**********/virtual activation::function& activation_function() = 0;virtual const vec_t& forward_propagation(const vec_t& in, size_t worker_index) = 0;virtual const vec_t& back_propagation(const vec_t& current_delta, size_t worker_index) = 0;virtual const vec_t& back_propagation_2nd(const vec_t& current_delta2) = 0;

   五、中间状态保存

  由于卷积神经网络的训练时间都较长,因此有必要定义保存中间训练结果的接口以完成断点续传(这个用词可能不太恰当),因此在layer_base中提供了用以保存和加载网络中间训练状态的结构函数save和load:

        /**********保存网络层中的权重和偏置(中间训练结果)**********/virtual void save(std::ostream& os) const {if (is_exploded()) throw nn_error("failed to save weights because of infinite weight");for (auto w : W_) os << w << " ";for (auto b : b_) os << b << " ";}/**********加载中间训练值**********/virtual void load(std::istream& is) {for (auto& w : W_) is >> w;for (auto& b : b_) is >> b;}

   这里主要通过流操作来完成结果的输入输出操作,同样体现出了强力的C++特性。

  六、权值更新

  layer_base对权值更新的操作主要有两个,一是权值和偏置的参数的初始化操作set_size(),这个前文已经介绍过了;二是更新函数update_weight()。update_weight()函数主要是通过调用各个收敛算法(如这里默认使用的gradient_descent_levenberg_marquardt算法)中的update()函数来完成对应权值和偏置的更新操作:

  至于update函数的具体实现细节则取决于所使用的收敛算法,有关这部分内容我会在之后介绍收敛算法(Optimizer结构体)的博文中专门进行详细的介绍。不过从表面的调用形式上可以看出,在BP算法对权值进行更新的过程中,需要用到dW(一阶导数)和海森矩阵(二阶导数)。

  七、属性返回参数

  这部分结构函数几乎是各个网络层的必备函数,方便用户查看对应网络层的具体参数信息和特征输出结果,一般都包含两个方面,return语句和output_to_image类型的视觉转换函数。return语句负责返回网络层的相关成员变量(可以在内部进行一些简单运算),output_to_image()函数则负责将映射核、特征输出结果转换成图像的形式供我们观赏,这些在之前的博文中都有提到过,这里不再赘述。

  八、layer类结构分析

  相对于layer_base类,layer的结构功能则简单了很多,大体上可以分为三类。激活函数实例化,保存/加载函数具体化,定义错误提示信息。

  8.1 激活函数实例化

  由于在layer_base类中将激活函数定义为纯虚函数,作者选择在子类layer中对其进行实例化:

  这里涉及到了Activation类的使用,在这个类中封装了各种各样类型的激活函数,在后续的博文中会专门拿出一两篇的篇幅来对这个类进行分析。

  8.2 保存、加载中间训练值函数具体化

  这里没什么可细说的,通过流操作basic_ostream来进行输入输出:

    /**********辅助的保存、加载操作**********/template <typename Char, typename CharTraits>std::basic_ostream<Char, CharTraits>& operator << (std::basic_ostream<Char, CharTraits>& os, const layer_base& v) {v.save(os);return os;}template <typename Char, typename CharTraits>std::basic_istream<Char, CharTraits>& operator >> (std::basic_istream<Char, CharTraits>& os, layer_base& v) {v.load(os);return os;}

  8.3 错误提示函数定义

  在layer中定义了三种错误类型的信息提示函数:连接不匹配、输入特征维数不匹配、下采样维数不匹配:

  (1)连接不匹配信息提示函数connection_mismatch。这个函数主要是在程序发现当前一层的特征输出维数与后一层的特征输入维数不同时调用,格式化输出错误信息,指明出现问题的具体层。

  (2)输入特征维数不匹配信息提示函数data_mismatch:这个函数主要是在程序发现输入数据的维数与当前层的输入维数不匹配时调用,格式化输出错误信息,指明出现问题的具体层。

  (3)下采样维数不匹配信息提示函数pooling_size_mismatch:这个函数主要是在程序发现当前特征维数不能被下采样窗口尺寸整除时调用,格式化输出错误信息,指明出现问题的具体层。

  需要注意的一点是,以上三个函数只负责格式化输出错误信息提示,具体错误检查机制需要在对应的可能的调用环境中中自行编写进行判断。

  九、注意事项

  1、范围for循环

  在tiny_cnn工程中对容器进行遍历时,全部采用了范围for循环,这点对于之前一直使用传统for循环的童鞋来说刚开始可能有点难以接受,但毕竟范围for循环既安全又简答,以后也要多多使用。

  2、layer_base的函数并没有介绍完全

  上文中对layer_base类中的成员函数并没有百分之百的介绍完全,对于一些小的补丁试的成员函数在后续用到时再进行解释。

  3、激活函数不等于收敛算法

  这里强调一个初学者容易混淆的概念,就是激活函数和收敛算法。首先这两者是完全不同的,举个栗子通俗的说明一下:激活函数包含sigmoid,tanh,Relu;收敛算法则主要指梯度下降法,怎么样,是不是茅塞顿开了。

 



如果觉得这篇文章对您有所启发,欢迎关注我的公众号,我会尽可能积极和大家交流,谢谢。


这篇关于C++卷积神经网络实例:tiny_cnn代码详解(10)——layer_base和layer类结构分析的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1115504

相关文章

Spring Security基于数据库验证流程详解

Spring Security 校验流程图 相关解释说明(认真看哦) AbstractAuthenticationProcessingFilter 抽象类 /*** 调用 #requiresAuthentication(HttpServletRequest, HttpServletResponse) 决定是否需要进行验证操作。* 如果需要验证,则会调用 #attemptAuthentica

【C++ Primer Plus习题】13.4

大家好,这里是国中之林! ❥前些天发现了一个巨牛的人工智能学习网站,通俗易懂,风趣幽默,忍不住分享一下给大家。点击跳转到网站。有兴趣的可以点点进去看看← 问题: 解答: main.cpp #include <iostream>#include "port.h"int main() {Port p1;Port p2("Abc", "Bcc", 30);std::cout <<

性能分析之MySQL索引实战案例

文章目录 一、前言二、准备三、MySQL索引优化四、MySQL 索引知识回顾五、总结 一、前言 在上一讲性能工具之 JProfiler 简单登录案例分析实战中已经发现SQL没有建立索引问题,本文将一起从代码层去分析为什么没有建立索引? 开源ERP项目地址:https://gitee.com/jishenghua/JSH_ERP 二、准备 打开IDEA找到登录请求资源路径位置

C++包装器

包装器 在 C++ 中,“包装器”通常指的是一种设计模式或编程技巧,用于封装其他代码或对象,使其更易于使用、管理或扩展。包装器的概念在编程中非常普遍,可以用于函数、类、库等多个方面。下面是几个常见的 “包装器” 类型: 1. 函数包装器 函数包装器用于封装一个或多个函数,使其接口更统一或更便于调用。例如,std::function 是一个通用的函数包装器,它可以存储任意可调用对象(函数、函数

OpenHarmony鸿蒙开发( Beta5.0)无感配网详解

1、简介 无感配网是指在设备联网过程中无需输入热点相关账号信息,即可快速实现设备配网,是一种兼顾高效性、可靠性和安全性的配网方式。 2、配网原理 2.1 通信原理 手机和智能设备之间的信息传递,利用特有的NAN协议实现。利用手机和智能设备之间的WiFi 感知订阅、发布能力,实现了数字管家应用和设备之间的发现。在完成设备间的认证和响应后,即可发送相关配网数据。同时还支持与常规Sof

C++11第三弹:lambda表达式 | 新的类功能 | 模板的可变参数

🌈个人主页: 南桥几晴秋 🌈C++专栏: 南桥谈C++ 🌈C语言专栏: C语言学习系列 🌈Linux学习专栏: 南桥谈Linux 🌈数据结构学习专栏: 数据结构杂谈 🌈数据库学习专栏: 南桥谈MySQL 🌈Qt学习专栏: 南桥谈Qt 🌈菜鸡代码练习: 练习随想记录 🌈git学习: 南桥谈Git 🌈🌈🌈🌈🌈🌈🌈🌈🌈🌈🌈🌈🌈�

【C++】_list常用方法解析及模拟实现

相信自己的力量,只要对自己始终保持信心,尽自己最大努力去完成任何事,就算事情最终结果是失败了,努力了也不留遗憾。💓💓💓 目录   ✨说在前面 🍋知识点一:什么是list? •🌰1.list的定义 •🌰2.list的基本特性 •🌰3.常用接口介绍 🍋知识点二:list常用接口 •🌰1.默认成员函数 🔥构造函数(⭐) 🔥析构函数 •🌰2.list对象

活用c4d官方开发文档查询代码

当你问AI助手比如豆包,如何用python禁止掉xpresso标签时候,它会提示到 这时候要用到两个东西。https://developers.maxon.net/论坛搜索和开发文档 比如这里我就在官方找到正确的id描述 然后我就把参数标签换过来

06 C++Lambda表达式

lambda表达式的定义 没有显式模版形参的lambda表达式 [捕获] 前属性 (形参列表) 说明符 异常 后属性 尾随类型 约束 {函数体} 有显式模版形参的lambda表达式 [捕获] <模版形参> 模版约束 前属性 (形参列表) 说明符 异常 后属性 尾随类型 约束 {函数体} 含义 捕获:包含零个或者多个捕获符的逗号分隔列表 模板形参:用于泛型lambda提供个模板形参的名

usaco 1.3 Mixing Milk (结构体排序 qsort) and hdu 2020(sort)

到了这题学会了结构体排序 于是回去修改了 1.2 milking cows 的算法~ 结构体排序核心: 1.结构体定义 struct Milk{int price;int milks;}milk[5000]; 2.自定义的比较函数,若返回值为正,qsort 函数判定a>b ;为负,a<b;为0,a==b; int milkcmp(const void *va,c