SEAN 代码略解

2024-03-06 20:38
文章标签 代码 略解 sean

本文主要是介绍SEAN 代码略解,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

这篇《SEAN: Image Synthesis with Semantic Region-Adaptive Normalization》是2020年CVPR的一篇oral论文,下面对它的代码做一个梳理。

由于此前已经做过关于SPADE的解析,这一篇主要看看SEAN在SPADE的基础上做了哪些改进。

不同之处一: models/networks/generator.py

    def __init__(self, opt):
        """Build the SEAN generator.

        Mirrors the SPADE generator but threads per-region style codes
        (from ``Zencoder``) through every ``SPADEResnetBlock``.

        Args:
            opt: options namespace; fields used here are ``ngf``,
                ``semantic_nc`` and ``num_upsampling_layers``.
        """
        super().__init__()
        self.opt = opt
        nf = opt.ngf  # base number of generator filters
        self.sw, self.sh = self.compute_latent_vector_size(opt)
        # SEAN defaults to a VAE-style setup: Zencoder extracts one
        # 512-dim style code per semantic region from the RGB style image
        # (defined in models/networks/architecture.py).
        self.Zencoder = Zencoder(3, 512)
        # Projects the downsampled label map to the initial 16*nf feature map.
        self.fc = nn.Conv2d(self.opt.semantic_nc, 16 * nf, 3, padding=1)
        self.head_0 = SPADEResnetBlock(16 * nf, 16 * nf, opt, Block_Name='head_0')
        self.G_middle_0 = SPADEResnetBlock(16 * nf, 16 * nf, opt, Block_Name='G_middle_0')
        self.G_middle_1 = SPADEResnetBlock(16 * nf, 16 * nf, opt, Block_Name='G_middle_1')
        self.up_0 = SPADEResnetBlock(16 * nf, 8 * nf, opt, Block_Name='up_0')
        self.up_1 = SPADEResnetBlock(8 * nf, 4 * nf, opt, Block_Name='up_1')
        self.up_2 = SPADEResnetBlock(4 * nf, 2 * nf, opt, Block_Name='up_2')
        # NOTE(review): up_3 disables the per-region style path (use_rgb=False);
        # forward() also keeps the 'most' up_4 stage commented out.
        self.up_3 = SPADEResnetBlock(2 * nf, 1 * nf, opt, Block_Name='up_3', use_rgb=False)
        final_nc = nf
        if opt.num_upsampling_layers == 'most':
            self.up_4 = SPADEResnetBlock(1 * nf, nf // 2, opt, Block_Name='up_4')
            final_nc = nf // 2
        self.conv_img = nn.Conv2d(final_nc, 3, 3, padding=1)
        self.up = nn.Upsample(scale_factor=2)
        # self.up = nn.Upsample(scale_factor=2, mode='bilinear')
    def forward(self, input, rgb_img, obj_dic=None):
        """Synthesize an image from a segmentation map and a style image.

        Args:
            input: one-hot segmentation map, shape (B, semantic_nc, H, W).
            rgb_img: reference RGB style image fed to the style encoder.
            obj_dic: optional per-region style dictionary (UI mode).

        Returns:
            Generated image tensor in [-1, 1] (tanh output).
        """
        seg = input
        # Start from the label map downsampled to the latent spatial size.
        x = F.interpolate(seg, size=(self.sh, self.sw))
        x = self.fc(x)
        # One style code per semantic region, extracted from the style image.
        style_codes = self.Zencoder(input=rgb_img, segmap=seg)
        x = self.head_0(x, seg, style_codes, obj_dic=obj_dic)
        x = self.up(x)
        x = self.G_middle_0(x, seg, style_codes, obj_dic=obj_dic)
        if self.opt.num_upsampling_layers == 'more' or \
                self.opt.num_upsampling_layers == 'most':
            x = self.up(x)
        x = self.G_middle_1(x, seg, style_codes, obj_dic=obj_dic)
        x = self.up(x)
        x = self.up_0(x, seg, style_codes, obj_dic=obj_dic)
        x = self.up(x)
        x = self.up_1(x, seg, style_codes, obj_dic=obj_dic)
        x = self.up(x)
        x = self.up_2(x, seg, style_codes, obj_dic=obj_dic)
        x = self.up(x)
        x = self.up_3(x, seg, style_codes, obj_dic=obj_dic)
        # The extra 'most' upsampling stage is deliberately disabled upstream:
        # if self.opt.num_upsampling_layers == 'most':
        #     x = self.up(x)
        #     x = self.up_4(x, seg, style_codes, obj_dic=obj_dic)
        x = self.conv_img(F.leaky_relu(x, 2e-1))
        # Tensor.tanh replaces the long-deprecated F.tanh alias.
        x = x.tanh()
        return x

不同之处二:models/networks/architecture.py

class Zencoder(torch.nn.Module):
    """Per-region style encoder used by SEAN.

    Encodes a style (RGB) image into a dense feature map, then performs
    region-wise average pooling over the segmentation map so that every
    semantic label receives one ``output_nc``-dimensional style code.
    """

    def __init__(self, input_nc, output_nc, ngf=32, n_downsampling=2, norm_layer=nn.InstanceNorm2d):
        super(Zencoder, self).__init__()
        self.output_nc = output_nc

        model = [nn.ReflectionPad2d(1), nn.Conv2d(input_nc, ngf, kernel_size=3, padding=0),
                 norm_layer(ngf), nn.LeakyReLU(0.2, False)]
        # downsample
        for i in range(n_downsampling):
            mult = 2 ** i
            model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1),
                      norm_layer(ngf * mult * 2), nn.LeakyReLU(0.2, False)]
        # upsample once; with output_padding == stride - 1 the spatial size
        # is exactly doubled
        for i in range(1):
            mult = 2 ** (n_downsampling - i)
            model += [nn.ConvTranspose2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1, output_padding=1),
                      # the original passed int(ngf * mult / 2) here, which does
                      # not match the ngf * mult * 2 channels actually produced;
                      # harmless for affine-free InstanceNorm but flagged by
                      # recent PyTorch, so the true channel count is used
                      norm_layer(ngf * mult * 2), nn.LeakyReLU(0.2, False)]
        # final projection to the style dimension; the original hard-coded 256
        # in-channels (only valid for ngf=32, n_downsampling=2) — computed here
        model += [nn.ReflectionPad2d(1), nn.Conv2d(ngf * mult * 2, output_nc, kernel_size=3, padding=0), nn.Tanh()]
        self.model = nn.Sequential(*model)

    def forward(self, input, segmap):
        """Return per-region style codes of shape [B, n_labels, output_nc].

        Args:
            input: style image; codes is an ``output_nc``-channel feature map
                of the same spatial layout.
            segmap: one-hot segmentation map, resized here to match codes.
        """
        codes = self.model(input)
        segmap = F.interpolate(segmap, size=codes.size()[2:], mode='nearest')

        b_size = codes.shape[0]
        f_size = codes.shape[1]
        s_size = segmap.shape[1]

        codes_vector = torch.zeros((b_size, s_size, f_size), dtype=codes.dtype, device=codes.device)

        # region-wise average pooling: one mean feature vector per present label
        for i in range(b_size):
            for j in range(s_size):
                # number of pixels carrying label j in sample i, in [0, H*W];
                # 0 means the label is absent and its code stays zero
                component_mask_area = torch.sum(segmap.bool()[i, j])
                if component_mask_area > 0:
                    # masked_select broadcasts the HxW mask over all f_size
                    # channels, so the flat result reshapes to (f_size, area);
                    # mean(1) averages over the region's pixels
                    codes_component_feature = codes[i].masked_select(segmap.bool()[i, j]).reshape(f_size, component_mask_area).mean(1)
                    codes_vector[i][j] = codes_component_feature

        return codes_vector
class SPADEResnetBlock(nn.Module):
    """SEAN residual block.

    Same layout as the SPADE ResNet block, but every normalization layer is
    replaced by an ACE layer (region-adaptive, style-conditioned), so the
    forward pass additionally threads through ``style_codes`` and ``obj_dic``.
    """

    def __init__(self, fin, fout, opt, Block_Name=None, use_rgb=True):
        super().__init__()
        self.use_rgb = use_rgb
        self.Block_Name = Block_Name
        self.status = opt.status

        # Attributes
        self.learned_shortcut = (fin != fout)
        fmiddle = min(fin, fout)

        # create conv layers
        self.conv_0 = nn.Conv2d(fin, fmiddle, kernel_size=3, padding=1)
        self.conv_1 = nn.Conv2d(fmiddle, fout, kernel_size=3, padding=1)
        if self.learned_shortcut:
            self.conv_s = nn.Conv2d(fin, fout, kernel_size=1, bias=False)

        # apply spectral norm if specified
        if 'spectral' in opt.norm_G:
            self.conv_0 = spectral_norm(self.conv_0)
            self.conv_1 = spectral_norm(self.conv_1)
            if self.learned_shortcut:
                self.conv_s = spectral_norm(self.conv_s)

        # define normalization layers
        spade_config_str = opt.norm_G.replace('spectral', '')

        ###########  Modifications 1
        # SEAN always builds its ACE layers with sync-batch normalization here,
        # regardless of opt.norm_G (an unused `normtype_list` local listing the
        # alternatives was removed from the original).
        our_norm_type = 'spadesyncbatch3x3'
        self.ace_0 = ACE(our_norm_type, fin, 3, ACE_Name=Block_Name + '_ACE_0', status=self.status, spade_params=[spade_config_str, fin, opt.semantic_nc], use_rgb=use_rgb)
        self.ace_1 = ACE(our_norm_type, fmiddle, 3, ACE_Name=Block_Name + '_ACE_1', status=self.status, spade_params=[spade_config_str, fmiddle, opt.semantic_nc], use_rgb=use_rgb)
        if self.learned_shortcut:
            self.ace_s = ACE(our_norm_type, fin, 3, ACE_Name=Block_Name + '_ACE_s', status=self.status, spade_params=[spade_config_str, fin, opt.semantic_nc], use_rgb=use_rgb)

    # note the resnet block with SPADE also takes in |seg|,
    # the semantic segmentation map as input
    def forward(self, x, seg, style_codes, obj_dic=None):
        x_s = self.shortcut(x, seg, style_codes, obj_dic)

        ###########  Modifications 1
        dx = self.ace_0(x, seg, style_codes, obj_dic)
        dx = self.conv_0(self.actvn(dx))
        dx = self.ace_1(dx, seg, style_codes, obj_dic)
        dx = self.conv_1(self.actvn(dx))
        ###########  Modifications 1

        out = x_s + dx
        return out

    def shortcut(self, x, seg, style_codes, obj_dic):
        # learned 1x1 projection only when input/output channel counts differ
        if self.learned_shortcut:
            x_s = self.ace_s(x, seg, style_codes, obj_dic)
            x_s = self.conv_s(x_s)
        else:
            x_s = x
        return x_s

    def actvn(self, x):
        return F.leaky_relu(x, 2e-1)

 

不同之处三: models/networks/normalization.py

class ACE(nn.Module):
    """SEAN's Adaptive (region-wise) normalization layer.

    Normalizes the input with a parameter-free norm, then modulates it with a
    blend of (a) gamma/beta predicted from per-region style codes and
    (b) gamma/beta predicted by a plain SPADE branch from the label map.
    When ``use_rgb`` is False only the SPADE branch is used.
    """

    def __init__(self, config_text, norm_nc, label_nc, ACE_Name=None, status='train', spade_params=None, use_rgb=True):
        super().__init__()
        self.ACE_Name = ACE_Name
        self.status = status
        self.save_npy = True  # dump per-region style codes during 'test'
        self.Spade = SPADE(*spade_params)
        self.use_rgb = use_rgb
        self.style_length = 512  # dimensionality of one per-region style code
        # learned blend weights between style-based and SPADE-based modulation
        self.blending_gamma = nn.Parameter(torch.zeros(1), requires_grad=True)
        self.blending_beta = nn.Parameter(torch.zeros(1), requires_grad=True)
        # per-channel scale of the noise injected before normalization
        self.noise_var = nn.Parameter(torch.zeros(norm_nc), requires_grad=True)

        assert config_text.startswith('spade')
        # raw string: '\D' in a plain literal is an invalid escape sequence
        parsed = re.search(r'spade(\D+)(\d)x\d', config_text)
        param_free_norm_type = str(parsed.group(1))
        ks = int(parsed.group(2))
        pw = ks // 2
        if param_free_norm_type == 'instance':
            self.param_free_norm = nn.InstanceNorm2d(norm_nc, affine=False)
        elif param_free_norm_type == 'syncbatch':
            self.param_free_norm = SynchronizedBatchNorm2d(norm_nc, affine=False)
        elif param_free_norm_type == 'batch':
            self.param_free_norm = nn.BatchNorm2d(norm_nc, affine=False)
        else:
            raise ValueError('%s is not a recognized param-free norm type in SPADE'
                             % param_free_norm_type)

        # The dimension of the intermediate embedding space. Yes, hardcoded.
        if self.use_rgb:
            self.create_gamma_beta_fc_layers()
            self.conv_gamma = nn.Conv2d(self.style_length, norm_nc, kernel_size=ks, padding=pw)
            self.conv_beta = nn.Conv2d(self.style_length, norm_nc, kernel_size=ks, padding=pw)

    def forward(self, x, segmap, style_codes=None, obj_dic=None):
        # Part 1. generate parameter-free normalized activations.
        # Noise is created on x's device — the original hard-coded .cuda(),
        # which crashed CPU-only runs (note: this moves the noise to the
        # device RNG stream; the noise is random either way).
        added_noise = (torch.randn(x.shape[0], x.shape[3], x.shape[2], 1, device=x.device) * self.noise_var).transpose(1, 3)
        normalized = self.param_free_norm(x + added_noise)

        # Part 2. produce scaling and bias conditioned on semantic map
        segmap = F.interpolate(segmap, size=x.size()[2:], mode='nearest')

        if self.use_rgb:
            [b_size, f_size, h_size, w_size] = normalized.shape
            middle_avg = torch.zeros((b_size, self.style_length, h_size, w_size), device=normalized.device)

            if self.status == 'UI_mode':
                # hard-coded to batch size 1 in interactive (UI) mode
                for i in range(1):
                    for j in range(segmap.shape[1]):
                        component_mask_area = torch.sum(segmap.bool()[i, j])
                        if component_mask_area > 0:
                            if obj_dic is None:
                                print('wrong even it is the first input')
                            else:
                                # style code for label j supplied by the UI
                                style_code_tmp = obj_dic[str(j)]['ACE']
                                middle_mu = F.relu(getattr(self, 'fc_mu' + str(j))(style_code_tmp))
                                # broadcast the 512-dim code over the region's pixels
                                component_mu = middle_mu.reshape(self.style_length, 1).expand(self.style_length, component_mask_area)
                                middle_avg[i].masked_scatter_(segmap.bool()[i, j], component_mu)
            else:
                for i in range(b_size):
                    for j in range(segmap.shape[1]):
                        component_mask_area = torch.sum(segmap.bool()[i, j])
                        if component_mask_area > 0:
                            middle_mu = F.relu(getattr(self, 'fc_mu' + str(j))(style_codes[i][j]))
                            component_mu = middle_mu.reshape(self.style_length, 1).expand(self.style_length, component_mask_area)
                            middle_avg[i].masked_scatter_(segmap.bool()[i, j], component_mu)

                            if self.status == 'test' and self.save_npy and self.ACE_Name == 'up_2_ACE_0':
                                tmp = style_codes[i][j].cpu().numpy()
                                dir_path = 'styles_test'
                                # NOTE(review): obj_dic[i] is assumed to be a
                                # path-like string here — verify against caller
                                im_name = os.path.basename(obj_dic[i])
                                folder_path = os.path.join(dir_path, 'style_codes', im_name, str(j))
                                # exist_ok avoids the check-then-create race
                                os.makedirs(folder_path, exist_ok=True)
                                style_code_path = os.path.join(folder_path, 'ACE.npy')
                                np.save(style_code_path, tmp)

            gamma_avg = self.conv_gamma(middle_avg)
            beta_avg = self.conv_beta(middle_avg)
            gamma_spade, beta_spade = self.Spade(segmap)
            # torch.sigmoid replaces the deprecated F.sigmoid alias
            gamma_alpha = torch.sigmoid(self.blending_gamma)
            beta_alpha = torch.sigmoid(self.blending_beta)
            gamma_final = gamma_alpha * gamma_avg + (1 - gamma_alpha) * gamma_spade
            beta_final = beta_alpha * beta_avg + (1 - beta_alpha) * beta_spade
            out = normalized * (1 + gamma_final) + beta_final
        else:
            gamma_spade, beta_spade = self.Spade(segmap)
            gamma_final = gamma_spade
            beta_final = beta_spade
            out = normalized * (1 + gamma_final) + beta_final
        return out

    def create_gamma_beta_fc_layers(self):
        """Create one per-label style MLP, fc_mu0 ... fc_mu18 (19 labels)."""
        style_length = self.style_length
        # Generated in a loop instead of 19 copy-pasted lines; the attribute
        # names (and hence state_dict keys) are unchanged.
        for j in range(19):
            setattr(self, 'fc_mu' + str(j), nn.Linear(style_length, style_length))

未完待续

这篇关于SEAN 代码略解的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/781285

相关文章

活用c4d官方开发文档查询代码

当你问AI助手比如豆包,如何用python禁止掉xpresso标签时候,它会提示到 这时候要用到两个东西。https://developers.maxon.net/论坛搜索和开发文档 比如这里我就在官方找到正确的id描述 然后我就把参数标签换过来

poj 1258 Agri-Net(最小生成树模板代码)

感觉用这题来当模板更适合。 题意就是给你邻接矩阵求最小生成树啦。~ prim代码:效率很高。172k...0ms。 #include<stdio.h>#include<algorithm>using namespace std;const int MaxN = 101;const int INF = 0x3f3f3f3f;int g[MaxN][MaxN];int n

计算机毕业设计 大学志愿填报系统 Java+SpringBoot+Vue 前后端分离 文档报告 代码讲解 安装调试

🍊作者:计算机编程-吉哥 🍊简介:专业从事JavaWeb程序开发,微信小程序开发,定制化项目、 源码、代码讲解、文档撰写、ppt制作。做自己喜欢的事,生活就是快乐的。 🍊心愿:点赞 👍 收藏 ⭐评论 📝 🍅 文末获取源码联系 👇🏻 精彩专栏推荐订阅 👇🏻 不然下次找不到哟~Java毕业设计项目~热门选题推荐《1000套》 目录 1.技术选型 2.开发工具 3.功能

代码随想录冲冲冲 Day39 动态规划Part7

198. 打家劫舍 dp数组的意义是在第i位的时候偷的最大钱数是多少 如果nums的size为0 总价值当然就是0 如果nums的size为1 总价值是nums[0] 遍历顺序就是从小到大遍历 之后是递推公式 对于dp[i]的最大价值来说有两种可能 1.偷第i个 那么最大价值就是dp[i-2]+nums[i] 2.不偷第i个 那么价值就是dp[i-1] 之后取这两个的最大值就是d

pip-tools:打造可重复、可控的 Python 开发环境,解决依赖关系,让代码更稳定

在 Python 开发中,管理依赖关系是一项繁琐且容易出错的任务。手动更新依赖版本、处理冲突、确保一致性等等,都可能让开发者感到头疼。而 pip-tools 为开发者提供了一套稳定可靠的解决方案。 什么是 pip-tools? pip-tools 是一组命令行工具,旨在简化 Python 依赖关系的管理,确保项目环境的稳定性和可重复性。它主要包含两个核心工具:pip-compile 和 pip

D4代码AC集

贪心问题解决的步骤: (局部贪心能导致全局贪心)    1.确定贪心策略    2.验证贪心策略是否正确 排队接水 #include<bits/stdc++.h>using namespace std;int main(){int w,n,a[32000];cin>>w>>n;for(int i=1;i<=n;i++){cin>>a[i];}sort(a+1,a+n+1);int i=1

html css jquery选项卡 代码练习小项目

在学习 html 和 css jquery 结合使用的时候 做好是能尝试做一些简单的小功能,来提高自己的 逻辑能力,熟悉代码的编写语法 下面分享一段代码 使用html css jquery选项卡 代码练习 <div class="box"><dl class="tab"><dd class="active">手机</dd><dd>家电</dd><dd>服装</dd><dd>数码</dd><dd

生信代码入门:从零开始掌握生物信息学编程技能

少走弯路,高效分析;了解生信云,访问 【生信圆桌x生信专用云服务器】 : www.tebteb.cc 介绍 生物信息学是一个高度跨学科的领域,结合了生物学、计算机科学和统计学。随着高通量测序技术的发展,海量的生物数据需要通过编程来进行处理和分析。因此,掌握生信编程技能,成为每一个生物信息学研究者的必备能力。 生信代码入门,旨在帮助初学者从零开始学习生物信息学中的编程基础。通过学习常用

husky 工具配置代码检查工作流:提交代码至仓库前做代码检查

提示:这篇博客以我前两篇博客作为先修知识,请大家先去看看我前两篇博客 博客指路:前端 ESlint 代码规范及修复代码规范错误-CSDN博客前端 Vue3 项目开发—— ESLint & prettier 配置代码风格-CSDN博客 husky 工具配置代码检查工作流的作用 在工作中,我们经常需要将写好的代码提交至代码仓库 但是由于程序员疏忽而将不规范的代码提交至仓库,显然是不合理的 所

Unity3D自带Mouse Look鼠标视角代码解析。

Unity3D自带Mouse Look鼠标视角代码解析。 代码块 代码块语法遵循标准markdown代码,例如: using UnityEngine;using System.Collections;/// MouseLook rotates the transform based on the mouse delta./// Minimum and Maximum values can