【Pytorch】生成对抗网络实战

2024-08-26 13:12

本文主要是介绍【Pytorch】生成对抗网络实战,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

GAN框架基于两个模型的竞争,Generator生成器和Discriminator鉴别器。生成器生成假图像,鉴别器则尝试从假图像中识别真实的图像。作为这种竞争的结果,生成器将生成更好看的假图像,而鉴别器将更好地识别它们。

目录

创建数据集

定义生成器

定义鉴别器

初始化模型权重

定义损失函数

定义优化器

训练模型

部署生成器


创建数据集

使用 PyTorch torchvision 包中提供的 STL-10 数据集,数据集中有 10 个类:飞机、鸟、车、猫、鹿、狗、马、猴、船、卡车。图像为96*96像素的RGB图像。数据集包含 5,000 张训练图像和 8,000 张测试图像。在训练数据集和测试数据集中,每个类分别有 500 和 800 张图像。

 STL-10数据集详细参考http://t.csdnimg.cn/ojBn6中数据加载和处理部分 

from torchvision import datasets
import torchvision.transforms as transforms
import os# 定义数据集路径
path2data="./data"
# 创建数据集路径
os.makedirs(path2data, exist_ok= True)# 定义图像尺寸
h, w = 64, 64
# 定义均值
mean = (0.5, 0.5, 0.5)
# 定义标准差
std = (0.5, 0.5, 0.5)
# 定义数据预处理
transform= transforms.Compose([transforms.Resize((h,w)),  # 调整图像尺寸transforms.CenterCrop((h,w)),  # 中心裁剪transforms.ToTensor(),  # 转换为张量transforms.Normalize(mean, std)])  # 归一化# 加载训练集
train_ds=datasets.STL10(path2data, split='train', download=False,transform=transform)

 展示示例图像张量形状、最小值和最大值

import torch
for x, _ in train_ds:print(x.shape, torch.min(x), torch.max(x))break

 展示示例图像

from torchvision.transforms.functional import to_pil_image
import matplotlib.pylab as plt
%matplotlib inline
plt.imshow(to_pil_image(0.5*x+0.5))

 

创建数据加载器 

import torch
batch_size = 32
train_dl = torch.utils.data.DataLoader(train_ds, batch_size=batch_size, shuffle=True)

 示例

for x,y in train_dl:print(x.shape, y.shape)break

定义生成器

GAN框架是基于两个模型的竞争,generator生成器和discriminator鉴别器。生成器生成假图像,鉴别器尝试从假图像中识别真实的图像。

作为这种竞争的结果,生成器将生成更好看的假图像,而鉴别器将更好地识别它们。

定义生成器模型 

from torch import nn
import torch.nn.functional as Fclass Generator(nn.Module):def __init__(self, params):super(Generator, self).__init__()# 获取参数nz = params["nz"]ngf = params["ngf"]noc = params["noc"]# 定义反卷积层1self.dconv1 = nn.ConvTranspose2d( nz, ngf * 8, kernel_size=4,stride=1, padding=0, bias=False)# 定义批归一化层1self.bn1 = nn.BatchNorm2d(ngf * 8)# 定义反卷积层2self.dconv2 = nn.ConvTranspose2d(ngf * 8, ngf * 4, kernel_size=4, stride=2, padding=1, bias=False)# 定义批归一化层2self.bn2 = nn.BatchNorm2d(ngf * 4)# 定义反卷积层3self.dconv3 = nn.ConvTranspose2d( ngf * 4, ngf * 2, kernel_size=4, stride=2, padding=1, bias=False)# 定义批归一化层3self.bn3 = nn.BatchNorm2d(ngf * 2)# 定义反卷积层4self.dconv4 = nn.ConvTranspose2d( ngf * 2, ngf, kernel_size=4, stride=2, padding=1, bias=False)# 定义批归一化层4self.bn4 = nn.BatchNorm2d(ngf)# 定义反卷积层5self.dconv5 = nn.ConvTranspose2d( ngf, noc, kernel_size=4, stride=2, padding=1, bias=False)# 前向传播def forward(self, x):# 反卷积层1x = F.relu(self.bn1(self.dconv1(x)))# 反卷积层2x = F.relu(self.bn2(self.dconv2(x)))            # 反卷积层3x = F.relu(self.bn3(self.dconv3(x)))        # 反卷积层4x = F.relu(self.bn4(self.dconv4(x)))    # 反卷积层5out = torch.tanh(self.dconv5(x))return out

设定生成器模型参数、移动模型到cuda设备并打印模型结构 

params_gen = {"nz": 100,"ngf": 64,"noc": 3,}
model_gen = Generator(params_gen)
device = torch.device("cuda:0")
model_gen.to(device)
print(model_gen)

定义鉴别器

定义鉴别器模型, 用于鉴别真实图像

class Discriminator(nn.Module):def __init__(self, params):super(Discriminator, self).__init__()# 获取参数nic= params["nic"]ndf = params["ndf"]# 定义卷积层1self.conv1 = nn.Conv2d(nic, ndf, kernel_size=4, stride=2, padding=1, bias=False)# 定义卷积层2self.conv2 = nn.Conv2d(ndf, ndf * 2, kernel_size=4, stride=2, padding=1, bias=False)# 定义批归一化层2self.bn2 = nn.BatchNorm2d(ndf * 2)            # 定义卷积层3self.conv3 = nn.Conv2d(ndf * 2, ndf * 4, kernel_size=4, stride=2, padding=1, bias=False)# 定义批归一化层3self.bn3 = nn.BatchNorm2d(ndf * 4)# 定义卷积层4self.conv4 = nn.Conv2d(ndf * 4, ndf * 8, kernel_size=4, stride=2, padding=1, bias=False)# 定义批归一化层4self.bn4 = nn.BatchNorm2d(ndf * 8)# 定义卷积层5self.conv5 = nn.Conv2d(ndf * 8, 1, kernel_size=4, stride=1, padding=0, bias=False)def forward(self, x):# 使用leaky_relu激活函数对卷积层1的输出进行激活x = F.leaky_relu(self.conv1(x), 0.2, True)# 使用leaky_relu激活函数对卷积层2的输出进行激活,并使用批归一化层2进行批归一化x = F.leaky_relu(self.bn2(self.conv2(x)), 0.2, inplace = True)# 使用leaky_relu激活函数对卷积层3的输出进行激活,并使用批归一化层3进行批归一化x = F.leaky_relu(self.bn3(self.conv3(x)), 0.2, inplace = True)# 使用leaky_relu激活函数对卷积层4的输出进行激活,并使用批归一化层4进行批归一化x = F.leaky_relu(self.bn4(self.conv4(x)), 0.2, inplace = True)        # 使用sigmoid激活函数对卷积层5的输出进行激活,并返回结果# Sigmoid激活函数是一种常用的非线性激活函数,它将输入值压缩到0和1之间,[ \sigma(x) = \frac{1}{1 + e^{-x}} ]out = torch.sigmoid(self.conv5(x))return out.view(-1)

设置模型参数,移动模型到cuda设备,打印模型结构 


params_dis = {"nic": 3,"ndf": 64}
model_dis = Discriminator(params_dis)
model_dis.to(device)
print(model_dis)

初始化模型权重

定义函数,初始化模型权重 

def initialize_weights(model):# 获取模型类的名称classname = model.__class__.__name__# 如果模型类名称中包含'Conv',则初始化权重为均值为0,标准差为0.02的正态分布if classname.find('Conv') != -1:nn.init.normal_(model.weight.data, 0.0, 0.02)# 如果模型类名称中包含'BatchNorm',则初始化权重为均值为1,标准差为0.02的正态分布,偏置为0elif classname.find('BatchNorm') != -1:nn.init.normal_(model.weight.data, 1.0, 0.02)nn.init.constant_(model.bias.data, 0)

初始化生成器模型和鉴别器模型的权重 

# 对生成器模型应用初始化权重函数
model_gen.apply(initialize_weights);
# 对判别器模型应用初始化权重函数
model_dis.apply(initialize_weights);

定义损失函数

定义二元交叉熵(BCE)损失函数 

loss_func = nn.BCELoss()

定义优化器

定义Adam优化器

from torch import optim
# 学习率
lr = 2e-4 
# Adam优化器的beta1参数
beta1 = 0.5
# 定义鉴别器模型的优化器,学习率为lr,beta1参数为beta1,beta2参数为0.999
opt_dis = optim.Adam(model_dis.parameters(), lr=lr, betas=(beta1, 0.999))
# 定义生成器模型的优化器
opt_gen = optim.Adam(model_gen.parameters(), lr=lr, betas=(beta1, 0.999))

训练模型

 示例训练1000个epochs

# 定义真实标签和虚假标签
real_label = 1
fake_label = 0
# 获取生成器的噪声维度
nz = params_gen["nz"]
# 设置训练轮数
num_epochs = 1000
# 定义损失历史记录
loss_history={"gen": [],"dis": []}
# 定义批次数
batch_count = 0
# 遍历训练轮数
for epoch in range(num_epochs):# 遍历训练数据for xb, yb in train_dl:# 获取批大小ba_si = xb.size(0)# 将判别器梯度置零model_dis.zero_grad()# 将输入数据移动到指定设备xb = xb.to(device)# 将标签数据转换为指定设备yb = torch.full((ba_si,), real_label, device=device)# 判别器输出out_dis = model_dis(xb)# 将输出和标签转换为浮点数out_dis = out_dis.float()yb = yb.float()# 计算真实样本的损失loss_r = loss_func(out_dis, yb)# 反向传播loss_r.backward()# 生成噪声noise = torch.randn(ba_si, nz, 1, 1, device=device)# 生成器输出out_gen = model_gen(noise)# 判别器输出out_dis = model_dis(out_gen.detach())# 将标签数据填充为虚假标签yb.fill_(fake_label)    # 计算虚假样本的损失loss_f = loss_func(out_dis, yb)# 反向传播loss_f.backward()# 计算判别器的总损失loss_dis = loss_r + loss_f  # 更新判别器的参数opt_dis.step()   # 将生成器梯度置零model_gen.zero_grad()# 将标签数据填充为真实标签yb.fill_(real_label)  # 判别器输出out_dis = model_dis(out_gen)# 计算生成器的损失loss_gen = loss_func(out_dis, yb)# 反向传播loss_gen.backward()# 更新生成器的参数opt_gen.step()# 记录生成器和判别器的损失loss_history["gen"].append(loss_gen.item())loss_history["dis"].append(loss_dis.item())# 更新批次数batch_count += 1# 每100个批打印一次损失if batch_count % 100 == 0:print(epoch, loss_gen.item(),loss_dis.item())

 绘制损失图像

plt.figure(figsize=(10,5))
plt.title("Loss Progress")
plt.plot(loss_history["gen"],label="Gen. Loss")
plt.plot(loss_history["dis"],label="Dis. Loss")
plt.xlabel("batch count")
plt.ylabel("Loss")
plt.legend()
plt.show()

存储模型权重 

import os
path2models = "./models/"
os.makedirs(path2models, exist_ok=True)
path2weights_gen = os.path.join(path2models, "weights_gen_128.pt")
path2weights_dis = os.path.join(path2models, "weights_dis_128.pt")
torch.save(model_gen.state_dict(), path2weights_gen)
torch.save(model_dis.state_dict(), path2weights_dis)

部署生成器

通常情况下,训练完成后放弃鉴别器模型而保留生成器模型,部署经过训练的生成器来生成新的图像。为部署生成器模型,将训练好的权重加载到模型中,然后给模型提供随机噪声。

# 加载生成器模型的权重
weights = torch.load(path2weights_gen)
# 将权重加载到生成器模型中
model_gen.load_state_dict(weights)
# 将生成器模型设置为评估模式
model_gen.eval()

 生成图像

import numpy as np
with torch.no_grad():# 生成固定噪声fixed_noise = torch.randn(16, nz, 1, 1, device=device)# 打印噪声形状print(fixed_noise.shape)# 生成假图像img_fake = model_gen(fixed_noise).detach().cpu()    
# 打印假图像形状
print(img_fake.shape)
# 创建画布
plt.figure(figsize=(10,10))
# 遍历假图像
for ii in range(16):# 在画布上绘制图像plt.subplot(4,4,ii+1)# 将图像转换为PIL图像plt.imshow(to_pil_image(0.5*img_fake[ii]+0.5))# 关闭坐标轴plt.axis("off")

其中一些可能看起来扭曲,而另一些看起来相对真实。为改进结果,可以在单个数据类上训练模型,而不是在多个类上一起训练。GAN在使用单个类进行训练时表现更好。此外,可以尝试更长时间地训练模型。

这篇关于【Pytorch】生成对抗网络实战的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1108663

相关文章

MySQL 多列 IN 查询之语法、性能与实战技巧(最新整理)

《MySQL多列IN查询之语法、性能与实战技巧(最新整理)》本文详解MySQL多列IN查询,对比传统OR写法,强调其简洁高效,适合批量匹配复合键,通过联合索引、分批次优化提升性能,兼容多种数据库... 目录一、基础语法:多列 IN 的两种写法1. 直接值列表2. 子查询二、对比传统 OR 的写法三、性能分析

Python办公自动化实战之打造智能邮件发送工具

《Python办公自动化实战之打造智能邮件发送工具》在数字化办公场景中,邮件自动化是提升工作效率的关键技能,本文将演示如何使用Python的smtplib和email库构建一个支持图文混排,多附件,多... 目录前言一、基础配置:搭建邮件发送框架1.1 邮箱服务准备1.2 核心库导入1.3 基础发送函数二、

PowerShell中15个提升运维效率关键命令实战指南

《PowerShell中15个提升运维效率关键命令实战指南》作为网络安全专业人员的必备技能,PowerShell在系统管理、日志分析、威胁检测和自动化响应方面展现出强大能力,下面我们就来看看15个提升... 目录一、PowerShell在网络安全中的战略价值二、网络安全关键场景命令实战1. 系统安全基线核查

Linux中压缩、网络传输与系统监控工具的使用完整指南

《Linux中压缩、网络传输与系统监控工具的使用完整指南》在Linux系统管理中,压缩与传输工具是数据备份和远程协作的桥梁,而系统监控工具则是保障服务器稳定运行的眼睛,下面小编就来和大家详细介绍一下它... 目录引言一、压缩与解压:数据存储与传输的优化核心1. zip/unzip:通用压缩格式的便捷操作2.

从原理到实战深入理解Java 断言assert

《从原理到实战深入理解Java断言assert》本文深入解析Java断言机制,涵盖语法、工作原理、启用方式及与异常的区别,推荐用于开发阶段的条件检查与状态验证,并强调生产环境应使用参数验证工具类替代... 目录深入理解 Java 断言(assert):从原理到实战引言:为什么需要断言?一、断言基础1.1 语

Java MQTT实战应用

《JavaMQTT实战应用》本文详解MQTT协议,涵盖其发布/订阅机制、低功耗高效特性、三种服务质量等级(QoS0/1/2),以及客户端、代理、主题的核心概念,最后提供Linux部署教程、Sprin... 目录一、MQTT协议二、MQTT优点三、三种服务质量等级四、客户端、代理、主题1. 客户端(Clien

在Spring Boot中集成RabbitMQ的实战记录

《在SpringBoot中集成RabbitMQ的实战记录》本文介绍SpringBoot集成RabbitMQ的步骤,涵盖配置连接、消息发送与接收,并对比两种定义Exchange与队列的方式:手动声明(... 目录前言准备工作1. 安装 RabbitMQ2. 消息发送者(Producer)配置1. 创建 Spr

深度解析Spring Boot拦截器Interceptor与过滤器Filter的区别与实战指南

《深度解析SpringBoot拦截器Interceptor与过滤器Filter的区别与实战指南》本文深度解析SpringBoot中拦截器与过滤器的区别,涵盖执行顺序、依赖关系、异常处理等核心差异,并... 目录Spring Boot拦截器(Interceptor)与过滤器(Filter)深度解析:区别、实现

深度解析Spring AOP @Aspect 原理、实战与最佳实践教程

《深度解析SpringAOP@Aspect原理、实战与最佳实践教程》文章系统讲解了SpringAOP核心概念、实现方式及原理,涵盖横切关注点分离、代理机制(JDK/CGLIB)、切入点类型、性能... 目录1. @ASPect 核心概念1.1 AOP 编程范式1.2 @Aspect 关键特性2. 完整代码实

MySQL中的索引结构和分类实战案例详解

《MySQL中的索引结构和分类实战案例详解》本文详解MySQL索引结构与分类,涵盖B树、B+树、哈希及全文索引,分析其原理与优劣势,并结合实战案例探讨创建、管理及优化技巧,助力提升查询性能,感兴趣的朋... 目录一、索引概述1.1 索引的定义与作用1.2 索引的基本原理二、索引结构详解2.1 B树索引2.2