libtorch---day02[第一个分类器]

2024-09-01 08:04

本文主要是介绍libtorch---day02[第一个分类器],希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

参考pytorch。

加载数据集CIFAR10

因为例程中使用的是torchvision加载数据集CIFAR10,但是torchvision的c++版本提供的功能太少,不考虑使用了,直接下载bin文件进行读取加载,CIFAR10数据格式:

CIFAR-10 数据集的图像数据以二进制格式存储。数据集中每张图像由以下几个部分组成:

  • 标签 (1 byte):

      表示图像所属的类别(0-9)。每个标签占用 1 个字节。
    
  • 图像数据 (3072 bytes):

      1、图像数据以 3072 字节的顺序存储,其中每个图像由 3 个通道(红色、绿色、蓝色)组成。2、每个通道的大小为 32 x 32 = 1024 字节,因此图像数据总大小为 1024 x 3 = 3072 字节。3、数据排列顺序为:红色通道的所有像素,接着是绿色通道的所有像素,最后是蓝色通道的所有像素。
    
  • 每个样本的二进制格式结构

字节 1: 标签 (0-9)
字节 2-1025: 红色通道的像素数据 (32x32)
字节 1026-2049: 绿色通道的像素数据 (32x32)
字节 2050-3073: 蓝色通道的像素数据 (32x32)

#define IMAGE_HEIGHT 32
#define IMAGE_WIDTH 32
// ...
void load_CIFAR(std::vector<std::pair<char, cv::Mat> >& images)
{std::ifstream fs(CIFAR_data_path.c_str(), std::ios::binary);if (fs.is_open()){while (!fs.eof()){char label;std::vector<uchar> image;image.resize(3 * IMAGE_HEIGHT * IMAGE_WIDTH);fs.read(&label, 1);fs.read(reinterpret_cast<char*>(image.data()), 3 * IMAGE_HEIGHT * IMAGE_WIDTH);cv::Mat image_cv(IMAGE_WIDTH, IMAGE_HEIGHT, CV_8UC3);
#pragma omp forfor (int i = 0; i < IMAGE_HEIGHT; i++){for (int j = 0; j < IMAGE_WIDTH; j++){image_cv.at<cv::Vec3b>(i, j)[0] = image[i * IMAGE_WIDTH + j];image_cv.at<cv::Vec3b>(i, j)[1] = image[i * IMAGE_WIDTH + j + IMAGE_HEIGHT * IMAGE_WIDTH];image_cv.at<cv::Vec3b>(i, j)[2] = image[i * IMAGE_WIDTH + j + IMAGE_HEIGHT * IMAGE_WIDTH * 2];}}images.push_back({ label, image_cv });}fs.close();}
}

构造网络

这里的torch::nn::Linear、torch::nn::Conv2d等等都是torch::nn::ModuleHolder,本质上都是智能指针,因此在定义的时候就需要给初始值。

class Classifier : public torch::nn::Module
{
public:torch::nn::Linear fc1{ nullptr }, fc2{ nullptr }, fc3{ nullptr };torch::nn::Conv2d conv1{ nullptr }, conv2{nullptr};torch::nn::MaxPool2d pool{ nullptr };Classifier() : Module(),fc1(torch::nn::Linear(torch::nn::LinearOptions(16 * 5 * 5, 120))),fc2(torch::nn::Linear(torch::nn::LinearOptions(120, 84))),fc3(torch::nn::Linear(torch::nn::LinearOptions(84, 10))),conv1(torch::nn::Conv2d(torch::nn::Conv2dOptions(3, 6, 5))),conv2(torch::nn::Conv2d(torch::nn::Conv2dOptions(6, 16, 5))),pool(torch::nn::MaxPool2d(torch::nn::MaxPool2dOptions(2).stride(2))){register_module("conv1", conv1);register_module("conv2", conv2);register_module("pool", pool);register_module("fc1", fc1);register_module("fc2", fc2);register_module("fc3", fc3);}~Classifier(){}torch::Tensor forward(torch::Tensor x){x = pool->forward(torch::nn::functional::relu(conv1->forward(x)));x = pool->forward(torch::nn::functional::relu(conv2->forward(x)));x = torch::flatten(x, 1);x = torch::nn::functional::relu(fc1->forward(x));x = torch::nn::functional::relu(fc2->forward(x));auto y = fc3->forward(x);return y;}
};

数据转换

这里比较关键的是:cv::Mat和torch::Tensor的相互转化,如果不考虑可视化的话,也可以在加载CIFAR10数据集的时候,在读取字节流的时候顺手把它转化为tensor也是可以的。

auto tensor = torch::from_blob(image.second.data, { image.second.rows, image.second.cols, 3 }, torch::kUInt8).clone();
tensor = tensor.permute({ 2, 0, 1 });
tensor = tensor.unsqueeze(0);
tensor = tensor.to(torch::kF32) / 255.;
tensor = tensor.sub(0.5).div(0.5);

这里的image是一个cv::Mat格式的数据,使用from_blob方法进行转化,并且使用clone方法进行拷贝防止在转化完成之后,这部分内存空间被opencv自动回收,需要显式调用unsqueeze添加batch的维度。

完整代码

#include <torch/torch.h>
#include <torch/nn/module.h>
#include <torch/nn/modules/conv.h>
#include <torch/nn/modules/linear.h>
#include <torch/nn/modules/pooling.h>
#include <torch/nn/functional.h>
#include <torch/optim/optimizer.h>
#include <opencv2/opencv.hpp>
#include <fstream>
#define IMAGE_HEIGHT 32
#define IMAGE_WIDTH 32
std::string CIFAR_data_path = "P:/torch-cpp/dataSet/cifar-10-batches-bin/data_batch_1.bin";
void load_CIFAR(std::vector<std::pair<char, cv::Mat> >& images)
{std::ifstream fs(CIFAR_data_path.c_str(), std::ios::binary);if (fs.is_open()){while (!fs.eof()){char label;std::vector<uchar> image;image.resize(3 * IMAGE_HEIGHT * IMAGE_WIDTH);fs.read(&label, 1);fs.read(reinterpret_cast<char*>(image.data()), 3 * IMAGE_HEIGHT * IMAGE_WIDTH);cv::Mat image_cv(IMAGE_WIDTH, IMAGE_HEIGHT, CV_8UC3);
#pragma omp forfor (int i = 0; i < IMAGE_HEIGHT; i++){for (int j = 0; j < IMAGE_WIDTH; j++){image_cv.at<cv::Vec3b>(i, j)[0] = image[i * IMAGE_WIDTH + j];image_cv.at<cv::Vec3b>(i, j)[1] = image[i * IMAGE_WIDTH + j + IMAGE_HEIGHT * IMAGE_WIDTH];image_cv.at<cv::Vec3b>(i, j)[2] = image[i * IMAGE_WIDTH + j + IMAGE_HEIGHT * IMAGE_WIDTH * 2];}}images.push_back({ label, image_cv });}fs.close();}
}
class Classifier : public torch::nn::Module
{
public:torch::nn::Linear fc1{ nullptr }, fc2{ nullptr }, fc3{ nullptr };torch::nn::Conv2d conv1{ nullptr }, conv2{nullptr};torch::nn::MaxPool2d pool{ nullptr };Classifier() : Module(),fc1(torch::nn::Linear(torch::nn::LinearOptions(16 * 5 * 5, 120))),fc2(torch::nn::Linear(torch::nn::LinearOptions(120, 84))),fc3(torch::nn::Linear(torch::nn::LinearOptions(84, 10))),conv1(torch::nn::Conv2d(torch::nn::Conv2dOptions(3, 6, 5))),conv2(torch::nn::Conv2d(torch::nn::Conv2dOptions(6, 16, 5))),pool(torch::nn::MaxPool2d(torch::nn::MaxPool2dOptions(2).stride(2))){register_module("conv1", conv1);register_module("conv2", conv2);register_module("pool", pool);register_module("fc1", fc1);register_module("fc2", fc2);register_module("fc3", fc3);}~Classifier(){}torch::Tensor forward(torch::Tensor x){x = pool->forward(torch::nn::functional::relu(conv1->forward(x)));x = pool->forward(torch::nn::functional::relu(conv2->forward(x)));x = torch::flatten(x, 1);x = torch::nn::functional::relu(fc1->forward(x));x = torch::nn::functional::relu(fc2->forward(x));auto y = fc3->forward(x);return y;}
};
int main()
{std::vector<std::pair<char, cv::Mat>> images;load_CIFAR(images);Classifier net;net.train();torch::nn::CrossEntropyLoss criterion;auto optimizer = torch::optim::SGD(net.parameters(), torch::optim::SGDOptions(0.01));int batch_size = 16;int epoches = 100;for (int epoch = 0; epoch < epoches; epoch++){double running_loss = 0.;for (int i = 0; i < images.size(); i+= batch_size){std::vector<torch::Tensor> batch_images;std::vector<torch::Tensor> batch_labels;
#pragma omp forfor (int batch = 0; batch < batch_size && batch + i < images.size(); batch++){auto image = images[i + batch];auto tensor = torch::from_blob(image.second.data, { image.second.rows, image.second.cols, 3 }, torch::kUInt8).clone();tensor = tensor.permute({ 2, 0, 1 });tensor = tensor.unsqueeze(0);tensor = tensor.to(torch::kF32) / 255.;tensor = tensor.sub(0.5).div(0.5);auto target = torch::tensor({ static_cast<int>(image.first) }, torch::kLong);
#pragma omp critical{batch_images.push_back(tensor);batch_labels.push_back(target);}}// 将小批次的图像和标签堆叠成一个大批次auto input_batch = torch::cat(batch_images, 0);auto label_batch = torch::cat(batch_labels, 0);auto predict = net.forward(input_batch);optimizer.zero_grad();auto loss = criterion(predict, label_batch);loss.backward();optimizer.step();running_loss += loss.item<double>();std::cout << "Epoch [" << (epoch + 1) << "/" << epoches << "] - Loss: " << running_loss / images.size() << std::endl;}}cv::Mat image_show;std::vector<cv::Mat> image_show_vector;for (int i = 0; i < 5; i++){std::vector<cv::Mat> image_row_vector;for (int j = 0; j < 5; j++){auto image = images[i * 5 + j];auto tensor = torch::from_blob(image.second.data, { image.second.rows, image.second.cols, 3 }, torch::kUInt8).clone();tensor = tensor.permute({ 2, 0, 1 });tensor = tensor.unsqueeze(0);tensor = tensor.to(torch::kF32) / 255.;tensor = tensor.sub(0.5).div(0.5);auto predict = net.forward(tensor);int label = torch::argmax(predict, 1).item<int>();cv::Mat image_resize;cv::resize(image.second, image_resize, cv::Size(), 8, 8, cv::INTER_CUBIC);cv::putText(image_resize, "ground true: " + std::to_string(image.first), { 10, 30 }, cv::FONT_HERSHEY_SIMPLEX, 1, cv::Scalar(0, 0, 255), 3);cv::putText(image_resize, "pridict: " + std::to_string(label), { 10, 80 }, cv::FONT_HERSHEY_SIMPLEX, 1, cv::Scalar(0, 0, 255), 3);image_row_vector.push_back(image_resize);}cv::Mat image_temp;cv::hconcat(image_row_vector, image_temp);image_show_vector.push_back(image_temp);}cv::vconcat(image_show_vector, image_show);cv::imwrite("validation.png", image_show);/*cv::imshow("image", image_show);cv::waitKey(0);*/return 0;
}

结果

请添加图片描述

损失函数区别

day01使用的是MSELoss(均方根误差),基本公式
M S E = 1 n ∑ i = 1 n ( y i − y ^ i ) 2 MSE = \frac{1}{n}\sum_{i =1 }^n(y_i-\hat{y}_i)^2 MSE=n1i=1n(yiy^i)2
描述的是回归任务中的预测值和真实值之间的差的平方均值。
在这里是分类任务,使用的是CrossEntropyLoss(交叉熵损失函数),假设网络的输出是一个向量
p ^ c = [ p ^ 1 c p ^ 2 c ⋯ p ^ n c ] \hat{p}_c = \left[\begin{matrix}\hat{p}_{1c}&\hat{p}_{2c}\cdots \hat{p}_{nc}\end{matrix}\right] p^c=[p^1cp^2cp^nc]
这个向量的每一个元素代表对于每一个标签的预测的概率大小, y c y_c yc是真实标签,是一个标量,那么交叉熵损失函数定义为
C r o s s E n t r o p y = − 1 n ∑ i = 1 n ∑ c = 1 C y i c l o g ( p ^ i c ) CrossEntropy = -\frac{1}{n}\sum_{i=1}^n\sum_{c=1}^Cy_{ic}log(\hat{p}_{ic}) CrossEntropy=n1i=1nc=1Cyiclog(p^ic)

这篇关于libtorch---day02[第一个分类器]的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1126382

相关文章

好题——hdu2522(小数问题:求1/n的第一个循环节)

好喜欢这题,第一次做小数问题,一开始真心没思路,然后参考了网上的一些资料。 知识点***********************************无限不循环小数即无理数,不能写作两整数之比*****************************(一开始没想到,小学没学好) 此题1/n肯定是一个有限循环小数,了解这些后就能做此题了。 按照除法的机制,用一个函数表示出来就可以了,代码如下

理解分类器(linear)为什么可以做语义方向的指导?(解纠缠)

Attribute Manipulation(属性编辑)、disentanglement(解纠缠)常用的两种做法:线性探针和PCA_disentanglement和alignment-CSDN博客 在解纠缠的过程中,有一种非常简单的方法来引导G向某个方向进行生成,然后我们通过向不同的方向进行行走,那么就会得到这个属性上的图像。那么你利用多个方向进行生成,便得到了各种方向的图像,每个方向对应了很多

Spring Roo 实站( 一 )部署安装 第一个示例程序

转自:http://blog.csdn.net/jun55xiu/article/details/9380213 一:安装 注:可以参与官网spring-roo: static.springsource.org/spring-roo/reference/html/intro.html#intro-exploring-sampleROO_OPTS http://stati

使用gradle做第一个java项目

涉及到的任务如下: assemble任务会编译程序中的源代码,并打包生成Jar文件,这个任务不执行单元测试。 Total time: 5.581 secs E:\workspace\Test>gradle assemble :compileJava :processResources UP-TO-DATE :classes :findMainClass :jar :b

vue2实践:第一个非正规的自定义组件-动态表单对话框

前言 vue一个很重要的概念就是组件,作为一个没有经历过前几代前端开发的我来说,不太能理解它所带来的“进步”,但是,将它与后端c++、java类比,我感觉,组件就像是这些语言中的类和对象的概念,通过封装好的组件(类),可以通过挂载的方式,非常方便的调用其提供的功能,而不必重新写一遍实现逻辑。 我们常用的element UI就是由饿了么所提供的组件库,但是在项目开发中,我们可能还需要额外地定义一

SpringMVC的第一个案例 Helloword 步骤

第一步:web.xml配置 <?xml version="1.0" encoding="UTF-8"?> <web-app version="2.5" xmlns="http://java.sun.com/xml/ns/javaee" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocati

我的第一次份实习工作-iOS实习生-第一个月

实习时间:2015-08-20 到 2015-12-25  实习公司;福建天棣互联有限公司 实习岗位:iOS开发实习生 第一个月: 第一天来公司,前台报道后,人资带我去我工作的地方。到了那,就由一个组长带我,当时还没有我的办公桌,组长在第三排给我找了一个位置,擦了下桌子,把旁边的准备的电脑帮我装了下,因为学的是iOS,实习生就只能用黑苹果了,这是我实习用的电脑。 帮我装了一下电脑后,开机

从零开始:打造你的第一个餐厅点餐小程序

目录 1 为什么选择点餐小程序2 会有哪些功能2.1 顾客端2.2 服务员端2.3 后厨端2.4 收银端2.5 管理员(老板)端 3 开发工具选择4 你将获得什么让我们开始吧 最近,有不少粉丝咨询,有没有系统的低代码学习教程呀?为啥你的教程有的刚看的提起兴趣,怎么突然就中断了。有没有系统的视频学习教程呀,你是不是还有压箱底的好宝贝,没开放给我们看呀。 还真不是,压箱底的好宝贝已

javaweb-day02-2(00:40:06 XML 解析 - Dom4j解析开发包)

导入dom4j开发包:dom4j-1.6.1.jar   在工程下建一个文件夹lib,将dom4j-1.6.1.jar拷到里边。右键add to build path。  dom4j-1.6.1\lib文件夹下还有一些jar包,是开发过程中dom4j所需要依赖的jar包,如开发过程中报错,则需导入。   用dom4j怎么做呢? 只要是开源jar包提供给你的时候,它会在开源包里面提供

javaweb-day02-2(XML 解析 - Jaxp的sax方式解析)

Jaxp解析开发包 Sax解析方式只能做查询: Sax解析方式和DOM解析方式的区别:     在使用 DOM 解析 XML 文档时,需要读取整个 XML文档,在内存中构架代表整个DOM 树的Doucment对象,从而再对XML文档进行操作。此种情况下,如果XML 文档特别大,就会消耗计算机的大量内存,并且容易导致内存溢出。  SAX解析允许在读取文档的时候,即对文档进行处