本文主要是介绍记录使用pytorch训练crnn,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!
工程来源:
https://github.com/WenmuZhou/PytorchOCR?tab=readme-ov-file#train
基本数据准备和配置和paddleOCR是一样的;记录一下使用时遇到的问题。
1.环境
我使用的是:172.31.50.201:5000/algorithm/pytorch-1.11.0-cuda11.3-cudnn8-devel-arcface:v2
然后陆续按照要求安装了库:
#pip install imgaug -i https://pypi.tuna.tsinghua.edu.cn/simple
#pip install pyclipper -i https://pypi.tuna.tsinghua.edu.cn/simple
#pip install lmdb -i https://pypi.tuna.tsinghua.edu.cn/simple
#pip install rapidfuzz -i https://pypi.tuna.tsinghua.edu.cn/simple
2.训练时遇到的问题:训练一开始就NAN,使用小数据集时,acc一直为0:
解决办法是修改了CTCloss初始化:
在class CTCLoss(nn.Module)中
self.loss_func = nn.CTCLoss(blank=0, reduction='none',zero_infinity=True)
遇到问题时给的一些好的参考:
[深度学习][pytorch][原创]crnn在高版本pytorch上训练loss为nan解决办法_crnn中train loss: nan-CSDN博客 关于pytorch自带的CTCloss使用时的注意事项_pytorch ctc-CSDN博客
https://zhuanlan.zhihu.com/p/67415439
然后就没有报错了
3.加载预训练模型代码修改
def load_pretrained_params(model, pretrained_model):# checkpoint = torch.load(pretrained_model, map_location=torch.device('cpu'))# model.load_state_dict(checkpoint['state_dict'], strict=False)backbone_dict = model.state_dict()pretrained_dict = torch.load(pretrained_model, map_location=torch.device('cpu'))pretrained_dict_backbone_ = {}for k, v in pretrained_dict['state_dict'].items():k_ = k.replace('module.', '')if k_ in backbone_dict and backbone_dict[k_].size() == v.size():pretrained_dict_backbone_[k_] = velse:print(k_, backbone_dict[k_].size(), v.size())backbone_dict.update(pretrained_dict_backbone_)model.load_state_dict(backbone_dict)
这篇关于记录使用pytorch训练crnn的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!