VQ-VAE torch 实现

2023-10-27 21:36
文章标签 实现 torch vq vae

本文主要是介绍VQ-VAE torch 实现,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

文章目录

  • model
  • main

model

import torch
import torch.nn as nnclass ResidualBlock(nn.Module):def __init__(self, dim):super().__init__()self.relu = nn.ReLU()self.conv1 = nn.Conv2d(dim, dim, 3, 1, 1)self.conv2 = nn.Conv2d(dim, dim, 1)def forward(self, x):tmp = self.relu(x)tmp = self.conv1(tmp)tmp = self.relu(tmp)tmp = self.conv2(tmp)return x + tmpclass VQVAE(nn.Module):def __init__(self, input_dim, dim, n_embedding):super().__init__()self.encoder = nn.Sequential(nn.Conv2d(input_dim, dim, 4, 2, 1),nn.ReLU(), nn.Conv2d(dim, dim, 4, 2, 1),nn.ReLU(), nn.Conv2d(dim, dim, 3, 1, 1),ResidualBlock(dim), ResidualBlock(dim))self.vq_embedding = nn.Embedding(n_embedding, dim)self.vq_embedding.weight.data.uniform_(-1.0 / n_embedding,1.0 / n_embedding)self.decoder = nn.Sequential(nn.Conv2d(dim, dim, 3, 1, 1),ResidualBlock(dim), ResidualBlock(dim),nn.ConvTranspose2d(dim, dim, 4, 2, 1), nn.ReLU(),nn.ConvTranspose2d(dim, input_dim, 4, 2, 1))self.n_downsample = 2def forward(self, x):# encodeze = self.encoder(x)# ze: [N, C, H, W]# embedding [K, C]embedding = self.vq_embedding.weight.dataN, C, H, W = ze.shapeK, _ = embedding.shapeembedding_broadcast = embedding.reshape(1, K, C, 1, 1)ze_broadcast = ze.reshape(N, 1, C, H, W)distance = torch.sum((embedding_broadcast - ze_broadcast)**2, 2)nearest_neighbor = torch.argmin(distance, 1)# make C to the second dimzq = self.vq_embedding(nearest_neighbor).permute(0, 3, 1, 2)# stop gradientdecoder_input = ze + (zq - ze).detach()# decodex_hat = self.decoder(decoder_input)return x_hat, ze, zq@torch.no_grad()def encode(self, x):ze = self.encoder(x)embedding = self.vq_embedding.weight.data# ze: [N, C, H, W]# embedding [K, C]N, C, H, W = ze.shapeK, _ = embedding.shapeembedding_broadcast = embedding.reshape(1, K, C, 1, 1)ze_broadcast = ze.reshape(N, 1, C, H, W)distance = torch.sum((embedding_broadcast - ze_broadcast)**2, 2)nearest_neighbor = torch.argmin(distance, 1)return nearest_neighbor@torch.no_grad()def decode(self, discrete_latent):zq = self.vq_embedding(discrete_latent).permute(0, 3, 1, 2)x_hat = self.decoder(zq)return x_hat# Shape: [C, H, W]def get_latent_HW(self, input_shape):C, H, W = input_shapereturn (H // 2**self.n_downsample, W // 2**self.n_downsample)

main

def train_vqvae(model: VQVAE,img_shape=None,device='cuda',ckpt_path='dldemos/VQVAE/model.pth',batch_size=64,dataset_type='MNIST',lr=1e-3,n_epochs=100,l_w_embedding=1,l_w_commitment=0.25):print('batch size:', batch_size)dataloader = get_dataloader(dataset_type,batch_size,img_shape=img_shape,use_lmdb=USE_LMDB)model.to(device)model.train()optimizer = torch.optim.Adam(model.parameters(), lr)mse_loss = nn.MSELoss()tic = time.time()for e in range(n_epochs):total_loss = 0for x in dataloader:current_batch_size = x.shape[0]x = x.to(device)x_hat, ze, zq = model(x)l_reconstruct = mse_loss(x, x_hat)l_embedding = mse_loss(ze.detach(), zq)l_commitment = mse_loss(ze, zq.detach())loss = l_reconstruct + \l_w_embedding * l_embedding + l_w_commitment * l_commitmentoptimizer.zero_grad()loss.backward()optimizer.step()total_loss += loss.item() * current_batch_sizetotal_loss /= len(dataloader.dataset)toc = time.time()torch.save(model.state_dict(), ckpt_path)print(f'epoch {e} loss: {total_loss} elapsed {(toc - tic):.2f}s')print('Done')

这篇关于VQ-VAE torch 实现的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/288553

相关文章

Python实现终端清屏的几种方式详解

《Python实现终端清屏的几种方式详解》在使用Python进行终端交互式编程时,我们经常需要清空当前终端屏幕的内容,本文为大家整理了几种常见的实现方法,有需要的小伙伴可以参考下... 目录方法一:使用 `os` 模块调用系统命令方法二:使用 `subprocess` 模块执行命令方法三:打印多个换行符模拟

SpringBoot+EasyPOI轻松实现Excel和Word导出PDF

《SpringBoot+EasyPOI轻松实现Excel和Word导出PDF》在企业级开发中,将Excel和Word文档导出为PDF是常见需求,本文将结合​​EasyPOI和​​Aspose系列工具实... 目录一、环境准备与依赖配置1.1 方案选型1.2 依赖配置(商业库方案)二、Excel 导出 PDF

Python实现MQTT通信的示例代码

《Python实现MQTT通信的示例代码》本文主要介绍了Python实现MQTT通信的示例代码,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一... 目录1. 安装paho-mqtt库‌2. 搭建MQTT代理服务器(Broker)‌‌3. pytho

使用zip4j实现Java中的ZIP文件加密压缩的操作方法

《使用zip4j实现Java中的ZIP文件加密压缩的操作方法》本文介绍如何通过Maven集成zip4j1.3.2库创建带密码保护的ZIP文件,涵盖依赖配置、代码示例及加密原理,确保数据安全性,感兴趣的... 目录1. zip4j库介绍和版本1.1 zip4j库概述1.2 zip4j的版本演变1.3 zip4

python生成随机唯一id的几种实现方法

《python生成随机唯一id的几种实现方法》在Python中生成随机唯一ID有多种方法,根据不同的需求场景可以选择最适合的方案,文中通过示例代码介绍的非常详细,需要的朋友们下面随着小编来一起学习学习... 目录方法 1:使用 UUID 模块(推荐)方法 2:使用 Secrets 模块(安全敏感场景)方法

Spring StateMachine实现状态机使用示例详解

《SpringStateMachine实现状态机使用示例详解》本文介绍SpringStateMachine实现状态机的步骤,包括依赖导入、枚举定义、状态转移规则配置、上下文管理及服务调用示例,重点解... 目录什么是状态机使用示例什么是状态机状态机是计算机科学中的​​核心建模工具​​,用于描述对象在其生命

Spring Boot 结合 WxJava 实现文章上传微信公众号草稿箱与群发

《SpringBoot结合WxJava实现文章上传微信公众号草稿箱与群发》本文将详细介绍如何使用SpringBoot框架结合WxJava开发工具包,实现文章上传到微信公众号草稿箱以及群发功能,... 目录一、项目环境准备1.1 开发环境1.2 微信公众号准备二、Spring Boot 项目搭建2.1 创建

IntelliJ IDEA2025创建SpringBoot项目的实现步骤

《IntelliJIDEA2025创建SpringBoot项目的实现步骤》本文主要介绍了IntelliJIDEA2025创建SpringBoot项目的实现步骤,文中通过示例代码介绍的非常详细,对大家... 目录一、创建 Spring Boot 项目1. 新建项目2. 基础配置3. 选择依赖4. 生成项目5.

Linux下删除乱码文件和目录的实现方式

《Linux下删除乱码文件和目录的实现方式》:本文主要介绍Linux下删除乱码文件和目录的实现方式,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录linux下删除乱码文件和目录方法1方法2总结Linux下删除乱码文件和目录方法1使用ls -i命令找到文件或目录

SpringBoot+EasyExcel实现自定义复杂样式导入导出

《SpringBoot+EasyExcel实现自定义复杂样式导入导出》这篇文章主要为大家详细介绍了SpringBoot如何结果EasyExcel实现自定义复杂样式导入导出功能,文中的示例代码讲解详细,... 目录安装处理自定义导出复杂场景1、列不固定,动态列2、动态下拉3、自定义锁定行/列,添加密码4、合并