AE——重构数字(Pytorch+mnist)

2024-03-30 17:28
文章标签 pytorch 重构 数字 mnist ae

本文主要是介绍AE——重构数字(Pytorch+mnist),希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

1、简介

  • AE(自编码器)由编码器和解码器组成,编码器将输入数据映射到潜在空间,解码器将潜在表示映射回原始输入空间。
  • AE的训练目标通常是最小化重构误差,即尽可能地重构输入数据,使得解码器输出与原始输入尽可能接近。
  • AE通常用于数据压缩、去噪、特征提取等任务。
  • 本文利用AE,输入数字图像。训练后,输入测试数字图像,重构生成新的数字图像。
    • 【注】本文案例需要输入才能生成输出,目标是重构,而不是生成。
  • 可以看出,重构图片和原始图片差别不大。 

2、代码

  • import matplotlib.pyplot as plt
    import torch
    import torch.nn as nn
    import torch.optim as optim
    import torchvision# 设置种子和其他配置
    seed = 42  # 设置随机种子
    torch.manual_seed(seed)
    torch.backends.cudnn.benchmark = False  # 禁用 cuDNN 的自动寻找最佳算法
    torch.backends.cudnn.deterministic = True  # 设置 cuDNN 为确定性模式# 设置批大小、学习周期和学习率
    batch_size = 512
    epochs = 30
    learning_rate = 1e-3# 载入 MNIST 数据集中的图片进行训练
    transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])  # 将图像转换为张量train_dataset = torchvision.datasets.MNIST(root="~/torch_datasets", train=True, transform=transform, download=True
    )  # 加载 MNIST 数据集的训练集,设置路径、转换和下载为 Truetrain_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True
    )  # 创建一个数据加载器,用于加载训练数据,设置批处理大小和是否随机打乱数据# 在一个类中编写编码器和解码器层。为编码器和解码器层的组件都定义了全连接层
    class AE(nn.Module):def __init__(self, **kwargs):super().__init__()self.encoder_hidden_layer = nn.Linear(in_features=kwargs["input_shape"], out_features=128)  # 编码器隐藏层self.encoder_output_layer = nn.Linear(in_features=128, out_features=128)  # 编码器输出层self.decoder_hidden_layer = nn.Linear(in_features=128, out_features=128)  # 解码器隐藏层self.decoder_output_layer = nn.Linear(in_features=128, out_features=kwargs["input_shape"])  # 解码器输出层# 定义了模型的前向传播过程,包括激活函数的应用和重构图像的生成def forward(self, features):activation = self.encoder_hidden_layer(features)activation = torch.relu(activation)  # ReLU 激活函数,得到编码器的激活值code = self.encoder_output_layer(activation)code = torch.sigmoid(code)  # Sigmoid 激活函数,以确保编码后的表示在 [0, 1] 范围内activation = self.decoder_hidden_layer(code)activation = torch.relu(activation)activation = self.decoder_output_layer(activation)reconstructed = torch.sigmoid(activation)return reconstructed# 在使用定义的 AE 类之前,有以下事情要做:
    # 配置要在哪个设备上运行
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")# 建立 AE 模型并载入到 CPU 设备
    model = AE(input_shape=784).to(device)# Adam 优化器,学习率 10e-3
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)# 使用均方误差(MSE)损失函数
    criterion = nn.MSELoss()# 在CPU设备上运行,实例化一个输入大小为784的AE自编码器,并用Adam作为训练优化器用MSELoss作为损失函数
    # 训练:
    for epoch in range(epochs):loss = 0for batch_features, _ in train_loader:# 将小批数据变形为 [N, 784] 矩阵,并加载到 CPU 设备batch_features = batch_features.view(-1, 784).to(device)# 梯度设置为 0,因为 torch 会累加梯度optimizer.zero_grad()# 计算重构outputs = model(batch_features)# 计算训练重建损失train_loss = criterion(outputs, batch_features)# 计算累积梯度train_loss.backward()# 根据当前梯度更新参数optimizer.step()# 将小批量训练损失加到周期损失中loss += train_loss.item()# 计算每个周期的训练损失loss = loss / len(train_loader)# 显示每个周期的训练损失print("epoch : {}/{}, recon loss = {:.8f}".format(epoch + 1, epochs, loss))if __name__ == '__main__':# 用训练过的自编码器提取一些测试用例来重构test_dataset = torchvision.datasets.MNIST(root="~/torch_datasets", train=False, transform=transform, download=True)  # 加载 MNIST 测试数据集test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=10, shuffle=False)  # 创建一个测试数据加载器test_examples = None# 通过循环遍历测试数据加载器,获取一个批次的图像数据with torch.no_grad():  # 使用 torch.no_grad() 上下文管理器,确保在该上下文中不会进行梯度计算for batch_features in test_loader:  # 历测试数据加载器中的每个批次的图像数据batch_features = batch_features[0]  # 获取当前批次的图像数据test_examples = batch_features.view(-1, 784).to(device)  # 将当前批次的图像数据转换为大小为 (批大小, 784) 的张量,并加载到指定的设备(CPU 或 GPU)上reconstruction = model(test_examples)  # 使用训练好的自编码器模型对测试数据进行重构,即生成重构的图像break# 试着用训练过的自编码器重建一些测试图像with torch.no_grad():number = 10  # 设置要显示的图像数量plt.figure(figsize=(20, 4))  # 创建一个新的 Matplotlib 图形,设置图形大小为 (20, 4)for index in range(number):  # 遍历要显示的图像数量# 显示原始图ax = plt.subplot(2, number, index + 1)plt.imshow(test_examples[index].cpu().numpy().reshape(28, 28))plt.gray()ax.get_xaxis().set_visible(False)ax.get_yaxis().set_visible(False)# 显示重构图ax = plt.subplot(2, number, index + 1 + number)plt.imshow(reconstruction[index].cpu().numpy().reshape(28, 28))plt.gray()ax.get_xaxis().set_visible(False)ax.get_yaxis().set_visible(False)plt.savefig('reconstruction_results.png')  # 保存图像plt.show()

这篇关于AE——重构数字(Pytorch+mnist)的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/862112

相关文章

捷瑞数字业绩波动性明显:关联交易不低,募资必要性遭质疑

《港湾商业观察》施子夫 5月22日,山东捷瑞数字科技股份有限公司(以下简称,捷瑞数字)及保荐机构国新证券披露第三轮问询的回复,继续推进北交所上市进程。 从2023年6月递表开始,监管层已下发三轮审核问询函,关注到捷瑞数字存在同业竞争、关联交易、募资合理性、期后业绩波动等焦点问题。公司的上市之路多少被阴影笼罩。​ 业绩波动遭问询 捷瑞数字成立于2000年,公司是一家以数字孪生驱动的工

基于CTPN(tensorflow)+CRNN(pytorch)+CTC的不定长文本检测和识别

转发来源:https://swift.ctolib.com/ooooverflow-chinese-ocr.html chinese-ocr 基于CTPN(tensorflow)+CRNN(pytorch)+CTC的不定长文本检测和识别 环境部署 sh setup.sh 使用环境: python 3.6 + tensorflow 1.10 +pytorch 0.4.1 注:CPU环境

数据时代的数字企业

1.写在前面 讨论数据治理在数字企业中的影响和必要性,并介绍数据治理的核心内容和实践方法。作者强调了数据质量、数据安全、数据隐私和数据合规等方面是数据治理的核心内容,并介绍了具体的实践措施和案例分析。企业需要重视这些方面以实现数字化转型和业务增长。 数字化转型行业小伙伴可以加入我的星球,初衷成为各位数字化转型参考库,星球内容每周更新 个人工作经验资料全部放在这里,包含数据治理、数据要

剑指offer(C++)--和为S的两个数字

题目 输入一个递增排序的数组和一个数字S,在数组中查找两个数,使得他们的和正好是S,如果有多对数字的和等于S,输出两个数的乘积最小的。 class Solution {public:vector<int> FindNumbersWithSum(vector<int> array,int sum) {vector<int> result;int len = array.size();if(

剑指offer(C++)--数组中只出现一次的数字

题目 一个整型数组里除了两个数字之外,其他的数字都出现了两次。请写程序找出这两个只出现一次的数字。 class Solution {public:void FindNumsAppearOnce(vector<int> data,int* num1,int *num2) {int len = data.size();if(len<2)return;int one = 0;for(int i

PyTorch模型_trace实战:深入理解与应用

pytorch使用trace模型 1、使用trace生成torchscript模型2、使用trace的模型预测 1、使用trace生成torchscript模型 def save_trace(model, input, save_path):traced_script_model = torch.jit.trace(model, input)<

神经网络第四篇:推理处理之手写数字识别

到目前为止,我们已经介绍完了神经网络的基本结构,现在用一个图像识别示例对前面的知识作整体的总结。本专题知识点如下: MNIST数据集图像数据转图像神经网络的推理处理批处理  MNIST数据集          mnist数据图像 MNIST数据集由0到9的数字图像构成。像素取值在0到255之间。每个图像数据都相应地标有“7”、“2”、“1”等数字标签。MNIST数据集中,

江西电信联合实在智能举办RPA数字员工培训班,培养“人工智能+”电信人才

近日,江西电信与实在智能合作的2024年数字员工开发应用培训班圆满闭幕。包括省公司及11个分公司的核心业务部门,超过40名学员积极报名参与此次培训,江西电信企业信息化部门总监徐建军出席活动并致辞,风控支撑室主任黄剑主持此次培训活动。 在培训会开幕仪式上,徐建军强调,科创是电信企业发展的核心动力,学习RPA技术是实现数字化转型的关键,他阐述了RPA在提高效率、降低成本和优化资源方面的价值,并鼓励学

LeetCode —— 只出现一次的数字

只出现一次的数字 I  本题依靠异或运算符的特性,两个相同数据异或等于0,数字与0异或为本身即可解答。代码如下: class Solution {public:int singleNumber(vector<int>& nums) {int ret = 0;for (auto e : nums){ret ^= e;}return ret;}};  只出现一次的数字 II

人工智能在数字病理切片虚拟染色以及染色标准化领域的研究进展|顶刊速递·24-06-23

小罗碎碎念 本期推文主题:人工智能在数字病理切片虚拟染色以及染色标准化领域的研究进展 这一期的推文是我发自内心觉得为数不多,特别宝贵的一篇推文,原因很简单——可参考的文献相对较少&方向非常具有研究意义&现在不卷。 数字病理方向的老师/同学应该清楚,不同中心提供的切片,染色方案是存在差异的,并且还存在各种质量问题,所以我们在数据预处理的时候,通常会先对切片的质量执行一遍筛选,然后再进行染