AI绘画训练一个扩散模型-上集

2023-12-25 06:44

本文主要是介绍AI绘画训练一个扩散模型-上集,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

介绍

AI绘画,其中最常见方案基于扩散模型,Stable Diffusion 在此基础上,增加了 VAE 模块和 CLIP 模块,本文搞了一个测试Demo,分为上下两集,第一集是denoising_diffusion_pytorch ,第二集是diffusers。
对于专业的算法同学而言,我更推荐使用 diffusers 来训练。原因是 diffusers 工具包在实际的 AI 绘画项目中用得更多,并且也更易于我们修改代码逻辑,实现定制化功能。
https://arxiv.org/abs/2112.10752

基础模块

  • 创建UNet模型和高斯扩散模型(Gaussian Diffusion)。

UNet是一个编码器-解码器结构的全卷积神经网络。Gaussian Diffusion用于定义噪声过程和损失函数。

  • 将模型加载到GPU上(如果有GPU)。

  • 使用随机初始化的图片进行一次训练,计算损失并反向传播。

这一步的目的是对模型进行一次预热,更新权重。

  • 使用diffusion模型采样生成图片。

这里采样1000步,也就是将噪声逐步减少,每步用UNet预测下一步的图像,最终输出生成的图片。

  • 如果图片在GPU上,将其移回到CPU。

  • 可视化第一张生成图片。

plt.imshow显示图片。

这样通过DDPM框架,可以从随机噪声生成符合数据分布的新图片。每次训练会使模型逐步逼近真实数据分布,从而产生更高质量的图片。

# 创建UNet和扩散模型from denoising_diffusion_pytorch import Unet, GaussianDiffusion
import torchmodel = Unet(dim = 64,dim_mults = (1, 2, 4, 8)
).cuda()diffusion = GaussianDiffusion(model,image_size = 128,timesteps = 1000   # number of steps
).cuda()# 使用随机初始化的图片进行一次训练
training_images = torch.randn(8, 3, 128, 128)
loss = diffusion(training_images.cuda())
loss.backward()# 采样1000步生成一张图片
sampled_images = diffusion.sample(batch_size = 4)
import torch
import matplotlib.pyplot as plt
from torchvision.utils import make_grid
import torchvision.transforms as transforms# 如果张量在 GPU上,需要移动到 CPU上
if sampled_images.is_cuda:sampled_images = sampled_images.cpu()# 检查我们生成的一张图
img = sampled_images[0].clone().detach().permute(1, 2, 0)plt.imshow(img)

数据集

  • 导入所需的库:PIL、io、datasets等。

  • 使用datasets库中的load_dataset方法加载Oxford Flowers数据集。

  • 创建一个目录来保存图片。

  • 遍历数据集的训练、验证、测试split,逐个图像获取图片bytes数据,并保存为PNG格式图片。

  • 使用PIL库的Image对象将bytes数据加载并保存为图片文件。

  • 使用tqdm显示循环进度。

# 数据集下载
from PIL import Image
from io import BytesIO
from datasets import load_dataset
import os
from tqdm import tqdmdataset = load_dataset("nelorth/oxford-flowers")# 创建一个用于保存图片的文件夹
images_dir = "./oxford-datasets/raw-images"
os.makedirs(images_dir, exist_ok=True)# 遍历所有图片并保存,针对oxford-flowers,整个过程要持续15分钟左右
for split in dataset.keys():for index, item in enumerate(tqdm(dataset[split])):image = item['image']image.save(os.path.join(images_dir, f"{split}_image_{index}.png"))

模型训练

  • 定义Unet模型架构和Gaussian Diffusion过程。

  • 加载数据,指定训练参数:

    • 训练总步数20000
    • batch size 16
    • 学习率2e-5
    • 梯度累积步数2
    • EMA指数衰减参数0.995
    • 使用混合精度训练
    • 每2000步保存一次模型
    • 创建Trainer进行模型训练。Trainer封装了训练循环逻辑。
  • 调用trainer.train()进行训练。

# 模型训练
import torch
from denoising_diffusion_pytorch import Unet, GaussianDiffusion, Trainermodel = Unet(dim = 64,dim_mults = (1, 2, 4, 8)
).cuda()diffusion = GaussianDiffusion(model,image_size = 128,timesteps = 1000   # 加噪总步数
).cuda()trainer = Trainer(diffusion,'./oxford-datasets/raw-images',train_batch_size = 16,train_lr = 2e-5,train_num_steps = 20000,          # 总共训练20000步gradient_accumulate_every = 2,    # 梯度累积步数ema_decay = 0.995,                # 指数滑动平均decay参数amp = True,                       # 使用混合精度训练加速calculate_fid = False,            # 我们关闭FID评测指标计算(比较耗时)。FID用于评测生成质量。save_and_sample_every = 2000      # 每隔2000步保存一次模型
)trainer.train()
# 你可以等待上面的模型训练完成后,查看生成结果from glob import globresult_images = glob(r"./results/*.png")
print(result_images)
# 可视化图像看看
from PIL import Imageimg = Image.open("./results/sample-1.png")
img

测试

https://colab.research.google.com/github/NightWalker888/ai_painting_journey/blob/main/lesson12/train_diffusion_v2.ipynb#scrollTo=8BVjfBPI93Ar

这篇关于AI绘画训练一个扩散模型-上集的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/534568

相关文章

Ilya-AI分享的他在OpenAI学习到的15个提示工程技巧

Ilya(不是本人,claude AI)在社交媒体上分享了他在OpenAI学习到的15个Prompt撰写技巧。 以下是详细的内容: 提示精确化:在编写提示时,力求表达清晰准确。清楚地阐述任务需求和概念定义至关重要。例:不用"分析文本",而用"判断这段话的情感倾向:积极、消极还是中性"。 快速迭代:善于快速连续调整提示。熟练的提示工程师能够灵活地进行多轮优化。例:从"总结文章"到"用

AI绘图怎么变现?想做点副业的小白必看!

在科技飞速发展的今天,AI绘图作为一种新兴技术,不仅改变了艺术创作的方式,也为创作者提供了多种变现途径。本文将详细探讨几种常见的AI绘图变现方式,帮助创作者更好地利用这一技术实现经济收益。 更多实操教程和AI绘画工具,可以扫描下方,免费获取 定制服务:个性化的创意商机 个性化定制 AI绘图技术能够根据用户需求生成个性化的头像、壁纸、插画等作品。例如,姓氏头像在电商平台上非常受欢迎,

大模型研发全揭秘:客服工单数据标注的完整攻略

在人工智能(AI)领域,数据标注是模型训练过程中至关重要的一步。无论你是新手还是有经验的从业者,掌握数据标注的技术细节和常见问题的解决方案都能为你的AI项目增添不少价值。在电信运营商的客服系统中,工单数据是客户问题和解决方案的重要记录。通过对这些工单数据进行有效标注,不仅能够帮助提升客服自动化系统的智能化水平,还能优化客户服务流程,提高客户满意度。本文将详细介绍如何在电信运营商客服工单的背景下进行

从去中心化到智能化:Web3如何与AI共同塑造数字生态

在数字时代的演进中,Web3和人工智能(AI)正成为塑造未来互联网的两大核心力量。Web3的去中心化理念与AI的智能化技术,正相互交织,共同推动数字生态的变革。本文将探讨Web3与AI的融合如何改变数字世界,并展望这一新兴组合如何重塑我们的在线体验。 Web3的去中心化愿景 Web3代表了互联网的第三代发展,它基于去中心化的区块链技术,旨在创建一个开放、透明且用户主导的数字生态。不同于传统

AI一键生成 PPT

AI一键生成 PPT 操作步骤 作为一名打工人,是不是经常需要制作各种PPT来分享我的生活和想法。但是,你们知道,有时候灵感来了,时间却不够用了!😩直到我发现了Kimi AI——一个能够自动生成PPT的神奇助手!🌟 什么是Kimi? 一款月之暗面科技有限公司开发的AI办公工具,帮助用户快速生成高质量的演示文稿。 无论你是职场人士、学生还是教师,Kimi都能够为你的办公文

Andrej Karpathy最新采访:认知核心模型10亿参数就够了,AI会打破教育不公的僵局

夕小瑶科技说 原创  作者 | 海野 AI圈子的红人,AI大神Andrej Karpathy,曾是OpenAI联合创始人之一,特斯拉AI总监。上一次的动态是官宣创办一家名为 Eureka Labs 的人工智能+教育公司 ,宣布将长期致力于AI原生教育。 近日,Andrej Karpathy接受了No Priors(投资博客)的采访,与硅谷知名投资人 Sara Guo 和 Elad G

Retrieval-based-Voice-Conversion-WebUI模型构建指南

一、模型介绍 Retrieval-based-Voice-Conversion-WebUI(简称 RVC)模型是一个基于 VITS(Variational Inference with adversarial learning for end-to-end Text-to-Speech)的简单易用的语音转换框架。 具有以下特点 简单易用:RVC 模型通过简单易用的网页界面,使得用户无需深入了

透彻!驯服大型语言模型(LLMs)的五种方法,及具体方法选择思路

引言 随着时间的发展,大型语言模型不再停留在演示阶段而是逐步面向生产系统的应用,随着人们期望的不断增加,目标也发生了巨大的变化。在短短的几个月的时间里,人们对大模型的认识已经从对其zero-shot能力感到惊讶,转变为考虑改进模型质量、提高模型可用性。 「大语言模型(LLMs)其实就是利用高容量的模型架构(例如Transformer)对海量的、多种多样的数据分布进行建模得到,它包含了大量的先验

图神经网络模型介绍(1)

我们将图神经网络分为基于谱域的模型和基于空域的模型,并按照发展顺序详解每个类别中的重要模型。 1.1基于谱域的图神经网络         谱域上的图卷积在图学习迈向深度学习的发展历程中起到了关键的作用。本节主要介绍三个具有代表性的谱域图神经网络:谱图卷积网络、切比雪夫网络和图卷积网络。 (1)谱图卷积网络 卷积定理:函数卷积的傅里叶变换是函数傅里叶变换的乘积,即F{f*g}

秋招最新大模型算法面试,熬夜都要肝完它

💥大家在面试大模型LLM这个板块的时候,不知道面试完会不会复盘、总结,做笔记的习惯,这份大模型算法岗面试八股笔记也帮助不少人拿到过offer ✨对于面试大模型算法工程师会有一定的帮助,都附有完整答案,熬夜也要看完,祝大家一臂之力 这份《大模型算法工程师面试题》已经上传CSDN,还有完整版的大模型 AI 学习资料,朋友们如果需要可以微信扫描下方CSDN官方认证二维码免费领取【保证100%免费