This article explains how to use datasets.ImageFolder and train_dataset.class_to_idx.
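datasets.ImageFolder turns folder names into labels: each subfolder of the root directory is treated as one class. It is intended for classification tasks.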
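import os
from torchvision import datasets

# image_path and data_transform are defined in the full code at the end of this article
train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"), transform=data_transform["train"])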
For example, the flower_photos folder contains five subfolders, each holding the images of one class:
|-- flower_photos
    |-- daisy
    |-- dandelion
    |-- roses
    |-- sunflowers
    |-- tulips
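After loading the folder with datasets.ImageFolder, daisy becomes label 0, dandelion becomes label 1, and so on. The labels follow the sorted (alphabetical) order of the folder names.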
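torchvision assigns these labels by sorting the subfolder names and numbering them in that order. Below is a minimal standard-library sketch of the same idea. It is not torchvision's actual source: the helper name `folder_class_to_idx` is hypothetical, and the demo builds a throwaway folder tree in a temporary directory so it can run without any images.

```python
import os
import tempfile


def folder_class_to_idx(root):
    """Hypothetical helper: map each subfolder name under root to an integer
    label, numbering the names in sorted order (as ImageFolder does)."""
    classes = sorted(entry.name for entry in os.scandir(root) if entry.is_dir())
    return {cls_name: idx for idx, cls_name in enumerate(classes)}


with tempfile.TemporaryDirectory() as root:
    # Create the five class folders in a deliberately unsorted order
    for name in ["tulips", "daisy", "sunflowers", "roses", "dandelion"]:
        os.makedirs(os.path.join(root, name))
    mapping = folder_class_to_idx(root)

print(mapping)
# {'daisy': 0, 'dandelion': 1, 'roses': 2, 'sunflowers': 3, 'tulips': 4}
```

Because this is a plain string sort, uppercase letters come before lowercase ones, so a folder named `Tulips` would get label 0 ahead of `daisy`. Keeping folder names consistently lowercase avoids surprises.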
flower_list = train_dataset.class_to_idx
This attribute holds the class names and their integer labels as a dictionary; its length is the number of classes.
The output is:
{'daisy': 0, 'dandelion': 1, 'roses': 2, 'sunflowers': 3, 'tulips': 4}
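class_to_idx maps names to labels, so decoding a model's predicted label back into a class name requires the inverse dictionary. The full code at the end of the article builds the same inverse before writing it to JSON. Here is a minimal sketch of that step, where the `predicted` list is made-up model output used only for illustration:

```python
# class_to_idx as printed above for the flower dataset
class_to_idx = {'daisy': 0, 'dandelion': 1, 'roses': 2, 'sunflowers': 3, 'tulips': 4}

num_classes = len(class_to_idx)  # number of classes
idx_to_class = {idx: name for name, idx in class_to_idx.items()}  # label -> name

# Hypothetical predicted labels, e.g. the argmax of a model's outputs
predicted = [3, 0, 4]
names = [idx_to_class[i] for i in predicted]
print(num_classes, names)
# 5 ['sunflowers', 'daisy', 'tulips']
```

ImageFolder also exposes `train_dataset.classes`, a list whose position i holds the name of label i. It can serve the same purpose as `idx_to_class`.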
The dataset can then be loaded in batches with torch.utils.data.DataLoader:
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=nw)
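The `batch_size` and `nw` values passed to DataLoader are defined in the full code. There, nw caps the number of loader worker processes at the smallest of the CPU count, the batch size, and 8. Below is a standalone sketch of that heuristic. The function name `pick_num_workers` is hypothetical, and unlike the original one-liner it falls back to 1 when `os.cpu_count()` returns None, which the standard library allows.

```python
import os


def pick_num_workers(batch_size, cpu_count=None, cap=8):
    """Hypothetical helper mirroring the article's num_workers heuristic:
    min(CPU count, batch size (or 0 if batch_size <= 1), cap)."""
    if cpu_count is None:
        cpu_count = os.cpu_count() or 1  # os.cpu_count() may return None
    return min([cpu_count, batch_size if batch_size > 1 else 0, cap])


print(pick_num_workers(32, cpu_count=4))   # 4
print(pick_num_workers(32, cpu_count=16))  # 8
print(pick_num_workers(1, cpu_count=16))   # 0
```

A value of 0 makes DataLoader load batches in the main process instead of spawning workers.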
Full code
First, split the dataset into the following layout:
|-- train
    |-- daisy
    |-- dandelion
    |-- roses
    |-- sunflowers
    |-- tulips
|-- val
    |-- daisy
    |-- dandelion
    |-- roses
    |-- sunflowers
    |-- tulips
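One way to produce this train/val layout from flower_photos is to copy a random fraction of each class folder into val and the rest into train. The article does not say how its split was made. The sketch below assumes a hypothetical helper `split_dataset`, a configurable validation fraction, and a fixed random seed. Every class folder is created under both train and val, so both ImageFolder instances find the same classes and therefore build the same class_to_idx.

```python
import os
import random
import shutil
import tempfile


def split_dataset(src_root, dst_root, val_rate=0.1, seed=0):
    """Hypothetical helper: copy images from src_root/<class>/ into
    dst_root/train/<class>/ and dst_root/val/<class>/."""
    rng = random.Random(seed)
    for cls_name in sorted(os.listdir(src_root)):
        cls_dir = os.path.join(src_root, cls_name)
        if not os.path.isdir(cls_dir):
            continue
        # Create the class folder in both splits so the class lists match
        for split in ("train", "val"):
            os.makedirs(os.path.join(dst_root, split, cls_name), exist_ok=True)
        images = sorted(os.listdir(cls_dir))
        val_images = set(rng.sample(images, k=int(len(images) * val_rate)))
        for img in images:
            split = "val" if img in val_images else "train"
            shutil.copy(os.path.join(cls_dir, img),
                        os.path.join(dst_root, split, cls_name, img))


# Demo on a tiny fake dataset: 10 empty "images" per class
with tempfile.TemporaryDirectory() as tmp:
    src = os.path.join(tmp, "flower_photos")
    for cls_name in ["daisy", "roses"]:
        os.makedirs(os.path.join(src, cls_name))
        for i in range(10):
            open(os.path.join(src, cls_name, f"{i}.jpg"), "w").close()
    dst = os.path.join(tmp, "flower_data")
    split_dataset(src, dst, val_rate=0.2)
    counts = {s: len(os.listdir(os.path.join(dst, s, "daisy"))) for s in ("train", "val")}

print(counts)  # {'train': 8, 'val': 2}
```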
The data can then be loaded as follows:
import os
import json

import torch
from torchvision import transforms, datasets


def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("using {} device.".format(device))

    data_transform = {
        "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),
        # Resize(224) would only scale the shorter side; (224, 224) fixes both dimensions
        "val": transforms.Compose([transforms.Resize((224, 224)),
                                   transforms.ToTensor(),
                                   transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])}

    data_root = os.path.abspath(os.path.join(os.getcwd(), "../.."))  # get data root path
    image_path = os.path.join("/kaggle/working/", "flower_data")  # flower data set path
    assert os.path.exists(image_path), "{} path does not exist.".format(image_path)

    train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),
                                         transform=data_transform["train"])
    train_num = len(train_dataset)

    # {'daisy': 0, 'dandelion': 1, 'roses': 2, 'sunflowers': 3, 'tulips': 4}
    flower_list = train_dataset.class_to_idx
    cla_dict = dict((val, key) for key, val in flower_list.items())  # label -> class name
    # write dict into json file
    json_str = json.dumps(cla_dict, indent=4)
    with open('class_indices.json', 'w') as json_file:
        json_file.write(json_str)

    batch_size = 32
    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workers
    print('Using {} dataloader workers every process'.format(nw))

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size, shuffle=True,
                                               num_workers=nw)

    validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"),
                                            transform=data_transform["val"])
    val_num = len(validate_dataset)
    validate_loader = torch.utils.data.DataLoader(validate_dataset,
                                                  batch_size=4, shuffle=False,
                                                  num_workers=nw)

    print("using {} images for training, {} images for validation.".format(train_num, val_num))


# The guard is needed when num_workers > 0 on platforms that spawn worker processes
if __name__ == '__main__':
    main()
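The class_indices.json file written by main() is typically read back at inference time to turn a predicted label into a flower name. JSON object keys are always strings, so json.dumps stores the integer labels as "0", "1", and so on, and the lookup must use a string key. The minimal sketch below uses an in-memory string instead of the file, along with a made-up predicted label.

```python
import json

# Same inversion as in main(): label -> class name
class_to_idx = {'daisy': 0, 'dandelion': 1, 'roses': 2, 'sunflowers': 3, 'tulips': 4}
cla_dict = dict((val, key) for key, val in class_to_idx.items())
json_str = json.dumps(cla_dict, indent=4)

loaded = json.loads(json_str)
print(list(loaded)[:2])  # ['0', '1']: the integer keys came back as strings

predicted = 2  # hypothetical model output (e.g. argmax of the logits)
print(loaded[str(predicted)])  # roses
```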