This post covers how to resolve the PyTorch error "torch.cat(): input types can't be cast to the desired output type Byte".
While training a model with U2Net recently, I ran into the following error:
RuntimeError: torch.cat(): input types can't be cast to the desired output type Byte
The full stack trace is shown below:
    Original Traceback (most recent call last):
      File "/datapython3.11/site-packages/torch/utils/data/_utils/worker.py", line 308, in _worker_loop
        data = fetcher.fetch(index)
               ^^^^^^^^^^^^^^^^^^^^
      File "/datapython3.11/site-packages/torch/utils/data/_utils/fetch.py", line 54, in fetch
        return self.collate_fn(data)
               ^^^^^^^^^^^^^^^^^^^^^
      File "/datapython3.11/site-packages/torch/utils/data/_utils/collate.py", line 268, in default_collate
        return collate(batch, collate_fn_map=default_collate_fn_map)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "/datapython3.11/site-packages/torch/utils/data/_utils/collate.py", line 127, in collate
        return elem_type({key: collate([d[key] for d in batch], collate_fn_map=collate_fn_map) for key in elem})
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "/datapython3.11/site-packages/torch/utils/data/_utils/collate.py", line 127, in
        return elem_type({key: collate([d[key] for d in batch], collate_fn_map=collate_fn_map) for key in elem})
                               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "/datapython3.11/site-packages/torch/utils/data/_utils/collate.py", line 119, in collate
        return collate_fn_map[elem_type](batch, collate_fn_map=collate_fn_map)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "/datapython3.11/site-packages/torch/utils/data/_utils/collate.py", line 165, in collate_tensor_fn
        return torch.stack(batch, 0, out=out)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    RuntimeError: torch.cat(): input types can't be cast to the desired output type Byte
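Cause: my guess is that this is a concurrency problem triggered by several workers running together. The traceback does point to something more specific, though. When collation runs inside a worker process, collate_tensor_fn preallocates a shared-memory output tensor with the dtype of the first sample in the batch and passes it to torch.stack as out. The message says that output is Byte (uint8), so the first sample is uint8 while at least one other sample in the same batch has a dtype, most likely a floating-point one, that cannot be cast to uint8. In other words, the dataset is probably returning tensors with inconsistent dtypes for the same field, and the worker-only code path is what exposes it.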
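To make that failure mode concrete, here is a minimal sketch that should reproduce the same error without any DataLoader. It assumes the mismatch is between a uint8 tensor and a float32 tensor; the variable names and shapes are invented for illustration and are not taken from U2Net.

```python
import torch

# Two hypothetical "samples" for the same field, with different dtypes.
first = torch.zeros(1, 4, 4, dtype=torch.uint8)
second = torch.zeros(1, 4, 4, dtype=torch.float32)

# Main-process path (out=None): torch.stack promotes to a common dtype.
promoted = torch.stack([first, second], 0)
print(promoted.dtype)

# Worker path: the output buffer is preallocated with the FIRST sample's
# dtype, which mirrors what collate_tensor_fn does in a background process.
out = torch.empty(2, 1, 4, 4, dtype=first.dtype)
error_message = None
try:
    torch.stack([first, second], 0, out=out)
except RuntimeError as exc:
    error_message = str(exc)
print(error_message)
```

If the second call raises while the first one succeeds, that matches the behaviour described above: the error depends on the preallocated output dtype rather than on timing between workers.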
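Fix 1:
When constructing the DataLoader instance, set num_workers to 0. Batches are then collated in the main process, so the shared-memory branch in collate_tensor_fn is never taken.
Drawback: data loading is no longer parallel, so training gets slower.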
    train_dataset = SalObjDataset(
        img_name_list=train_img_name_list,
        lbl_name_list=train_label_name_list,
        transform=transforms.Compose([
            RescaleT(320),
            # RandomCrop(288),
            ToTensorLab(flag=0),
        ]),
    )
    train_dataset.auto_collation = True
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size_train, shuffle=True, num_workers=0)
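If the underlying problem really is that samples come back with different dtypes, another option is to normalise the dtypes before collation and keep the workers. The sketch below is one way to do that, under assumptions: CastFields is a hypothetical helper (not part of U2Net or PyTorch), and the field names "image" and "label" are assumed for illustration and should be replaced with whatever keys the dataset actually returns.

```python
import torch
from torch.utils.data import Dataset


class CastFields(Dataset):
    """Wrap a dataset of dict samples and cast selected tensor fields.

    Hypothetical helper for illustration only.
    """

    def __init__(self, base, dtypes):
        self.base = base      # any map-style dataset that returns dicts
        self.dtypes = dtypes  # mapping: field name -> torch dtype

    def __len__(self):
        return len(self.base)

    def __getitem__(self, index):
        sample = dict(self.base[index])  # shallow copy, leave the original alone
        for key, dtype in self.dtypes.items():
            if key in sample and isinstance(sample[key], torch.Tensor):
                sample[key] = sample[key].to(dtype)
        return sample


# Toy stand-in for a dataset whose "label" dtype varies between samples.
toy = [
    {"image": torch.zeros(3, 4, 4), "label": torch.zeros(1, 4, 4, dtype=torch.uint8)},
    {"image": torch.zeros(3, 4, 4), "label": torch.zeros(1, 4, 4, dtype=torch.float64)},
]
fixed = CastFields(toy, {"image": torch.float32, "label": torch.float32})
label_dtypes = {fixed[i]["label"].dtype for i in range(len(fixed))}
print(label_dtypes)
```

In a training script, the wrapped dataset would be passed to DataLoader in place of train_dataset, with num_workers left above 0. Whether float32 is the right target dtype depends on what the model and loss expect.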
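Fix 2:
Modify the code inside PyTorch itself.
The stack trace above shows that the error is raised at /datapython3.11/site-packages/torch/utils/data/_utils/collate.py:165.
Open collate.py, find the collate_tensor_fn function, and comment out the body of the if statement. With out left as None, torch.stack allocates its own result instead of writing into the preallocated shared-memory tensor.
Drawback: you have to patch PyTorch's own source, which will be overwritten whenever the package is reinstalled or upgraded, and it costs one extra memory copy.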
    def collate_tensor_fn(batch, *, collate_fn_map: Optional[Dict[Union[Type, Tuple[Type, ...]], Callable]] = None):
        elem = batch[0]
        out = None
        # if torch.utils.data.get_worker_info() is not None:
        #     # If we're in a background process, concatenate directly into a
        #     # shared memory tensor to avoid an extra copy
        #     numel = sum(x.numel() for x in batch)
        #     storage = elem._typed_storage()._new_shared(numel, device=elem.device)
        #     out = elem.new(storage).resize_(len(batch), *list(elem.size()))
        return torch.stack(batch, 0, out=out)
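The same effect can probably be had without touching site-packages by passing a custom collate_fn to the DataLoader. The sketch below is an assumption-laden alternative, not the method used above: collate_without_shared_out is a hypothetical name, it only handles tensors and dicts of tensors specially, and it hands everything else to PyTorch's default_collate.

```python
import torch
from torch.utils.data import default_collate


def collate_without_shared_out(batch):
    """Collate samples, stacking tensors without a preallocated out= buffer.

    Hypothetical helper for illustration only.
    """
    elem = batch[0]
    if isinstance(elem, torch.Tensor):
        # No out= argument, so torch.stack picks a common dtype itself.
        return torch.stack(batch, 0)
    if isinstance(elem, dict):
        return {key: collate_without_shared_out([d[key] for d in batch]) for key in elem}
    # Anything else (numbers, strings, ...) goes through the default logic.
    return default_collate(batch)


# Toy batch with mismatched dtypes for the same field.
batch = [
    {"label": torch.zeros(1, 2, 2, dtype=torch.uint8)},
    {"label": torch.ones(1, 2, 2, dtype=torch.float32)},
]
collated = collate_without_shared_out(batch)
print(collated["label"].dtype, tuple(collated["label"].shape))
```

It would be wired in with DataLoader(..., num_workers=4, collate_fn=collate_without_shared_out), where the worker count is only an example. The function needs to be defined at module level so worker processes can import it. Like the patch above, it gives up the direct write into shared memory, so the extra copy remains, and it hides a dtype inconsistency rather than removing it.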