pytorch框架是如何autograd简易实现反向传播算法的?

2023-11-06 04:40

本文主要是介绍pytorch框架是如何autograd简易实现反向传播算法的?,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

import torch
x = torch.randn(3,4,requires_grad=True) #对指定的X进行求导
b = torch.randn(3,4,requires_grad=True)
t = x+b
y=t.sum()
print(y)
y.backward() #执行反向传播
print(b.grad)print(x.requires_grad, b.requires_grad, t.requires_grad)   #自动计算t 自动指定为true

for example:

z=b+y

y=x*w

(求偏导逐层来做 链式求导)

import torch
x = torch.rand(1)
b = torch.rand(1, requires_grad=True)
w = torch.rand(1, requires_grad=True)
y = w * x
z = y + b
print(x.requires_grad, b.requires_grad, w.requires_grad, y.requires_grad)
# dw/dx 要先通过对y求导 故y也需要
print(x.is_leaf, w.is_leaf, b.is_leaf, y.is_leaf, z.is_leaf)
#检测是否为leaf节点 True True True False False#反向传播的计算
z.backward(retain_graph=True) #如果对梯度不清零 会累加 实际训练模型时一般不需要累加
print(w.grad)
print(b.grad)

反向传播操作全部已经封装在函数内。

 

now,构建一个线性回归demo,have a try

import torch
import torch.nn as nn
import numpy as np
#大多时候对图像等进行数据读入都为np.array 需要转换
#先构建x、y
x_values = [i for i in range(11)]
x_train = np.array(x_values, dtype=np.float32)
x_train = x_train.reshape(-1, 1) #转为矩阵
print(x_train.shape)y_values = [2*i + 1 for i in x_values]
y_train = np.array(y_values, dtype=np.float32)
y_train = y_train.reshape(-1, 1)
print(y_train.shape)#其实线性回归就是一个不加激活函数的全连接层
class LinearRegressionModel(nn.Module):def __init__(self, input_dim, output_dim):super(LinearRegressionModel, self).__init__()self.linear = nn.Linear(input_dim, output_dim)def forward(self, x):out = self.linear(x)return out#y=2x+1
input_dim = 1
output_dim = 1model = LinearRegressionModel(input_dim, output_dim)
print(model)#指定好参数和损失函数进行训练
epochs = 1000
learning_rate = 0.01
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
#优化器SGD
criterion = nn.MSELoss()
#定义损失函数 回归函数->MSE#训练模型
for epoch in range(epochs):epoch += 1# 注意转行成tensorinputs = torch.from_numpy(x_train)labels = torch.from_numpy(y_train)# 梯度要清零每一次迭代optimizer.zero_grad()# 前向传播outputs = model(inputs)# 计算损失loss = criterion(outputs, labels)# 返向传播loss.backward()# 更新权重参数optimizer.step()if epoch % 5 == 0:print('epoch {}, loss {}'.format(epoch, loss.item()))#测试模型预测结果
predicted = model(torch.from_numpy(x_train).requires_grad_()).data.numpy()
print(predicted)#模型的保存与读取
torch.save(model.state_dict(), 'model.pkl')
model.load_state_dict(torch.load('model.pkl'))

 

 

 

 

这篇关于pytorch框架是如何autograd简易实现反向传播算法的?的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/354522

相关文章

Linux下删除乱码文件和目录的实现方式

《Linux下删除乱码文件和目录的实现方式》:本文主要介绍Linux下删除乱码文件和目录的实现方式,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录linux下删除乱码文件和目录方法1方法2总结Linux下删除乱码文件和目录方法1使用ls -i命令找到文件或目录

SpringBoot+EasyExcel实现自定义复杂样式导入导出

《SpringBoot+EasyExcel实现自定义复杂样式导入导出》这篇文章主要为大家详细介绍了SpringBoot如何结果EasyExcel实现自定义复杂样式导入导出功能,文中的示例代码讲解详细,... 目录安装处理自定义导出复杂场景1、列不固定,动态列2、动态下拉3、自定义锁定行/列,添加密码4、合并

mybatis执行insert返回id实现详解

《mybatis执行insert返回id实现详解》MyBatis插入操作默认返回受影响行数,需通过useGeneratedKeys+keyProperty或selectKey获取主键ID,确保主键为自... 目录 两种方式获取自增 ID:1. ​​useGeneratedKeys+keyProperty(推

Spring Boot集成Druid实现数据源管理与监控的详细步骤

《SpringBoot集成Druid实现数据源管理与监控的详细步骤》本文介绍如何在SpringBoot项目中集成Druid数据库连接池,包括环境搭建、Maven依赖配置、SpringBoot配置文件... 目录1. 引言1.1 环境准备1.2 Druid介绍2. 配置Druid连接池3. 查看Druid监控

Linux在线解压jar包的实现方式

《Linux在线解压jar包的实现方式》:本文主要介绍Linux在线解压jar包的实现方式,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录linux在线解压jar包解压 jar包的步骤总结Linux在线解压jar包在 Centos 中解压 jar 包可以使用 u

c++ 类成员变量默认初始值的实现

《c++类成员变量默认初始值的实现》本文主要介绍了c++类成员变量默认初始值,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧... 目录C++类成员变量初始化c++类的变量的初始化在C++中,如果使用类成员变量时未给定其初始值,那么它将被

Qt使用QSqlDatabase连接MySQL实现增删改查功能

《Qt使用QSqlDatabase连接MySQL实现增删改查功能》这篇文章主要为大家详细介绍了Qt如何使用QSqlDatabase连接MySQL实现增删改查功能,文中的示例代码讲解详细,感兴趣的小伙伴... 目录一、创建数据表二、连接mysql数据库三、封装成一个完整的轻量级 ORM 风格类3.1 表结构

基于Python实现一个图片拆分工具

《基于Python实现一个图片拆分工具》这篇文章主要为大家详细介绍了如何基于Python实现一个图片拆分工具,可以根据需要的行数和列数进行拆分,感兴趣的小伙伴可以跟随小编一起学习一下... 简单介绍先自己选择输入的图片,默认是输出到项目文件夹中,可以自己选择其他的文件夹,选择需要拆分的行数和列数,可以通过

Python中将嵌套列表扁平化的多种实现方法

《Python中将嵌套列表扁平化的多种实现方法》在Python编程中,我们常常会遇到需要将嵌套列表(即列表中包含列表)转换为一个一维的扁平列表的需求,本文将给大家介绍了多种实现这一目标的方法,需要的朋... 目录python中将嵌套列表扁平化的方法技术背景实现步骤1. 使用嵌套列表推导式2. 使用itert

Python使用pip工具实现包自动更新的多种方法

《Python使用pip工具实现包自动更新的多种方法》本文深入探讨了使用Python的pip工具实现包自动更新的各种方法和技术,我们将从基础概念开始,逐步介绍手动更新方法、自动化脚本编写、结合CI/C... 目录1. 背景介绍1.1 目的和范围1.2 预期读者1.3 文档结构概述1.4 术语表1.4.1 核