本文主要是介绍soft-argmax踩坑,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!
最近在2D human pose estimation时需要用到soft-argmax,找了几个版本的函数,都有一个问题
RuntimeError: "softmax_lastdim_kernel_impl" not implemented for 'Long'
一、代码如下
def softargmax2d(input, beta=100):*_, h, w = input.shapeinput = beta*input.reshape(*_, h * w)input = F.softmax( input, dim=-1)indices_c, indices_r = np.meshgrid(np.linspace(0, 1, w),np.linspace(0, 1, h),indexing='xy')indices_r = torch.tensor(np.reshape(indices_r, (-1, h * w)))indices_c = torch.tensor(np.reshape(indices_c, (-1, h * w)))result_r = torch.sum((h - 1) * input * indices_r, dim=-1)result_c = torch.sum((w - 1) * input * indices_c, dim=-1)result = torch.stack([result_r, result_c], dim=-1)return result
二、测试如下
c=[[[[1,2,3],[4,5,16],[7,8,9]],[[1,2,3],[4,5,6],[7,8,9]]]]
c=torch.tensor(c)
print(c.size())
b=softargmax2d(c)
print(b)
三、结果如下
Traceback (most recent call last):
torch.Size([1, 2, 3, 3])
File "F:/pythonprogram/mon_repnet/wrm_model.py", line 202, in <module>
b=softargmax2d(c)
File "F:/pythonprogram/mon_repnet/wrm_model.py", line 172, in softargmax2d
input = F.softmax( input, dim=-1)
File "E:\software\python36\lib\site-packages\torch\nn\functional.py", line 1231, in softmax
ret = input.softmax(dim)
RuntimeError: "softmax_lastdim_kernel_impl" not implemented for 'Long'
找了很久你会发现很难搜到解决办法,其实……只要
四、修正如下
c=[[[[1,2,3],[4,5,16],[7,8,9]],[[1,2,3],[4,5,6],[7,8,9]]]]
c=torch.tensor(c).float()
print(c.size())
b=softargmax2d(c)
print(b)
转换输入的类型为float即可,额~~~~~~~
五、最终结果如下
torch.Size([1, 2, 3, 3])
tensor([[[1., 2.],
[2., 2.]]], dtype=torch.float64)
答案正确,最大值坐标分别为(1,2),(2,2)
这篇关于soft-argmax踩坑的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!