本文主要是介绍解决:RuntimeError: mat1 and mat2 shapes cannot be multiplied(单张图片输入情况,可参考),希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!
在练习使用pytorch加载模型,识别图片时,出现了这一问题。
解决方法:使用torch.reshape()将输入数据格式改成与网络相符的格式。
详细过程:
报错代码:
import torch
import torchvision
from PIL import Imagefrom model import *image = Image.open("./img/dog.png")
transform = torchvision.transforms.Compose([torchvision.transforms.Resize((32, 32)),torchvision.transforms.ToTensor()])
image = transform(image)model = MyNet()
model.load_state_dict(torch.load("./testmodel/mynet_7.pth"))
output = model(image)
print(output)
其中我的model.py文件中的代码:
#!/usr/bin/env python
# _*_ coding: utf-8 _*_
# @Time : 2023-09-22 15:57
# @Author : Kanbara
# @File : model.pyimport torch
from torch import nn
from torch.nn import Sequential, Conv2d, MaxPool2d, Linear, Flattenclass MyNet(nn.Module):def __init__(self):super(MyNet, self).__init__()self.model = Sequential(Conv2d(3, 32, 5, padding=2),MaxPool2d(2),Conv2d(32, 32, 5, padding=2),MaxPool2d(2),Conv2d(32, 64, 5, padding=2),MaxPool2d(2),Flatten(),Linear(64*4*4, 64),Linear(64, 10))def forward(self, x):x = self.model(x)return x#测试该文件是否编译有问题
if __name__ == '__main__':mynet = MyNet()input = torch.ones([64, 3, 32, 32])output = mynet(input)print(output.shape)
model中的网络经过测试,本身不存在问题。
实际上,是输入图像尺寸少了一个参数batch_size导致。
print(image.shape)
>>torch.Size([3, 32, 32])
而根据网络设置,输入应有四个维度,第一个维度为batch_size。通过torch.reshape功能,添加代码:
image = torch.reshape(image, (1, 3, 32, 32))
即可解决。此时代码能够正常运行,修改后代码:
import torch
import torchvision
from PIL import Imagefrom model import *image = Image.open("./img/dog.png")
transform = torchvision.transforms.Compose([torchvision.transforms.Resize((32, 32)),torchvision.transforms.ToTensor()])
image = transform(image)model = MyNet()
model.load_state_dict(torch.load("./testmodel/mynet_7.pth"))
image = torch.reshape(image, (1, 3, 32, 32))
output = model(image)
print(output)
这篇关于解决:RuntimeError: mat1 and mat2 shapes cannot be multiplied(单张图片输入情况,可参考)的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!