This post describes a problem with `batch[0].new(storage)` that came up while working with the VOT dataset; hopefully it is a useful reference for developers who hit the same issue.
Recently, while processing the VOT dataset, I ran into a strange problem, which I am recording here.
The original code is as follows:
```python
def ltr_collate_stack1(batch):
    """Puts each data field into a tensor. The tensors are stacked at dim=1 to form the batch"""
    error_msg = "batch must contain tensors, numbers, dicts or lists; found {}"
    elem_type = type(batch[0])
    if isinstance(batch[0], torch.Tensor):
        out = None
        if _check_use_shared_memory():
            # If we're in a background process, concatenate directly into a
            # shared memory tensor to avoid an extra copy
            numel = sum([x.numel() for x in batch])
            storage = batch[0].storage()._new_shared(numel)
            out = batch[0].new(storage)
            # print(batch.shape)
            # print(out.shape)
        return torch.stack(batch, 1, out=out)
```
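For reference, the dim=1 stacking that the docstring describes behaves like this (a minimal sketch with made-up shapes, independent of the shared-memory path):

```python
import torch

# Four (2, 3) tensors standing in for four samples of a batch.
batch = [torch.full((2, 3), float(i)) for i in range(4)]

# Stacking at dim=1 puts the batch dimension in the middle:
# four (2, 3) tensors -> one (2, 4, 3) tensor.
stacked = torch.stack(batch, 1)
print(stacked.shape)  # torch.Size([2, 4, 3])
```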
In principle, when this code finishes, `out` should be a flat view over the whole of `storage`, so their sizes agree. Under PyTorch 2.1.1, however, the two do not match. Modifying the code as follows makes it run correctly:
```python
def ltr_collate_stack1(batch):
    """Puts each data field into a tensor. The tensors are stacked at dim=1 to form the batch"""
    error_msg = "batch must contain tensors, numbers, dicts or lists; found {}"
    elem_type = type(batch[0])
    if isinstance(batch[0], torch.Tensor):
        out = None
        if _check_use_shared_memory():
            # If we're in a background process, concatenate directly into a
            # shared memory tensor to avoid an extra copy
            numel = sum([x.numel() for x in batch])
            storage = batch[0].storage()._new_shared(numel)
            out = batch[0].new(storage).view(-1)
            # print(batch.shape)
            # print(out.shape)
        return torch.stack(batch, 1, out=out)
```
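To see the fixed path end to end, here is a minimal sketch of the shared-memory branch on its own. It skips the `_check_use_shared_memory()` guard and allocates the shared storage directly; the shapes are made up for illustration:

```python
import torch

batch = [torch.arange(6, dtype=torch.float32).reshape(2, 3) for _ in range(4)]

# Allocate one shared-memory storage big enough for the whole batch.
numel = sum(x.numel() for x in batch)
storage = batch[0].storage()._new_shared(numel)

# .view(-1) forces `out` to be a flat 1-D tensor, which torch.stack
# can then reshape to the stacked (2, 4, 3) result.
out = batch[0].new(storage).view(-1)
result = torch.stack(batch, 1, out=out)
print(result.shape)  # torch.Size([2, 4, 3])
```

Note that `torch.stack(..., out=out)` resizes `out` to the stacked shape, so the important property is that `out` starts out as a well-defined 1-D tensor, which is exactly what `.view(-1)` guarantees.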