本文主要是介绍Datawhale 零基础入门CV-Task02.数据读取与数据扩增,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!
主要内容
- 数据读取
- 数据扩增方法
Pytorch
读取赛题数据
学习目标
- 学会
Python
和Pytorch
中图像读取 - 学会扩增方法和
Pytorch
读取赛题数据
图像读取
- 由于赛题数据是图像数据,赛题的任务是识别图像中的字符。因此需要完成对数据的读取操作,在
Python
中有很多库可以完成数据读取的操作,比较常见的有Pillow
和OpenCV
Pillow
Pillow
是Python
图像处理函数库PIL
的一个分支,Pillow
提供了常见的图像读取和处理的操作,而且可以与ipython notebook
无缝集成,是应用比较广泛的库
- 实现
from PIL import Image,ImageFilter
im = Image.open(r"D:\input\mchar_train\timg.JFIF")
plt.imshow(im)
- 应用模糊滤镜
- 首先可以利用系统自带的画图工具转为
jpg
格式 - 实现应用模糊滤镜
from PIL import Image,ImageFilter,ImageFilter
im = Image.open(r"D:\input\mchar_train\timg.jpg")
im2 = im.filter(ImageFilter.BLUR)
im2.save('blur.jpg','jpeg')
plt.imshow(im2)
- 图片放缩
Pillow官方文档
OpenCV
OpenCV
是一个跨平台的计算机视觉库,最早由Intel
开源得来,拥有众多的计算机视觉、数字图像处理和机器视觉等功能。OpenCV
在功能上比Pillow
更强大
- 实现
# 库在前面已经导入过了
import cv2
img = cv2.imread(r"D:\input\mchar_train\mchar_train\000000.png")
img = cv2.cvtColor(img,cv2.COLOR_BGR2GRAY)
plt.imshow(img)
OpenCV官网
OpenCV扩展算法库
数据扩增方法
- 在赛题中需要对图像进行字符识别,因此需要完成数据的读取操作同时也需要完成数据扩增操作
数据扩增介绍
- 数据扩增可以增加训练集的样本,同时可以有效缓解模型过拟合的情况,也可以给模型带来的更强的泛化能力
- 数据扩增的作用:数据扩增可以扩展样本空间
数据扩增方法
- 从颜色空间、尺度空间到样本空间,同时根据不同任务数据扩增都有相应的区别
- 对于图像分类,数据扩增一般不会改变标签:对于物体检测、数据扩增会改变物体坐标位置;对于图像分割,数据扩增会改变像素标签
常见的数据扩增方法 - 在常见的数据扩增方法中,一般会从图像颜色、尺寸、形态、空间和像素等角度进行变换。不同的数据扩增方法可以自由进行组合,得到更丰富的数据扩增方法,下面给出以
torchvision
为例,常见的数据扩增方法
transforms.CenterCrop
:对图片中心进行裁剪
thansforms.ColorJitter
:对图像颜色的对比度、饱和度和零度进行变换
transforms.FiveCrop
:对图像四个角和中心进行剪裁得到五分图像
transforms.Grayscale
:对图像进行灰度变换
transforms.Pad
:使用固定值进行像素填充
transforms.RandomAffine
:随机仿射变换
transforms.RandomCrop
:随机区域裁剪
transforms.RandomHorizontalFlip
:随机水平翻转
transforms.RandomRotation
:随即旋转
transforms.RandomVerticalFilp
:随机垂直翻转
- 对于图像中的字符进行识别,不能进行翻转操作,翻转后可能改变字符原本的含义
常用的数据扩增库
torchvision
:pytorch
官方提供的数据扩增库,提供了基本的数据扩增方法,可以与torch
进行集成,但数据扩增方法种类较少,速度中等
github
imagaug
:常用的第三方数据扩增库,提供了多样的数据扩增方法,组合起来比较方便,速度较快
github
albumentations
:常用的第三方数据扩增库,提供了多样的数据扩增方法,对图像分类、语义分割,物体检测和关键点检测都支持,速度较快
使用文档
Pytorch读取数据
- 在
Pytorch
中数据是通过Dataset
进行封装,并通过DataLoder
进行并行读取,所以只需重载一下数据读取的逻辑就可以完成数据的读取
import os, sys, glob, shutil, json
import cv2
from PIL import Image
import numpy as np
import torch
from torch.utils.data.dataset import Dataset
import torchvision.transforms as transforms
class SVHNDataset(Dataset):def __init__(self, img_path, img_label, transform=None):self.img_path = img_pathself.img_label = img_labelif transform is not None:self.transform = transformelse:self.transform = Nonedef __getitem__(self, index):img = Image.open(self.img_path[index]).convert('RGB')if self.transform is not None:img = self.transform(img)# 原始SVHN中类别10为数字0lbl = np.array(self.img_label[index], dtype=np.int)lbl = list(lbl) + (5 - len(lbl)) * [10]return img, torch.from_numpy(np.array(lbl[:5]))def __len__(self):return len(self.img_path)
train_path = glob.glob('../input/train/*.png')
train_path.sort()
train_json = json.load(open(r"D:\input\mchar_train.json"))
train_label = [train_json[x]['label'] for x in train_json]data = SVHNDataset(train_path, train_label,transforms.Compose([# 缩放到固定尺⼨transforms.Resize((64, 128)),# 随机颜⾊变换transforms.ColorJitter(0.2, 0.2, 0.2),# 加⼊随机旋转transforms.RandomRotation(5),# 将图⽚转换为pytorch 的tesntortransforms.ToTensor(),# 对图像像素进⾏归⼀化transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])]))
Dataset
:对数据集的封装,提供索引方式的对数据样本进行读取DataLoder
:对Dataset
进行封装,提供批量读取的迭代读取- 加入
DataLoder
后,数据读取代码改写如下
import os, sys, glob, shutil, json
import cv2from PIL import Image
import numpy as npimport torch
from torch.utils.data.dataset import Dataset
import torchvision.transforms as transformsclass SVHNDataset(Dataset):def __init__(self, img_path, img_label, transform=None):self.img_path = img_pathself.img_label = img_labelif transform is not None:self.transform = transformelse:self.transform = Nonedef __getitem__(self, index):img = Image.open(self.img_path[index]).convert('RGB')if self.transform is not None:img = self.transform(img)# 原始SVHN中类别10为数字0lbl = np.array(self.img_label[index], dtype=np.int)lbl = list(lbl) + (5 - len(lbl)) * [10]return img, torch.from_numpy(np.array(lbl[:5]))def __len__(self):return len(self.img_path)train_path = glob.glob('../input/train/*.png')
train_path.sort()
train_json = json.load(open('../input/train.json'))
train_label = [train_json[x]['label'] for x in train_json]train_loader = torch.utils.data.DataLoader(SVHNDataset(train_path, train_label,transforms.Compose([transforms.Resize((64, 128)),transforms.ColorJitter(0.3, 0.3, 0.2),transforms.RandomRotation(5),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])),batch_size=10, # 每批样本个数shuffle=False, # 是否打乱顺序num_workers=10, # 读取的线程个数
)for data in train_loader:break
- 加入
DataLoder
后,数据按照批次获取,每批次调用Dataset
读取单个样本进行拼接,此时data
的格式为:
torch.Size([10, 3, 64, 128]), torch.Size([10, 6])
- 前者为图像文件,为
batchsize * chanel * height * width
次序;后者为字符标签
本章小结
- 对数据读取进行详细了解,学会常见的数据扩增方法和使用,最后使用
Pytorch
框架对赛题的数据进行读取
这篇关于Datawhale 零基础入门CV-Task02.数据读取与数据扩增的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!