This article looks at whether gradients can still back-propagate in PyTorch after part of a tensor has been sliced away. Hopefully it serves as a useful reference for anyone running into the same question.
Yes, they can. Slicing (indexing) is itself a differentiable operation: autograd records it, gradients flow back to the selected elements, and the discarded region simply receives no gradient. The script below demonstrates this with a 1x1 convolution whose output is cropped to the top-left 3x3 region before the loss is computed.
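As a quick contrast (not part of the original script): slicing keeps the autograd graph alive, whereas truncating with `.detach()` cuts it, so nothing reaches the weights. A minimal sketch:

```python
import torch
import torch.nn as nn

conv = nn.Conv2d(2, 1, kernel_size=1, bias=False)
x = torch.rand(1, 2, 5, 5)

# Slicing: the gradient still reaches the conv weight
out = conv(x)[:, :, :3, :3].sum()
out.backward()
print(conv.weight.grad)        # a non-zero tensor

# detach(): the graph is cut, so the sliced result carries no grad
conv.weight.grad = None
out2 = conv(x).detach()[:, :, :3, :3].sum()
print(out2.requires_grad)      # False -> calling backward() here would raise an error
```

The difference is that `[:, :, :3, :3]` is a recorded autograd op, while `.detach()` explicitly severs the tensor from the graph.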
import torch
from torch import optim
import torch.nn as nn

class g(nn.Module):
    def __init__(self):
        super(g, self).__init__()
        # 1x1 convolution, single output channel, no bias
        self.k = nn.Conv2d(in_channels=2, out_channels=1, kernel_size=1,
                           padding=0, bias=False)

    def forward(self, z):
        return self.k(z)

c = 2
h = 5
w = 5
z = torch.rand((1, c, h, w)).float().view(1, c, h, w) * 100
z.requires_grad = True

k = g()
optimizer = optim.Adam(k.parameters(), lr=1)  # renamed to avoid shadowing the optim module
optimizer.zero_grad()

r = k(z)
r = r[:, :, :3, :3]   # keep only the top-left 3x3 region of the output
r = r.sum()
loss = (r - 1) * (r - 1)

# Weight and input before the update
for name, v in k.named_parameters():
    print(name, v)
print(z)
print("*********************")

loss.backward()
optimizer.step()

# Weight and input after the update
for name, v in k.named_parameters():
    print(name, v)
print(z)
Output:
k.weight Parameter containing:
tensor([[[[-0.0464]],
[[ 0.4256]]]], requires_grad=True)
tensor([[[[65.6508, 65.0099, 38.5205, 78.4769, 31.6377],
[27.1530, 5.7923, 23.9614, 59.5419, 3.5597],
[69.9373, 29.7657, 91.4004, 85.5130, 65.2210],
[62.6357, 23.9004, 95.3394, 59.5155, 48.1762],
[98.7728, 97.2193, 66.3625, 65.0421, 22.0612]],
[[19.3582, 2.4226, 47.2068, 20.1124, 31.9324],
[23.4966, 5.0654, 12.4682, 35.3092, 90.3394],
[ 8.4709, 91.5994, 79.7592, 93.8652, 92.6337],
[49.0805, 63.9460, 81.2459, 63.4729, 77.1670],
[17.8333, 18.6162, 44.9271, 44.8790, 3.6609]]]], requires_grad=True)
*********************
k.weight Parameter containing:
tensor([[[[-1.0464]],
[[-0.5744]]]], requires_grad=True)
tensor([[[[65.6508, 65.0099, 38.5205, 78.4769, 31.6377],
[27.1530, 5.7923, 23.9614, 59.5419, 3.5597],
[69.9373, 29.7657, 91.4004, 85.5130, 65.2210],
[62.6357, 23.9004, 95.3394, 59.5155, 48.1762],
[98.7728, 97.2193, 66.3625, 65.0421, 22.0612]],
[[19.3582, 2.4226, 47.2068, 20.1124, 31.9324],
[23.4966, 5.0654, 12.4682, 35.3092, 90.3394],
[ 8.4709, 91.5994, 79.7592, 93.8652, 92.6337],
[49.0805, 63.9460, 81.2459, 63.4729, 77.1670],
[17.8333, 18.6162, 44.9271, 44.8790, 3.6609]]]], requires_grad=True)
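The weight changed after the update, which confirms that the gradient passed through the slice. A minimal sketch (not part of the original script) makes the gradient pattern explicit on a plain tensor: after backpropagating through a slice, only the selected elements receive a gradient, and the discarded region stays at zero.

```python
import torch

z = torch.randn(4, 4, requires_grad=True)

# Keep only the top-left 2x2 block, then reduce to a scalar loss
loss = z[:2, :2].sum()
loss.backward()

print(z.grad)
# Gradient is 1 inside the selected 2x2 block and 0 everywhere else
```

This zero pattern is exactly why the truncated part of the conv output in the script above contributes nothing to the weight update, while the 3x3 region drives it.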
That concludes this look at whether gradients can back-propagate through a sliced tensor in Torch; hopefully it is helpful.