本文主要是介绍手撸AI-1:构建DatasetDataloader,搭建模型训练基础架构,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!
一. 构建Dataset
构建Dataset无非就是创建一个类继承dataset并重写三个方法:
from torch.utils.data import Dataset
from PIL import Imageclass MyDataset(Dataset):def __init__(self):# 初始化数据集passdef __getitem__(self, index):# 根据索引获取数据样本passdef __len__(self):# 返回数据集大小pass# 创建自定义数据集实例
dataset = MyDataset()
1. __init__(self,params...) 方法
通过创建实例输入实例属性,如数据集文件夹地址;
需设置一些其他方法所需的属性,如self.data_path指明数据所在文件夹等一些属性。
def __init__(self, data_path):self.data_path = data_pathself.data_list = os.listdir(self.data_path)
2. __getitem__(self, index)方法
参数index一般用于data_list[index] 指定某个与index捆绑的实例。
该方法需返回所需的数据,返回的数据相当于Dataset的一个实例。
def __getitem__(self, index):file_name = self.data_list[index] #用index指定文件data_label = file_name.split('.')[0] #将标签置于文件名data_path = os.path.join(self.data_path, file_name) #获取文件地址img = Image.open(data_path) #读取数据return img, data_label #返回实例数据
3. __len__(self)方法
返回数据集实例的数量,也就是数据集的大小。
def __len__(self):return len(self.data_list)
二. 创建dataloader
dataset(数据集):需要提取数据的数据集,Dataset对象
batch_size(批大小):每一次装载样本的个数,int型
shuffle(洗牌):进行新一轮epoch时是否要重新洗牌,Boolean型
num_workers:是否多进程读取机制
drop_last:当样本数不能被batchsize整除时, 是否舍弃最后一批数据
#导入dataloader的包
from torch.utils.data import DataLoader#读取文件夹下数据以创建数据集
test_Dataset = MyDataset(dir_address)#创建一个dataloader,设置批大小为4,每一个epoch重新洗牌,
#不进行多进程读取机制,不舍弃不能被整除的批次
dataloader = DataLoader(dataset=test_dataset,batch_size=4,shuffle=True,num_workers=0,drop_last=False)
三. 搭建模型基础架构(pytorch版)
1. 导包(torch,nn,tqdm几乎必须)
import torch
from torch import nn
from tqdm import tqdm
2. 搭建通用核心训练架构函数
def train_loop(model, loader, epochs, optim, device, display=False, store_path='model.pt'):mse = nn.MSELoss()best_loss = float('inf')for epoch in tqdm(range(epochs),desc=f"Training process", colour='#00f00'):epoch_loss = 0.0for step, batch in enumerate(tqdm(loader, leave=False, desc=f"Epoch {epoch + 1}/{epochs}", colour="#005500")):#loading data,对应了Dataset返回的实例数据x = batch[0].to(device)y = batch[1].to(device)n = len(x) #batchsizeloss = mse(x, y) # the average loss of one batchoptim.zero_grad()optim.step()epoch_loss += loss.item() * n / len(loader.dataset)#display results generated at this epochif display:passlog_string = f"Loss at epoch {epoch + 1}: {epoch_loss:.3f}"#save the modelif best_loss > epoch_loss:best_loss = epoch_losstorch.save(ddpm.state_dict(), store_path)log_string += " --> Best model ever(stored)"print(log_string)
这篇关于手撸AI-1:构建DatasetDataloader,搭建模型训练基础架构的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!