张量(Tensor)维度尺寸对不齐(Expected size xx but got size xx for tensor)

2024-01-14 11:20

本文主要是介绍张量(Tensor)维度尺寸对不齐(Expected size xx but got size xx for tensor),希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

本文以U-Net举例,演示如何解决张量(Tensor)维度尺寸对不齐的问题
U-Net的网络架构可以参考这篇文章:U-Net原理分析与代码解读
这是本文演示所用的U-Net代码

class UNet(nn.Module):def __init__(self):super(UNet, self).__init__()# 输入层self.input_conv = nn.Conv2d(3, 64, kernel_size=3, padding=1)# 下采样部分self.down1 = nn.Sequential(nn.Conv2d(64, 128, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=2))self.down2 = nn.Sequential(nn.Conv2d(128, 256, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=2))self.down3 = nn.Sequential(nn.Conv2d(256, 512, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=2))self.down4 = nn.Sequential(nn.Conv2d(512, 1024, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=2))# 桥接层 - 将输出通道数修改为1024,以便与down4_out拼接时通道数一致self.bridge = nn.Sequential(nn.Conv2d(1024, 1024, kernel_size=3, padding=1),nn.ReLU(inplace=True))# 上采样部分 - 调整每个上采样的第一个卷积层输入通道数self.up1 = nn.Sequential(nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),nn.Conv2d(2048, 1024, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.Conv2d(1024, 512, kernel_size=3, padding=1),nn.ReLU(inplace=True))self.up2 = nn.Sequential(nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),nn.Conv2d(1024, 512, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.Conv2d(512, 256, kernel_size=3, padding=1),nn.ReLU(inplace=True))self.up3 = nn.Sequential(nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),nn.Conv2d(512, 256, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.Conv2d(256, 128, kernel_size=3, padding=1),nn.ReLU(inplace=True))self.up4 = nn.Sequential(nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),nn.Conv2d(256, 128, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.Conv2d(128, 64, kernel_size=3, padding=1),nn.ReLU(inplace=True))# 输出层self.final_conv = nn.Conv2d(64, NUM_CLASSES, kernel_size=1)def forward(self, x):x = self.input_conv(x)  # 对原始输入进行处理down1_out = self.down1(x)down2_out = self.down2(down1_out)down3_out = self.down3(down2_out)down4_out = self.down4(down3_out)bridge_out = self.bridge(down4_out)up1_out = self.up1(torch.cat([bridge_out, down4_out], dim=1))up2_out = self.up2(torch.cat([up1_out, down3_out], dim=1))up3_out = self.up3(torch.cat([up2_out, down2_out], dim=1))up4_out = self.up4(torch.cat([up3_out, down1_out], dim=1))final_out = self.final_conv(up4_out)return torch.sigmoid(final_out)  # 因为是二分类问题,所以输出通过sigmoid激活

假设本文输入的图像是600乘以400像素的尺寸,那么对于本文U-Net代码所需的512乘以512像素的输入是肯定不匹配的。

一、图像缩放
既然输入图像的尺寸与网络所需输入的尺寸不符合,那就将输入图像的尺寸缩放到符合网络所需输入的尺寸就可以了。
在预处理函数中直接对原始图像进行缩放。
本文举例U-Net的所需输入是512乘以512像素,所以直接缩放为512乘以512像素

# 定义预处理函数
def get_transforms():# 对于图像的transformsimage_transforms_list = [transforms.Resize((512, 512)),  # 缩放至512x512transforms.ToTensor(),transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])  # 根据实际数据调整]image_transform = transforms.Compose(image_transforms_list)# 对于mask的transforms(不需要归一化)mask_transforms_list = [transforms.Resize((512, 512)),  # 缩放至512x512transforms.ToTensor()]mask_transform = transforms.Compose(mask_transforms_list)return image_transform, mask_transform

二、尺寸裁剪或尺寸填充
由于直接对原始图像进行缩放可能会对丢失一定的原始信息已经可能会扭曲一定的原始信息,所以更加建议使用尺寸裁剪或尺寸填充的方法。
尺寸裁剪或尺寸填充并不是在预处理函数中使用,而是在网络结构前向传播中使用,因为这时往往只需要改动几个像素点,对原始图像的改动较小。

定义尺寸裁剪或尺寸填充函数:
可以通过修改pad_value来决定用什么数值来填充(建议修改成背景的数值

def crop_or_pad_tensor(tensor, height_crop, width_crop, pad_value=0):'''裁剪或扩展Tensor在高度(仅底部)和宽度(仅右侧)维度上的最后一个像素。正数表示扩展(用0填充),负数表示裁剪。参数:tensor (torch.Tensor): 输入的4维张量,形状为 (batch_size, channels, height, width)height_crop (int): 高度方向上底部要裁剪或扩展的像素数量,默认为1width_crop (int): 宽度方向上右侧要裁剪或扩展的像素数量,默认为1pad_value (float or int): 填充时使用的值,默认为0返回:cropped_or_padded_tensor (torch.Tensor): 裁剪或扩展后的张量'''assert len(tensor.shape) == 4, '输入的tensor应为4维'# 获取原始的高度和宽度original_height, original_width = tensor.shape[2], tensor.shape[3]# 计算需要裁剪的数量(正值代表不裁剪,负值时代表裁剪)height_to_remove_from_bottom = min(original_height, -height_crop) if height_crop < 0 else 0width_to_remove_from_right = min(original_width, -width_crop) if width_crop < 0 else 0# 计算需要填充的数量(正值代表填充,负值代表不填充)pad_bottom = abs(height_crop) if height_crop > 0 else 0pad_right = abs(width_crop) if width_crop > 0 else 0# 先填充,再裁剪padded_tensor = F.pad(tensor, pad=(0, pad_right, 0, pad_bottom), mode='constant', value=pad_value)# 在高度和宽度维度上进行裁剪(如果需要)if height_to_remove_from_bottom > 0 and width_to_remove_from_right > 0:# 同时裁剪高度和宽度cropped_or_padded_tensor = padded_tensor[:, :, :-height_to_remove_from_bottom, :-width_to_remove_from_right]elif height_to_remove_from_bottom > 0:# 只裁剪高度cropped_or_padded_tensor = padded_tensor[:, :, :-height_to_remove_from_bottom, :]elif width_to_remove_from_right > 0:# 只裁剪宽度cropped_or_padded_tensor = padded_tensor[:, :, :, :-width_to_remove_from_right]else:# 不裁剪任何维度cropped_or_padded_tensor = padded_tensorreturn cropped_or_padded_tensor

在网络架构的forward方法中调用尺寸裁剪或尺寸填充函数:

# 对height裁剪一个像素,对width保持不变
crop_or_pad_tensor(up1_out, -1, 0)
# 对height保持不变,对width裁剪一个像素
crop_or_pad_tensor(up2_out, 0, -1)

在深度学习中,一个四维张量Tensor)通常代表的是批量图像数据,其维度排列通常是[batch_size, channels, height, width]
也就是Batch Size(批大小)Channels(通道数)Height(高度)、Width(宽度)
本文只讨论因Tensor中的heightwidth对不齐问题,batch_sizeChannels比较基础,就不提及了。

这篇关于张量(Tensor)维度尺寸对不齐(Expected size xx but got size xx for tensor)的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/604984

相关文章

4-4.Andorid Camera 之简化编码模板(获取摄像头 ID、选择最优预览尺寸)

一、Camera 简化思路 在 Camera 的开发中,其实我们通常只关注打开相机、图像预览和关闭相机,其他的步骤我们不应该花费太多的精力 为此,应该提供一个工具类,它有处理相机的一些基本工具方法,包括获取摄像头 ID、选择最优预览尺寸以及打印相机参数信息 二、Camera 工具类 CameraIdResult.java public class CameraIdResult {

Windows11电脑上自带的画图软件修改照片大小(不裁剪尺寸的情况下)

针对一张图片,有时候上传的图片有大小限制,那么在这种情况下如何修改其大小呢,在不裁剪尺寸的情况下 步骤如下: 1.选定一张图片,右击->打开方式->画图,如下: 第二步:打开图片后,我们可以看到图片的大小为82.1kb,点击上面工具栏的“重设大小和倾斜”进行调整,如下: 第三步:修改水平和垂直的数字,此处我修改为分别都修改为50,然后保存,可以看到大小变成63.5kb,如下:

android.database.CursorIndexOutOfBoundsException: Index 5 requested, with a size of 5

描述: 01-02 00:13:43.380: E/flyLog:ChatManager(963): getUnreadChatGroupandroid.database.CursorIndexOutOfBoundsException: Index 5 requested, with a size of 5 01-02 00:13:43.380: E/flyLog:ChatManager(

java中的length与length()与size()

正确用法 Array.length int[] arr = {1,2,3};int x = arr.length;//arr.length = 3 String.length() String s = "123";int x = s.length();//s.length() = 3 Collection.size() ArrayList<Integer> list = n

nexus3.XX的下载安装和配置

nexus的下载: 官方地址:http://www.sonatype.com/download-oss-sonatype 百度网盘:http://pan.baidu.com/s/1eSBeid0 下载完成后,将nexus解压到指定位置:(如d:\nexus3) nexus的安装: 开始 -> 运行 -> cmd install  安装 uninstall 卸载 此

【DL--03】深度学习基本概念—张量

张量 TensorFlow中的中心数据单位是张量。张量由一组成形为任意数量的数组的原始值组成。张量的等级是其维数。以下是张量的一些例子: 3 # a rank 0 tensor; this is a scalar with shape [][1. ,2., 3.] # a rank 1 tensor; this is a vector with shape [3][[1., 2., 3.]

【CSS】尺寸单位

在 CSS 中,常见的尺寸单位有以下几种: 像素(px): 这是最常用的绝对单位。例如 width: 200px; 表示宽度为 200 像素。像素是固定的尺寸,不会随着屏幕分辨率或设备的不同而变化。 备注: 在不同的设备上,px 对应的物理尺寸并不固定。 对于电脑显示器,px 通常与屏幕的物理像素相对应,但这也会受到屏幕分辨率和缩放设置的影响。例如,在标准分辨率(通常为 96 DPI 左右

【vue使用Sass报错】启动项目报错 Syntax Error: SassError: expected selector

出现的问题 新项目启动的时候,提示: Syntax Error: SassError: expected selector 看了一下发现是sass使用样式穿透/deep/报的错 /deep/其实是已经过期的写法,某个版本之后就不支持了 但是我同事并没有出现同样的问题,不知道是为啥,也有可能是电脑(mac)的原因 解决办法 将 /deep/更换为::v-deep 但是这个项目是多人协作的,有

WebAPI(二)、DOM事件监听、事件对象event、事件流、事件委托、页面加载与滚动事件、页面尺寸事件

文章目录 一、 DOM事件1. 事件监听2. 事件类型(1)、鼠标事件(2)、焦点事件(3)、键盘事件(4)、文本事件 3. 事件对象(1)、获取事件对象(2)、事件对象常用属性 4. 环境对象 this5. 回调函数 二、 DOM事件进阶1. 事件流(1)、 捕获阶段(2)、 冒泡阶段(3)、 阻止冒泡(4) 、阻止元素默认行为(5) 、解绑事件 2. 事件委托3. 其他事件(1)、页面加

当网工,华为认证哪种适合我?四个维度来解惑

随着网络技术的不断进步,对网工的专业技能要求也越来越高。 在这种背景下,获得权威认证成为了提升个人技能、证明专业能力的重要途径。 华为,作为全球领先的ICT解决方案提供商,其认证项目在业界享有极高的声誉。 华为认证不仅涵盖了网络技术的各个方面,还根据不同的技能水平和职业发展阶段,提供了不同级别的认证,包括HCIA、HCIP、HCIE。 这些认证不仅有助于网络工程师提升自己的技术水平,也是企业在招聘