Paddle 基于ANN(全连接神经网络)的GAN(生成对抗网络)实现

2024-05-10 04:04

本文主要是介绍Paddle 基于ANN(全连接神经网络)的GAN(生成对抗网络)实现,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

什么是GAN

GAN是生成对抗网络,将会根据一个随机向量,实现数据的生成(如生成手写数字、生成文本等)。

GAN的训练过程中,需要有一个生成器G和一个鉴别器D.

生成器用于生成数据,鉴定器用于鉴定数据的准确性,其实就是在鉴别数据是人生成的还是机器生成的,因为生成器需要以假乱真。

鉴别器将会与生成器一起训练。鉴别器将会先训练,这样才有适当的能力去鉴定生成器生成数据的准确性。

鉴别器的训练过程中,需要先给它准确的数据,和通过随机向量传入生成器产生的数据(一律视为负样本),并通过损失函数对其进行训练;生成器训练过程中,会先给它一个随机向量进行前向传播,然后让鉴别器判断其正确性,并通过损失函数(不正确的数据意味着有损失)进行训练:

生成器训练过程中,需要先通过随机向量获取其结果,然后让鉴别器进行鉴别,在通过鉴别器的鉴别结果计算损失(如果鉴别器认为这是生成器生成的,则产生损失),最后更新梯度和参数:

训练过程直到生成器拟合训练集(收敛),判别器的输出总是0.5(均方误差损失函数应为0.25)为止.

形象的GAN的例子

想象一场由一位“名画伪造者”和一位“艺术鉴定家”参与的猫捉老鼠游戏。

在这个场景中,名画伪造者(即GAN中的生成器)的目标是创造出一幅足以欺骗艺术鉴定家(即GAN中的判别器)的假画。开始时,伪造者的技艺并不精湛,他制作的假画充满了破绽,很容易被鉴定家一眼识破。

然而,随着伪造者不断尝试和失败,他逐渐从每一次的失败中学习,逐渐提升了自己的技艺。他开始注意到真画的每一个细节,从笔触、色彩到构图,都尽量模仿得惟妙惟肖。每一次的失败都让他更接近成功,他制作的假画也越来越难以辨别真伪。

而艺术鉴定家也不甘示弱。他开始时能够轻易地识别出伪造者的假画,但随着伪造者技艺的提升,他也需要不断提升自己的鉴定能力。他开始深入研究真画的每一个特点,以便更准确地识别出伪造者的假画。

这个过程就像GAN中的训练过程一样。生成器不断尝试生成新的数据(在这里是假画),而判别器则不断尝试区分这些数据是真实的还是生成的。两者在相互竞争的过程中不断提升自己的能力,最终达到了一个平衡状态。

在这个例子中,名画伪造者就是GAN中的生成器,他负责生成新的数据;而艺术鉴定家则是GAN中的判别器,他负责区分数据的真伪。两者在相互竞争的过程中共同进步,使得生成的数据越来越接近真实的数据。

代码实现

本文将以基于MNIST(手写数据集)为数据集,实现一个生成手写数字的GAN模型:

首先创建models.py,用于定义判别器和生成器:

import paddle# Generator Code
class Generator(paddle.nn.Layer):def __init__(self, ):super(Generator, self).__init__()self.gen = paddle.nn.Sequential(paddle.nn.Linear(in_features=100, out_features=256),paddle.nn.ReLU(True),paddle.nn.Linear(in_features=256, out_features=512),paddle.nn.ReLU(True),paddle.nn.Linear(in_features=512, out_features=1024),paddle.nn.Tanh(),)def forward(self, x):x = self.gen(x)out = paddle.reshape(x,[-1,1,32,32])return out# Discriminator Code
class Discriminator(paddle.nn.Layer):def __init__(self, ):super(Discriminator, self).__init__()self.dis = paddle.nn.Sequential(paddle.nn.Linear(in_features=1024, out_features=512),paddle.nn.LeakyReLU(0.2),paddle.nn.Linear(in_features=512, out_features=256),paddle.nn.LeakyReLU(0.2),paddle.nn.Linear(in_features=256, out_features=1),paddle.nn.Sigmoid())def forward(self, x):x = paddle.reshape(x, [-1, 1024])out = self.dis(x)return out

其中,生成器将接收一个长度为100的张量(随机向量),输出一个长度为1024的张量(生成的图片);鉴别器将接收一个长度为1024的张量(图片) ,输出长度为1的张量(鉴别结果)

然后创建main.py,用于训练:

import paddle
import matplotlib.pyplot as plt
from models import Generator, Discriminator
import numpy as npdataset = paddle.vision.datasets.MNIST(mode='train',transform=paddle.vision.transforms.Compose([paddle.vision.transforms.Resize((32, 32)),paddle.vision.transforms.Normalize([0], [255])]))dataloader = paddle.io.DataLoader(dataset, batch_size=32, shuffle=True)netG = Generator()
netD = Discriminator()if 1:try:mydict = paddle.load('generator.params')netG.set_dict(mydict)mydict = paddle.load('discriminator.params')netD.set_dict(mydict)except:print('fail to load model')optimizerD = paddle.optimizer.Adam(parameters=netD.parameters(), learning_rate=0.0002, beta1=0.5, beta2=0.999)
optimizerG = paddle.optimizer.Adam(parameters=netG.parameters(), learning_rate=0.0002, beta1=0.5, beta2=0.999)# 最大迭代epoch
max_epoch = 10for epoch in range(max_epoch):now_step = 0for step, (data, label) in enumerate(dataloader):############################# (1) 更新鉴别器############################ 清除D的梯度optimizerD.clear_grad()# 传入正样本,并更新梯度pos_img = datalabel = paddle.full([pos_img.shape[0], 1], 1, dtype='float32')pre = netD(pos_img)loss_D_1 = paddle.nn.functional.mse_loss(pre, label)loss_D_1.backward()# 通过randn构造随机数,制造负样本,并传入D,更新梯度noise = paddle.randn([pos_img.shape[0], 100], 'float32')neg_img = netG(noise)label = paddle.full([pos_img.shape[0], 1], 0, dtype='float32')pre = netD(neg_img.detach())  # 通过detach阻断网络梯度传播,不影响G的梯度计算loss_D_2 = paddle.nn.functional.mse_loss(pre, label)loss_D_2.backward()# 更新D网络参数optimizerD.step()optimizerD.clear_grad()loss_D = loss_D_1 + loss_D_2############################# (2) 更新生成器############################ 清除D的梯度optimizerG.clear_grad()noise = paddle.randn([pos_img.shape[0], 100], 'float32')fake = netG(noise)label = paddle.full((pos_img.shape[0], 1), 1, dtype=np.float32, )output = netD(fake)# 这个写法没有问题,因为这个mse_loss既会影响到netG(output=netD(netG(noise)))的梯度,也会影响到netD的梯度,但是之后的代码并没有更新netD的参数,而循环开头就清除了netD的梯度loss_G = paddle.nn.functional.mse_loss(output, label)loss_G.backward()# 更新G网络参数optimizerG.step()optimizerG.clear_grad()now_step += 1############################ 输出日志###########################if now_step % 100 == 0:print(f'Epoch ID={epoch} Batch ID={now_step} \n\n D-Loss={float(loss_D)} G-Loss={float(loss_G)}')paddle.save(netG.state_dict(), "generator.params")
paddle.save(netD.state_dict(), "discriminator.params")

如果是第一次训练或不使用原有训练参数,可以将if 1改成if 0.

接下来创建use.py,用于生成图片:

import paddle
from models import Generator
import matplotlib.pyplot as pltimport paddle
from models import Generator
import matplotlib.pyplot as plt
import numpy as np# 加载模型
netG = Generator()
mydict = paddle.load('generator.params')
netG.set_dict(mydict)# 设置matplotlib的显示环境
fig, axs = plt.subplots(nrows=2, ncols=5, figsize=(15, 6))  # 创建一个2x5的子图网格# 生成10个噪声向量
for i, ax in enumerate(axs.flatten()):noise = paddle.randn([1, 100], 'float32')img = netG(noise)img = img.numpy()[0][0]  # img.numpy():张量转np数组img[img < 0] = 0  # 将img中所有小于0的元素赋值为0img = np.clip(img, 0, 1)  # 将img中所有小于0的元素设为0,大于1的设为1(如果需要)# 显示图片ax.imshow(img)ax.axis('off')  # 不显示坐标轴# 显示图像
plt.show()

进行多轮训练后,生成结果:

 

可以看到,它很好的生成了我们想要的图片。

GANs

但是,我们这个模型只能随机产生数字,还不能生成指定的数字(如让机器生成一个1).为了解决这个问题,我们可以针对每一个数字生成一个对应的GAN,所有这样的GAN组合起来,就是GANs. 这里不展开讲解。

参考

MNIST数据集下用Paddle框架的动态图模式玩耍经典对抗生成网络(GAN)-使用文档-PaddlePaddle深度学习平台

【飞桨PaddlePaddle】四天搞懂生成对抗网络(一)——通俗理解经典GAN_四天搞懂生成对抗网络(一)-CSDN博客

这篇关于Paddle 基于ANN(全连接神经网络)的GAN(生成对抗网络)实现的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/975382

相关文章

W外链微信推广短连接怎么做?

制作微信推广链接的难点分析 一、内容创作难度 制作微信推广链接时,首先需要创作有吸引力的内容。这不仅要求内容本身有趣、有价值,还要能够激起人们的分享欲望。对于许多企业和个人来说,尤其是那些缺乏创意和写作能力的人来说,这是制作微信推广链接的一大难点。 二、精准定位难度 微信用户群体庞大,不同用户的需求和兴趣各异。因此,制作推广链接时需要精准定位目标受众,以便更有效地吸引他们点击并分享链接

hdu1043(八数码问题,广搜 + hash(实现状态压缩) )

利用康拓展开将一个排列映射成一个自然数,然后就变成了普通的广搜题。 #include<iostream>#include<algorithm>#include<string>#include<stack>#include<queue>#include<map>#include<stdio.h>#include<stdlib.h>#include<ctype.h>#inclu

AI一键生成 PPT

AI一键生成 PPT 操作步骤 作为一名打工人,是不是经常需要制作各种PPT来分享我的生活和想法。但是,你们知道,有时候灵感来了,时间却不够用了!😩直到我发现了Kimi AI——一个能够自动生成PPT的神奇助手!🌟 什么是Kimi? 一款月之暗面科技有限公司开发的AI办公工具,帮助用户快速生成高质量的演示文稿。 无论你是职场人士、学生还是教师,Kimi都能够为你的办公文

【C++】_list常用方法解析及模拟实现

相信自己的力量,只要对自己始终保持信心,尽自己最大努力去完成任何事,就算事情最终结果是失败了,努力了也不留遗憾。💓💓💓 目录   ✨说在前面 🍋知识点一:什么是list? •🌰1.list的定义 •🌰2.list的基本特性 •🌰3.常用接口介绍 🍋知识点二:list常用接口 •🌰1.默认成员函数 🔥构造函数(⭐) 🔥析构函数 •🌰2.list对象

【Prometheus】PromQL向量匹配实现不同标签的向量数据进行运算

✨✨ 欢迎大家来到景天科技苑✨✨ 🎈🎈 养成好习惯,先赞后看哦~🎈🎈 🏆 作者简介:景天科技苑 🏆《头衔》:大厂架构师,华为云开发者社区专家博主,阿里云开发者社区专家博主,CSDN全栈领域优质创作者,掘金优秀博主,51CTO博客专家等。 🏆《博客》:Python全栈,前后端开发,小程序开发,人工智能,js逆向,App逆向,网络系统安全,数据分析,Django,fastapi

让树莓派智能语音助手实现定时提醒功能

最初的时候是想直接在rasa 的chatbot上实现,因为rasa本身是带有remindschedule模块的。不过经过一番折腾后,忽然发现,chatbot上实现的定时,语音助手不一定会有响应。因为,我目前语音助手的代码设置了长时间无应答会结束对话,这样一来,chatbot定时提醒的触发就不会被语音助手获悉。那怎么让语音助手也具有定时提醒功能呢? 我最后选择的方法是用threading.Time

Linux 网络编程 --- 应用层

一、自定义协议和序列化反序列化 代码: 序列化反序列化实现网络版本计算器 二、HTTP协议 1、谈两个简单的预备知识 https://www.baidu.com/ --- 域名 --- 域名解析 --- IP地址 http的端口号为80端口,https的端口号为443 url为统一资源定位符。CSDNhttps://mp.csdn.net/mp_blog/creation/editor

Android实现任意版本设置默认的锁屏壁纸和桌面壁纸(两张壁纸可不一致)

客户有些需求需要设置默认壁纸和锁屏壁纸  在默认情况下 这两个壁纸是相同的  如果需要默认的锁屏壁纸和桌面壁纸不一样 需要额外修改 Android13实现 替换默认桌面壁纸: 将图片文件替换frameworks/base/core/res/res/drawable-nodpi/default_wallpaper.*  (注意不能是bmp格式) 替换默认锁屏壁纸: 将图片资源放入vendo

pdfmake生成pdf的使用

实际项目中有时会有根据填写的表单数据或者其他格式的数据,将数据自动填充到pdf文件中根据固定模板生成pdf文件的需求 文章目录 利用pdfmake生成pdf文件1.下载安装pdfmake第三方包2.封装生成pdf文件的共用配置3.生成pdf文件的文件模板内容4.调用方法生成pdf 利用pdfmake生成pdf文件 1.下载安装pdfmake第三方包 npm i pdfma

C#实战|大乐透选号器[6]:实现实时显示已选择的红蓝球数量

哈喽,你好啊,我是雷工。 关于大乐透选号器在前面已经记录了5篇笔记,这是第6篇; 接下来实现实时显示当前选中红球数量,蓝球数量; 以下为练习笔记。 01 效果演示 当选择和取消选择红球或蓝球时,在对应的位置显示实时已选择的红球、蓝球的数量; 02 标签名称 分别设置Label标签名称为:lblRedCount、lblBlueCount