本文主要介绍 PyTorch 报错 RuntimeError: output with shape [1, 28, 28] doesn't match the broadcast shape [3, 28, 28] 的原因及解决办法，希望对大家解决编程问题提供一定的参考价值，需要的开发者们随着小编来一起学习吧！
使用 PyTorch 加载 MNIST 数据集的源码如下：
# Import things like usual%matplotlib inline
%config InlineBackend.figure_format = 'retina'import numpy as np
import torchimport helperimport matplotlib.pyplot as plt
from torchvision import datasets, transforms
# Define a transform to normalize the data
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),])
# Download and load the training data
trainset = datasets.MNIST('MNIST_data/', download=True, train=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)# Download and load the test data
testset = datasets.MNIST('MNIST_data/', download=True, train=False, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=True)dataiter = iter(trainloader)
images, labels = dataiter.next()#报错如下RuntimeError Traceback (most recent call last)
<ipython-input-22-840309f5aa1d> in <module>
      1 dataiter = iter(trainloader)
----> 2 images, labels = dataiter.next()

C:\ProgramData\Anaconda3\lib\site-packages\torch\utils\data\dataloader.py in __next__(self)
    558         if self.num_workers == 0:  # same-process loading
    559             indices = next(self.sample_iter)  # may raise StopIteration
--> 560             batch = self.collate_fn([self.dataset[i] for i in indices])
    561             if self.pin_memory:
    562                 batch = _utils.pin_memory.pin_memory_batch(batch)

C:\ProgramData\Anaconda3\lib\site-packages\torch\utils\data\dataloader.py in <listcomp>(.0)
    558         if self.num_workers == 0:  # same-process loading
    559             indices = next(self.sample_iter)  # may raise StopIteration
--> 560             batch = self.collate_fn([self.dataset[i] for i in indices])
    561             if self.pin_memory:
    562                 batch = _utils.pin_memory.pin_memory_batch(batch)

C:\ProgramData\Anaconda3\lib\site-packages\torchvision\datasets\mnist.py in __getitem__(self, index)
     93
     94         if self.transform is not None:
---> 95             img = self.transform(img)
     96
     97         if self.target_transform is not None:

C:\ProgramData\Anaconda3\lib\site-packages\torchvision\transforms\transforms.py in __call__(self, img)
     59     def __call__(self, img):
     60         for t in self.transforms:
---> 61             img = t(img)
     62         return img
     63

C:\ProgramData\Anaconda3\lib\site-packages\torchvision\transforms\transforms.py in __call__(self, tensor)
    162             Tensor: Normalized Tensor image.
    163         """
--> 164         return F.normalize(tensor, self.mean, self.std, self.inplace)
    165
    166     def __repr__(self):

C:\ProgramData\Anaconda3\lib\site-packages\torchvision\transforms\functional.py in normalize(tensor, mean, std, inplace)
    206     mean = torch.as_tensor(mean, dtype=torch.float32, device=tensor.device)
    207     std = torch.as_tensor(std, dtype=torch.float32, device=tensor.device)
--> 208     tensor.sub_(mean[:, None, None]).div_(std[:, None, None])
    209     return tensor
    210

RuntimeError: output with shape [1, 28, 28] doesn't match the broadcast shape [3, 28, 28]
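解决办法：
MNIST 中的图片是单通道（灰度）图像，经过 ToTensor() 转换后张量形状为 [1, 28, 28]；而 Normalize 传入的是三个通道的均值和标准差，形状为 [3, 1, 1]，两者无法广播匹配，因此报错。只需将三通道的标准化参数改为单通道即可，如下：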
# Before: three mean/std values (one per channel of a 3-channel image).
# This fails on MNIST, whose images have only one channel:
# transform = transforms.Compose([transforms.ToTensor(),
#     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
# ])
# After: MNIST is single-channel (grayscale), so Normalize receives exactly
# one mean value and one std value.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)
修改后问题解决。参考原文：
https://blog.csdn.net/weixin_43159148/article/details/88778371
这篇关于 RuntimeError: output with shape [1, 28, 28] doesn't match the broadcast shape [3, 28, 28] 的文章就介绍到这儿，希望我们推荐的文章对编程师们有所帮助！