【扩散模型(六)】IP-Adapter 是如何训练的?2 源码篇(IP-Adapter Plus)

2024-08-28 16:20

本文主要是介绍【扩散模型(六)】IP-Adapter 是如何训练的?2 源码篇(IP-Adapter Plus),希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

系列文章目录

  • 【扩散模型(二)】IP-Adapter 从条件分支的视角,快速理解相关的可控生成研究
  • 【扩散模型(三)】IP-Adapter 源码详解1-训练输入 介绍了训练代码中的 image prompt 的输入部分,即 img projection 模块。
  • 【扩散模型(四)】IP-Adapter 源码详解2-训练核心(cross-attention)详细介绍 IP-Adapter 训练代码的核心部分,即插入 Unet 中的、针对 Image prompt 的 cross-attention 模块。
  • 【扩散模型(五)】IP-Adapter 源码详解3-推理代码 详细介绍 IP-Adapter 推理过程代码。
  • 【可控图像生成系列论文(四)】IP-Adapter 具体是如何训练的?1公式篇
  • 本文则以 IP-Adapter Plus 训练代码为例,进行详细介绍。

文章目录

  • 系列文章目录
  • 整体训练框架
  • 一、训了哪些部分?
      • 第一块 - image_proj_model
      • 第二块 - adapter_modules
  • 二、训练目标


整体训练框架

在这里插入图片描述

一、训了哪些部分?

本文以原仓库 1 的 /path/IP-Adapter/tutorial_train_plus.py 为例,该文件为 SD1.5 IP-Adapter Plus 的训练代码。

从以下代码可以看出,IPAdapter 主要由 unet, image_proj_model, adapter_modules 3 个部分组成,而权重需要被优化的(训练到的)只有 ip_adapter.image_proj_model.parameters(), 和 ip_adapter.adapter_modules.parameters() 。

	ip_adapter = IPAdapter(unet, image_proj_model, adapter_modules, args.pretrained_ip_adapter_path)# optimizerparams_to_opt = itertools.chain(ip_adapter.image_proj_model.parameters(),  ip_adapter.adapter_modules.parameters())optimizer = torch.optim.AdamW(params_to_opt, lr=args.learning_rate, weight_decay=args.weight_decay)...# Prepare everything with our `accelerator`.ip_adapter, optimizer, train_dataloader = accelerator.prepare(ip_adapter, optimizer, train_dataloader)

第一块 - image_proj_model

在 IP-Adapter Plus 中,采用的是 Resampler 作为img embedding 到 ip_tokens 的映射网络,对图像(image prompt)中信息的抽取更加细粒度。其他模块都不需要梯度下降,如下代码所示。

	# freeze parameters of models to save more memoryunet.requires_grad_(False)vae.requires_grad_(False)text_encoder.requires_grad_(False)image_encoder.requires_grad_(False)#ip-adapter-plusimage_proj_model = Resampler(dim=unet.config.cross_attention_dim,depth=4,dim_head=64,heads=12,num_queries=args.num_tokens,embedding_dim=image_encoder.config.hidden_size,output_dim=unet.config.cross_attention_dim,ff_mult=4)...

第二块 - adapter_modules

Decoupled cross-attention 则在以下代码中进行初始化,关键是在特定的 unet 层中进行替换,详细位置可以参考前文中的图片,本文的重点是后续训练的实现。

	# init adapter modulesattn_procs = {}unet_sd = unet.state_dict()for name in unet.attn_processors.keys():cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dimif name.startswith("mid_block"):hidden_size = unet.config.block_out_channels[-1]elif name.startswith("up_blocks"):block_id = int(name[len("up_blocks.")])hidden_size = list(reversed(unet.config.block_out_channels))[block_id]elif name.startswith("down_blocks"):block_id = int(name[len("down_blocks.")])hidden_size = unet.config.block_out_channels[block_id]if cross_attention_dim is None:attn_procs[name] = AttnProcessor()else:layer_name = name.split(".processor")[0]weights = {"to_k_ip.weight": unet_sd[layer_name + ".to_k.weight"],"to_v_ip.weight": unet_sd[layer_name + ".to_v.weight"],}attn_procs[name] = IPAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, num_tokens=args.num_tokens)attn_procs[name].load_state_dict(weights)unet.set_attn_processor(attn_procs)adapter_modules = torch.nn.ModuleList(unet.attn_processors.values())

二、训练目标

每个 epoch 是遍历完一整个 dataset,我们直接从每个训练步的循环中来看:

  • latents 是通过 vae 将输入的 image prompt 压到了隐空间(latent space)中。
  • 准备相应的 noise 和 timesteps ,再通过 noise_scheduler 来制作出 noisy_latents。
        for step, batch in enumerate(train_dataloader):load_data_time = time.perf_counter() - beginwith accelerator.accumulate(ip_adapter):# Convert images to latent spacewith torch.no_grad():latents = vae.encode(batch["images"].to(accelerator.device, dtype=weight_dtype)).latent_dist.sample()latents = latents * vae.config.scaling_factor# Sample noise that we'll add to the latentsnoise = torch.randn_like(latents)bsz = latents.shape[0]# Sample a random timestep for each imagetimesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bsz,), device=latents.device)timesteps = timesteps.long()# Add noise to the latents according to the noise magnitude at each timestep# (this is the forward diffusion process)noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
  • clip_images 和 drop_image_embed 是在准备数据的过程中,做了一个随机 drop 的方式进行数据增强,提升模型鲁棒性。
    • 数据增强:通过随机丢弃一些图像,模型被迫学习从剩余的图像中提取信息,这可以增加模型的泛化能力。
    • 模型鲁棒性:训练模型以处理不完整的数据,使其在实际应用中对缺失数据更加鲁棒。
     clip_images = []for clip_image, drop_image_embed in zip(batch["clip_images"], batch["drop_image_embeds"]):if drop_image_embed == 1:clip_images.append(torch.zeros_like(clip_image))else:clip_images.append(clip_image)clip_images = torch.stack(clip_images, dim=0)with torch.no_grad():image_embeds = image_encoder(clip_images.to(accelerator.device, dtype=weight_dtype), output_hidden_states=True).hidden_states[-2]with torch.no_grad():encoder_hidden_states = text_encoder(batch["text_input_ids"].to(accelerator.device))[0]

  1. https://github.com/tencent-ailab/IP-Adapter/tree/main ↩︎

这篇关于【扩散模型(六)】IP-Adapter 是如何训练的?2 源码篇(IP-Adapter Plus)的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1115272

相关文章

Spring Boot + MyBatis Plus 高效开发实战从入门到进阶优化(推荐)

《SpringBoot+MyBatisPlus高效开发实战从入门到进阶优化(推荐)》本文将详细介绍SpringBoot+MyBatisPlus的完整开发流程,并深入剖析分页查询、批量操作、动... 目录Spring Boot + MyBATis Plus 高效开发实战:从入门到进阶优化1. MyBatis

Python实现无痛修改第三方库源码的方法详解

《Python实现无痛修改第三方库源码的方法详解》很多时候,我们下载的第三方库是不会有需求不满足的情况,但也有极少的情况,第三方库没有兼顾到需求,本文将介绍几个修改源码的操作,大家可以根据需求进行选择... 目录需求不符合模拟示例 1. 修改源文件2. 继承修改3. 猴子补丁4. 追踪局部变量需求不符合很

Java的IO模型、Netty原理解析

《Java的IO模型、Netty原理解析》Java的I/O是以流的方式进行数据输入输出的,Java的类库涉及很多领域的IO内容:标准的输入输出,文件的操作、网络上的数据传输流、字符串流、对象流等,这篇... 目录1.什么是IO2.同步与异步、阻塞与非阻塞3.三种IO模型BIO(blocking I/O)NI

Spring Boot结成MyBatis-Plus最全配置指南

《SpringBoot结成MyBatis-Plus最全配置指南》本文主要介绍了SpringBoot结成MyBatis-Plus最全配置指南,包括依赖引入、配置数据源、Mapper扫描、基本CRUD操... 目录前言详细操作一.创建项目并引入相关依赖二.配置数据源信息三.编写相关代码查zsRArly询数据库数

基于Flask框架添加多个AI模型的API并进行交互

《基于Flask框架添加多个AI模型的API并进行交互》:本文主要介绍如何基于Flask框架开发AI模型API管理系统,允许用户添加、删除不同AI模型的API密钥,感兴趣的可以了解下... 目录1. 概述2. 后端代码说明2.1 依赖库导入2.2 应用初始化2.3 API 存储字典2.4 路由函数2.5 应

Linux系统中配置静态IP地址的详细步骤

《Linux系统中配置静态IP地址的详细步骤》本文详细介绍了在Linux系统中配置静态IP地址的五个步骤,包括打开终端、编辑网络配置文件、配置IP地址、保存并重启网络服务,这对于系统管理员和新手都极具... 目录步骤一:打开终端步骤二:编辑网络配置文件步骤三:配置静态IP地址步骤四:保存并关闭文件步骤五:重

Linux配置IP地址的三种实现方式

《Linux配置IP地址的三种实现方式》:本文主要介绍Linux配置IP地址的三种实现方式,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录环境RedHat9第一种安装 直接配置网卡文件第二种方式 nmcli(Networkmanager command-line

Spring 中 BeanFactoryPostProcessor 的作用和示例源码分析

《Spring中BeanFactoryPostProcessor的作用和示例源码分析》Spring的BeanFactoryPostProcessor是容器初始化的扩展接口,允许在Bean实例化前... 目录一、概览1. 核心定位2. 核心功能详解3. 关键特性二、Spring 内置的 BeanFactory

mybatis-plus分页无效问题解决

《mybatis-plus分页无效问题解决》本文主要介绍了mybatis-plus分页无效问题解决,原因是配置分页插件的版本问题,旧版本和新版本的MyBatis-Plus需要不同的分页配置,感兴趣的可... 昨天在做一www.chinasem.cn个新项目使用myBATis-plus分页一直失败,后来经过多方

mybatis-plus 实现查询表名动态修改的示例代码

《mybatis-plus实现查询表名动态修改的示例代码》通过MyBatis-Plus实现表名的动态替换,根据配置或入参选择不同的表,本文主要介绍了mybatis-plus实现查询表名动态修改的示... 目录实现数据库初始化依赖包配置读取类设置 myBATis-plus 插件测试通过 mybatis-plu