[torch notes] ToTensor and Normalize in torchvision.transforms
Contents
- ToTensor and Normalize in torchvision.transforms
- References
- transforms.ToTensor()
- transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
- Calling both together: Compose
References
- On the transforms.Normalize() function
- contiguous in PyTorch
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
transforms.ToTensor()
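ToTensor maps pixel values from the 0-255 range to the 0-1 range (for PIL images in the supported modes and for uint8 ndarrays), and changes the layout from H x W x C to C x H x W.

Looking at the source: calling a ToTensor instance invokes the class's __call__ method, and that method in turn calls F.to_tensor, where F is the functional.py module.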
    class ToTensor(object):
        """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor.

        Converts a PIL Image or numpy.ndarray (H x W x C) in the range
        [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0]
        if the PIL Image belongs to one of the modes (L, LA, P, I, F, RGB, YCbCr, RGBA, CMYK, 1)
        or if the numpy.ndarray has dtype = np.uint8

        In the other cases, tensors are returned without scaling.
        """

        def __call__(self, pic):
            """
            Args:
                pic (PIL Image or numpy.ndarray): Image to be converted to tensor.

            Returns:
                Tensor: Converted image.
            """
            # This calls the to_tensor function in functional.py
            return F.to_tensor(pic)

        def __repr__(self):
            return self.__class__.__name__ + '()'
Source of to_tensor:
    def to_tensor(pic):
        """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor.

        See ``ToTensor`` for more details.

        Args:
            pic (PIL Image or numpy.ndarray): Image to be converted to tensor.

        Returns:
            Tensor: Converted image.
        """
        if not(_is_pil_image(pic) or _is_numpy_image(pic)):
            raise TypeError('pic should be PIL Image or ndarray. Got {}'.format(type(pic)))

        if isinstance(pic, np.ndarray):
            # handle numpy array
            if pic.ndim == 2:
                pic = pic[:, :, None]

            img = torch.from_numpy(pic.transpose((2, 0, 1)))
            # backward compatibility
            if isinstance(img, torch.ByteTensor):
                return img.float().div(255)
            else:
                return img

        if accimage is not None and isinstance(pic, accimage.Image):
            nppic = np.zeros([pic.channels, pic.height, pic.width], dtype=np.float32)
            pic.copyto(nppic)
            return torch.from_numpy(nppic)

        # handle PIL Image
        if pic.mode == 'I':
            img = torch.from_numpy(np.array(pic, np.int32, copy=False))
        elif pic.mode == 'I;16':
            img = torch.from_numpy(np.array(pic, np.int16, copy=False))
        elif pic.mode == 'F':
            img = torch.from_numpy(np.array(pic, np.float32, copy=False))
        elif pic.mode == '1':
            img = 255 * torch.from_numpy(np.array(pic, np.uint8, copy=False))
        else:
            img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes()))
        # PIL image mode: L, LA, P, I, F, RGB, YCbCr, RGBA, CMYK
        if pic.mode == 'YCbCr':
            nchannel = 3
        elif pic.mode == 'I;16':
            nchannel = 1
        else:
            nchannel = len(pic.mode)
        img = img.view(pic.size[1], pic.size[0], nchannel)
        # put it from HWC to CHW format
        # yikes, this transpose takes 80% of the loading time/CPU
        img = img.transpose(0, 1).transpose(0, 2).contiguous()
        if isinstance(img, torch.ByteTensor):
            return img.float().div(255)
        else:
            return img
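For an RGB image, to_tensor does four main things:

1. img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes()))
   This reads the image's raw bytes into a one-dimensional uint8 tensor.
2. nchannel = len(pic.mode)
   This sets the number of channels (3 for 'RGB'). Then img = img.view(pic.size[1], pic.size[0], nchannel) reshapes img into a three-dimensional H x W x C tensor (PIL's pic.size is (width, height), so pic.size[1] is the height).
3. img = img.transpose(0, 1).transpose(0, 2).contiguous()
   transpose swaps two axes, much like transposing a matrix in two dimensions; the two calls together reorder the tensor from H x W x C to C x H x W. contiguous() then makes the tensor's underlying storage contiguous, which amounts to rewriting the underlying one-dimensional array in the new order.
   (Operations such as torch.view require a contiguous tensor. transpose and permute do not modify the underlying one-dimensional array; they create new tensor metadata and assign new strides in it. view, by contract, never modifies the array itself and only looks at the data through a new shape. So calling view directly after transpose or permute generally makes PyTorch raise an error, unless contiguous() is called in between.)
4. return img.float().div(255)
   This converts to float and divides by 255. The data was uint8, in the range 0-255, so dividing by 255 puts it in the range 0-1.

That completes the conversion of the image array into a tensor in the 0-1 range.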
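The four steps can be traced on a tiny made-up image using only torch. This is a minimal sketch, not the library's code: the 2 x 3 size and the `raw` byte buffer are assumptions standing in for what pic.tobytes() would return for an 'RGB' PIL image, and the flat tensor is built with torch.tensor rather than ByteStorage.from_buffer.

```python
import torch

# Hypothetical 2 x 3 RGB image, given as raw bytes in H x W x C order
# (a stand-in for what pic.tobytes() returns for an 'RGB' PIL image).
height, width, nchannel = 2, 3, 3
raw = bytes(range(height * width * nchannel))

# Step 1: a flat, one-dimensional uint8 tensor.
flat = torch.tensor(list(raw), dtype=torch.uint8)

# Step 2: view the same data as H x W x C.
hwc = flat.view(height, width, nchannel)

# Step 3: reorder the axes to C x H x W. The transposes only change strides,
# so the result is not contiguous until .contiguous() copies the data.
transposed = hwc.transpose(0, 1).transpose(0, 2)
chw = transposed.contiguous()
print(transposed.is_contiguous(), chw.is_contiguous())  # False True

# view() on the non-contiguous tensor fails; on the contiguous copy it works.
try:
    transposed.view(-1)
    view_failed = False
except RuntimeError:
    view_failed = True
print(view_failed)  # True
flat_again = chw.view(-1)

# Step 4: uint8 in [0, 255] -> float32 in [0.0, 1.0].
out = chw.float().div(255)
print(out.shape, out.dtype)
```

The two chained transposes give the same axis order as hwc.permute(2, 0, 1), which is an easy way to check the result.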
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))

With these arguments, Normalize maps the 0-1 range to the range [-1, 1].
    class Normalize(object):
        """Normalize a tensor image with mean and standard deviation.
        Given mean: ``(M1,...,Mn)`` and std: ``(S1,..,Sn)`` for ``n`` channels, this transform
        will normalize each channel of the input ``torch.*Tensor`` i.e.
        ``input[channel] = (input[channel] - mean[channel]) / std[channel]``

        .. note::
            This transform acts out of place, i.e., it does not mutates the input tensor.

        Args:
            mean (sequence): Sequence of means for each channel.
            std (sequence): Sequence of standard deviations for each channel.
        """

        def __init__(self, mean, std, inplace=False):
            self.mean = mean
            self.std = std
            self.inplace = inplace

        def __call__(self, tensor):
            """
            Args:
                tensor (Tensor): Tensor image of size (C, H, W) to be normalized.

            Returns:
                Tensor: Normalized Tensor image.
            """
            return F.normalize(tensor, self.mean, self.std, self.inplace)

        def __repr__(self):
            return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std)
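As with ToTensor, calling the object invokes its __call__ method. The tensor passed in at this point has values in the 0-1 range (the original post shows a screenshot of it here).

__call__ then calls the normalize function in functional.py, whose source is: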
    def normalize(tensor, mean, std, inplace=False):
        """Normalize a tensor image with mean and standard deviation.

        .. note::
            This transform acts out of place by default, i.e., it does not mutates the input tensor.

        See :class:`~torchvision.transforms.Normalize` for more details.

        Args:
            tensor (Tensor): Tensor image of size (C, H, W) to be normalized.
            mean (sequence): Sequence of means for each channel.
            std (sequence): Sequence of standard deviations for each channely.

        Returns:
            Tensor: Normalized Tensor image.
        """
        if not _is_tensor_image(tensor):
            raise TypeError('tensor is not a torch image.')

        if not inplace:
            tensor = tensor.clone()

        mean = torch.tensor(mean, dtype=torch.float32)
        std = torch.tensor(std, dtype=torch.float32)
        tensor.sub_(mean[:, None, None]).div_(std[:, None, None])
        return tensor
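The analysis:
- inplace defaults to False, so the tensor is cloned first and the input is left unmodified.
- mean and std are both the tuple (0.5, 0.5, 0.5); torch.tensor turns each of them into a one-dimensional float32 tensor.
- The transformation itself happens in this line:
  tensor.sub_(mean[:, None, None]).div_(std[:, None, None])
- mean[:, None, None] and std[:, None, None] reshape each of them to a (3, 1, 1) tensor, so the subtraction and division broadcast over every pixel of the matching channel. For x in [0, 1], (x - 0.5) / 0.5 lies in [-1, 1].
- That is how the conversion is done.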
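The broadcasting step is easy to check by hand. The sketch below repeats the same sub_/div_ line on a made-up 3 x 2 x 2 tensor; the pixel values are assumptions chosen so the arithmetic comes out exact.

```python
import torch

mean = torch.tensor((0.5, 0.5, 0.5), dtype=torch.float32)
std = torch.tensor((0.5, 0.5, 0.5), dtype=torch.float32)

# Hypothetical C x H x W image (3 x 2 x 2) with values already in [0, 1].
img = torch.tensor([0.0, 0.25, 0.5, 1.0]).view(1, 2, 2).repeat(3, 1, 1)

# Indexing with None adds axes: shape (3,) becomes (3, 1, 1).
print(mean[:, None, None].shape)  # torch.Size([3, 1, 1])

# Same order as the source: clone first (inplace=False), then modify the clone.
out = img.clone()
out.sub_(mean[:, None, None]).div_(std[:, None, None])

print(out[0].tolist())  # [[-1.0, -0.5], [0.0, 1.0]]
print(img[0].tolist())  # [[0.0, 0.25], [0.5, 1.0]], the input is unchanged
```

With a different mean or std per channel, each channel would be shifted and scaled by its own pair of values; here all three channels get the same treatment.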
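Calling both together: Compose

It usually looks like this:

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

Compose takes a list of transforms, and its __call__ method simply loops over the list and applies each one in turn. Its source: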
    class Compose(object):
        """Composes several transforms together.

        Args:
            transforms (list of ``Transform`` objects): list of transforms to compose.

        Example:
            >>> transforms.Compose([
            >>>     transforms.CenterCrop(10),
            >>>     transforms.ToTensor(),
            >>> ])
        """

        def __init__(self, transforms):
            self.transforms = transforms

        def __call__(self, img):
            for t in self.transforms:
                img = t(img)
            return img

        def __repr__(self):
            format_string = self.__class__.__name__ + '('
            for t in self.transforms:
                format_string += '\n'
                format_string += '    {0}'.format(t)
            format_string += '\n)'
            return format_string
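To see the chaining without installing torchvision, here is a rough sketch with stand-ins. SimpleCompose, to_unit_range and normalize_half are hypothetical names invented for this example; they only imitate the call loop of Compose and the arithmetic of ToTensor and Normalize(0.5, 0.5), not their input checks or layout changes.

```python
import torch


class SimpleCompose:
    """Hypothetical stand-in for transforms.Compose: apply callables in order."""

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, img):
        for t in self.transforms:
            img = t(img)
        return img


def to_unit_range(img):
    # Imitates the scaling in ToTensor: uint8 [0, 255] -> float32 [0.0, 1.0].
    return img.float().div(255)


def normalize_half(img):
    # Imitates Normalize with mean 0.5 and std 0.5: [0, 1] -> [-1, 1].
    return (img - 0.5) / 0.5


pipeline = SimpleCompose([to_unit_range, normalize_half])

# The darkest and brightest uint8 pixel values.
pixels = torch.tensor([0, 255], dtype=torch.uint8)
result = pipeline(pixels)
print(result.tolist())  # [-1.0, 1.0]
```

The order of the list matters: the normalization assumes its input is already in the 0-1 range, so the scaling step has to come first.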