手撸AI-1:构建DatasetDataloader,搭建模型训练基础架构

2024-02-27 09:20

本文主要是介绍手撸AI-1:构建DatasetDataloader,搭建模型训练基础架构,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

一. 构建Dataset

构建Dataset无非就是创建一个类继承dataset并重写三个方法:

from torch.utils.data import Dataset
from PIL import Imageclass MyDataset(Dataset):def __init__(self):# 初始化数据集passdef __getitem__(self, index):# 根据索引获取数据样本passdef __len__(self):# 返回数据集大小pass# 创建自定义数据集实例
dataset = MyDataset()

1. __init__(self,params...) 方法

通过创建实例输入实例属性,如数据集文件夹地址;

需设置一些其他方法所需的属性,如self.data_path指明数据所在文件夹等一些属性。

def __init__(self, data_path):self.data_path = data_pathself.data_list = os.listdir(self.data_path)

2. __getitem__(self, index)方法

参数index一般用于data_list[index] 指定某个与index捆绑的实例。

该方法需返回所需的数据,返回的数据相当于Dataset的一个实例。

    def __getitem__(self, index):file_name = self.data_list[index] #用index指定文件data_label = file_name.split('.')[0] #将标签置于文件名data_path = os.path.join(self.data_path, file_name) #获取文件地址img = Image.open(data_path) #读取数据return img, data_label #返回实例数据

3. __len__(self)方法

返回数据集实例的数量,也就是数据集的大小。

    def __len__(self):return len(self.data_list)

二. 创建dataloader

dataset(数据集):需要提取数据的数据集,Dataset对象
batch_size(批大小):每一次装载样本的个数,int型
 shuffle(洗牌):进行新一轮epoch时是否要重新洗牌,Boolean型
num_workers:是否多进程读取机制
drop_last:当样本数不能被batchsize整除时, 是否舍弃最后一批数据

#导入dataloader的包
from torch.utils.data import DataLoader#读取文件夹下数据以创建数据集
test_Dataset = MyDataset(dir_address)#创建一个dataloader,设置批大小为4,每一个epoch重新洗牌,
#不进行多进程读取机制,不舍弃不能被整除的批次
dataloader = DataLoader(dataset=test_dataset,batch_size=4,shuffle=True,num_workers=0,drop_last=False)

三. 搭建模型基础架构(pytorch版)

1. 导包(torch,nn,tqdm几乎必须)

import torch
from torch import nn
from tqdm import tqdm

2. 搭建通用核心训练架构函数

def train_loop(model, loader, epochs, optim, device, display=False, store_path='model.pt'):mse = nn.MSELoss()best_loss = float('inf')for epoch in tqdm(range(epochs),desc=f"Training process", colour='#00f00'):epoch_loss = 0.0for step, batch in enumerate(tqdm(loader, leave=False, desc=f"Epoch {epoch + 1}/{epochs}", colour="#005500")):#loading data,对应了Dataset返回的实例数据x = batch[0].to(device)y = batch[1].to(device)n = len(x) #batchsizeloss = mse(x, y) # the average loss of one batchoptim.zero_grad()optim.step()epoch_loss += loss.item() * n / len(loader.dataset)#display results generated at this epochif display:passlog_string = f"Loss at epoch {epoch + 1}: {epoch_loss:.3f}"#save the modelif best_loss > epoch_loss:best_loss = epoch_losstorch.save(ddpm.state_dict(), store_path)log_string += " --> Best model ever(stored)"print(log_string)

这篇关于手撸AI-1:构建DatasetDataloader,搭建模型训练基础架构的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/751943

相关文章

Ilya-AI分享的他在OpenAI学习到的15个提示工程技巧

Ilya(不是本人,claude AI)在社交媒体上分享了他在OpenAI学习到的15个Prompt撰写技巧。 以下是详细的内容: 提示精确化:在编写提示时,力求表达清晰准确。清楚地阐述任务需求和概念定义至关重要。例:不用"分析文本",而用"判断这段话的情感倾向:积极、消极还是中性"。 快速迭代:善于快速连续调整提示。熟练的提示工程师能够灵活地进行多轮优化。例:从"总结文章"到"用

AI绘图怎么变现?想做点副业的小白必看!

在科技飞速发展的今天,AI绘图作为一种新兴技术,不仅改变了艺术创作的方式,也为创作者提供了多种变现途径。本文将详细探讨几种常见的AI绘图变现方式,帮助创作者更好地利用这一技术实现经济收益。 更多实操教程和AI绘画工具,可以扫描下方,免费获取 定制服务:个性化的创意商机 个性化定制 AI绘图技术能够根据用户需求生成个性化的头像、壁纸、插画等作品。例如,姓氏头像在电商平台上非常受欢迎,

大模型研发全揭秘:客服工单数据标注的完整攻略

在人工智能(AI)领域,数据标注是模型训练过程中至关重要的一步。无论你是新手还是有经验的从业者,掌握数据标注的技术细节和常见问题的解决方案都能为你的AI项目增添不少价值。在电信运营商的客服系统中,工单数据是客户问题和解决方案的重要记录。通过对这些工单数据进行有效标注,不仅能够帮助提升客服自动化系统的智能化水平,还能优化客户服务流程,提高客户满意度。本文将详细介绍如何在电信运营商客服工单的背景下进行

从去中心化到智能化:Web3如何与AI共同塑造数字生态

在数字时代的演进中,Web3和人工智能(AI)正成为塑造未来互联网的两大核心力量。Web3的去中心化理念与AI的智能化技术,正相互交织,共同推动数字生态的变革。本文将探讨Web3与AI的融合如何改变数字世界,并展望这一新兴组合如何重塑我们的在线体验。 Web3的去中心化愿景 Web3代表了互联网的第三代发展,它基于去中心化的区块链技术,旨在创建一个开放、透明且用户主导的数字生态。不同于传统

AI一键生成 PPT

AI一键生成 PPT 操作步骤 作为一名打工人,是不是经常需要制作各种PPT来分享我的生活和想法。但是,你们知道,有时候灵感来了,时间却不够用了!😩直到我发现了Kimi AI——一个能够自动生成PPT的神奇助手!🌟 什么是Kimi? 一款月之暗面科技有限公司开发的AI办公工具,帮助用户快速生成高质量的演示文稿。 无论你是职场人士、学生还是教师,Kimi都能够为你的办公文

Andrej Karpathy最新采访:认知核心模型10亿参数就够了,AI会打破教育不公的僵局

夕小瑶科技说 原创  作者 | 海野 AI圈子的红人,AI大神Andrej Karpathy,曾是OpenAI联合创始人之一,特斯拉AI总监。上一次的动态是官宣创办一家名为 Eureka Labs 的人工智能+教育公司 ,宣布将长期致力于AI原生教育。 近日,Andrej Karpathy接受了No Priors(投资博客)的采访,与硅谷知名投资人 Sara Guo 和 Elad G

嵌入式QT开发:构建高效智能的嵌入式系统

摘要: 本文深入探讨了嵌入式 QT 相关的各个方面。从 QT 框架的基础架构和核心概念出发,详细阐述了其在嵌入式环境中的优势与特点。文中分析了嵌入式 QT 的开发环境搭建过程,包括交叉编译工具链的配置等关键步骤。进一步探讨了嵌入式 QT 的界面设计与开发,涵盖了从基本控件的使用到复杂界面布局的构建。同时也深入研究了信号与槽机制在嵌入式系统中的应用,以及嵌入式 QT 与硬件设备的交互,包括输入输出设

Retrieval-based-Voice-Conversion-WebUI模型构建指南

一、模型介绍 Retrieval-based-Voice-Conversion-WebUI(简称 RVC)模型是一个基于 VITS(Variational Inference with adversarial learning for end-to-end Text-to-Speech)的简单易用的语音转换框架。 具有以下特点 简单易用:RVC 模型通过简单易用的网页界面,使得用户无需深入了

搭建Kafka+zookeeper集群调度

前言 硬件环境 172.18.0.5        kafkazk1        Kafka+zookeeper                Kafka Broker集群 172.18.0.6        kafkazk2        Kafka+zookeeper                Kafka Broker集群 172.18.0.7        kafkazk3

透彻!驯服大型语言模型(LLMs)的五种方法,及具体方法选择思路

引言 随着时间的发展,大型语言模型不再停留在演示阶段而是逐步面向生产系统的应用,随着人们期望的不断增加,目标也发生了巨大的变化。在短短的几个月的时间里,人们对大模型的认识已经从对其zero-shot能力感到惊讶,转变为考虑改进模型质量、提高模型可用性。 「大语言模型(LLMs)其实就是利用高容量的模型架构(例如Transformer)对海量的、多种多样的数据分布进行建模得到,它包含了大量的先验