[PyTorch][chapter 57][WGAN-GP 代码实现]

2023-10-08 12:29

本文主要是介绍[PyTorch][chapter 57][WGAN-GP 代码实现],希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

前言:

 下图为WGAN 的效果图:

  绿色为真实数据的分布: 8个高斯分布

  红色: 为随机产生的数据分布,跟真实分布基本一致

WGAN-GP:

1 判别器D: 最后一层去掉sigmoid
2 生成器G 和判别器D: loss不取log
3 损失函数 增加了penalty,使用Adam

 Wasserstein GAN
1 判别器D: 最后一层去掉sigmoid
2 生成器G 和判别器D: loss不取log
3 每次更新判别器的参数之后把它们的绝对值截断到不超过一个固定常数c
4 不要用基于动量的优化算法(包括momentum和Adam),推荐RMSProp,SGD也行
 


一  简介

    1.1 模型结构

 1.2 伪代码

      


二  wgan.py

 主要变化:

    Generator 中 去掉了之前的logit 函数

# -*- coding: utf-8 -*-
"""
Created on Thu Sep 28 11:10:19 2023@author: chengxf2
"""import torch
from   torch import nn#生成器模型
h_dim = 400
class Generator(nn.Module):def __init__(self):super(Generator,self).__init__()# z: [batch,input_features]self.net = nn.Sequential(nn.Linear(2, h_dim),nn.ReLU(True),nn.Linear( h_dim, h_dim),nn.ReLU(True),nn.Linear(h_dim, h_dim),nn.ReLU(True),nn.Linear(h_dim, 2))def forward(self, z):output = self.net(z)return output#鉴别器模型
class Discriminator(nn.Module):def __init__(self):super(Discriminator,self).__init__()hDim=400# x: [batch,input_features]self.net = nn.Sequential(nn.Linear(2, hDim),nn.ReLU(True),nn.Linear(hDim, hDim),nn.ReLU(True),nn.Linear(hDim, hDim),nn.ReLU(True),nn.Linear(hDim, 1),)def forward(self, x):#x:[batch,1]output = self.net(x)out = output.view(-1)return out

三 main.py

  主要变化:

    损失函数中增加了gradient_penalty

# -*- coding: utf-8 -*-
"""
Created on Thu Sep 28 11:28:32 2023@author: chengxf2
"""import visdom
from gan  import  Discriminator
from gan  import Generator
import numpy as np
import random
import torch
from   torch import nn, optim
from    matplotlib import pyplot as plt
from torch import autogradh_dim =400
batchsz = 256
viz = visdom.Visdom()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")def weights_init(net):if isinstance(net, nn.Linear):# net.weight.data.normal_(0.0, 0.02)nn.init.kaiming_normal_(net.weight)net.bias.data.fill_(0)def data_generator():"""8- gaussian destributionReturns-------None."""scale = 2a = np.sqrt(2.0)centers =[(1,0),(-1,0),(0,1),(0,-1),(1/a,1/a),(1/a,-1/a),(-1/a, 1/a),(-1/a,-1/a)]centers = [(scale*x, scale*y) for x,y in centers]while True:dataset =[]for i in range(batchsz):point = np.random.randn(2)*0.02center = random.choice(centers)point[0] += center[0]point[1] += center[1]dataset.append(point)dataset = np.array(dataset).astype(np.float32)dataset /=a#生成器函数是一个特殊的函数,可以返回一个迭代器yield datasetdef generate_image(D, G, xr, epoch):      #xr表示真实的sample"""Generates and saves a plot of the true distribution, the generator, and thecritic."""N_POINTS = 128RANGE = 3plt.clf()points = np.zeros((N_POINTS, N_POINTS, 2), dtype='float32')points[:, :, 0] = np.linspace(-RANGE, RANGE, N_POINTS)[:, None]points[:, :, 1] = np.linspace(-RANGE, RANGE, N_POINTS)[None, :]points = points.reshape((-1, 2))             # (16384, 2)x = y = np.linspace(-RANGE, RANGE, N_POINTS)N = len(x)# draw contourwith torch.no_grad():points = torch.Tensor(points)      # [16384, 2]disc_map = D(points).cpu().numpy() # [16384]plt.contour(x, y, disc_map.reshape((N, N)).transpose())#plt.clabel(cs, inline=1, fontsize=10)plt.colorbar()# draw sampleswith torch.no_grad():z = torch.randn(batchsz, 2)                 # [b, 2]samples = G(z).cpu().numpy()                # [b, 2]plt.scatter(xr[:, 0], xr[:, 1], c='green', marker='.')plt.scatter(samples[:, 0], samples[:, 1], c='red', marker='+')viz.matplot(plt, win='contour', opts=dict(title='p(x):%d'%epoch))def gradient_penalty(D, xr,xf):#[b,1]t =  torch.rand(batchsz, 1).to(device)       #[b,1]=>[b,2]  保证每个sample t 相同t =  t.expand_as(xr)#sample penalty interpoation [b,2]mid = t*xr +(1-t)*xfmid.requires_grad_()pred = D(mid) #[256]'''grad_outputs:   如果outputs 是向量,则此参数必须写retain_graph:  True 则保留计算图, False则释放计算图create_graph: 若要计算高阶导数,则必须选为Trueallow_unused: 允许输入变量不进入计算'''grads = autograd.grad(outputs= pred, inputs = mid,grad_outputs= torch.ones_like(pred),create_graph=True,retain_graph=True,only_inputs=True)[0]gp = torch.pow(grads.norm(2, dim=1)-1,2).mean()return gpdef main():lambd = 0.2 #超参数maxIter = 1000torch.manual_seed(10)np.random.seed(10)data_iter  = data_generator()G = Generator().to(device)D = Discriminator().to(device)G.apply(weights_init)D.apply(weights_init)optim_G = optim.Adam(G.parameters(),lr =5e-4, betas=(0.5,0.9))optim_D = optim.Adam(D.parameters(),lr =5e-4, betas=(0.5,0.9))K = 5viz.line([[0,0]], [0], win='loss', opts=dict(title='loss', legend=['D', 'G']))for epoch in range(maxIter):#1: train Discrimator fistlyfor k in range(K):#1.1: train on real dataxr = next(data_iter)xr = torch.from_numpy(xr).to(device)predr = D(xr)#max(predr) == min(-predr)lossr = -predr.mean()#1.2: train on fake dataz = torch.randn(batchsz,2).to(device) #[b,2] 随机产生的噪声xf = G(z).detach() #固定G,不更新G参数 tf.stop_gradient()predf =D(xf)lossf = predf.mean()#1.3 gradient_penaltygp = gradient_penalty(D, xr,xf.detach())#aggregate allloss_D = lossr + lossf +lambd*gpoptim_D.zero_grad()loss_D.backward()optim_D.step()#print("\n Discriminator 训练结束 ",loss_D.item())# 2 train  Generator#2.1 train on fake dataz = torch.randn(batchsz, 2).to(device)xf = G(z)predf =D(xf) #期望最大loss_G= -predf.mean()#optimizeoptim_G.zero_grad()loss_G.backward()optim_G.step()if epoch %100 ==0:viz.line([[loss_D.item(), loss_G.item()]], [epoch], win='loss', update='append')generate_image(D, G, xr, epoch)print("\n epoch: %d"%epoch,"\t lossD: %7.4f"%loss_D.item(),"\t lossG: %7.4f"%loss_G.item())if __name__ == "__main__":main()

参考:

课时130 WGAN-GP实战_哔哩哔哩_bilibili

WGAN基本原理及Pytorch实现WGAN-CSDN博客

CSDN

这篇关于[PyTorch][chapter 57][WGAN-GP 代码实现]的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/165261

相关文章

SpringBoot3实现Gzip压缩优化的技术指南

《SpringBoot3实现Gzip压缩优化的技术指南》随着Web应用的用户量和数据量增加,网络带宽和页面加载速度逐渐成为瓶颈,为了减少数据传输量,提高用户体验,我们可以使用Gzip压缩HTTP响应,... 目录1、简述2、配置2.1 添加依赖2.2 配置 Gzip 压缩3、服务端应用4、前端应用4.1 N

SpringBoot实现数据库读写分离的3种方法小结

《SpringBoot实现数据库读写分离的3种方法小结》为了提高系统的读写性能和可用性,读写分离是一种经典的数据库架构模式,在SpringBoot应用中,有多种方式可以实现数据库读写分离,本文将介绍三... 目录一、数据库读写分离概述二、方案一:基于AbstractRoutingDataSource实现动态

Python FastAPI+Celery+RabbitMQ实现分布式图片水印处理系统

《PythonFastAPI+Celery+RabbitMQ实现分布式图片水印处理系统》这篇文章主要为大家详细介绍了PythonFastAPI如何结合Celery以及RabbitMQ实现简单的分布式... 实现思路FastAPI 服务器Celery 任务队列RabbitMQ 作为消息代理定时任务处理完整

springboot循环依赖问题案例代码及解决办法

《springboot循环依赖问题案例代码及解决办法》在SpringBoot中,如果两个或多个Bean之间存在循环依赖(即BeanA依赖BeanB,而BeanB又依赖BeanA),会导致Spring的... 目录1. 什么是循环依赖?2. 循环依赖的场景案例3. 解决循环依赖的常见方法方法 1:使用 @La

Java枚举类实现Key-Value映射的多种实现方式

《Java枚举类实现Key-Value映射的多种实现方式》在Java开发中,枚举(Enum)是一种特殊的类,本文将详细介绍Java枚举类实现key-value映射的多种方式,有需要的小伙伴可以根据需要... 目录前言一、基础实现方式1.1 为枚举添加属性和构造方法二、http://www.cppcns.co

使用Python实现快速搭建本地HTTP服务器

《使用Python实现快速搭建本地HTTP服务器》:本文主要介绍如何使用Python快速搭建本地HTTP服务器,轻松实现一键HTTP文件共享,同时结合二维码技术,让访问更简单,感兴趣的小伙伴可以了... 目录1. 概述2. 快速搭建 HTTP 文件共享服务2.1 核心思路2.2 代码实现2.3 代码解读3.

使用C#代码在PDF文档中添加、删除和替换图片

《使用C#代码在PDF文档中添加、删除和替换图片》在当今数字化文档处理场景中,动态操作PDF文档中的图像已成为企业级应用开发的核心需求之一,本文将介绍如何在.NET平台使用C#代码在PDF文档中添加、... 目录引言用C#添加图片到PDF文档用C#删除PDF文档中的图片用C#替换PDF文档中的图片引言在当

C#使用SQLite进行大数据量高效处理的代码示例

《C#使用SQLite进行大数据量高效处理的代码示例》在软件开发中,高效处理大数据量是一个常见且具有挑战性的任务,SQLite因其零配置、嵌入式、跨平台的特性,成为许多开发者的首选数据库,本文将深入探... 目录前言准备工作数据实体核心技术批量插入:从乌龟到猎豹的蜕变分页查询:加载百万数据异步处理:拒绝界面

MySQL双主搭建+keepalived高可用的实现

《MySQL双主搭建+keepalived高可用的实现》本文主要介绍了MySQL双主搭建+keepalived高可用的实现,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,... 目录一、测试环境准备二、主从搭建1.创建复制用户2.创建复制关系3.开启复制,确认复制是否成功4.同

Java实现文件图片的预览和下载功能

《Java实现文件图片的预览和下载功能》这篇文章主要为大家详细介绍了如何使用Java实现文件图片的预览和下载功能,文中的示例代码讲解详细,感兴趣的小伙伴可以跟随小编一起学习一下... Java实现文件(图片)的预览和下载 @ApiOperation("访问文件") @GetMapping("