AIGC Notes -- Building a VQVAE Model

2024-01-16 06:12

1--The VQVAE Model

        The content generated by a plain VAE is often of low quality. One possible reason is that the VAE encodes images into continuous variables (mapped to a standard normal distribution), whereas encoding images into discrete variables may work better: in everyday life we tend to describe things with discrete attributes, e.g., a person being tall or short, fat or thin.

        The VQVAE model has three key modules: the Encoder, the Decoder, and the Codebook.

        The Encoder encodes the input into feature vectors. For each feature vector, its similarity (L2 distance) to the Embedding vectors in the Codebook is computed; the most similar Embedding vector is substituted for the feature vector and fed into the Decoder to reconstruct the input, as shown in the sketch below.
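A minimal sketch of that nearest-neighbor lookup (the sizes are illustrative; the full implementation in section 2 expands the squared distance by hand instead of calling torch.cdist):

import torch

codebook = torch.randn(512, 64)           # K=512 embeddings, each of dimension D=64
features = torch.randn(1024, 64)          # flattened Encoder outputs, one D-dim vector per spatial location

dists = torch.cdist(features, codebook)   # pairwise L2 distances, shape (1024, 512)
indices = dists.argmin(dim=1)             # index of the nearest embedding for each feature vector
quantized = codebook[indices]             # replace each feature vector with its nearest Codebook entry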

        VQVAE's loss function consists of the reconstruction loss between the source image and the reconstructed image, plus the quantization loss vq_loss produced by the Codebook quantization step.
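For reference, the full objective from the VQ-VAE paper (van den Oord et al., 2017, "Neural Discrete Representation Learning") is:

L = log p(x | z_q(x)) + || sg[z_e(x)] - e ||_2^2 + beta * || z_e(x) - sg[e] ||_2^2

where sg[.] is the stop-gradient operator, z_e(x) is the Encoder output, and e is the chosen Codebook embedding. The first term is the reconstruction loss, the second trains the Codebook (q_latent_loss in the code below), and the third is the commitment loss (e_latent_loss, weighted by commitment_cost).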

        For a detailed introduction to VQ-VAE, see: 轻松理解 VQ-VAE (in Chinese).

2--A Simple Code Example

import torch
import torch.nn as nn
import torch.nn.functional as F


class VectorQuantizer(nn.Module):
    def __init__(self, num_embeddings, embedding_dim, commitment_cost):
        super(VectorQuantizer, self).__init__()
        self._embedding_dim = embedding_dim
        self._num_embeddings = num_embeddings
        self._embedding = nn.Embedding(self._num_embeddings, self._embedding_dim)
        self._embedding.weight.data.uniform_(-1/self._num_embeddings, 1/self._num_embeddings)
        self._commitment_cost = commitment_cost

    def forward(self, inputs):
        # convert inputs from BCHW -> BHWC
        inputs = inputs.permute(0, 2, 3, 1).contiguous()
        input_shape = inputs.shape

        # Flatten input
        flat_input = inputs.view(-1, self._embedding_dim)

        # Calculate distances
        distances = (torch.sum(flat_input**2, dim=1, keepdim=True)
                     + torch.sum(self._embedding.weight**2, dim=1)
                     - 2 * torch.matmul(flat_input, self._embedding.weight.t()))

        # Encoding
        encoding_indices = torch.argmin(distances, dim=1).unsqueeze(1)
        encodings = torch.zeros(encoding_indices.shape[0], self._num_embeddings, device=inputs.device)
        encodings.scatter_(1, encoding_indices, 1)

        # Quantize and unflatten
        quantized = torch.matmul(encodings, self._embedding.weight).view(input_shape)

        # Loss
        e_latent_loss = F.mse_loss(quantized.detach(), inputs)  # third term of the loss in the paper
        q_latent_loss = F.mse_loss(quantized, inputs.detach())  # second term of the loss in the paper
        loss = q_latent_loss + self._commitment_cost * e_latent_loss

        quantized = inputs + (quantized - inputs).detach()  # straight-through gradient copy
        avg_probs = torch.mean(encodings, dim=0)
        perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))

        # convert quantized from BHWC -> BCHW
        return loss, quantized.permute(0, 3, 1, 2).contiguous(), perplexity, encodings


class VectorQuantizerEMA(nn.Module):
    def __init__(self, num_embeddings, embedding_dim, commitment_cost, decay, epsilon=1e-5):
        super(VectorQuantizerEMA, self).__init__()
        self._embedding_dim = embedding_dim
        self._num_embeddings = num_embeddings
        self._embedding = nn.Embedding(self._num_embeddings, self._embedding_dim)
        self._embedding.weight.data.normal_()
        self._commitment_cost = commitment_cost

        self.register_buffer('_ema_cluster_size', torch.zeros(num_embeddings))
        self._ema_w = nn.Parameter(torch.Tensor(num_embeddings, self._embedding_dim))
        self._ema_w.data.normal_()

        self._decay = decay
        self._epsilon = epsilon

    def forward(self, inputs):
        # convert inputs from BCHW -> BHWC
        inputs = inputs.permute(0, 2, 3, 1).contiguous()
        input_shape = inputs.shape  # B(256) H(8) W(8) C(64)

        # Flatten input BHWC -> BHW, C
        flat_input = inputs.view(-1, self._embedding_dim)

        # Calculate distances to every embedding in the embedding space
        distances = (torch.sum(flat_input**2, dim=1, keepdim=True)
                     + torch.sum(self._embedding.weight**2, dim=1)
                     - 2 * torch.matmul(flat_input, self._embedding.weight.t()))

        # Encoding
        encoding_indices = torch.argmin(distances, dim=1).unsqueeze(1)  # pick the most similar embedding
        encodings = torch.zeros(encoding_indices.shape[0], self._num_embeddings, device=inputs.device)
        encodings.scatter_(1, encoding_indices, 1)  # map to one-hot vectors

        # Quantize and unflatten
        quantized = torch.matmul(encodings, self._embedding.weight).view(input_shape)  # look up the embeddings by index

        # Use EMA to update the embedding vectors
        if self.training:
            self._ema_cluster_size = self._ema_cluster_size * self._decay + \
                                     (1 - self._decay) * torch.sum(encodings, 0)

            # Laplace smoothing of the cluster size
            n = torch.sum(self._ema_cluster_size.data)
            self._ema_cluster_size = ((self._ema_cluster_size + self._epsilon)
                                      / (n + self._num_embeddings * self._epsilon) * n)

            dw = torch.matmul(encodings.t(), flat_input)
            self._ema_w = nn.Parameter(self._ema_w * self._decay + (1 - self._decay) * dw)

            self._embedding.weight = nn.Parameter(self._ema_w / self._ema_cluster_size.unsqueeze(1))  # Eq. (8) in the paper

        # Loss
        e_latent_loss = F.mse_loss(quantized.detach(), inputs)  # loss between the Encoder output (inputs) and the Decoder input (quantized)
        loss = self._commitment_cost * e_latent_loss

        # Straight Through Estimator
        quantized = inputs + (quantized - inputs).detach()  # trick: copy the gradient of the Decoder input to the Encoder output

        avg_probs = torch.mean(encodings, dim=0)
        perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))

        # convert quantized from BHWC -> BCHW
        return loss, quantized.permute(0, 3, 1, 2).contiguous(), perplexity, encodings


class Residual(nn.Module):
    def __init__(self, in_channels, num_hiddens, num_residual_hiddens):
        super(Residual, self).__init__()
        self._block = nn.Sequential(
            nn.ReLU(True),
            nn.Conv2d(in_channels=in_channels,
                      out_channels=num_residual_hiddens,
                      kernel_size=3, stride=1, padding=1, bias=False),
            nn.ReLU(True),
            nn.Conv2d(in_channels=num_residual_hiddens,
                      out_channels=num_hiddens,
                      kernel_size=1, stride=1, bias=False))

    def forward(self, x):
        return x + self._block(x)


class ResidualStack(nn.Module):
    def __init__(self, in_channels, num_hiddens, num_residual_layers, num_residual_hiddens):
        super(ResidualStack, self).__init__()
        self._num_residual_layers = num_residual_layers
        self._layers = nn.ModuleList([Residual(in_channels, num_hiddens, num_residual_hiddens)
                                      for _ in range(self._num_residual_layers)])

    def forward(self, x):
        for i in range(self._num_residual_layers):
            x = self._layers[i](x)
        return F.relu(x)


class Encoder(nn.Module):
    def __init__(self, in_channels, num_hiddens, num_residual_layers, num_residual_hiddens):
        super(Encoder, self).__init__()
        self._conv_1 = nn.Conv2d(in_channels=in_channels,
                                 out_channels=num_hiddens//2,
                                 kernel_size=4, stride=2, padding=1)
        self._conv_2 = nn.Conv2d(in_channels=num_hiddens//2,
                                 out_channels=num_hiddens,
                                 kernel_size=4, stride=2, padding=1)
        self._conv_3 = nn.Conv2d(in_channels=num_hiddens,
                                 out_channels=num_hiddens,
                                 kernel_size=3, stride=1, padding=1)
        self._residual_stack = ResidualStack(in_channels=num_hiddens,
                                             num_hiddens=num_hiddens,
                                             num_residual_layers=num_residual_layers,
                                             num_residual_hiddens=num_residual_hiddens)

    def forward(self, inputs):
        x = self._conv_1(inputs)
        x = F.relu(x)
        x = self._conv_2(x)
        x = F.relu(x)
        x = self._conv_3(x)
        return self._residual_stack(x)


class Decoder(nn.Module):
    def __init__(self, in_channels, num_hiddens, num_residual_layers, num_residual_hiddens):
        super(Decoder, self).__init__()
        self._conv_1 = nn.Conv2d(in_channels=in_channels,
                                 out_channels=num_hiddens,
                                 kernel_size=3, stride=1, padding=1)
        self._residual_stack = ResidualStack(in_channels=num_hiddens,
                                             num_hiddens=num_hiddens,
                                             num_residual_layers=num_residual_layers,
                                             num_residual_hiddens=num_residual_hiddens)
        self._conv_trans_1 = nn.ConvTranspose2d(in_channels=num_hiddens,
                                                out_channels=num_hiddens//2,
                                                kernel_size=4, stride=2, padding=1)
        self._conv_trans_2 = nn.ConvTranspose2d(in_channels=num_hiddens//2,
                                                out_channels=3,
                                                kernel_size=4, stride=2, padding=1)

    def forward(self, inputs):
        x = self._conv_1(inputs)
        x = self._residual_stack(x)
        x = self._conv_trans_1(x)
        x = F.relu(x)
        return self._conv_trans_2(x)


class Model(nn.Module):
    def __init__(self, num_hiddens, num_residual_layers, num_residual_hiddens,
                 num_embeddings, embedding_dim, commitment_cost, decay=0):
        super(Model, self).__init__()
        self._encoder = Encoder(3, num_hiddens,
                                num_residual_layers, num_residual_hiddens)
        self._pre_vq_conv = nn.Conv2d(in_channels=num_hiddens,
                                      out_channels=embedding_dim,
                                      kernel_size=1, stride=1)
        if decay > 0.0:
            self._vq_vae = VectorQuantizerEMA(num_embeddings, embedding_dim,
                                              commitment_cost, decay)
        else:
            self._vq_vae = VectorQuantizer(num_embeddings, embedding_dim,
                                           commitment_cost)
        self._decoder = Decoder(embedding_dim, num_hiddens,
                                num_residual_layers, num_residual_hiddens)

    def forward(self, x):  # x.shape: B(256) C(3) H(32) W(32)
        z = self._encoder(x)
        z = self._pre_vq_conv(z)
        loss, quantized, perplexity, _ = self._vq_vae(z)
        x_recon = self._decoder(quantized)  # decode back to an image: B(256) C(3) H(32) W(32)
        return loss, x_recon, perplexity
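A quick smoke test of the full model might look like this (the hyperparameter values are illustrative, chosen only to match the 32x32 shape comments above; they are not prescribed by the code itself):

model = Model(num_hiddens=128, num_residual_layers=2, num_residual_hiddens=32,
              num_embeddings=512, embedding_dim=64, commitment_cost=0.25, decay=0.99)
x = torch.randn(8, 3, 32, 32)            # dummy batch of 32x32 RGB images
vq_loss, x_recon, perplexity = model(x)
print(x_recon.shape)                     # torch.Size([8, 3, 32, 32])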

For the complete code, see: liujf69/VQ-VAE

3--Selected Implementation Details:

Reconstruction loss:

        Compute the MSE loss between the source image and the reconstructed image:

vq_loss, data_recon, perplexity = self.model(data)
recon_error = F.mse_loss(data_recon, data) / self.data_variance  # MSE normalized by the variance of the training data
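The reconstruction term is then added to vq_loss to form the total training objective. A minimal sketch of one training step under that assumption (train_step and its arguments are my names, not from the original code):

def train_step(model, optimizer, data, data_variance):
    # One optimization step: total loss = normalized reconstruction error + quantization loss
    optimizer.zero_grad()
    vq_loss, data_recon, perplexity = model(data)
    recon_error = F.mse_loss(data_recon, data) / data_variance
    loss = recon_error + vq_loss
    loss.backward()
    optimizer.step()
    return recon_error.item(), vq_loss.item(), perplexity.item()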

VQ quantization loss:

        inputs is the Encoder's output, and quantized is the vector in the Codebook closest to inputs. Note that in the EMA variant (VectorQuantizerEMA), the Codebook is updated by exponential moving averages instead of through q_latent_loss, so only the commitment term remains in the loss;

# Loss
e_latent_loss = F.mse_loss(quantized.detach(), inputs)  # third term of the loss in the paper
q_latent_loss = F.mse_loss(quantized, inputs.detach())  # second term of the loss in the paper
loss = q_latent_loss + self._commitment_cost * e_latent_loss

Copying the Decoder's gradient to the Encoder (the Straight-Through Estimator): inputs is the Encoder's output, and quantized is the Decoder's input;

quantized = inputs + (quantized - inputs).detach()  # copy the gradient straight through the quantization step
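A small self-contained check (my illustration, using torch.round as a stand-in for the non-differentiable Codebook lookup) confirms that gradients flow through the quantization step as if it were the identity:

import torch

inputs = torch.randn(4, 8, requires_grad=True)
quantized = torch.round(inputs)                      # stand-in for the non-differentiable quantization
quantized = inputs + (quantized - inputs).detach()   # straight-through estimator
quantized.sum().backward()
print(inputs.grad)                                   # all ones: quantization is bypassed in the backward pass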
