GAN生成抛物线

2023-10-10 08:40
文章标签 生成 gan 抛物线

本文主要是介绍GAN生成抛物线,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

本文主要讲解GAN的原理以及一个小实战,利用GAN生成抛物线,首先我们看一下GAN的原理。

GAN是2014年提出来的,他的原理可以这样理解,他有一个生成器和一个判别器,生成器是不断的生成数据,判别器的原理是将真实图片和生成器制作的数据区分开来,目的就是鉴别生成器生成的数据是假的,把原始数据判定为真。为生成器相反,他的目的就是源源不断的生成数据,让判别器无法分辨真假,从而以假乱真。常看到的一个例子就是坏人制作假币,警察差查假币。最终达到判别器无法分辨谁真谁假,也就是对于生成数据,它判断的为假的概率也是0.5,对于一个真实数据,它判断为真的概率也是0.5,这是最理想的状态。

任何一个神经网络模型都有损失函数,对于GAN模型,自然也不例外,他也有损失函数,因为他有丄生成网络和判别网络之分,所以当然是两个损失函数的,我们先看判别网路的损失函数,

min (-log(D(x)) -log(1-D(G(x))))
这个就是判别网络的损失函数,这样理解,我们把真实数据判定为1是对的,把生成网络生成的数据判定为0是对的,也就是我们把真实数据判定为1那么损失函数越小,把生成数据判定为0,判别网络损失函数越小,然后我们再看损失函数,不难发现,要想损失函数越小,D(x)越接近1越好,这里x表示真实数据输入判别网络,可以看出D(G(x))越接近0损失越小,其中G(x)表示生成网络的输出输入判别网络。这也很符合我们平时的理解。再看生成网络的损失函数

min(-log(D(G(x)))
从生成网络我们可以看出,生成网络的目的是让D(G(x))越接近1越好,越就是D(G(x))越接近1,损失函数越小,这就和判别网络矛盾了,那就形成了竞争。

有了这些理论,我们再看实际代码实现生成抛物线的对抗神经网络:

import numpy as np
import tensorflow as tf
import matplotlib.pyplot as pltnp.random.seed(2018)
tf.set_random_seed(2018)def real_data(num):#x = np.random.uniform(-10,10,[num,1])x = np.linspace(-3,3, num) + np.random.random(num) * 0.01x = x.reshape([-1,1])#sample = np.sin(x) + 1sample = x**2 +1return x,sampledef fake_data(num):x = np.linspace(-3,3, num) + np.random.random(num) * 0.01return x.reshape([-1,1])batch_size = 64
iters = 2000
#hidden_units = batch_size//2
alpha = 0.01
lr = 0.0001
#gen_num = 10
# out_dim = Nonedef generator(inputs,alpha,reuse=False):# out_dim 表示输出的大小,最后一层全连接层输出的大小,所以和batch_size大小一样# hidden_units隐藏层神经元的个数# alpha Leaky ReLU激活函数的参数# reuse 是否重用参数变量with tf.variable_scope('generator',reuse=reuse) as scope:hidden_1 = tf.layers.dense(inputs, 64, activation=None)ac1 = tf.maximum(alpha * hidden_1, hidden_1)#ac1 = tf.nn.tanh( hidden_1)bn1 = tf.layers.batch_normalization(ac1)hidden_2 = tf.layers.dense(bn1, 128, activation=None)ac2 = tf.maximum(alpha * hidden_2, hidden_2)#ac2 = tf.nn.tanh(hidden_2)bn2 = tf.layers.batch_normalization(ac2)hidden_3 = tf.layers.dense(bn2, 256, activation=None)ac3 = tf.maximum(alpha *hidden_3, hidden_3)#ac3 = tf.nn.tanh(ac2)bn3 = tf.layers.batch_normalization(ac3)out = tf.layers.dense(bn3,1,activation=None)return outdef discriminator(discr_input,alpha,reuse=False,name='discriminator'):# discr_input 判别器的输入# alpha Leaky ReLU激活函数的参数# hidden_units 隐藏层的神经元个数# reuse 是否重用变量with tf.variable_scope(name,reuse=reuse) as scope:hidden_1 = tf.layers.dense(discr_input,units=64,activation=None)#ac1 = tf.maximum(alpha*hidden_1,hidden_1)ac1 = tf.nn.tanh(hidden_1)hidden_2 = tf.layers.dense(ac1, units=128, activation=None)#ac2 = tf.maximum(alpha * hidden_2, hidden_2)ac2 = tf.nn.tanh(hidden_2)hidden_3 = tf.layers.dense(ac2, units=128, activation=None)#ac3 = tf.maximum(alpha * hidden_3, hidden_3)ac3 = tf.nn.tanh(hidden_3)logits = tf.layers.dense(ac3,1,activation=None)out = tf.nn.sigmoid(logits)return logits,out
def plot_data(gen_x,gen_y):x_r,y_r = real_data(64)plt.scatter(x_r, y_r, label='real data')plt.scatter(gen_x,gen_y, label='generated data')plt.title('GAN')plt.xlabel('x')plt.ylabel('y')plt.legend()plt.show()with tf.name_scope('gen_input') as scope:gen_input = tf.placeholder(dtype=tf.float32,shape=[None,1],name='gen_input')
with tf.name_scope('discriminator_input') as scope:real_input = tf.placeholder(dtype=tf.float32, shape=[None, 1], name='real_input')out_gen = generator(gen_input,alpha,reuse=False)real_logits,label_real = discriminator(real_input,alpha,reuse=False)
logits_gen,label_fake = discriminator(out_gen,alpha,reuse=True)with tf.name_scope('discr_train') as scope:train_input = tf.placeholder(dtype=tf.float32, shape=[None, 1], name='train_input')
train_disc = discriminator(train_input,alpha,reuse=False,name='train_dis')
para = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='train_dis')
train_loss = tf.reduce_mean(tf.square(train_disc-train_input))#with tf.Session() as sess:
#    sess.run(tf.global_variables_initializer())with tf.name_scope('metrics') as name:loss_g = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(logits_gen)*0.99,logits=logits_gen))loss_d_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.zeros_like(logits_gen),logits=logits_gen))loss_d_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(real_logits)*0.99, logits=real_logits))loss_d = loss_d_fake+loss_d_realvar_list_g = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,scope='generator')var_list_d = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,scope='discriminator')d_optimizer = tf.train.AdamOptimizer(lr).minimize(loss_d,var_list=var_list_d)g_optimizer = tf.train.AdamOptimizer(lr).minimize(loss_g,var_list=var_list_g)with tf.Session() as sess:sess.run(tf.global_variables_initializer())saver = tf.train.Saver()writer = tf.summary.FileWriter('./graph/gan',sess.graph)# for i in range(1000):#     _, real = real_data(batch_size)#     _ = sess.run(train_loss,feed_dict={train_input:real})# train_weig = sess.run(para)# for i in range(len((var_list_d))):#     sess.run(var_list_d[i].assign(train_weig[i]))for iter in range(iters):_,real = real_data(batch_size)fake = fake_data(batch_size)_,train_loss_d = sess.run([d_optimizer,loss_d],feed_dict={real_input:real,gen_input: fake})_, train_loss_g = sess.run([g_optimizer, loss_g], feed_dict={gen_input: fake})fake = fake_data(batch_size)_, train_loss_g = sess.run([g_optimizer, loss_g], feed_dict={gen_input: fake})fake = fake_data(batch_size)_, train_loss_g = sess.run([g_optimizer, loss_g], feed_dict={gen_input: fake})if iter % 200 == 0:print(train_loss_d)print(train_loss_g)gen_x = np.linspace(-3,3,500).reshape([-1,1])gen_y = sess.run(out_gen,feed_dict={gen_input:gen_x})plot_data(gen_x, gen_y)saver.save(sess, "./checkpoints/gen")writer.close()
下面展示一个训练过程的图像:

在做这个的时候,我试图生成正弦曲线,但是效果比较差,我猜测是和正弦函数泰勒展开有关系,又知道的希望提点一下。






这篇关于GAN生成抛物线的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/179335

相关文章

AI一键生成 PPT

AI一键生成 PPT 操作步骤 作为一名打工人,是不是经常需要制作各种PPT来分享我的生活和想法。但是,你们知道,有时候灵感来了,时间却不够用了!😩直到我发现了Kimi AI——一个能够自动生成PPT的神奇助手!🌟 什么是Kimi? 一款月之暗面科技有限公司开发的AI办公工具,帮助用户快速生成高质量的演示文稿。 无论你是职场人士、学生还是教师,Kimi都能够为你的办公文

pdfmake生成pdf的使用

实际项目中有时会有根据填写的表单数据或者其他格式的数据,将数据自动填充到pdf文件中根据固定模板生成pdf文件的需求 文章目录 利用pdfmake生成pdf文件1.下载安装pdfmake第三方包2.封装生成pdf文件的共用配置3.生成pdf文件的文件模板内容4.调用方法生成pdf 利用pdfmake生成pdf文件 1.下载安装pdfmake第三方包 npm i pdfma

poj 1258 Agri-Net(最小生成树模板代码)

感觉用这题来当模板更适合。 题意就是给你邻接矩阵求最小生成树啦。~ prim代码:效率很高。172k...0ms。 #include<stdio.h>#include<algorithm>using namespace std;const int MaxN = 101;const int INF = 0x3f3f3f3f;int g[MaxN][MaxN];int n

poj 1287 Networking(prim or kruscal最小生成树)

题意给你点与点间距离,求最小生成树。 注意点是,两点之间可能有不同的路,输入的时候选择最小的,和之前有道最短路WA的题目类似。 prim代码: #include<stdio.h>const int MaxN = 51;const int INF = 0x3f3f3f3f;int g[MaxN][MaxN];int P;int prim(){bool vis[MaxN];

poj 2349 Arctic Network uva 10369(prim or kruscal最小生成树)

题目很麻烦,因为不熟悉最小生成树的算法调试了好久。 感觉网上的题目解释都没说得很清楚,不适合新手。自己写一个。 题意:给你点的坐标,然后两点间可以有两种方式来通信:第一种是卫星通信,第二种是无线电通信。 卫星通信:任何两个有卫星频道的点间都可以直接建立连接,与点间的距离无关; 无线电通信:两个点之间的距离不能超过D,无线电收发器的功率越大,D越大,越昂贵。 计算无线电收发器D

hdu 1102 uva 10397(最小生成树prim)

hdu 1102: 题意: 给一个邻接矩阵,给一些村庄间已经修的路,问最小生成树。 解析: 把已经修的路的权值改为0,套个prim()。 注意prim 最外层循坏为n-1。 代码: #include <iostream>#include <cstdio>#include <cstdlib>#include <algorithm>#include <cstri

【生成模型系列(初级)】嵌入(Embedding)方程——自然语言处理的数学灵魂【通俗理解】

【通俗理解】嵌入(Embedding)方程——自然语言处理的数学灵魂 关键词提炼 #嵌入方程 #自然语言处理 #词向量 #机器学习 #神经网络 #向量空间模型 #Siri #Google翻译 #AlexNet 第一节:嵌入方程的类比与核心概念【尽可能通俗】 嵌入方程可以被看作是自然语言处理中的“翻译机”,它将文本中的单词或短语转换成计算机能够理解的数学形式,即向量。 正如翻译机将一种语言

poj 3723 kruscal,反边取最大生成树。

题意: 需要征募女兵N人,男兵M人。 每征募一个人需要花费10000美元,但是如果已经招募的人中有一些关系亲密的人,那么可以少花一些钱。 给出若干的男女之间的1~9999之间的亲密关系度,征募某个人的费用是10000 - (已经征募的人中和自己的亲密度的最大值)。 要求通过适当的招募顺序使得征募所有人的费用最小。 解析: 先设想无向图,在征募某个人a时,如果使用了a和b之间的关系

Thymeleaf:生成静态文件及异常处理java.lang.NoClassDefFoundError: ognl/PropertyAccessor

我们需要引入包: <dependency><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-thymeleaf</artifactId></dependency><dependency><groupId>org.springframework</groupId><artifactId>sp

前端-06-eslint9大变样后,如何生成旧版本的.eslintrc.cjs配置文件

目录 问题解决办法 问题 最近在写一个vue3+ts的项目,看了尚硅谷的视频,到了配置eslintrc.cjs的时候我犯了难,因为eslint从9.0之后重大更新,跟以前完全不一样,但是我还是想用和老师一样的eslintrc.cjs文件,该怎么做呢? 视频链接:尚硅谷Vue项目实战硅谷甄选,vue3项目+TypeScript前端项目一套通关 解决办法 首先 eslint 要