本文主要是介绍内涵:pyTorch学习之加载自己的数据集,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!
pyTorch根据filelist加载自己的数据集合,无论图片是否在一个文件夹还是一个类的图片在一个文件夹。
第一步:继承实现Dataset类别
def default_loader(path):return Image.open(path).convert('RGB')
class MyDataset(Dataset):def __init__(self, txt, transform=None, target_transform=None, loader=default_loader):fh = open(txt, 'r')imgs = []for line in fh:line = line.strip('\n')line = line.rstrip()words = line.split()imgs.append((words[0],int(words[1])))self.imgs = imgsself.transform = transformself.target_transform = target_transformself.loader = loaderdef __getitem__(self, index):fn, label = self.imgs[index]img = self.loader(fn)if self.transform is not None:img = self.transform(img)else:img = Tensor.from_numpy(img)return img,label
def __len__(self): return len(self.imgs)
第二步骤:就直接可以用自己定义的这个类,来构建自己的dataset了
transform = transforms.Compose([transforms.Scale((227,227)),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])])
train_data = MyDataset(txt='train_filelist.txt',transform=transform)
其中比较有用的一个点是
transforms.Scale((227,227))
用来将不同大小的图片resize到统一尺寸。
还有一个点就是,彩色图片都要做的归一化
transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])
这篇关于内涵:pyTorch学习之加载自己的数据集的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!