AE——重构数字(Pytorch+mnist)

2024-03-30 17:28
文章标签 pytorch 重构 数字 mnist ae

本文主要是介绍AE——重构数字(Pytorch+mnist),希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

1、简介

  • AE(自编码器)由编码器和解码器组成,编码器将输入数据映射到潜在空间,解码器将潜在表示映射回原始输入空间。
  • AE的训练目标通常是最小化重构误差,即尽可能地重构输入数据,使得解码器输出与原始输入尽可能接近。
  • AE通常用于数据压缩、去噪、特征提取等任务。
  • 本文利用AE,输入数字图像。训练后,输入测试数字图像,重构生成新的数字图像。
    • 【注】本文案例需要输入才能生成输出,目标是重构,而不是生成。
  • 可以看出,重构图片和原始图片差别不大。 

2、代码

  • import matplotlib.pyplot as plt
    import torch
    import torch.nn as nn
    import torch.optim as optim
    import torchvision# 设置种子和其他配置
    seed = 42  # 设置随机种子
    torch.manual_seed(seed)
    torch.backends.cudnn.benchmark = False  # 禁用 cuDNN 的自动寻找最佳算法
    torch.backends.cudnn.deterministic = True  # 设置 cuDNN 为确定性模式# 设置批大小、学习周期和学习率
    batch_size = 512
    epochs = 30
    learning_rate = 1e-3# 载入 MNIST 数据集中的图片进行训练
    transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])  # 将图像转换为张量train_dataset = torchvision.datasets.MNIST(root="~/torch_datasets", train=True, transform=transform, download=True
    )  # 加载 MNIST 数据集的训练集,设置路径、转换和下载为 Truetrain_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True
    )  # 创建一个数据加载器,用于加载训练数据,设置批处理大小和是否随机打乱数据# 在一个类中编写编码器和解码器层。为编码器和解码器层的组件都定义了全连接层
    class AE(nn.Module):def __init__(self, **kwargs):super().__init__()self.encoder_hidden_layer = nn.Linear(in_features=kwargs["input_shape"], out_features=128)  # 编码器隐藏层self.encoder_output_layer = nn.Linear(in_features=128, out_features=128)  # 编码器输出层self.decoder_hidden_layer = nn.Linear(in_features=128, out_features=128)  # 解码器隐藏层self.decoder_output_layer = nn.Linear(in_features=128, out_features=kwargs["input_shape"])  # 解码器输出层# 定义了模型的前向传播过程,包括激活函数的应用和重构图像的生成def forward(self, features):activation = self.encoder_hidden_layer(features)activation = torch.relu(activation)  # ReLU 激活函数,得到编码器的激活值code = self.encoder_output_layer(activation)code = torch.sigmoid(code)  # Sigmoid 激活函数,以确保编码后的表示在 [0, 1] 范围内activation = self.decoder_hidden_layer(code)activation = torch.relu(activation)activation = self.decoder_output_layer(activation)reconstructed = torch.sigmoid(activation)return reconstructed# 在使用定义的 AE 类之前,有以下事情要做:
    # 配置要在哪个设备上运行
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")# 建立 AE 模型并载入到 CPU 设备
    model = AE(input_shape=784).to(device)# Adam 优化器,学习率 10e-3
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)# 使用均方误差(MSE)损失函数
    criterion = nn.MSELoss()# 在CPU设备上运行,实例化一个输入大小为784的AE自编码器,并用Adam作为训练优化器用MSELoss作为损失函数
    # 训练:
    for epoch in range(epochs):loss = 0for batch_features, _ in train_loader:# 将小批数据变形为 [N, 784] 矩阵,并加载到 CPU 设备batch_features = batch_features.view(-1, 784).to(device)# 梯度设置为 0,因为 torch 会累加梯度optimizer.zero_grad()# 计算重构outputs = model(batch_features)# 计算训练重建损失train_loss = criterion(outputs, batch_features)# 计算累积梯度train_loss.backward()# 根据当前梯度更新参数optimizer.step()# 将小批量训练损失加到周期损失中loss += train_loss.item()# 计算每个周期的训练损失loss = loss / len(train_loader)# 显示每个周期的训练损失print("epoch : {}/{}, recon loss = {:.8f}".format(epoch + 1, epochs, loss))if __name__ == '__main__':# 用训练过的自编码器提取一些测试用例来重构test_dataset = torchvision.datasets.MNIST(root="~/torch_datasets", train=False, transform=transform, download=True)  # 加载 MNIST 测试数据集test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=10, shuffle=False)  # 创建一个测试数据加载器test_examples = None# 通过循环遍历测试数据加载器,获取一个批次的图像数据with torch.no_grad():  # 使用 torch.no_grad() 上下文管理器,确保在该上下文中不会进行梯度计算for batch_features in test_loader:  # 历测试数据加载器中的每个批次的图像数据batch_features = batch_features[0]  # 获取当前批次的图像数据test_examples = batch_features.view(-1, 784).to(device)  # 将当前批次的图像数据转换为大小为 (批大小, 784) 的张量,并加载到指定的设备(CPU 或 GPU)上reconstruction = model(test_examples)  # 使用训练好的自编码器模型对测试数据进行重构,即生成重构的图像break# 试着用训练过的自编码器重建一些测试图像with torch.no_grad():number = 10  # 设置要显示的图像数量plt.figure(figsize=(20, 4))  # 创建一个新的 Matplotlib 图形,设置图形大小为 (20, 4)for index in range(number):  # 遍历要显示的图像数量# 显示原始图ax = plt.subplot(2, number, index + 1)plt.imshow(test_examples[index].cpu().numpy().reshape(28, 28))plt.gray()ax.get_xaxis().set_visible(False)ax.get_yaxis().set_visible(False)# 显示重构图ax = plt.subplot(2, number, index + 1 + number)plt.imshow(reconstruction[index].cpu().numpy().reshape(28, 28))plt.gray()ax.get_xaxis().set_visible(False)ax.get_yaxis().set_visible(False)plt.savefig('reconstruction_results.png')  # 保存图像plt.show()

这篇关于AE——重构数字(Pytorch+mnist)的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/862112

相关文章

Pytorch介绍与安装过程

《Pytorch介绍与安装过程》PyTorch因其直观的设计、卓越的灵活性以及强大的动态计算图功能,迅速在学术界和工业界获得了广泛认可,成为当前深度学习研究和开发的主流工具之一,本文给大家介绍Pyto... 目录1、Pytorch介绍1.1、核心理念1.2、核心组件与功能1.3、适用场景与优势总结1.4、优

conda安装GPU版pytorch默认却是cpu版本

《conda安装GPU版pytorch默认却是cpu版本》本文主要介绍了遇到Conda安装PyTorchGPU版本却默认安装CPU的问题,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的... 目录一、问题描述二、网上解决方案罗列【此节为反面方案罗列!!!】三、发现的根本原因[独家]3.1 p

PyTorch中cdist和sum函数使用示例详解

《PyTorch中cdist和sum函数使用示例详解》torch.cdist是PyTorch中用于计算**两个张量之间的成对距离(pairwisedistance)**的函数,常用于点云处理、图神经网... 目录基本语法输出示例1. 简单的 2D 欧几里得距离2. 批量形式(3D Tensor)3. 使用不

PyTorch高级特性与性能优化方式

《PyTorch高级特性与性能优化方式》:本文主要介绍PyTorch高级特性与性能优化方式,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录一、自动化机制1.自动微分机制2.动态计算图二、性能优化1.内存管理2.GPU加速3.多GPU训练三、分布式训练1.分布式数据

判断PyTorch是GPU版还是CPU版的方法小结

《判断PyTorch是GPU版还是CPU版的方法小结》PyTorch作为当前最流行的深度学习框架之一,支持在CPU和GPU(NVIDIACUDA)上运行,所以对于深度学习开发者来说,正确识别PyTor... 目录前言为什么需要区分GPU和CPU版本?性能差异硬件要求如何检查PyTorch版本?方法1:使用命

Python实现特殊字符判断并去掉非字母和数字的特殊字符

《Python实现特殊字符判断并去掉非字母和数字的特殊字符》在Python中,可以通过多种方法来判断字符串中是否包含非字母、数字的特殊字符,并将这些特殊字符去掉,本文为大家整理了一些常用的,希望对大家... 目录1. 使用正则表达式判断字符串中是否包含特殊字符去掉字符串中的特殊字符2. 使用 str.isa

pytorch自动求梯度autograd的实现

《pytorch自动求梯度autograd的实现》autograd是一个自动微分引擎,它可以自动计算张量的梯度,本文主要介绍了pytorch自动求梯度autograd的实现,具有一定的参考价值,感兴趣... autograd是pytorch构建神经网络的核心。在 PyTorch 中,结合以下代码例子,当你

在PyCharm中安装PyTorch、torchvision和OpenCV详解

《在PyCharm中安装PyTorch、torchvision和OpenCV详解》:本文主要介绍在PyCharm中安装PyTorch、torchvision和OpenCV方式,具有很好的参考价值,... 目录PyCharm安装PyTorch、torchvision和OpenCV安装python安装PyTor

pytorch之torch.flatten()和torch.nn.Flatten()的用法

《pytorch之torch.flatten()和torch.nn.Flatten()的用法》:本文主要介绍pytorch之torch.flatten()和torch.nn.Flatten()的用... 目录torch.flatten()和torch.nn.Flatten()的用法下面举例说明总结torch

使用PyTorch实现手写数字识别功能

《使用PyTorch实现手写数字识别功能》在人工智能的世界里,计算机视觉是最具魅力的领域之一,通过PyTorch这一强大的深度学习框架,我们将在经典的MNIST数据集上,见证一个神经网络从零开始学会识... 目录当计算机学会“看”数字搭建开发环境MNIST数据集解析1. 认识手写数字数据库2. 数据预处理的