GAN生成对抗网络:花卉生成

2023-12-08 19:30
文章标签 生成 网络 对抗 gan 花卉

本文主要是介绍GAN生成对抗网络:花卉生成,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

文章目录

  • 简介
  • 一、GAN生成对抗网络基础知识
  • 二、数据集介绍
  • 三、代码实现
    • 参数设置
    • 数据处理
    • 搭建网络
    • 定义优化器与损失函数
    • 训练网络
    • 保存网络
    • 结果展示
  • 总结


简介

本篇文章利用pytorch搭建GAN生成对抗网络实现花卉生成的任务

一、GAN生成对抗网络基础知识

关于GAN生成对抗网络的基础知识以下文章有详细讲解,可供参考:
GAN(生成对抗网络)的系统全面介绍(醍醐灌顶)

二、数据集介绍

本文使用花卉数据集,该数据集包含了4317张图片,包含雏菊、蒲公英、玫瑰、向日葵、郁金香五种花卉,我已将数据集拆分为训练集和测试集两部分,本文仅使用了训练集部分,以下是数据集目录:
在这里插入图片描述在这里插入图片描述
数据集已放于以下链接,有需要可自行下载
花卉数据集

三、代码实现

参数设置

step1.参数continue_train:是否继续训练
step2.参数dir:训练集路径
step3.参数batch_size:单次训练图片量
step4.参数device:使用GPU
step5.参数epochs:训练周期
step6.参数generator_num:每k轮训练一次生成器
step7.参数discriminator_num:每k轮训练一次判别器

if __name__ == '__main__':parser = argparse.ArgumentParser()parser.add_argument('--continue_train', type=bool, default=False, help='continue training')parser.add_argument('--dir', type=str, default='./flowers/train', help='dataset path')parser.add_argument('--batch_size', type=int, default=50, help='batch size')parser.add_argument('--device', type=int, default=0, help='GPU id')parser.add_argument('--epochs', type=int, default=200, help='train epochs')parser.add_argument('--generator_num', type=int, default=5, help='train generator every k epochs')parser.add_argument('--discriminator_num', type=int, default=1, help='train discriminator every k epochs')args = parser.parse_args()main(args)

数据处理

step1.定义训练集中图像输入判别器前的transform操作
step2.准备Dataset与Dataloader

    transform = transforms.Compose([transforms.Resize((96, 96)),  # 将图片resize至 96 * 96transforms.ToTensor(),  # 转换为张量transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])data_set = datasets.ImageFolder(root=args.dir, transform=transform)data_loader = dataloader.DataLoader(dataset=data_set, batch_size=args.batch_size, num_workers=4, shuffle=True, drop_last=True)print('already load data...')

搭建网络

step1.生成器使用反卷积,最终输出3 * 96 * 96大小的图片,且像素值 ∈ [ − 1 , 1 ] ∈[-1,1] [1,1]
step2.生成器使用卷积,最终输出判别为真的概率

class Generator(nn.Module):def __init__(self):super(Generator,self).__init__()self.main = nn.Sequential(      # 神经网络模块将按照在传入构造器的顺序依次被添加到计算图中执行nn.ConvTranspose2d(100, 512, kernel_size=4, stride=1, padding=0, bias=False),nn.BatchNorm2d(512),nn.ReLU(True),       # 512 × 4 × 4        (1-1)*1+1*(4-1)+0+1 = 4nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1, bias=False),nn.BatchNorm2d(256),nn.ReLU(True),      # 256 × 8 × 8     (4-1)*2-2*1+1*(4-1)+0+1 = 8nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1, bias=False),nn.BatchNorm2d(128),nn.ReLU(True),  # 128 × 16 × 16nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1, bias=False),nn.BatchNorm2d(64),nn.ReLU(True),      # 64 × 32 × 32nn.ConvTranspose2d(64, 3, kernel_size=5, stride=3, padding=1, bias=False),nn.Tanh()       # 3 * 96 * 96)def forward(self, input):return self.main(input)class Discriminator(nn.Module):def __init__(self):super(Discriminator,self).__init__()self.main = nn.Sequential(nn.Conv2d(3, 64, kernel_size=5, stride=3, padding=1, bias=False),nn.LeakyReLU(0.2, inplace=True),        # 64 * 32 * 32nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1, bias=False),nn.BatchNorm2d(128),nn.LeakyReLU(0.2, inplace=True),         # 128 * 16 * 16nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1, bias=False),nn.BatchNorm2d(256),nn.LeakyReLU(0.2, inplace=True),  # 256 * 8 * 8nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1, bias=False),nn.BatchNorm2d(512),nn.LeakyReLU(0.2, inplace=True),  # 512 * 4 * 4nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=0, bias=False),nn.Sigmoid()        # 输出一个概率)def forward(self, input):return self.main(input).view(-1)

定义优化器与损失函数

step1.生成器与判别器的优化器都使用Adam
step2.将损失函数使用二元交叉熵损失

    optimizer_G = torch.optim.Adam(model_G.parameters(), lr=2e-4, betas=(0.5, 0.999))optimizer_D = torch.optim.Adam(model_D.parameters(), lr=2e-4, betas=(0.5, 0.999))loss = nn.BCELoss()print('already prepared optimizer and loss_function...')

训练网络

每discriminator_num轮:
step1.输入真图片让判别器鉴别
step2.生成器利用随机噪声生成图片,并让判别器鉴别
step3.计算判别器损失(真鉴别为真,假鉴别为假),反向传播后更新判别器参数
每generator_num轮:
step4.生成器利用随机噪声生成图片,并让判别器鉴别
step5.计算生成器损失(假鉴别为真),反向传播后更新生成器参数
step6.每100轮保存一次结果

    print('start training...')for epoch in range(args.epochs):print('epoch:{}'.format(epoch + 1))for i, data in enumerate(data_loader):if (i + 1) % args.discriminator_num == 0:optimizer_D.zero_grad()real_img = data[0]batchsize = len(real_img)real_img = real_img.cuda(args.device)out_D_real = model_D(real_img)real_labels = torch.ones(batchsize).cuda(args.device)loss_D_real = loss(out_D_real, real_labels)loss_D_real.backward()noise = torch.randn(args.batch_size, 100, 1, 1).cuda(args.device)fake_img = model_G(noise)out_D_fake = model_D(fake_img)fake_labels = torch.zeros(batchsize).cuda(args.device)loss_D_fake = loss(out_D_fake, fake_labels)loss_D_fake.backward()optimizer_D.step()if (i + 1) % args.generator_num == 0:optimizer_G.zero_grad()real_img = data[0]batchsize = len(real_img)noise = torch.randn(args.batch_size, 100, 1, 1).cuda(args.device)fake_img = model_G(noise)out_D_fake = model_D(fake_img)real_labels = torch.ones(batchsize).cuda(args.device)loss_G = loss(out_D_fake, real_labels)loss_G.backward()optimizer_G.step()if (epoch + 1) % 100 == 0:fix_noise = torch.randn(40, 100, 1, 1).cuda(args.device)final_img = model_G(fix_noise)final_img = final_img * 0.5 + 0.5final_img = final_img.cpu()plt.figure(1)for i in range(40):img = final_img[i].detach().numpy()plt.subplot(5, 8, i+1)plt.imshow(np.transpose(img, (1, 2, 0)))plt.savefig("./outcome/{}.png".format(epoch + 1))plt.show()print('end training...')

保存网络

    torch.save(model_G.state_dict(), './generator.pt')torch.save(model_D.state_dict(), './discriminator.pt')print('already saved model...')

结果展示

训练3000轮后得到结果如下:
在这里插入图片描述

总结

以上就是利用生成对抗网络实现图像生成的介绍,完整代码如下:

import argparse
import torchvision.datasets as datasets
import torch.utils.data.dataloader as dataloader
import torchvision.transforms as transforms
import torch.nn as nn
import torch
import numpy as np
import matplotlib.pyplot as pltclass Generator(nn.Module):def __init__(self):super(Generator,self).__init__()self.main = nn.Sequential(      # 神经网络模块将按照在传入构造器的顺序依次被添加到计算图中执行nn.ConvTranspose2d(100, 512, kernel_size=4, stride=1, padding=0, bias=False),nn.BatchNorm2d(512),nn.ReLU(True),       # 512 × 4 × 4        (1-1)*1+1*(4-1)+0+1 = 4nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1, bias=False),nn.BatchNorm2d(256),nn.ReLU(True),      # 256 × 8 × 8     (4-1)*2-2*1+1*(4-1)+0+1 = 8nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1, bias=False),nn.BatchNorm2d(128),nn.ReLU(True),  # 128 × 16 × 16nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1, bias=False),nn.BatchNorm2d(64),nn.ReLU(True),      # 64 × 32 × 32nn.ConvTranspose2d(64, 3, kernel_size=5, stride=3, padding=1, bias=False),nn.Tanh()       # 3 * 96 * 96)def forward(self, input):return self.main(input)class Discriminator(nn.Module):def __init__(self):super(Discriminator,self).__init__()self.main = nn.Sequential(nn.Conv2d(3, 64, kernel_size=5, stride=3, padding=1, bias=False),nn.LeakyReLU(0.2, inplace=True),        # 64 * 32 * 32nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1, bias=False),nn.BatchNorm2d(128),nn.LeakyReLU(0.2, inplace=True),         # 128 * 16 * 16nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1, bias=False),nn.BatchNorm2d(256),nn.LeakyReLU(0.2, inplace=True),  # 256 * 8 * 8nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1, bias=False),nn.BatchNorm2d(512),nn.LeakyReLU(0.2, inplace=True),  # 512 * 4 * 4nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=0, bias=False),nn.Sigmoid()        # 输出一个概率)def forward(self, input):return self.main(input).view(-1)def main(args):transform = transforms.Compose([transforms.Resize((96, 96)),  # 将图片resize至 96 * 96transforms.ToTensor(),  # 转换为张量transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])data_set = datasets.ImageFolder(root=args.dir, transform=transform)data_loader = dataloader.DataLoader(dataset=data_set, batch_size=args.batch_size, num_workers=4, shuffle=True, drop_last=True)print('already load data...')model_G = Generator()model_D = Discriminator()if args.continue_train == True:model_G.load_state_dict(torch.load('./generator.pt'))model_D.load_state_dict(torch.load('./discriminator.pt'))model_G.train()model_D.train()print('already prepared model...')optimizer_G = torch.optim.Adam(model_G.parameters(), lr=2e-4, betas=(0.5, 0.999))optimizer_D = torch.optim.Adam(model_D.parameters(), lr=2e-4, betas=(0.5, 0.999))loss = nn.BCELoss()print('already prepared optimizer and loss_function...')if torch.cuda.is_available() == True:model_G.cuda(args.device)model_D.cuda(args.device)loss.cuda(args.device)print('already in GPU...')print('start training...')for epoch in range(args.epochs):print('epoch:{}'.format(epoch + 1))for i, data in enumerate(data_loader):if (i + 1) % args.discriminator_num == 0:optimizer_D.zero_grad()real_img = data[0]batchsize = len(real_img)real_img = real_img.cuda(args.device)out_D_real = model_D(real_img)real_labels = torch.ones(batchsize).cuda(args.device)loss_D_real = loss(out_D_real, real_labels)loss_D_real.backward()noise = torch.randn(args.batch_size, 100, 1, 1).cuda(args.device)fake_img = model_G(noise)out_D_fake = model_D(fake_img)fake_labels = torch.zeros(batchsize).cuda(args.device)loss_D_fake = loss(out_D_fake, fake_labels)loss_D_fake.backward()optimizer_D.step()if (i + 1) % args.generator_num == 0:optimizer_G.zero_grad()real_img = data[0]batchsize = len(real_img)noise = torch.randn(args.batch_size, 100, 1, 1).cuda(args.device)fake_img = model_G(noise)out_D_fake = model_D(fake_img)real_labels = torch.ones(batchsize).cuda(args.device)loss_G = loss(out_D_fake, real_labels)loss_G.backward()optimizer_G.step()if (epoch + 1) % 10 == 0:fix_noise = torch.randn(40, 100, 1, 1).cuda(args.device)final_img = model_G(fix_noise)final_img = final_img * 0.5 + 0.5final_img = final_img.cpu()plt.figure(1)for i in range(40):img = final_img[i].detach().numpy()plt.subplot(5, 8, i+1)plt.imshow(np.transpose(img, (1, 2, 0)))plt.savefig("./outcome/{}.png".format(epoch + 1))plt.show()print('end training...')torch.save(model_G.state_dict(), './generator.pt')torch.save(model_D.state_dict(), './discriminator.pt')print('already saved model...')if __name__ == '__main__':parser = argparse.ArgumentParser()parser.add_argument('--continue_train', type=bool, default=False, help='continue training')parser.add_argument('--dir', type=str, default='./flowers/train', help='dataset path')parser.add_argument('--batch_size', type=int, default=50, help='batch size')parser.add_argument('--device', type=int, default=0, help='GPU id')parser.add_argument('--epochs', type=int, default=3000, help='train epochs')parser.add_argument('--generator_num', type=int, default=5, help='train generator every k epochs')parser.add_argument('--discriminator_num', type=int, default=1, help='train discriminator every k epochs')args = parser.parse_args()main(args)

这篇关于GAN生成对抗网络:花卉生成的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/471063

相关文章

AI一键生成 PPT

AI一键生成 PPT 操作步骤 作为一名打工人,是不是经常需要制作各种PPT来分享我的生活和想法。但是,你们知道,有时候灵感来了,时间却不够用了!😩直到我发现了Kimi AI——一个能够自动生成PPT的神奇助手!🌟 什么是Kimi? 一款月之暗面科技有限公司开发的AI办公工具,帮助用户快速生成高质量的演示文稿。 无论你是职场人士、学生还是教师,Kimi都能够为你的办公文

Linux 网络编程 --- 应用层

一、自定义协议和序列化反序列化 代码: 序列化反序列化实现网络版本计算器 二、HTTP协议 1、谈两个简单的预备知识 https://www.baidu.com/ --- 域名 --- 域名解析 --- IP地址 http的端口号为80端口,https的端口号为443 url为统一资源定位符。CSDNhttps://mp.csdn.net/mp_blog/creation/editor

pdfmake生成pdf的使用

实际项目中有时会有根据填写的表单数据或者其他格式的数据,将数据自动填充到pdf文件中根据固定模板生成pdf文件的需求 文章目录 利用pdfmake生成pdf文件1.下载安装pdfmake第三方包2.封装生成pdf文件的共用配置3.生成pdf文件的文件模板内容4.调用方法生成pdf 利用pdfmake生成pdf文件 1.下载安装pdfmake第三方包 npm i pdfma

poj 1258 Agri-Net(最小生成树模板代码)

感觉用这题来当模板更适合。 题意就是给你邻接矩阵求最小生成树啦。~ prim代码:效率很高。172k...0ms。 #include<stdio.h>#include<algorithm>using namespace std;const int MaxN = 101;const int INF = 0x3f3f3f3f;int g[MaxN][MaxN];int n

poj 1287 Networking(prim or kruscal最小生成树)

题意给你点与点间距离,求最小生成树。 注意点是,两点之间可能有不同的路,输入的时候选择最小的,和之前有道最短路WA的题目类似。 prim代码: #include<stdio.h>const int MaxN = 51;const int INF = 0x3f3f3f3f;int g[MaxN][MaxN];int P;int prim(){bool vis[MaxN];

poj 2349 Arctic Network uva 10369(prim or kruscal最小生成树)

题目很麻烦,因为不熟悉最小生成树的算法调试了好久。 感觉网上的题目解释都没说得很清楚,不适合新手。自己写一个。 题意:给你点的坐标,然后两点间可以有两种方式来通信:第一种是卫星通信,第二种是无线电通信。 卫星通信:任何两个有卫星频道的点间都可以直接建立连接,与点间的距离无关; 无线电通信:两个点之间的距离不能超过D,无线电收发器的功率越大,D越大,越昂贵。 计算无线电收发器D

hdu 1102 uva 10397(最小生成树prim)

hdu 1102: 题意: 给一个邻接矩阵,给一些村庄间已经修的路,问最小生成树。 解析: 把已经修的路的权值改为0,套个prim()。 注意prim 最外层循坏为n-1。 代码: #include <iostream>#include <cstdio>#include <cstdlib>#include <algorithm>#include <cstri

【生成模型系列(初级)】嵌入(Embedding)方程——自然语言处理的数学灵魂【通俗理解】

【通俗理解】嵌入(Embedding)方程——自然语言处理的数学灵魂 关键词提炼 #嵌入方程 #自然语言处理 #词向量 #机器学习 #神经网络 #向量空间模型 #Siri #Google翻译 #AlexNet 第一节:嵌入方程的类比与核心概念【尽可能通俗】 嵌入方程可以被看作是自然语言处理中的“翻译机”,它将文本中的单词或短语转换成计算机能够理解的数学形式,即向量。 正如翻译机将一种语言

ASIO网络调试助手之一:简介

多年前,写过几篇《Boost.Asio C++网络编程》的学习文章,一直没机会实践。最近项目中用到了Asio,于是抽空写了个网络调试助手。 开发环境: Win10 Qt5.12.6 + Asio(standalone) + spdlog 支持协议: UDP + TCP Client + TCP Server 独立的Asio(http://www.think-async.com)只包含了头文件,不依

poj 3723 kruscal,反边取最大生成树。

题意: 需要征募女兵N人,男兵M人。 每征募一个人需要花费10000美元,但是如果已经招募的人中有一些关系亲密的人,那么可以少花一些钱。 给出若干的男女之间的1~9999之间的亲密关系度,征募某个人的费用是10000 - (已经征募的人中和自己的亲密度的最大值)。 要求通过适当的招募顺序使得征募所有人的费用最小。 解析: 先设想无向图,在征募某个人a时,如果使用了a和b之间的关系