Caffe使用——01 以LeNet训练Mnist数据集为例

2023-11-08 18:58

本文主要是介绍Caffe使用——01 以LeNet训练Mnist数据集为例,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

1 CNN训练初体验(使用几个命令来训练手写数字数据集)

1.1 下载数据、转换数据格式

设CAFFE_ROOT为caffe的安装路径。

cd $CAFFE_ROOT
./data/mnist/get_mnist.sh
./examples/mnist/create_mnist.sh

上述脚本中的内容完成的工作就是下载并转换数据,暂不做详细介绍。

1.2 训练

cd $CAFFE_ROOT
./examples/mnist/train_lenet.sh

训练命令:

caffe train -solver lenet_solver.prototxt -gpu 0 -log_dir ./

caffe命令参数解释:

commands

train 训练和微调一个模型
test 对一个模型打分
device_query 显示GPU诊断信息
time 评估模型执行时间

flags

gpu : 指定用哪块GPU训练
model : 模型定义文件
log_dir : 指定log文件输出的路径。(这个路径必须事先存在)
weights : 用已经训练好的模型来初始化参数。
snapshot : 从之前训练的某个solver 状态恢复训练。
iterations : 和solver中的test_iter类似,运行迭代次数。
sighup_effect : 当收到SIGHUP信号时采取的动作,可选项:snap/stop/none。默认为snapshot,即打快照。
sigint_effect : 当收到SIGINT信号时要采取的动作,可选项同上,默认为stop。
solver : 指定求解器文本文件名。

1.3 评估模型性能

caffe time -model lenet.prototxt -gpu 0

2 求解器(solver)——训练超参数

查看训练脚本:

➜  caffe git:(zxdev_mac) cat ./examples/mnist/train_lenet.sh
#!/usr/bin/env sh
set -e./build/tools/caffe train --solver=examples/mnist/lenet_solver.prototxt $@

查看solver.prototxt

➜  caffe git:(zxdev_mac) cat examples/mnist/lenet_solver.prototxt
# The train/test net protocol buffer definition
# 用于训练测试的网络结构文件
net: "examples/mnist/lenet_train_test.prototxt"
# test_iter specifies how many forward passes the test should carry out.
# In the case of MNIST, we have test batch size 100 and 100 test iterations,
# covering the full 10,000 testing images.
# test_iter 指定test执行的时候迭代次数
test_iter: 100
# Carry out testing every 500 training iterations.
# 每训练500次执行一次test
test_interval: 500
# The base learning rate, momentum and the weight decay of the network.
# 网络的基础学习率,冲量,权衰量
base_lr: 0.01
momentum: 0.9
weight_decay: 0.0005
# The learning rate policy
# inv 的学习策略,lr = base_lr * (1 + gamma * iter) ^ (-power)
lr_policy: "inv"
gamma: 0.0001
power: 0.75
# Display every 100 iterations
# 每迭代多少次显示 一次当前训练的信息,主要是loss和学习率
display: 100
# The maximum number of iterations
# 指定最大迭代次数
max_iter: 10000
# snapshot intermediate results
# 每迭代多少次保存一次模型的参数和训练状态。
snapshot: 5000
snapshot_prefix: "examples/mnist/lenet"
# solver mode: CPU or GPU
solver_mode: GPU

3 定义网络结构 lenet_train_val.prototxt

网络结构定义在examples/mnist/lenet_train_test.prototxt中。

➜  caffe git:(zxdev_mac) cat examples/mnist/lenet_train_test.prototxt
# 网络(net)的名称为LeNet
name: "LeNet"
layer {# 这一层的名字是mnistname: "mnist"# 这一层的类型是Datao数据层type: "Data"# 这一层产生两个blobs,分别是data blob和label blobtop: "data"top: "label"include {# 该层参数 只在训练阶段有效phase: TRAIN}transform_param {# 此处还可添加mean_value,数据先减mean_value,再乘scale。注意若有此项,需要在inference时减均值。# mean_value: 128# 1/256.0 = 0.00390625,像素值控制在0到1之间。scale: 0.00390625}data_param {source: "examples/mnist/mnist_train_lmdb"# 指定训练阶段,每次迭代用50个。batch_size: 64backend: LMDB}
}
layer {name: "mnist"type: "Data"top: "data"top: "label"include {phase: TEST}transform_param {scale: 0.00390625}data_param {source: "examples/mnist/mnist_test_lmdb"batch_size: 100backend: LMDB}
}
layer {name: "conv1"type: "Convolution"bottom: "data"top: "conv1"# 卷积核学习率为基础学习率乘以 lr_multparam {lr_mult: 1}# 偏置学习率为基础学习率乘以 lr_multparam {lr_mult: 2}convolution_param {# 输出20个通道num_output: 20# 卷积核尺寸是5kernel_size: 5# 步长是1stride: 1# 随机初始化权重,用xavier算法,自动根据输入输出的数量来定初始化的比例weight_filler {type: "xavier"}# bais使用常数,默认用0填充。bias_filler {type: "constant"}}
}
layer {name: "pool1"type: "Pooling"bottom: "conv1"top: "pool1"pooling_param {# 采用最大值下采样pool: MAX# 池化核尺寸为2,步长为2kernel_size: 2stride: 2}
}
layer {name: "conv2"type: "Convolution"bottom: "pool1"top: "conv2"param {lr_mult: 1}param {lr_mult: 2}convolution_param {num_output: 50kernel_size: 5stride: 1weight_filler {type: "xavier"}bias_filler {type: "constant"}}
}
layer {name: "pool2"type: "Pooling"bottom: "conv2"top: "pool2"pooling_param {pool: MAXkernel_size: 2stride: 2}
}
layer {name: "ip1"type: "InnerProduct"bottom: "pool2"top: "ip1"param {lr_mult: 1}param {lr_mult: 2}inner_product_param {num_output: 500weight_filler {type: "xavier"}bias_filler {type: "constant"}}
}
layer {name: "relu1"type: "ReLU"bottom: "ip1"top: "ip1"
}
layer {name: "ip2"type: "InnerProduct"bottom: "ip1"top: "ip2"param {lr_mult: 1}param {lr_mult: 2}inner_product_param {num_output: 10weight_filler {type: "xavier"}bias_filler {type: "constant"}}
}
# 分类准确率层,只在测试阶段有效。用于计算分类的准确率
layer {name: "accuracy"type: "Accuracy"bottom: "ip2"bottom: "label"top: "accuracy"include {phase: TEST}
}
layer {name: "loss"type: "SoftmaxWithLoss"# 没有输出,只是计算lossbottom: "ip2"bottom: "label"top: "loss"
}

4 查看训练过程中的准确率和loss

将log_dir指定路径下的日志重命名后缀为log,例如mnist_train.log。
在log_dir下生成准确率图片:

../tools/extra/plot_training_log.py.example 0 test_acc_vs_iters.png mnist_train.log
../tools/extra/plot_training_log.py.example 2 test_loss_vs_iters.png mnist_train.log
../tools/extra/plot_training_log.py.example 6 train_acc_vs_iters.png mnist_train.log
../tools/extra/plot_training_log.py.example 4 lr_vs_iters.png mnist_train.log

这篇关于Caffe使用——01 以LeNet训练Mnist数据集为例的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/371841

相关文章

使用Python合并 Excel单元格指定行列或单元格范围

《使用Python合并Excel单元格指定行列或单元格范围》合并Excel单元格是Excel数据处理和表格设计中的一项常用操作,本文将介绍如何通过Python合并Excel中的指定行列或单... 目录python Excel库安装Python合并Excel 中的指定行Python合并Excel 中的指定列P

浅析Rust多线程中如何安全的使用变量

《浅析Rust多线程中如何安全的使用变量》这篇文章主要为大家详细介绍了Rust如何在线程的闭包中安全的使用变量,包括共享变量和修改变量,文中的示例代码讲解详细,有需要的小伙伴可以参考下... 目录1. 向线程传递变量2. 多线程共享变量引用3. 多线程中修改变量4. 总结在Rust语言中,一个既引人入胜又可

一文详解Python中数据清洗与处理的常用方法

《一文详解Python中数据清洗与处理的常用方法》在数据处理与分析过程中,缺失值、重复值、异常值等问题是常见的挑战,本文总结了多种数据清洗与处理方法,文中的示例代码简洁易懂,有需要的小伙伴可以参考下... 目录缺失值处理重复值处理异常值处理数据类型转换文本清洗数据分组统计数据分箱数据标准化在数据处理与分析过

大数据小内存排序问题如何巧妙解决

《大数据小内存排序问题如何巧妙解决》文章介绍了大数据小内存排序的三种方法:数据库排序、分治法和位图法,数据库排序简单但速度慢,对设备要求高;分治法高效但实现复杂;位图法可读性差,但存储空间受限... 目录三种方法:方法概要数据库排序(http://www.chinasem.cn对数据库设备要求较高)分治法(常

golang1.23版本之前 Timer Reset方法无法正确使用

《golang1.23版本之前TimerReset方法无法正确使用》在Go1.23之前,使用`time.Reset`函数时需要先调用`Stop`并明确从timer的channel中抽取出东西,以避... 目录golang1.23 之前 Reset ​到底有什么问题golang1.23 之前到底应该如何正确的

详解Vue如何使用xlsx库导出Excel文件

《详解Vue如何使用xlsx库导出Excel文件》第三方库xlsx提供了强大的功能来处理Excel文件,它可以简化导出Excel文件这个过程,本文将为大家详细介绍一下它的具体使用,需要的小伙伴可以了解... 目录1. 安装依赖2. 创建vue组件3. 解释代码在Vue.js项目中导出Excel文件,使用第三

Linux alias的三种使用场景方式

《Linuxalias的三种使用场景方式》文章介绍了Linux中`alias`命令的三种使用场景:临时别名、用户级别别名和系统级别别名,临时别名仅在当前终端有效,用户级别别名在当前用户下所有终端有效... 目录linux alias三种使用场景一次性适用于当前用户全局生效,所有用户都可调用删除总结Linux

java图像识别工具类(ImageRecognitionUtils)使用实例详解

《java图像识别工具类(ImageRecognitionUtils)使用实例详解》:本文主要介绍如何在Java中使用OpenCV进行图像识别,包括图像加载、预处理、分类、人脸检测和特征提取等步骤... 目录前言1. 图像识别的背景与作用2. 设计目标3. 项目依赖4. 设计与实现 ImageRecogni

Python将大量遥感数据的值缩放指定倍数的方法(推荐)

《Python将大量遥感数据的值缩放指定倍数的方法(推荐)》本文介绍基于Python中的gdal模块,批量读取大量多波段遥感影像文件,分别对各波段数据加以数值处理,并将所得处理后数据保存为新的遥感影像... 本文介绍基于python中的gdal模块,批量读取大量多波段遥感影像文件,分别对各波段数据加以数值处

python管理工具之conda安装部署及使用详解

《python管理工具之conda安装部署及使用详解》这篇文章详细介绍了如何安装和使用conda来管理Python环境,它涵盖了从安装部署、镜像源配置到具体的conda使用方法,包括创建、激活、安装包... 目录pytpshheraerUhon管理工具:conda部署+使用一、安装部署1、 下载2、 安装3