PyTorch|transforms.Normalize

2024-01-06 01:04


It is generally accepted that normalizing image data helps gradient descent find better optima during training. So what exactly does PyTorch's transforms.Normalize do? That is worth knowing.

Consider the following formula, where x is a value taken from a dataset C, mean is the mean of that data, and std is its standard deviation:

x = (x - mean) / std

In other words, Normalize simply transforms the input data according to this formula.

Consider the following piece of code:

import numpy as np

List1 = np.array([1, 2, 3, 4])
mean = np.mean(List1)
std = np.std(List1)
List2 = (List1 - mean) / std

>>> List1
array([1, 2, 3, 4])
>>> List2
array([-1.34164079, -0.4472136 ,  0.4472136 ,  1.34164079])

After Normalize, List1 becomes List2.
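
As a quick sanity check (a small sketch continuing the snippet above), the standardized values should come out with mean roughly 0 and standard deviation roughly 1:

# After standardization the mean should be ~0 and the standard deviation ~1,
# up to floating-point rounding.
print(np.mean(List2))   # ~0.0
print(np.std(List2))    # ~1.0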

So how exactly does Normalize work on image data?

Suppose we have four images. Borrowing the data-loading approach from an earlier article, load the data:

import os
from PIL import Image
import numpy as np
from torchvision import transforms
import torch

path = "E:\\3-10\\dogandcats\\source"
IMG = []
filenames = [name for name in os.listdir(path)]
for i, filename in enumerate(filenames):
    img = Image.open(os.path.join(path, filename))
    img = img.resize((28, 28))                    # resize the image to 28x28 pixels
    img = np.array(img)                           # convert the image to a numpy array
    img = torch.tensor(img, dtype=torch.float32)  # convert to a float tensor (mean/std/Normalize need floats)
    img = img.permute(2, 0, 1)                    # reorder H, W, C to C, H, W
    IMG.append(img)                               # collect the image tensors in the list IMG
IMGEND = torch.stack([ig for ig in IMG], dim=0)   # stack the tensors into one batch
>>> IMGEND.size()
torch.Size([4, 3, 28, 28])

The four images have been loaded and converted into a single tensor.

Compute the mean of the r, g, b channels:

>>> mean = torch.mean(IMGEND, dim=(0, 2, 3), keepdim=True)
>>> mean
tensor([[[[160.8753]],
         [[149.3600]],
         [[126.5810]]]])

Compute the standard deviation of the r, g, b channels:

>>> std = torch.std(IMGEND, dim=(0, 2, 3), keepdim=True)
>>> std
tensor([[[[61.7317]],
         [[65.0915]],
         [[84.2025]]]])
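
If the dim=(0, 2, 3) reduction looks opaque, the same per-channel statistics can be obtained by first flattening every channel into one row; the following equivalent sketch (reusing IMGEND from above) may make the intent clearer:

# Gather all pixels of each channel into one row, then reduce over that row.
per_channel = IMGEND.permute(1, 0, 2, 3).reshape(3, -1)   # shape: (3, 4*28*28)
print(per_channel.mean(dim=1))   # same values as torch.mean(IMGEND, dim=(0, 2, 3))
print(per_channel.std(dim=1))    # same values as torch.std(IMGEND, dim=(0, 2, 3))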

Now normalize with these statistics:

>>> process = transforms.Normalize([160.8753, 149.3600, 126.5810], [61.7317, 65.0915, 84.2025])
>>> dataend1 = process(IMGEND)
>>> dataend1
tensor([[[[-1.3587, -0.9213, -0.7269,  ..., -0.3382, -0.3868, -0.4516],
          [-1.4397, -0.8727, -0.6135,  ..., -0.1114, -0.1762, -0.2248],
          [-1.8771, -1.3587, -0.9375,  ...,  0.1640,  0.0830, -0.1438],
          ...,
          [-1.9095, -1.8285, -1.8123,  ..., -2.1687, -2.2497, -2.2173],
          [-1.9419, -1.8609, -1.8123,  ..., -2.3469, -2.4117, -2.2983],
          [-1.9257, -1.8447, -1.8447,  ..., -2.3307, -2.3307, -2.2821]],
         ...]])
(printout truncated here; the full output covers all four images and all three channels)

Now compute the same transformation directly from the formula:

>>> enddata = (IMGEND - mean) / std
>>> enddata
tensor([[[[-1.3587, -0.9213, -0.7269,  ..., -0.3382, -0.3868, -0.4516],
          [-1.4397, -0.8727, -0.6135,  ..., -0.1114, -0.1762, -0.2248],
          [-1.8771, -1.3587, -0.9375,  ...,  0.1640,  0.0830, -0.1438],
          ...,
          [-1.9095, -1.8285, -1.8123,  ..., -2.1687, -2.2497, -2.2173],
          [-1.9419, -1.8609, -1.8123,  ..., -2.3469, -2.4117, -2.2983],
          [-1.9257, -1.8447, -1.8447,  ..., -2.3307, -2.3307, -2.2821]],
         ...]])
(printout truncated here; the full output matches dataend1 above, apart from rounding in the last displayed digit)

Clearly, the two results are the same, which shows that transforms.Normalize does nothing more than apply this formula to the input data.
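
Rather than eyeballing the two long printouts, the agreement can also be checked programmatically; a minimal sketch, assuming dataend1 and enddata from the sessions above:

# Verify that transforms.Normalize and the manual formula agree.
# A small tolerance is used because the mean/std passed to Normalize were
# rounded to four decimal places.
print(torch.allclose(dataend1, enddata, atol=1e-3))   # expected: True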

Furthermore, when the mean and standard deviation passed to transforms.Normalize are those of the data being transformed, the transformed data is guaranteed to have mean 0 and standard deviation 1.

When the mean and standard deviation passed to transforms.Normalize are not those of the data being transformed, the resulting data need not have mean 0 or standard deviation 1; the data is simply transformed according to the formula.
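
The first case is easy to confirm on the example above; a small sketch, assuming dataend1 from the earlier session:

# Normalize was given the data's own statistics, so each channel of the
# result should have mean ~0 and standard deviation ~1.
print(dataend1.mean(dim=(0, 2, 3)))   # approximately [0., 0., 0.]
print(dataend1.std(dim=(0, 2, 3)))    # approximately [1., 1., 1.]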

For example:

>>> process = transforms.Normalize([0.5, 0.6, 0.4], [0.36, 0.45, 0.45])
>>> data = process(inputdata)

Here [0.5, 0.6, 0.4] and [0.36, 0.45, 0.45] are not the mean and standard deviation of inputdata; they were chosen arbitrarily, simply to transform the original data. Consequently the mean of the new data is not necessarily 0, and its standard deviation is not necessarily 1.
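
Since inputdata above is only a placeholder, here is a sketch of the same idea applied to the IMGEND batch from earlier (substituting IMGEND for inputdata is my assumption, not part of the original example):

# Statistics that are NOT those of the data: the result is merely shifted and
# scaled; its per-channel mean/std end up nowhere near 0 and 1.
arbitrary = transforms.Normalize([0.5, 0.6, 0.4], [0.36, 0.45, 0.45])
data = arbitrary(IMGEND.clone())   # clone defensively (very old torchvision versions normalized in place)
print(data.mean(dim=(0, 2, 3)))    # far from [0, 0, 0] for pixel values in [0, 255]
print(data.std(dim=(0, 2, 3)))     # far from [1, 1, 1]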

Of course, when preprocessing images you will often see these two transforms appear together:

transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])

Here transforms.ToTensor() converts the input into a tensor, rearranges the shape from H, W, C to C, H, W, and divides the (uint8) pixel values by 255, scaling the data into [0, 1].

Plugging the endpoints of [0, 1] into the formula x = (x - mean) / std with mean = 0.5 and std = 0.5 gives:

(0 - 0.5) / 0.5 = -1

(1 - 0.5) / 0.5 = 1

So the new data lies in the range [-1, 1], but its mean is not necessarily 0 and its standard deviation is not necessarily 1; this is worth keeping in mind.

This is because [0.5, 0.5, 0.5] and [0.5, 0.5, 0.5] are not necessarily the actual mean and standard deviation of the original data.
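
A short end-to-end sketch of this pipeline, using a random uint8 image in place of a real photo (the random image is just an assumption for illustration), shows the resulting range:

from PIL import Image
import numpy as np
from torchvision import transforms

fake_img = Image.fromarray(np.random.randint(0, 256, (28, 28, 3), dtype=np.uint8))
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
out = transform(fake_img)
print(out.shape)              # torch.Size([3, 28, 28]) -- channels first
print(out.min(), out.max())   # both within [-1, 1]
print(out.mean(), out.std())  # not guaranteed to be 0 and 1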



