Caffe使用——01 以LeNet训练Mnist数据集为例

2023-11-08 18:58

本文主要是介绍Caffe使用——01 以LeNet训练Mnist数据集为例,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

1 CNN训练初体验(使用几个命令来训练手写数字数据集)

1.1 下载数据、转换数据格式

设CAFFE_ROOT为caffe的安装路径。

cd $CAFFE_ROOT
./data/mnist/get_mnist.sh
./examples/mnist/create_mnist.sh

上述脚本中的内容完成的工作就是下载并转换数据,暂不做详细介绍。

1.2 训练

cd $CAFFE_ROOT
./examples/mnist/train_lenet.sh

训练命令:

caffe train -solver lenet_solver.prototxt -gpu 0 -log_dir ./

caffe命令参数解释:

commands

train 训练和微调一个模型
test 对一个模型打分
device_query 显示GPU诊断信息
time 评估模型执行时间

flags

gpu : 指定用哪块GPU训练
model : 模型定义文件
log_dir : 指定log文件输出的路径。(这个路径必须事先存在)
weights : 用已经训练好的模型来初始化参数。
snapshot : 从之前训练的某个solver 状态恢复训练。
iterations : 和solver中的test_iter类似,运行迭代次数。
sighup_effect : 当收到SIGHUP信号时采取的动作,可选项:snap/stop/none。默认为snapshot,即打快照。
sigint_effect : 当收到SIGINT信号时要采取的动作,可选项同上,默认为stop。
solver : 指定求解器文本文件名。

1.3 评估模型性能

caffe time -model lenet.prototxt -gpu 0

2 求解器(solver)——训练超参数

查看训练脚本:

➜  caffe git:(zxdev_mac) cat ./examples/mnist/train_lenet.sh
#!/usr/bin/env sh
set -e./build/tools/caffe train --solver=examples/mnist/lenet_solver.prototxt $@

查看solver.prototxt

➜  caffe git:(zxdev_mac) cat examples/mnist/lenet_solver.prototxt
# The train/test net protocol buffer definition
# 用于训练测试的网络结构文件
net: "examples/mnist/lenet_train_test.prototxt"
# test_iter specifies how many forward passes the test should carry out.
# In the case of MNIST, we have test batch size 100 and 100 test iterations,
# covering the full 10,000 testing images.
# test_iter 指定test执行的时候迭代次数
test_iter: 100
# Carry out testing every 500 training iterations.
# 每训练500次执行一次test
test_interval: 500
# The base learning rate, momentum and the weight decay of the network.
# 网络的基础学习率,冲量,权衰量
base_lr: 0.01
momentum: 0.9
weight_decay: 0.0005
# The learning rate policy
# inv 的学习策略,lr = base_lr * (1 + gamma * iter) ^ (-power)
lr_policy: "inv"
gamma: 0.0001
power: 0.75
# Display every 100 iterations
# 每迭代多少次显示 一次当前训练的信息,主要是loss和学习率
display: 100
# The maximum number of iterations
# 指定最大迭代次数
max_iter: 10000
# snapshot intermediate results
# 每迭代多少次保存一次模型的参数和训练状态。
snapshot: 5000
snapshot_prefix: "examples/mnist/lenet"
# solver mode: CPU or GPU
solver_mode: GPU

3 定义网络结构 lenet_train_val.prototxt

网络结构定义在examples/mnist/lenet_train_test.prototxt中。

➜  caffe git:(zxdev_mac) cat examples/mnist/lenet_train_test.prototxt
# 网络(net)的名称为LeNet
name: "LeNet"
layer {# 这一层的名字是mnistname: "mnist"# 这一层的类型是Datao数据层type: "Data"# 这一层产生两个blobs,分别是data blob和label blobtop: "data"top: "label"include {# 该层参数 只在训练阶段有效phase: TRAIN}transform_param {# 此处还可添加mean_value,数据先减mean_value,再乘scale。注意若有此项,需要在inference时减均值。# mean_value: 128# 1/256.0 = 0.00390625,像素值控制在0到1之间。scale: 0.00390625}data_param {source: "examples/mnist/mnist_train_lmdb"# 指定训练阶段,每次迭代用50个。batch_size: 64backend: LMDB}
}
layer {name: "mnist"type: "Data"top: "data"top: "label"include {phase: TEST}transform_param {scale: 0.00390625}data_param {source: "examples/mnist/mnist_test_lmdb"batch_size: 100backend: LMDB}
}
layer {name: "conv1"type: "Convolution"bottom: "data"top: "conv1"# 卷积核学习率为基础学习率乘以 lr_multparam {lr_mult: 1}# 偏置学习率为基础学习率乘以 lr_multparam {lr_mult: 2}convolution_param {# 输出20个通道num_output: 20# 卷积核尺寸是5kernel_size: 5# 步长是1stride: 1# 随机初始化权重,用xavier算法,自动根据输入输出的数量来定初始化的比例weight_filler {type: "xavier"}# bais使用常数,默认用0填充。bias_filler {type: "constant"}}
}
layer {name: "pool1"type: "Pooling"bottom: "conv1"top: "pool1"pooling_param {# 采用最大值下采样pool: MAX# 池化核尺寸为2,步长为2kernel_size: 2stride: 2}
}
layer {name: "conv2"type: "Convolution"bottom: "pool1"top: "conv2"param {lr_mult: 1}param {lr_mult: 2}convolution_param {num_output: 50kernel_size: 5stride: 1weight_filler {type: "xavier"}bias_filler {type: "constant"}}
}
layer {name: "pool2"type: "Pooling"bottom: "conv2"top: "pool2"pooling_param {pool: MAXkernel_size: 2stride: 2}
}
layer {name: "ip1"type: "InnerProduct"bottom: "pool2"top: "ip1"param {lr_mult: 1}param {lr_mult: 2}inner_product_param {num_output: 500weight_filler {type: "xavier"}bias_filler {type: "constant"}}
}
layer {name: "relu1"type: "ReLU"bottom: "ip1"top: "ip1"
}
layer {name: "ip2"type: "InnerProduct"bottom: "ip1"top: "ip2"param {lr_mult: 1}param {lr_mult: 2}inner_product_param {num_output: 10weight_filler {type: "xavier"}bias_filler {type: "constant"}}
}
# 分类准确率层,只在测试阶段有效。用于计算分类的准确率
layer {name: "accuracy"type: "Accuracy"bottom: "ip2"bottom: "label"top: "accuracy"include {phase: TEST}
}
layer {name: "loss"type: "SoftmaxWithLoss"# 没有输出,只是计算lossbottom: "ip2"bottom: "label"top: "loss"
}

4 查看训练过程中的准确率和loss

将log_dir指定路径下的日志重命名后缀为log,例如mnist_train.log。
在log_dir下生成准确率图片:

../tools/extra/plot_training_log.py.example 0 test_acc_vs_iters.png mnist_train.log
../tools/extra/plot_training_log.py.example 2 test_loss_vs_iters.png mnist_train.log
../tools/extra/plot_training_log.py.example 6 train_acc_vs_iters.png mnist_train.log
../tools/extra/plot_training_log.py.example 4 lr_vs_iters.png mnist_train.log

这篇关于Caffe使用——01 以LeNet训练Mnist数据集为例的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/371841

相关文章

一文详解如何使用Java获取PDF页面信息

《一文详解如何使用Java获取PDF页面信息》了解PDF页面属性是我们在处理文档、内容提取、打印设置或页面重组等任务时不可或缺的一环,下面我们就来看看如何使用Java语言获取这些信息吧... 目录引言一、安装和引入PDF处理库引入依赖二、获取 PDF 页数三、获取页面尺寸(宽高)四、获取页面旋转角度五、判断

MyBatis-Plus通用中等、大量数据分批查询和处理方法

《MyBatis-Plus通用中等、大量数据分批查询和处理方法》文章介绍MyBatis-Plus分页查询处理,通过函数式接口与Lambda表达式实现通用逻辑,方法抽象但功能强大,建议扩展分批处理及流式... 目录函数式接口获取分页数据接口数据处理接口通用逻辑工具类使用方法简单查询自定义查询方法总结函数式接口

C++中assign函数的使用

《C++中assign函数的使用》在C++标准模板库中,std::list等容器都提供了assign成员函数,它比操作符更灵活,支持多种初始化方式,下面就来介绍一下assign的用法,具有一定的参考价... 目录​1.assign的基本功能​​语法​2. 具体用法示例​​​(1) 填充n个相同值​​(2)

Spring StateMachine实现状态机使用示例详解

《SpringStateMachine实现状态机使用示例详解》本文介绍SpringStateMachine实现状态机的步骤,包括依赖导入、枚举定义、状态转移规则配置、上下文管理及服务调用示例,重点解... 目录什么是状态机使用示例什么是状态机状态机是计算机科学中的​​核心建模工具​​,用于描述对象在其生命

使用Python删除Excel中的行列和单元格示例详解

《使用Python删除Excel中的行列和单元格示例详解》在处理Excel数据时,删除不需要的行、列或单元格是一项常见且必要的操作,本文将使用Python脚本实现对Excel表格的高效自动化处理,感兴... 目录开发环境准备使用 python 删除 Excphpel 表格中的行删除特定行删除空白行删除含指定

深入理解Go语言中二维切片的使用

《深入理解Go语言中二维切片的使用》本文深入讲解了Go语言中二维切片的概念与应用,用于表示矩阵、表格等二维数据结构,文中通过示例代码介绍的非常详细,需要的朋友们下面随着小编来一起学习学习吧... 目录引言二维切片的基本概念定义创建二维切片二维切片的操作访问元素修改元素遍历二维切片二维切片的动态调整追加行动态

prometheus如何使用pushgateway监控网路丢包

《prometheus如何使用pushgateway监控网路丢包》:本文主要介绍prometheus如何使用pushgateway监控网路丢包问题,具有很好的参考价值,希望对大家有所帮助,如有错误... 目录监控网路丢包脚本数据图表总结监控网路丢包脚本[root@gtcq-gt-monitor-prome

Python通用唯一标识符模块uuid使用案例详解

《Python通用唯一标识符模块uuid使用案例详解》Pythonuuid模块用于生成128位全局唯一标识符,支持UUID1-5版本,适用于分布式系统、数据库主键等场景,需注意隐私、碰撞概率及存储优... 目录简介核心功能1. UUID版本2. UUID属性3. 命名空间使用场景1. 生成唯一标识符2. 数

SpringBoot中如何使用Assert进行断言校验

《SpringBoot中如何使用Assert进行断言校验》Java提供了内置的assert机制,而Spring框架也提供了更强大的Assert工具类来帮助开发者进行参数校验和状态检查,下... 目录前言一、Java 原生assert简介1.1 使用方式1.2 示例代码1.3 优缺点分析二、Spring Fr

Android kotlin中 Channel 和 Flow 的区别和选择使用场景分析

《Androidkotlin中Channel和Flow的区别和选择使用场景分析》Kotlin协程中,Flow是冷数据流,按需触发,适合响应式数据处理;Channel是热数据流,持续发送,支持... 目录一、基本概念界定FlowChannel二、核心特性对比数据生产触发条件生产与消费的关系背压处理机制生命周期