本文主要是介绍Pytorch中的scatter_函数,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!
(1). scatter_函数详细描述如下:
scatter_(input,dim,index,value)
将value对应的值按照index确定的索引写入input张量中,其中索引是根据给定的dim(维度)来确定的。
"""
Args:
input:要进行scatter_填充的tensor
dim:在input张量进行scatter_填充的维度
index:input对应dim的填充索引,要小于对应填充维度的长度,且index维度要与input张量维度一致
value:填充值
"""
(2). 代码实现
import torch
label = torch.zeros(2, 4)
print("label:",label)
label.scatter_(dim=1,index=torch.LongTensor([[2],[3]]),value=1)
print("new_label: ",label)
显示结果:
label: tensor([[0., 0., 0., 0.],[0., 0., 0., 0.]])
new_label: tensor([[0., 0., 1., 0.],[0., 0., 0., 1.]])
这篇关于Pytorch中的scatter_函数的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!