Torch截断一部分后是否能梯度回传

2023-10-09 22:30

本文主要是介绍Torch截断一部分后是否能梯度回传,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!


import torchfrom torch import optim 
import torch.nn as nnclass g(nn.Module):def __init__(self):super(g, self).__init__()self.k = nn.Conv2d(in_channels=2, out_channels=1, kernel_size=1, padding=0, bias=False)def forward(self, z):return self.k(z)c = 2
h = 5
w = 5
z = torch.rand( (1,c , h , w)).float().view(1, c, h, w)*100
z.requires_grad = True
k = g()optim = optim.Adam(k.parameters(), lr=1)
optim.zero_grad()
r = k(z)
r= r[:,:,:3,:3]
r = r.sum()
loss = (r - 1) * (r - 1)for name,v in k.named_parameters():print(name,v) 
print(z)
print("*********************")loss.backward()
optim.step()
for name,v in k.named_parameters():print(name,v) 
print(z)

输出:


tensor([[[[-0.0464]],

         [[ 0.4256]]]], requires_grad=True)
tensor([[[[65.6508, 65.0099, 38.5205, 78.4769, 31.6377],
          [27.1530,  5.7923, 23.9614, 59.5419,  3.5597],
          [69.9373, 29.7657, 91.4004, 85.5130, 65.2210],
          [62.6357, 23.9004, 95.3394, 59.5155, 48.1762],
          [98.7728, 97.2193, 66.3625, 65.0421, 22.0612]],

         [[19.3582,  2.4226, 47.2068, 20.1124, 31.9324],
          [23.4966,  5.0654, 12.4682, 35.3092, 90.3394],
          [ 8.4709, 91.5994, 79.7592, 93.8652, 92.6337],
          [49.0805, 63.9460, 81.2459, 63.4729, 77.1670],
          [17.8333, 18.6162, 44.9271, 44.8790,  3.6609]]]], requires_grad=True)
*********************
k.weight Parameter containing:
tensor([[[[-1.0464]],

         [[-0.5744]]]], requires_grad=True)
tensor([[[[65.6508, 65.0099, 38.5205, 78.4769, 31.6377],
          [27.1530,  5.7923, 23.9614, 59.5419,  3.5597],
          [69.9373, 29.7657, 91.4004, 85.5130, 65.2210],
          [62.6357, 23.9004, 95.3394, 59.5155, 48.1762],
          [98.7728, 97.2193, 66.3625, 65.0421, 22.0612]],

         [[19.3582,  2.4226, 47.2068, 20.1124, 31.9324],
          [23.4966,  5.0654, 12.4682, 35.3092, 90.3394],
          [ 8.4709, 91.5994, 79.7592, 93.8652, 92.6337],
          [49.0805, 63.9460, 81.2459, 63.4729, 77.1670],
          [17.8333, 18.6162, 44.9271, 44.8790,  3.6609]]]], requires_grad=True)

这篇关于Torch截断一部分后是否能梯度回传的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/176009

相关文章

Codeforces Round #113 (Div. 2) B 判断多边形是否在凸包内

题目点击打开链接 凸多边形A, 多边形B, 判断B是否严格在A内。  注意AB有重点 。  将A,B上的点合在一起求凸包,如果凸包上的点是B的某个点,则B肯定不在A内。 或者说B上的某点在凸包的边上则也说明B不严格在A里面。 这个处理有个巧妙的方法,只需在求凸包的时候, <=  改成< 也就是说凸包一条边上的所有点都重复点都记录在凸包里面了。 另外不能去重点。 int

easyui同时验证账户格式和ajax是否存在

accountName: {validator: function (value, param) {if (!/^[a-zA-Z][a-zA-Z0-9_]{3,15}$/i.test(value)) {$.fn.validatebox.defaults.rules.accountName.message = '账户名称不合法(字母开头,允许4-16字节,允许字母数字下划线)';return fal

【408DS算法题】039进阶-判断图中路径是否存在

Index 题目分析实现总结 题目 对于给定的图G,设计函数实现判断G中是否含有从start结点到stop结点的路径。 分析实现 对于图的路径的存在性判断,有两种做法:(本文的实现均基于邻接矩阵存储方式的图) 1.图的BFS BFS的思路相对比较直观——从起始结点出发进行层次遍历,遍历过程中遇到结点i就表示存在路径start->i,故只需判断每个结点i是否就是stop

linux 判断某个命令是否安装

linux 判断某个命令是否安装 if ! [ -x "$(command -v git)" ]; thenecho 'Error: git is not installed.' >&2exit 1fi

✨机器学习笔记(二)—— 线性回归、代价函数、梯度下降

1️⃣线性回归(linear regression) f w , b ( x ) = w x + b f_{w,b}(x) = wx + b fw,b​(x)=wx+b 🎈A linear regression model predicting house prices: 如图是机器学习通过监督学习运用线性回归模型来预测房价的例子,当房屋大小为1250 f e e t 2 feet^

AI学习指南深度学习篇-带动量的随机梯度下降法的基本原理

AI学习指南深度学习篇——带动量的随机梯度下降法的基本原理 引言 在深度学习中,优化算法被广泛应用于训练神经网络模型。随机梯度下降法(SGD)是最常用的优化算法之一,但单独使用SGD在收敛速度和稳定性方面存在一些问题。为了应对这些挑战,动量法应运而生。本文将详细介绍动量法的原理,包括动量的概念、指数加权移动平均、参数更新等内容,最后通过实际示例展示动量如何帮助SGD在参数更新过程中平稳地前进。

AI学习指南深度学习篇-带动量的随机梯度下降法简介

AI学习指南深度学习篇 - 带动量的随机梯度下降法简介 引言 在深度学习的广阔领域中,优化算法扮演着至关重要的角色。它们不仅决定了模型训练的效率,还直接影响到模型的最终表现之一。随着神经网络模型的不断深化和复杂化,传统的优化算法在许多领域逐渐暴露出其不足之处。带动量的随机梯度下降法(Momentum SGD)应运而生,并被广泛应用于各类深度学习模型中。 在本篇文章中,我们将深入探讨带动量的随

pytorch torch.nn.functional.one_hot函数介绍

torch.nn.functional.one_hot 是 PyTorch 中用于生成独热编码(one-hot encoding)张量的函数。独热编码是一种常用的编码方式,特别适用于分类任务或对离散的类别标签进行处理。该函数将整数张量的每个元素转换为一个独热向量。 函数签名 torch.nn.functional.one_hot(tensor, num_classes=-1) 参数 t

torch.nn 与 torch.nn.functional的区别?

区别 PyTorch中torch.nn与torch.nn.functional的区别是:1.继承方式不同;2.可训练参数不同;3.实现方式不同;4.调用方式不同。 1.继承方式不同 torch.nn 中的模块大多数是通过继承torch.nn.Module 类来实现的,这些模块都是Python 类,需要进行实例化才能使用。而torch.nn.functional 中的函数是直接调用的,无需

如何判断一个数组中是否包含一个字符或字符串

第一种方法:遍历数组 String[] arr1 = {"1","2","3","4","6","7"}; for (int i = 0; i < arr1.length; i++) { if("5".equals(arr1[i])) { System.out.println("包含"); }else { System.out.println("不包含"); } } 第二种方法:先把数组