手撸AI-1:构建DatasetDataloader,搭建模型训练基础架构

2024-02-27 09:20

本文主要是介绍手撸AI-1:构建DatasetDataloader,搭建模型训练基础架构,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

一. 构建Dataset

构建Dataset无非就是创建一个类继承dataset并重写三个方法:

from torch.utils.data import Dataset
from PIL import Imageclass MyDataset(Dataset):def __init__(self):# 初始化数据集passdef __getitem__(self, index):# 根据索引获取数据样本passdef __len__(self):# 返回数据集大小pass# 创建自定义数据集实例
dataset = MyDataset()

1. __init__(self,params...) 方法

通过创建实例输入实例属性,如数据集文件夹地址;

需设置一些其他方法所需的属性,如self.data_path指明数据所在文件夹等一些属性。

def __init__(self, data_path):self.data_path = data_pathself.data_list = os.listdir(self.data_path)

2. __getitem__(self, index)方法

参数index一般用于data_list[index] 指定某个与index捆绑的实例。

该方法需返回所需的数据,返回的数据相当于Dataset的一个实例。

    def __getitem__(self, index):file_name = self.data_list[index] #用index指定文件data_label = file_name.split('.')[0] #将标签置于文件名data_path = os.path.join(self.data_path, file_name) #获取文件地址img = Image.open(data_path) #读取数据return img, data_label #返回实例数据

3. __len__(self)方法

返回数据集实例的数量,也就是数据集的大小。

    def __len__(self):return len(self.data_list)

二. 创建dataloader

dataset(数据集):需要提取数据的数据集,Dataset对象
batch_size(批大小):每一次装载样本的个数,int型
 shuffle(洗牌):进行新一轮epoch时是否要重新洗牌,Boolean型
num_workers:是否多进程读取机制
drop_last:当样本数不能被batchsize整除时, 是否舍弃最后一批数据

#导入dataloader的包
from torch.utils.data import DataLoader#读取文件夹下数据以创建数据集
test_Dataset = MyDataset(dir_address)#创建一个dataloader,设置批大小为4,每一个epoch重新洗牌,
#不进行多进程读取机制,不舍弃不能被整除的批次
dataloader = DataLoader(dataset=test_dataset,batch_size=4,shuffle=True,num_workers=0,drop_last=False)

三. 搭建模型基础架构(pytorch版)

1. 导包(torch,nn,tqdm几乎必须)

import torch
from torch import nn
from tqdm import tqdm

2. 搭建通用核心训练架构函数

def train_loop(model, loader, epochs, optim, device, display=False, store_path='model.pt'):mse = nn.MSELoss()best_loss = float('inf')for epoch in tqdm(range(epochs),desc=f"Training process", colour='#00f00'):epoch_loss = 0.0for step, batch in enumerate(tqdm(loader, leave=False, desc=f"Epoch {epoch + 1}/{epochs}", colour="#005500")):#loading data,对应了Dataset返回的实例数据x = batch[0].to(device)y = batch[1].to(device)n = len(x) #batchsizeloss = mse(x, y) # the average loss of one batchoptim.zero_grad()optim.step()epoch_loss += loss.item() * n / len(loader.dataset)#display results generated at this epochif display:passlog_string = f"Loss at epoch {epoch + 1}: {epoch_loss:.3f}"#save the modelif best_loss > epoch_loss:best_loss = epoch_losstorch.save(ddpm.state_dict(), store_path)log_string += " --> Best model ever(stored)"print(log_string)

这篇关于手撸AI-1:构建DatasetDataloader,搭建模型训练基础架构的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/751943

相关文章

Spring AI集成DeepSeek的详细步骤

《SpringAI集成DeepSeek的详细步骤》DeepSeek作为一款卓越的国产AI模型,越来越多的公司考虑在自己的应用中集成,对于Java应用来说,我们可以借助SpringAI集成DeepSe... 目录DeepSeek 介绍Spring AI 是什么?1、环境准备2、构建项目2.1、pom依赖2.2

0基础租个硬件玩deepseek,蓝耘元生代智算云|本地部署DeepSeek R1模型的操作流程

《0基础租个硬件玩deepseek,蓝耘元生代智算云|本地部署DeepSeekR1模型的操作流程》DeepSeekR1模型凭借其强大的自然语言处理能力,在未来具有广阔的应用前景,有望在多个领域发... 目录0基础租个硬件玩deepseek,蓝耘元生代智算云|本地部署DeepSeek R1模型,3步搞定一个应

Deepseek R1模型本地化部署+API接口调用详细教程(释放AI生产力)

《DeepseekR1模型本地化部署+API接口调用详细教程(释放AI生产力)》本文介绍了本地部署DeepSeekR1模型和通过API调用将其集成到VSCode中的过程,作者详细步骤展示了如何下载和... 目录前言一、deepseek R1模型与chatGPT o1系列模型对比二、本地部署步骤1.安装oll

Spring AI Alibaba接入大模型时的依赖问题小结

《SpringAIAlibaba接入大模型时的依赖问题小结》文章介绍了如何在pom.xml文件中配置SpringAIAlibaba依赖,并提供了一个示例pom.xml文件,同时,建议将Maven仓... 目录(一)pom.XML文件:(二)application.yml配置文件(一)pom.xml文件:首

本地搭建DeepSeek-R1、WebUI的完整过程及访问

《本地搭建DeepSeek-R1、WebUI的完整过程及访问》:本文主要介绍本地搭建DeepSeek-R1、WebUI的完整过程及访问的相关资料,DeepSeek-R1是一个开源的人工智能平台,主... 目录背景       搭建准备基础概念搭建过程访问对话测试总结背景       最近几年,人工智能技术

SpringBoot整合DeepSeek实现AI对话功能

《SpringBoot整合DeepSeek实现AI对话功能》本文介绍了如何在SpringBoot项目中整合DeepSeekAPI和本地私有化部署DeepSeekR1模型,通过SpringAI框架简化了... 目录Spring AI版本依赖整合DeepSeek API key整合本地化部署的DeepSeek

如何在本地部署 DeepSeek Janus Pro 文生图大模型

《如何在本地部署DeepSeekJanusPro文生图大模型》DeepSeekJanusPro模型在本地成功部署,支持图片理解和文生图功能,通过Gradio界面进行交互,展示了其强大的多模态处... 目录什么是 Janus Pro1. 安装 conda2. 创建 python 虚拟环境3. 克隆 janus

本地私有化部署DeepSeek模型的详细教程

《本地私有化部署DeepSeek模型的详细教程》DeepSeek模型是一种强大的语言模型,本地私有化部署可以让用户在自己的环境中安全、高效地使用该模型,避免数据传输到外部带来的安全风险,同时也能根据自... 目录一、引言二、环境准备(一)硬件要求(二)软件要求(三)创建虚拟环境三、安装依赖库四、获取 Dee

nginx-rtmp-module构建流媒体直播服务器实战指南

《nginx-rtmp-module构建流媒体直播服务器实战指南》本文主要介绍了nginx-rtmp-module构建流媒体直播服务器实战指南,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有... 目录1. RTMP协议介绍与应用RTMP协议的原理RTMP协议的应用RTMP与现代流媒体技术的关系2

DeepSeek模型本地部署的详细教程

《DeepSeek模型本地部署的详细教程》DeepSeek作为一款开源且性能强大的大语言模型,提供了灵活的本地部署方案,让用户能够在本地环境中高效运行模型,同时保护数据隐私,在本地成功部署DeepSe... 目录一、环境准备(一)硬件需求(二)软件依赖二、安装Ollama三、下载并部署DeepSeek模型选