libtorch---day03[自定义导数]

2024-09-02 14:52

本文主要是介绍libtorch---day03[自定义导数],希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

参考pytorch。

背景

希望使用勒让德多项式拟合一个周期内的正弦函数。
真值: y = s i n ( x ) , x ∈ [ − π , π ] y=sin(x),x\in\left[-\pi,\pi\right] y=sin(x),x[π,π]

torch::Tensor x = torch::linspace(-M_PI, M_PI, 2000, torch::kFloat);
torch::Tensor y = torch::sin(x);

预测值是 n = 3 n=3 n=3的勒让德多多项式: y ^ = a + b × P 3 ( c + d x ) \hat{y} = a+b\times P_3(c+dx) y^=a+b×P3(c+dx),其中 P 3 ( x ) = 1 2 ( 5 x 3 − 3 x ) P_3(x) = \frac{1}{2}(5x^3-3x) P3(x)=21(5x33x)

构造自动求导类

torch提供了一种可以让开发者自主定义前向传播和后向求导的机制:

1、写一个类,继承torch::autograd::Function
2、在类中定义静态的forwardbackward函数,必须是静态的,这样在调用torch::autograd::Function::applytorch::autograd::Function::backward的时候,会自动调用上述两个静态函数;

struct LegenderPolynominal3 : public torch::autograd::Function<LegenderPolynominal3>
{static torch::Tensor forward(torch::autograd::AutogradContext* ctx, torch::Tensor input){ctx->save_for_backward({ input });return 0.5 * (5 * torch::pow(input, 3) - 3 * input);}static std::vector<torch::Tensor> backward(torch::autograd::AutogradContext* ctx, std::vector<torch::Tensor> grad_output){auto saved = ctx->get_saved_variables();torch::Tensor input = saved[0];torch::Tensor grad_input = grad_output[0] * 1.5 * (5 * torch::pow(input, 2) - 1);return { grad_input };}
};

关键点

  • 必须显式调用**ctx->save_for_backward({ input });保存节点信息、调用auto saved = ctx->get_saved_variables();**获取保存的节点信息;
  • forward函数计算的是预测值,这个和认知里的forward的功能相同;
  • backward函数的输入是grad_output,是损失项关于输出的梯度 ∂ L ∂ y \frac{\partial L}{\partial y} yL,而backward计算的是损失函数关于输入的梯度 ∂ L ∂ x \frac{\partial L}{\partial x} xL,因此需要计算 ∂ L ∂ x = ∂ L ∂ y × ∂ y ∂ x \frac{\partial L}{\partial x} = \frac{\partial L}{\partial y}\times \frac{\partial y}{\partial x} xL=yL×xy
  • 必须要注意backwardforward的参数列表必须固定;

全部代码

#include <torch/torch.h>
#include <iostream>
#include "matplotlibcpp.h"struct LegenderPolynominal3 : public torch::autograd::Function<LegenderPolynominal3>
{static torch::Tensor forward(torch::autograd::AutogradContext* ctx, torch::Tensor input){ctx->save_for_backward({ input });return 0.5 * (5 * torch::pow(input, 3) - 3 * input);}static std::vector<torch::Tensor> backward(torch::autograd::AutogradContext* ctx, std::vector<torch::Tensor> grad_output){auto saved = ctx->get_saved_variables();torch::Tensor input = saved[0];torch::Tensor grad_input = grad_output[0] * 1.5 * (5 * torch::pow(input, 2) - 1);return { grad_input };}
};
void plot_tensor_xy_compare(const torch::Tensor x, const torch::Tensor y, const torch::Tensor predict)
{auto data_ptr = x.data_ptr<float>();std::vector<float> x_vector(data_ptr, data_ptr + x.numel());data_ptr = y.data_ptr<float>();std::vector<float> y_vector(data_ptr, data_ptr + y.numel());data_ptr = predict.data_ptr<float>();std::vector<float> predict_vector(data_ptr, data_ptr + predict.numel());std::map<std::string, std::string> key_words({ {"label", "ground_true"}, {"color", "blue"}, {"linestyle", "-"}});matplotlibcpp::plot(x_vector, y_vector, key_words);key_words["color"] = "red";key_words["linestyle"] = "--";key_words["label"] = "prediction";matplotlibcpp::plot(x_vector, predict_vector, key_words);matplotlibcpp::grid(true);matplotlibcpp::legend();matplotlibcpp::show();
}
int main()
{torch::Tensor x = torch::linspace(-M_PI, M_PI, 1000, torch::kFloat);torch::Tensor y = torch::sin(x);torch::Tensor a = torch::full({}, 0., torch::kFloat).set_requires_grad(true);torch::Tensor b = torch::full({}, -1., torch::kFloat).set_requires_grad(true);torch::Tensor c = torch::full({}, 0., torch::kFloat).set_requires_grad(true);torch::Tensor d = torch::full({}, 0.3, torch::kFloat).set_requires_grad(true);double learning_rate = 5e-6;torch::nn::MSELoss criterion;torch::optim::SGD optimizer({a, b, c, d}, torch::optim::SGDOptions(learning_rate));for (int i = 0; i < 2000; i++){auto P3 = LegenderPolynominal3::apply(c + d * x);torch::Tensor predict = a + b * P3;torch::Tensor loss = (predict - y).pow(2).sum();// auto loss = criterion(predict, y);loss.backward();optimizer.step();optimizer.zero_grad();std::cout << "iteration: " << i + 1 << "/2000" << ", loss: " << loss.item<double>() << std::endl;}auto P3 = LegenderPolynominal3::apply(c + d * x);torch::Tensor predict = a + b * P3;plot_tensor_xy_compare(x, y, predict);return 0;
}

结果

在这里插入图片描述

相应的nn模块

#include <torch/torch.h>
#include "matplotlibcpp.h"using namespace torch;
void plot_tensor_xy_compare(const torch::Tensor x, const torch::Tensor y, const torch::Tensor predict)
{auto data_ptr = x.data_ptr<float>();std::vector<float> x_vector(data_ptr, data_ptr + x.numel());data_ptr = y.data_ptr<float>();std::vector<float> y_vector(data_ptr, data_ptr + y.numel());data_ptr = predict.data_ptr<float>();std::vector<float> predict_vector(data_ptr, data_ptr + predict.numel());std::map<std::string, std::string> key_words({ {"label", "ground_true"}, {"color", "blue"}, {"linestyle", "-"} });matplotlibcpp::plot(x_vector, y_vector, key_words);key_words["color"] = "red";key_words["linestyle"] = "--";key_words["label"] = "prediction";matplotlibcpp::plot(x_vector, predict_vector, key_words);matplotlibcpp::grid(true);matplotlibcpp::legend();matplotlibcpp::show();
}
class auto_grad : public nn::Module
{
public:Tensor a, b, c, d;auto_grad() : a(torch::full({}, 0., kFloat).set_requires_grad(true)),b(torch::full({}, -1., kFloat).set_requires_grad(true)),c(torch::full({}, 0., kFloat).set_requires_grad(true)),d(torch::full({}, 0.3, kFloat).set_requires_grad(true)){register_parameter("a", a);register_parameter("b", b);register_parameter("c", c);register_parameter("d", d);}Tensor forward(Tensor input){auto P3 = c + d * input;return a + b * (0.5 * (5 * torch::pow(P3, 3) - 3 * P3));}
};
int main()
{auto_grad net;nn::MSELoss criterion;optim::SGDOptions opt(1e-5);opt.momentum(0.9);optim::SGD optim(net.parameters(), opt);torch::Tensor x = torch::linspace(-M_PI, M_PI, 1000, torch::kFloat);torch::Tensor y = torch::sin(x);int iteration = 1000;for (int i = 0; i < iteration; i++){auto predict = net.forward(x);auto loss = (predict - y).pow(2).sum();loss.backward();optim.step();optim.zero_grad();printf("[training iteration: %d/ %d, loss: %lf]\n", i +1, iteration, loss.item<double>());}auto predict = net.forward(x);plot_tensor_xy_compare(x, y, predict);return 0;
}

关键点

1、使用register_parameter显式注册参数;

这篇关于libtorch---day03[自定义导数]的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1130274

相关文章

【前端学习】AntV G6-08 深入图形与图形分组、自定义节点、节点动画(下)

【课程链接】 AntV G6:深入图形与图形分组、自定义节点、节点动画(下)_哔哩哔哩_bilibili 本章十吾老师讲解了一个复杂的自定义节点中,应该怎样去计算和绘制图形,如何给一个图形制作不间断的动画,以及在鼠标事件之后产生动画。(有点难,需要好好理解) <!DOCTYPE html><html><head><meta charset="UTF-8"><title>06

自定义类型:结构体(续)

目录 一. 结构体的内存对齐 1.1 为什么存在内存对齐? 1.2 修改默认对齐数 二. 结构体传参 三. 结构体实现位段 一. 结构体的内存对齐 在前面的文章里我们已经讲过一部分的内存对齐的知识,并举出了两个例子,我们再举出两个例子继续说明: struct S3{double a;int b;char c;};int mian(){printf("%zd\n",s

Spring 源码解读:自定义实现Bean定义的注册与解析

引言 在Spring框架中,Bean的注册与解析是整个依赖注入流程的核心步骤。通过Bean定义,Spring容器知道如何创建、配置和管理每个Bean实例。本篇文章将通过实现一个简化版的Bean定义注册与解析机制,帮助你理解Spring框架背后的设计逻辑。我们还将对比Spring中的BeanDefinition和BeanDefinitionRegistry,以全面掌握Bean注册和解析的核心原理。

Oracle type (自定义类型的使用)

oracle - type   type定义: oracle中自定义数据类型 oracle中有基本的数据类型,如number,varchar2,date,numeric,float....但有时候我们需要特殊的格式, 如将name定义为(firstname,lastname)的形式,我们想把这个作为一个表的一列看待,这时候就要我们自己定义一个数据类型 格式 :create or repla

HTML5自定义属性对象Dataset

原文转自HTML5自定义属性对象Dataset简介 一、html5 自定义属性介绍 之前翻译的“你必须知道的28个HTML5特征、窍门和技术”一文中对于HTML5中自定义合法属性data-已经做过些介绍,就是在HTML5中我们可以使用data-前缀设置我们需要的自定义属性,来进行一些数据的存放,例如我们要在一个文字按钮上存放相对应的id: <a href="javascript:" d

一步一步将PlantUML类图导出为自定义格式的XMI文件

一步一步将PlantUML类图导出为自定义格式的XMI文件 说明: 首次发表日期:2024-09-08PlantUML官网: https://plantuml.com/zh/PlantUML命令行文档: https://plantuml.com/zh/command-line#6a26f548831e6a8cPlantUML XMI文档: https://plantuml.com/zh/xmi

argodb自定义函数读取hdfs文件的注意点,避免FileSystem已关闭异常

一、问题描述 一位同学反馈,他写的argo存过中调用了一个自定义函数,函数会加载hdfs上的一个文件,但有些节点会报FileSystem closed异常,同时有时任务会成功,有时会失败。 二、问题分析 argodb的计算引擎是基于spark的定制化引擎,对于自定义函数的调用跟hive on spark的是一致的。udf要通过反射生成实例,然后迭代调用evaluate。通过代码分析,udf在

鸿蒙开发中实现自定义弹窗 (CustomDialog)

效果图 #思路 创建带有 @CustomDialog 修饰的组件 ,并且在组件内部定义controller: CustomDialogController 实例化CustomDialogController,加载组件,open()-> 打开对话框 , close() -> 关闭对话框 #定义弹窗 (CustomDialog)是什么? CustomDialog是自定义弹窗,可用于广告、中

mybatis框架基础以及自定义插件开发

文章目录 框架概览框架预览MyBatis框架的核心组件MyBatis框架的工作原理MyBatis框架的配置MyBatis框架的最佳实践 自定义插件开发1. 添加依赖2. 创建插件类3. 配置插件4. 启动类中注册插件5. 测试插件 参考文献 框架概览 MyBatis是一个优秀的持久层框架,它支持自定义SQL、存储过程以及高级映射,为开发者提供了极大的灵活性和便利性。以下是关于M

vue2实践:第一个非正规的自定义组件-动态表单对话框

前言 vue一个很重要的概念就是组件,作为一个没有经历过前几代前端开发的我来说,不太能理解它所带来的“进步”,但是,将它与后端c++、java类比,我感觉,组件就像是这些语言中的类和对象的概念,通过封装好的组件(类),可以通过挂载的方式,非常方便的调用其提供的功能,而不必重新写一遍实现逻辑。 我们常用的element UI就是由饿了么所提供的组件库,但是在项目开发中,我们可能还需要额外地定义一