本文主要是介绍pytorch之torch.utils.model_zoo.load_url()在给定URL上加载Torch序列化对象,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!
torch.utils.model_zoo.load_url(url, model_dir=None)
在给定URL上加载Torch序列化对象。
通俗点说,就是通过提供的.pth
文件的url地址来下载指定的.pth
文件【在pytorch中.pth文件就是模型的参数文件】
参数:
url (string)
- 要下载对象的URLmodel_dir (string, optional)
- 保存对象的目录
如果对象已经存在于model_dir
中,则将被反序列化并返回。【也就是后面所说的——已经下载好模型的情况】
参数详细说明:
(1)这里被model_zoo加载的权重遵循命名约定标准——url的文件名: 模型名-<SHA256取前n位>.pth
,其中<SHA256取前n位>
是文件内容的SHA256哈希的前八位或更多位数字。哈希用于确保唯一的名称并验证文件的内容。
The filename part of the URL should follow the naming convention
filename-<sha256>.ext
where<sha256>
is the first eight or more digits of the SHA256 hash of the contents of the file.
例如:
http://data.lip6.fr/cadene/pretrainedmodels/dpn131-7af84be88.pth
http://data.lip6.fr/cadene/pretrainedmodels/resnext101_32x4d-29e315fa.pth
http://data.lip6.fr/cadene/pretrainedmodels/inceptionresnetv2-520b38e4.pth
如何查看SHA256?
- Window:
certutil -hashfile filename SHA256
- Linux:
sha256sum <filename>
补充——查看文件的三种哈希:
- Window:
certutil -hashfile filename MD5
certutil -hashfile filename SHA1
certutil -hashfile filename SHA256
- Linux:
md5sum <filename>
sha1sum <filename>
sha256sum <filename>
(2)model_dir
的默认值为$TORCH_HOME/models
,其中$TORCH_HOME
默认为~/.torch
。可以使用$TORCH_MODEL_ZOO
环境变量来覆盖默认目录。
我下载的默认路径是:
~/.cache/torch/checkpoints
可以通过设置全局变量export TORCH_HOME=/xx/xxx
即可修改下载的默认路径。修改后的路径为$TORCH_HOME/xx/xxx
例子:
weight_url='https://yjxiong.blob.core.windows.net/models/inceptionv3-cuhk-0e09b300b493bc74c.pth'
pretrained_dict = torch.utils.model_zoo.load_url(weight_url)
参考:官方文档
如果对象已经存在于model_dir
中,则将被反序列化并返回。
【已经下载好模型的情况】
如果你已经下载好模型了,那么可以通过torch.load(‘the/path/of/.pth’)
导入
因为torch.utils.model_zoo.load_url()
方法最后返回的时候也是用torch.load
接口封装成字典输出
例如我的代码(部分):
weight_url='https://yjxiong.blob.core.windows.net/models/inceptionv3-cuhk-0e09b300b493bc74c.pth'
pretrained_dict = torch.utils.model_zoo.load_url(weight_url)
self.load_state_dict(pretrained_dict)
因此不想下载的话可以用如下代码:
checkpoint=torch.load('~/.cache/torch/checkpoints/inceptionv3-cuhk-0e09b300b493bc74c.pth')
self.load_state_dict(checkpoint)
这篇关于pytorch之torch.utils.model_zoo.load_url()在给定URL上加载Torch序列化对象的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!