深度学习实战3--GAN:基础手写数字对抗生成

2024-08-30 23:44

本文主要是介绍深度学习实战3--GAN:基础手写数字对抗生成,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

本节目标

1.看懂GAN  基础架构的代码;

2.重点是GAN  的损失函数的构成;

3.理解如何从 GAN 修改成CGAN;

4.尝试复现本章实战任务

任务描述

        GAN 的任务是生成,用两个模型相互对抗,来增强生成模型的效果。此处准备的数据集是MNIST手写数字,希望生成类似的手写数字的图像。

判别器和生成器:生成器 G 是创造者,负责生成新的数据实例,而判别器 D 是鉴别者,负责评估数据实例的真伪。两者相互竞争,推动对方不断进步,从而提高生成数据的质量。

注意:BCE 是Binary_Cross_Entropy的缩写,可以理解为二分类问题。GAN 的任务是生成,用两个模型相互对抗,来增强生成模型的效果。那么 CGAN  就是 给定条件进行指定数字的生成。

以下内容是重点:

(1)GAN 的损失函数与BCE之间的转换;

(2)GAN 的判别器D 和生成器G 模型的输入输出;

(3)GAN 如何转化成CGAN;

(4)CGAN 中窥视到GAN 结构似乎有损害多样性的缺点

import torch #使用import语句时,要访问torch模块中的函数或类,你需要使用torch.前缀
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt#可以提高代码的可读性,因为它直接表明了你正在使用的是来自哪个模块的特定部分
import torchvision #torchvision则提供了图像处理和加载数据集的工具
from torchvision import transforms#使用from torch import nn后,可以直接使用nn而不需要torch.nn前缀,from也可避免冲突,因为是导入的特定模块
#数据归一化,它是数据预处理中的一种常用技术,目的是将数据调整到一个统一的尺度或范围内,以便于不同特征之间的比较和计算
transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize(0.5,0.5)
])
'''这个变换将PIL图像或Numpy数组转换为
torch.FloatTensor类型,并将数值范围从[0, 255]压缩到[0.0, 1.0]。
这是将数据从整数格式转换为浮点格式,以便于神经网络处理。'''
''''可能需要调整transforms.Normalize中
的均值和标准差参数,均值0.5,方差0.5,以匹配你的数据的实际范围'''
#加载内置数据集
train_ds=torchvision.datasets.MNIST('data',train=True,transform=transform,download=True)#PyTorch会从其官方源或者数据集的原始来源下载数据集
#创建一个数据加载器,它用于在训练过程中批量地加载数据
#每个批次包含64个样本。
#shuffle=True: 这个参数决定是否在每个epoch开始时对数据进行打乱(洗牌)。
#设置为True可以确保数据在每个epoch中以随机顺序加载,这有助于模型训练的泛化能力。
dataloader=torch.utils.data.DataLoader(train_ds,batch_size=64,shuffle=True)
#返回一个批次的数据
imgs,_=next(iter(dataloader))
#生成器
#在PyTorch中,nn.Module 的子类需要在构造函数中调用 super().__init__() 来确保正确初始化
class Generator(nn.Module):def __init__ (self):super(Generator,self).__init__()self.Linear=nn.Sequential(nn.Linear(100,256),## 将100维的输入映射到256维nn.Tanh(),## 使用双曲正切激活函数nn.Linear(256,512),# 将256维映射到512维nn.Tanh(),# 再次使用双曲正切激活函数nn.Linear(512,28*28),# 将512维映射到28*28维,适合MNIST图像nn.Tanh())#用于指定模型的前向传播过程'''这一行使用view方法改变了张量x的形状。view方法用于重塑张量而不改变其数据。这里的-1是一个特殊的参数,表示自动计算该维度的大小以保持总元素数量不变。因此,view(-1, 28, 28)将x重塑为一个形状为[batch_size, 28, 28]的三维张量,其中batch_size是输入数据中的批次大小。'''def forward(self,x):x=self.Linear(x)## 将输入数据x通过self.linear中的层进行处理x=x.view(-1,28,28)return x#方法返回处理后的张量x,它现在是一个具有28x28像素的二维图像张量
#辨别器,Discriminator 的类,目的是区分输入数据是真实的还是由生成器生成的
class Discriminator(nn.Module):def __init__(self):super(Discriminator,  self).__init__()# 创建了一个 nn.Sequential 容器,这是一个有序的容器,可以包含多个模块,它们将按顺序被应用#定义了一个线性层,它将输入特征从 28*28(即一个 28x28 的图像展平后的维度)转换为 512 维。#LeakyReLU 激活函数,它允许负值通过,以解决传统 ReLU 激活函数的“死亡ReLU”问题#激活函数的某些神经元可能会停止激活,即它们的输入值永远不会变为负数,导致这些神经元的梯度永久为零,从而不再更新。# 这会导致模型的某些部分不再学习,影响模型的性能。#定义了一个线性层,它将输入特征从 28*28(即一个 28x28 的图像展平后的维度)转换为 512 维。#定义了第三个线性层,将特征维度从 256 压缩到 1 维。这个输出通常代表判别器对输入数据是真实还是假数据的判断#添加了一个 Sigmoid 激活函数,它将线性层的输出压缩到 0 到 1 之间,通常用于二分类问题的概率输出。self.linear=nn.Sequential(nn.Linear(28*28,512),nn.LeakyReLU(),nn.Linear(512,256),nn.LeakyReLU(),nn.Linear(256,1),nn.Sigmoid())def forward(self, x):x = x.view(-1, 28*28)x = self.linear(x)return x
device='cuda' if torch.cuda.is_available() else 'cpu'
if device=='cuda':print('using cuda:',torch.cuda.get_device_name(0))#用于获取第一个CUDA设备的名称
Gen=Generator().to(device)#创建了一个 Generator 类的实例,并将其移动到之前定义的 device 上。这个 device 可能是一个 GPU 或cpu设备
#使用 .to(device) 是为了确保模型的参数和计算都在指定的设备上进行,这样可以利用 GPU 加速训练过程
Dis=Discriminator().to(device)#实例化 Discriminator 类,创建一个判别器对象
d_optim=torch.optim.Adam(Dis.parameters(),lr=0.001)#这行代码创建了一个Adam优化器实例,用于更新名为 Dis 的判别器网络的参数。
g_optim=torch.optim.Adam(Gen.parameters(),lr=0.001)#这行代码创建了另一个Adam优化器实例,用于更新名为 Gen 的生成器网络的参数。
'''学习率的值通常表示为一个介于0和1之间的小数
如果学习率设置得太高,可能会导致模型在训练过程中不稳定,甚至发散,因为每次更新的步长太大,可能会越过损失函数的最小点
如果学习率设置得太低,模型的收敛速度会很慢,因为每次更新的步长太小,需要更多的迭代次数才能达到最小点
'''
loss_function=torch.nn.BCELoss()#用来定义二元交叉熵损失函数
def gen_img_plot(model,test_input):prediction=np.squeeze(model(test_input).detach().cpu().numpy())#生成器模型 model(test_input) 生成图像,通过 .detach().cpu().numpy() #将生成的PyTorch张量转换为NumPy数组并移至CPU,然后使用 np.squeeze() 移除单维度以简化数组形状,为绘图准备数据。fig=plt.figure(figsize=(4,4))#创建一个新的matplotlib图形对象,设置图形的大小为4x4英寸for i in range(prediction.shape[0]):#prediction.shape[0] 表示生成的图像数量plt.subplot(4,4,i+1)plt.imshow((prediction[i]+1)/2)#显示当前图像。由于生成的图像数据可能在 [-1, 1] 范围内#这里通过 (prediction[i]+1)/2 将其规范化到 [0, 1] 范围plt.axis('off')#用于关闭matplotlib图表中的坐标轴的函数plt.show()
test_input=torch.randn(16,100,device=device)
D_loss=[]#用于存储判别器(Discriminator)在训练过程中的损失值
G_loss=[]
for epoch in range(20):#每次迭代称为一个epochd_epoch_loss=0#初始化判别器的累积损失g_epoch_loss=0#初始化生成器的累积损失count=len(dataloader)#计算数据加载器(dataloader)中的批次数量。for step,(img,_) in enumerate(dataloader):# 遍历数据加载器中的每个批次img=img.to(device)#将图像数据移动到指定的设备(例如GPU)size=img.size(0)#获取当前批次的大小random_noise=torch.randn(size,100,device=device)#生成随机噪声,用作生成器的输入,"噪声"通常指的是随机生成的数据d_optim.zero_grad()#清除判别器的梯度real_output=Dis(img)#使用判别器对真实图像进行判断d_real_loss=loss_function(real_output,torch.ones_like(real_output))#计算判别器对真实图像的损失。d_real_loss.backward()#对真实图像的损失进行反向传播gen_img=Gen(random_noise)#使用生成器生成假图像fake_output=Dis(gen_img.detach())#使用判别器对生成的假图像进行判断d_fake_loss=loss_function(fake_output,  #计算判别器对假图像的损失torch.zeros_like(fake_output))d_fake_loss.backward()#对假图像的损失进行反向传播d_loss=d_real_loss+d_fake_loss#计算判别器的总损失d_optim.step()# 更新判别器的参数g_optim.zero_grad()#清除生成器的优化器中的梯度。这是每次参数更新前的标准步骤,用于防止梯度累加。fake_output=Dis(gen_img)#将生成器生成的图像(gen_img)传递给判别器(Dis),以获取判别器对假图像的判断结果。g_loss=loss_function(fake_output,torch.ones_like(fake_output))#生成器的损失计算是至关重要的,因为它指导生成器如何改进以生成更逼真的图像g_loss.backward()#计算生成器损失的反向传播g_optim.step()#更新生成器的参数。这一步使用优化器(如SGD或Adam)根据反向传播计算得到的梯度来更新生成器的权重with torch.no_grad():d_epoch_loss+=d_loss#累加判别器的损失g_epoch_loss+= g_loss#累加生成器的损失with torch.no_grad():#这个上下文管理器指示 PyTorch 在这个代码块中不计算梯度。#这通常用于推理或评估阶段,以减少内存使用并提高性能d_epoch_loss/=count# 将判别器的累积损失除以批次数量 count,以计算整个epoch的平均损失。g_epoch_loss/=count#将生成器的累积损失除以批次数量 count,以计算整个epoch的平均损失。D_loss.append(d_epoch_loss)#将计算得到的判别器平均损失添加到 D_loss 列表中,用于记录每个epoch的损失。G_loss.append(g_epoch_loss)#将计算得到的生成器平均损失添加到 G_loss 列表中,用于记录每个epoch的损失。print('Epoch:',epoch+1)# 打印当前的epoch编号,epoch+1 因为 epoch 从0开始计数,而通常人们习惯从1开始计数。gen_img_plot(Gen,test_input)#绘制生成器的输出图像

缩进错误很难分清,代码又一样,写的时候需要仔细看清,找了好久才发现这个错误。

这篇关于深度学习实战3--GAN:基础手写数字对抗生成的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1122271

相关文章

网页解析 lxml 库--实战

lxml库使用流程 lxml 是 Python 的第三方解析库,完全使用 Python 语言编写,它对 XPath表达式提供了良好的支 持,因此能够了高效地解析 HTML/XML 文档。本节讲解如何通过 lxml 库解析 HTML 文档。 pip install lxml lxm| 库提供了一个 etree 模块,该模块专门用来解析 HTML/XML 文档,下面来介绍一下 lxml 库

HarmonyOS学习(七)——UI(五)常用布局总结

自适应布局 1.1、线性布局(LinearLayout) 通过线性容器Row和Column实现线性布局。Column容器内的子组件按照垂直方向排列,Row组件中的子组件按照水平方向排列。 属性说明space通过space参数设置主轴上子组件的间距,达到各子组件在排列上的等间距效果alignItems设置子组件在交叉轴上的对齐方式,且在各类尺寸屏幕上表现一致,其中交叉轴为垂直时,取值为Vert

Ilya-AI分享的他在OpenAI学习到的15个提示工程技巧

Ilya(不是本人,claude AI)在社交媒体上分享了他在OpenAI学习到的15个Prompt撰写技巧。 以下是详细的内容: 提示精确化:在编写提示时,力求表达清晰准确。清楚地阐述任务需求和概念定义至关重要。例:不用"分析文本",而用"判断这段话的情感倾向:积极、消极还是中性"。 快速迭代:善于快速连续调整提示。熟练的提示工程师能够灵活地进行多轮优化。例:从"总结文章"到"用

【前端学习】AntV G6-08 深入图形与图形分组、自定义节点、节点动画(下)

【课程链接】 AntV G6:深入图形与图形分组、自定义节点、节点动画(下)_哔哩哔哩_bilibili 本章十吾老师讲解了一个复杂的自定义节点中,应该怎样去计算和绘制图形,如何给一个图形制作不间断的动画,以及在鼠标事件之后产生动画。(有点难,需要好好理解) <!DOCTYPE html><html><head><meta charset="UTF-8"><title>06

从去中心化到智能化:Web3如何与AI共同塑造数字生态

在数字时代的演进中,Web3和人工智能(AI)正成为塑造未来互联网的两大核心力量。Web3的去中心化理念与AI的智能化技术,正相互交织,共同推动数字生态的变革。本文将探讨Web3与AI的融合如何改变数字世界,并展望这一新兴组合如何重塑我们的在线体验。 Web3的去中心化愿景 Web3代表了互联网的第三代发展,它基于去中心化的区块链技术,旨在创建一个开放、透明且用户主导的数字生态。不同于传统

学习hash总结

2014/1/29/   最近刚开始学hash,名字很陌生,但是hash的思想却很熟悉,以前早就做过此类的题,但是不知道这就是hash思想而已,说白了hash就是一个映射,往往灵活利用数组的下标来实现算法,hash的作用:1、判重;2、统计次数;

AI一键生成 PPT

AI一键生成 PPT 操作步骤 作为一名打工人,是不是经常需要制作各种PPT来分享我的生活和想法。但是,你们知道,有时候灵感来了,时间却不够用了!😩直到我发现了Kimi AI——一个能够自动生成PPT的神奇助手!🌟 什么是Kimi? 一款月之暗面科技有限公司开发的AI办公工具,帮助用户快速生成高质量的演示文稿。 无论你是职场人士、学生还是教师,Kimi都能够为你的办公文

性能分析之MySQL索引实战案例

文章目录 一、前言二、准备三、MySQL索引优化四、MySQL 索引知识回顾五、总结 一、前言 在上一讲性能工具之 JProfiler 简单登录案例分析实战中已经发现SQL没有建立索引问题,本文将一起从代码层去分析为什么没有建立索引? 开源ERP项目地址:https://gitee.com/jishenghua/JSH_ERP 二、准备 打开IDEA找到登录请求资源路径位置

usaco 1.2 Name That Number(数字字母转化)

巧妙的利用code[b[0]-'A'] 将字符ABC...Z转换为数字 需要注意的是重新开一个数组 c [ ] 存储字符串 应人为的在末尾附上 ‘ \ 0 ’ 详见代码: /*ID: who jayLANG: C++TASK: namenum*/#include<stdio.h>#include<string.h>int main(){FILE *fin = fopen (

pdfmake生成pdf的使用

实际项目中有时会有根据填写的表单数据或者其他格式的数据,将数据自动填充到pdf文件中根据固定模板生成pdf文件的需求 文章目录 利用pdfmake生成pdf文件1.下载安装pdfmake第三方包2.封装生成pdf文件的共用配置3.生成pdf文件的文件模板内容4.调用方法生成pdf 利用pdfmake生成pdf文件 1.下载安装pdfmake第三方包 npm i pdfma