【扩散模型(六)】IP-Adapter 是如何训练的?2 源码篇(IP-Adapter Plus)

2024-08-28 16:20

本文主要是介绍【扩散模型(六)】IP-Adapter 是如何训练的?2 源码篇(IP-Adapter Plus),希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

系列文章目录

  • 【扩散模型(二)】IP-Adapter 从条件分支的视角,快速理解相关的可控生成研究
  • 【扩散模型(三)】IP-Adapter 源码详解1-训练输入 介绍了训练代码中的 image prompt 的输入部分,即 img projection 模块。
  • 【扩散模型(四)】IP-Adapter 源码详解2-训练核心(cross-attention)详细介绍 IP-Adapter 训练代码的核心部分,即插入 Unet 中的、针对 Image prompt 的 cross-attention 模块。
  • 【扩散模型(五)】IP-Adapter 源码详解3-推理代码 详细介绍 IP-Adapter 推理过程代码。
  • 【可控图像生成系列论文(四)】IP-Adapter 具体是如何训练的?1公式篇
  • 本文则以 IP-Adapter Plus 训练代码为例,进行详细介绍。

文章目录

  • 系列文章目录
  • 整体训练框架
  • 一、训了哪些部分?
      • 第一块 - image_proj_model
      • 第二块 - adapter_modules
  • 二、训练目标


整体训练框架

在这里插入图片描述

一、训了哪些部分?

本文以原仓库 1 的 /path/IP-Adapter/tutorial_train_plus.py 为例,该文件为 SD1.5 IP-Adapter Plus 的训练代码。

从以下代码可以看出,IPAdapter 主要由 unet, image_proj_model, adapter_modules 3 个部分组成,而权重需要被优化的(训练到的)只有 ip_adapter.image_proj_model.parameters(), 和 ip_adapter.adapter_modules.parameters() 。

	ip_adapter = IPAdapter(unet, image_proj_model, adapter_modules, args.pretrained_ip_adapter_path)# optimizerparams_to_opt = itertools.chain(ip_adapter.image_proj_model.parameters(),  ip_adapter.adapter_modules.parameters())optimizer = torch.optim.AdamW(params_to_opt, lr=args.learning_rate, weight_decay=args.weight_decay)...# Prepare everything with our `accelerator`.ip_adapter, optimizer, train_dataloader = accelerator.prepare(ip_adapter, optimizer, train_dataloader)

第一块 - image_proj_model

在 IP-Adapter Plus 中,采用的是 Resampler 作为img embedding 到 ip_tokens 的映射网络,对图像(image prompt)中信息的抽取更加细粒度。其他模块都不需要梯度下降,如下代码所示。

	# freeze parameters of models to save more memoryunet.requires_grad_(False)vae.requires_grad_(False)text_encoder.requires_grad_(False)image_encoder.requires_grad_(False)#ip-adapter-plusimage_proj_model = Resampler(dim=unet.config.cross_attention_dim,depth=4,dim_head=64,heads=12,num_queries=args.num_tokens,embedding_dim=image_encoder.config.hidden_size,output_dim=unet.config.cross_attention_dim,ff_mult=4)...

第二块 - adapter_modules

Decoupled cross-attention 则在以下代码中进行初始化,关键是在特定的 unet 层中进行替换,详细位置可以参考前文中的图片,本文的重点是后续训练的实现。

	# init adapter modulesattn_procs = {}unet_sd = unet.state_dict()for name in unet.attn_processors.keys():cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dimif name.startswith("mid_block"):hidden_size = unet.config.block_out_channels[-1]elif name.startswith("up_blocks"):block_id = int(name[len("up_blocks.")])hidden_size = list(reversed(unet.config.block_out_channels))[block_id]elif name.startswith("down_blocks"):block_id = int(name[len("down_blocks.")])hidden_size = unet.config.block_out_channels[block_id]if cross_attention_dim is None:attn_procs[name] = AttnProcessor()else:layer_name = name.split(".processor")[0]weights = {"to_k_ip.weight": unet_sd[layer_name + ".to_k.weight"],"to_v_ip.weight": unet_sd[layer_name + ".to_v.weight"],}attn_procs[name] = IPAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, num_tokens=args.num_tokens)attn_procs[name].load_state_dict(weights)unet.set_attn_processor(attn_procs)adapter_modules = torch.nn.ModuleList(unet.attn_processors.values())

二、训练目标

每个 epoch 是遍历完一整个 dataset,我们直接从每个训练步的循环中来看:

  • latents 是通过 vae 将输入的 image prompt 压到了隐空间(latent space)中。
  • 准备相应的 noise 和 timesteps ,再通过 noise_scheduler 来制作出 noisy_latents。
        for step, batch in enumerate(train_dataloader):load_data_time = time.perf_counter() - beginwith accelerator.accumulate(ip_adapter):# Convert images to latent spacewith torch.no_grad():latents = vae.encode(batch["images"].to(accelerator.device, dtype=weight_dtype)).latent_dist.sample()latents = latents * vae.config.scaling_factor# Sample noise that we'll add to the latentsnoise = torch.randn_like(latents)bsz = latents.shape[0]# Sample a random timestep for each imagetimesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bsz,), device=latents.device)timesteps = timesteps.long()# Add noise to the latents according to the noise magnitude at each timestep# (this is the forward diffusion process)noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
  • clip_images 和 drop_image_embed 是在准备数据的过程中,做了一个随机 drop 的方式进行数据增强,提升模型鲁棒性。
    • 数据增强:通过随机丢弃一些图像,模型被迫学习从剩余的图像中提取信息,这可以增加模型的泛化能力。
    • 模型鲁棒性:训练模型以处理不完整的数据,使其在实际应用中对缺失数据更加鲁棒。
     clip_images = []for clip_image, drop_image_embed in zip(batch["clip_images"], batch["drop_image_embeds"]):if drop_image_embed == 1:clip_images.append(torch.zeros_like(clip_image))else:clip_images.append(clip_image)clip_images = torch.stack(clip_images, dim=0)with torch.no_grad():image_embeds = image_encoder(clip_images.to(accelerator.device, dtype=weight_dtype), output_hidden_states=True).hidden_states[-2]with torch.no_grad():encoder_hidden_states = text_encoder(batch["text_input_ids"].to(accelerator.device))[0]

  1. https://github.com/tencent-ailab/IP-Adapter/tree/main ↩︎

这篇关于【扩散模型(六)】IP-Adapter 是如何训练的?2 源码篇(IP-Adapter Plus)的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1115272

相关文章

Golang的CSP模型简介(最新推荐)

《Golang的CSP模型简介(最新推荐)》Golang采用了CSP(CommunicatingSequentialProcesses,通信顺序进程)并发模型,通过goroutine和channe... 目录前言一、介绍1. 什么是 CSP 模型2. Goroutine3. Channel4. Channe

shell脚本快速检查192.168.1网段ip是否在用的方法

《shell脚本快速检查192.168.1网段ip是否在用的方法》该Shell脚本通过并发ping命令检查192.168.1网段中哪些IP地址正在使用,脚本定义了网络段、超时时间和并行扫描数量,并使用... 目录脚本:检查 192.168.1 网段 IP 是否在用脚本说明使用方法示例输出优化建议总结检查 1

Redis连接失败:客户端IP不在白名单中的问题分析与解决方案

《Redis连接失败:客户端IP不在白名单中的问题分析与解决方案》在现代分布式系统中,Redis作为一种高性能的内存数据库,被广泛应用于缓存、消息队列、会话存储等场景,然而,在实际使用过程中,我们可能... 目录一、问题背景二、错误分析1. 错误信息解读2. 根本原因三、解决方案1. 将客户端IP添加到Re

SpringBoot基于MyBatis-Plus实现Lambda Query查询的示例代码

《SpringBoot基于MyBatis-Plus实现LambdaQuery查询的示例代码》MyBatis-Plus是MyBatis的增强工具,简化了数据库操作,并提高了开发效率,它提供了多种查询方... 目录引言基础环境配置依赖配置(Maven)application.yml 配置表结构设计demo_st

解决mybatis-plus-boot-starter与mybatis-spring-boot-starter的错误问题

《解决mybatis-plus-boot-starter与mybatis-spring-boot-starter的错误问题》本文主要讲述了在使用MyBatis和MyBatis-Plus时遇到的绑定异常... 目录myBATis-plus-boot-starpythonter与mybatis-spring-b

Java汇编源码如何查看环境搭建

《Java汇编源码如何查看环境搭建》:本文主要介绍如何在IntelliJIDEA开发环境中搭建字节码和汇编环境,以便更好地进行代码调优和JVM学习,首先,介绍了如何配置IntelliJIDEA以方... 目录一、简介二、在IDEA开发环境中搭建汇编环境2.1 在IDEA中搭建字节码查看环境2.1.1 搭建步

SpringBoot实现基于URL和IP的访问频率限制

《SpringBoot实现基于URL和IP的访问频率限制》在现代Web应用中,接口被恶意刷新或暴力请求是一种常见的攻击手段,为了保护系统资源,需要对接口的访问频率进行限制,下面我们就来看看如何使用... 目录1. 引言2. 项目依赖3. 配置 Redis4. 创建拦截器5. 注册拦截器6. 创建控制器8.

Python基于火山引擎豆包大模型搭建QQ机器人详细教程(2024年最新)

《Python基于火山引擎豆包大模型搭建QQ机器人详细教程(2024年最新)》:本文主要介绍Python基于火山引擎豆包大模型搭建QQ机器人详细的相关资料,包括开通模型、配置APIKEY鉴权和SD... 目录豆包大模型概述开通模型付费安装 SDK 环境配置 API KEY 鉴权Ark 模型接口Prompt

Spring Boot 中整合 MyBatis-Plus详细步骤(最新推荐)

《SpringBoot中整合MyBatis-Plus详细步骤(最新推荐)》本文详细介绍了如何在SpringBoot项目中整合MyBatis-Plus,包括整合步骤、基本CRUD操作、分页查询、批... 目录一、整合步骤1. 创建 Spring Boot 项目2. 配置项目依赖3. 配置数据源4. 创建实体类

Linux限制ip访问的解决方案

《Linux限制ip访问的解决方案》为了修复安全扫描中发现的漏洞,我们需要对某些服务设置访问限制,具体来说,就是要确保只有指定的内部IP地址能够访问这些服务,所以本文给大家介绍了Linux限制ip访问... 目录背景:解决方案:使用Firewalld防火墙规则验证方法深度了解防火墙逻辑应用场景与扩展背景: