本文主要是介绍starGAN原理代码分析,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!
下载:
git clone https://github.com/yunjey/StarGAN.git
cd StarGAN/
下载celebA训练数据:
bash download.sh
训练:
python main.py --mode='train' --dataset='CelebA' --c_dim=5 --image_size=128 \--sample_path='stargan_celebA/samples' --log_path='stargan_celebA/logs' \--model_save_path='stargan_celebA/models' --result_path='stargan_celebA/results'
代码分析
生成网络
第一个卷积层,输入为图像和label的串联,3表示图像为3通道,c_dim为label的维度,
layers = []
layers.append(nn.Conv2d(3+c_dim, conv_dim, kernel_size=7, stride=1, padding=3, bias=False))
layers.append(nn.InstanceNorm2d(conv_dim, affine=True))
layers.append(nn.ReLU(inplace=True))
2个卷积层,stride=2,即下采样,
# Down-Sampling
curr_dim = conv_dim
for i in range(2):layers.append(nn.Conv2d(curr_dim, curr_dim*2, kernel_size=4, stride=2, padding=1, bias=False))layers.append(nn.InstanceNorm2d(curr_dim*2, affine=True))layers.append(nn.ReLU(inplace=True))curr_dim = curr_dim * 2
残差层,
# Bottleneck
for i in range(repeat_num):layers.append(ResidualBlock(dim_in=curr_dim, dim_out=curr_dim))
残差网络结构,
class ResidualBlock(nn.Module):"""Residual Block."""def __init__(self, dim_in, dim_out):super(ResidualBlock, self).__init__()self.main = nn.Sequential(nn.Conv2d(dim_in, dim_out, kernel_size=3, stride=1, padding=1, bias=False),nn.InstanceNorm2d(dim_out, affine=True),nn.ReLU(inplace=True),nn.Conv2d(dim_out, dim_out, kernel_size=3, stride=1, padding=1, bias=False),nn.InstanceNorm2d(dim_out, affine=True))def forward(self, x):return x + self.main(x)
上采样,
# Up-Sampling
for i in range(2):layers.append(nn
这篇关于starGAN原理代码分析的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!