论文阅读: (CVPR2023 SDT )基于书写者风格和字符风格解耦的手写文字生成及源码对应

本文主要是介绍论文阅读: (CVPR2023 SDT )基于书写者风格和字符风格解耦的手写文字生成及源码对应,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

目录

        • 引言
        • SDT整体结构介绍
        • 代码与论文对应
          • 搭建模型部分
          • 数据集部分
        • 总结

引言
  • 许久不认真看论文了,这不赶紧捡起来。这也是自己看的第一篇用到Transformer结构的CV论文。
  • 之所以选择这篇文章来看,是考虑到之前做过手写字体生成的项目。这个工作可以用来合成一些手写体数据集,用来辅助手写体识别模型的训练。
  • 本篇文章将从论文与代码一一对应解析的方式来撰写,这样便于找到论文重点地方以及用代码如何实现的,更快地学到其中要点。这个项目的代码写得很好看,有着清晰的说明和整洁的代码规范。跟着仓库README就可以快速跑起整个项目。
  • 如果读者可以阅读英文的话,建议先去直接阅读英文论文,会更直接看到整个面貌。
  • PDF | Code
SDT整体结构介绍
  • 整体框架:
    SDT
  • 该工作提出从个体手写中解耦作家和字符级别的风格表示,以合成逼真的风格化在线手写字符。
  • 从上述框架图,可以看出整体可分为三大部分:Style encoderContent EncoderTransformer Decoder
    • Style Encoder: 主要学习给定的Style的Writer和Glyph两种风格表示,用于指导合成风格化的文字。包含两部分:CNN EncoderTransformer Encdoer
    • Content Encoder: 主要提取输入文字的特征,同样包含两部分:CNN EncoderTransformer Encdoer
  • ❓疑问:为什么要将CNN Encoder + Transformer Encoder结合使用呢?
    • 这个问题在论文中只说了Content Encoder使用两者的作用。CNN部分用来从content reference中学到compact feature map。Transformer encoder用来提取textual content表示。得益于Transformer强大的long-range 依赖的捕捉能力,Content Encdoer可以得到一个全局上下文的content feature。这里让我想到经典的CRNN结构,就是结合CNN + RNN两部分。
      在这里插入图片描述
代码与论文对应
  • 论文结构的最核心代码有两部分,一是搭建模型部分,二是数据集处理部分。
搭建模型部分
  • 该部分代码位于仓库中models/model.py,我这里只摘其中最关键部分添加注释来解释,其余细节请小伙伴自行挖掘。
class SDT_Generator(nn.Module):def __init__(self, d_model=512, nhead=8, num_encoder_layers=2, num_head_layers= 1,wri_dec_layers=2, gly_dec_layers=2, dim_feedforward=2048, dropout=0.1,activation="relu", normalize_before=True, return_intermediate_dec=True):super(SDT_Generator, self).__init__()### style encoder with dual heads# Feat_Encoder:对应论文中的CNN Encoder,用来提取图像经过CNN之后的特征,backbone选的是ResNet18self.Feat_Encoder = nn.Sequential(*([nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)] +list(models.resnet18(pretrained=True).children())[1:-2]))# self.base_encoder:对应论文中Style Encoder的Transformer Encoderb部分encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward,dropout, activation, normalize_before)self.base_encoder = TransformerEncoder(encoder_layer, num_encoder_layers, None)writer_norm = nn.LayerNorm(d_model) if normalize_before else Noneglyph_norm = nn.LayerNorm(d_model) if normalize_before else None# writer_head和glyph_head分别对应论文中的Writer Head和Glyph Head# 从这里来看,这两个分支使用的是1层的Transformer Encoder结构self.writer_head = TransformerEncoder(encoder_layer, num_head_layers, writer_norm)self.glyph_head = TransformerEncoder(encoder_layer, num_head_layers, glyph_norm)### content ecoder# content_encoder:对应论文中Content Encoder部分,# 从Content_TR源码来看,同样也是ResNet18作为CNN Encoder的backbone# Transformer Encoder部分用了3层的Transformer Encoder结构# 详情参见:https://github.com/dailenson/SDT/blob/1352b5cb779d47c5a8c87f6735e9dde94aa58f07/models/encoder.py#L8self.content_encoder = Content_TR(d_model, num_encoder_layers)### decoder for receiving writer-wise and character-wise styles# 这里对应框图中Transformer Decoder中前后两个部分decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward,dropout, activation, normalize_before)wri_decoder_norm = nn.LayerNorm(d_model) if normalize_before else Noneself.wri_decoder = TransformerDecoder(decoder_layer, wri_dec_layers, wri_decoder_norm,return_intermediate=return_intermediate_dec)gly_decoder_norm = nn.LayerNorm(d_model) if normalize_before else Noneself.gly_decoder = TransformerDecoder(decoder_layer, gly_dec_layers, gly_decoder_norm,return_intermediate=return_intermediate_dec)### two mlps that project style features into the space where nce_loss is appliedself.pro_mlp_writer = nn.Sequential(nn.Linear(512, 4096), nn.GELU(), nn.Linear(4096, 256))self.pro_mlp_character = nn.Sequential(nn.Linear(512, 4096), nn.GELU(), nn.Linear(4096, 256))self.SeqtoEmb = SeqtoEmb(hid_dim=d_model)self.EmbtoSeq = EmbtoSeq(hid_dim=d_model)# 这里位置嵌入来源于论文Attention is all you need.self.add_position = PositionalEncoding(dropout=0.1, dim=d_model)        self._reset_parameters()# the shape of style_imgs is [B, 2*N, C, H, W] during trainingdef forward(self, style_imgs, seq, char_img):batch_size, num_imgs, in_planes, h, w = style_imgs.shape# style_imgs: [B, 2*N, C:1, H, W] -> FEAT_ST_ENC: [4*N, B, C:512]style_imgs = style_imgs.view(-1, in_planes, h, w)  # [B*2N, C:1, H, W]# 经过CNN Encoderstyle_embe = self.Feat_Encoder(style_imgs)  # [B*2N, C:512, 2, 2]anchor_num = num_imgs//2style_embe = style_embe.view(batch_size*num_imgs, 512, -1).permute(2, 0, 1)  # [4, B*2N, C:512]FEAT_ST_ENC = self.add_position(style_embe)memory = self.base_encoder(FEAT_ST_ENC)  # [4, B*2N, C]writer_memory = self.writer_head(memory)glyph_memory = self.glyph_head(memory)writer_memory = rearrange(writer_memory, 't (b p n) c -> t (p b) n c',b=batch_size, p=2, n=anchor_num)  # [4, 2*B, N, C]glyph_memory = rearrange(glyph_memory, 't (b p n) c -> t (p b) n c',b=batch_size, p=2, n=anchor_num)  # [4, 2*B, N, C]# writer-ncememory_fea = rearrange(writer_memory, 't b n c ->(t n) b c')  # [4*N, 2*B, C]compact_fea = torch.mean(memory_fea, 0) # [2*B, C]# compact_fea:[2*B, C:512] ->  nce_emb: [B, 2, C:128]pro_emb = self.pro_mlp_writer(compact_fea)query_emb = pro_emb[:batch_size, :]pos_emb = pro_emb[batch_size:, :]nce_emb = torch.stack((query_emb, pos_emb), 1) # [B, 2, C]nce_emb = nn.functional.normalize(nce_emb, p=2, dim=2)# glyph-ncepatch_emb = glyph_memory[:, :batch_size]  # [4, B, N, C]# sample the positive pairanc, positive = self.random_double_sampling(patch_emb)n_channels = anc.shape[-1]anc = anc.reshape(batch_size, -1, n_channels)anc_compact = torch.mean(anc, 1, keepdim=True) anc_compact = self.pro_mlp_character(anc_compact) # [B, 1, C]positive = positive.reshape(batch_size, -1, n_channels)positive_compact = torch.mean(positive, 1, keepdim=True)positive_compact = self.pro_mlp_character(positive_compact) # [B, 1, C]nce_emb_patch = torch.cat((anc_compact, positive_compact), 1) # [B, 2, C]nce_emb_patch = nn.functional.normalize(nce_emb_patch, p=2, dim=2)# input the writer-wise & character-wise styles into the decoderwriter_style = memory_fea[:, :batch_size, :]  # [4*N, B, C]glyph_style = glyph_memory[:, :batch_size]  # [4, B, N, C]glyph_style = rearrange(glyph_style, 't b n c -> (t n) b c') # [4*N, B, C]# QUERY: [char_emb, seq_emb]seq_emb = self.SeqtoEmb(seq).permute(1, 0, 2)T, N, C = seq_emb.shape# ========================Content Encoder部分=========================char_emb = self.content_encoder(char_img) # [4, N, 512]char_emb = torch.mean(char_emb, 0) #[N, 512]char_emb = repeat(char_emb, 'n c -> t n c', t = 1)tgt = torch.cat((char_emb, seq_emb), 0) # [1+T], put the content token as the first tokentgt_mask = generate_square_subsequent_mask(sz=(T+1)).to(tgt)tgt = self.add_position(tgt)# 注意这里的执行顺序,Content Encoder输出 → Writer Decoder → Glyph Decoder → Embedding to Sequence# [wri_dec_layers, T, B, C]wri_hs = self.wri_decoder(tgt, writer_style, tgt_mask=tgt_mask)# [gly_dec_layers, T, B, C]hs = self.gly_decoder(wri_hs[-1], glyph_style, tgt_mask=tgt_mask)  h = hs.transpose(1, 2)[-1]  # B T Cpred_sequence = self.EmbtoSeq(h)return pred_sequence, nce_emb, nce_emb_patch
数据集部分
  • CASIA_CHINESE
    data/CASIA_CHINESE
    ├── character_dict.pkl   # 词典
    ├── Chinese_content.pkl  # Content reference
    ├── test
    ├── test_style_samples
    ├── train
    ├── train_style_samples  # 1300个pkl,每个pkl中是同一个人写的各个字,长度不一致
    └── writer_dict.pkl
    
  • 训练集中单个数据格式解析
    {'coords': torch.Tensor(coords),                # 写这个字,每一划的点阵'character_id': torch.Tensor([character_id]),  # content字的索引'writer_id': torch.Tensor([writer_id]),        # 某个人的style'img_list': torch.Tensor(img_list),            # 随机选中style的img_list'char_img': torch.Tensor(char_img),            # content字的图像'img_label': torch.Tensor([label_id]),         # style中图像的label
    }
    
  • 推理时:
    • 输入:
      • 一种style15个字符的图像
      • 原始输入字符
    • 输出:属于该style的原始字符
总结
  1. 感觉对于Transformer的用法,比较粗暴。当然,Transformer本来就很粗暴
  2. 模型69M (position_layer2_dim512_iter138k_test_acc0.9443.pth) 比较容易接受,这和我之前以为的Transformer系列都很大,有些出入。这也算是纠正自己的盲目认知了
  3. 学到了einops库的用法,语义化操作,很有意思,值得学习。
  4. 第一次了解到NCE(Noise Contrastive Estimation)这个Loss,主要解决了class很多时,将其转换为二分类问题。

这篇关于论文阅读: (CVPR2023 SDT )基于书写者风格和字符风格解耦的手写文字生成及源码对应的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/812815

相关文章

AI一键生成 PPT

AI一键生成 PPT 操作步骤 作为一名打工人,是不是经常需要制作各种PPT来分享我的生活和想法。但是,你们知道,有时候灵感来了,时间却不够用了!😩直到我发现了Kimi AI——一个能够自动生成PPT的神奇助手!🌟 什么是Kimi? 一款月之暗面科技有限公司开发的AI办公工具,帮助用户快速生成高质量的演示文稿。 无论你是职场人士、学生还是教师,Kimi都能够为你的办公文

JAVA智听未来一站式有声阅读平台听书系统小程序源码

智听未来,一站式有声阅读平台听书系统 🌟 开篇:遇见未来,从“智听”开始 在这个快节奏的时代,你是否渴望在忙碌的间隙,找到一片属于自己的宁静角落?是否梦想着能随时随地,沉浸在知识的海洋,或是故事的奇幻世界里?今天,就让我带你一起探索“智听未来”——这一站式有声阅读平台听书系统,它正悄悄改变着我们的阅读方式,让未来触手可及! 📚 第一站:海量资源,应有尽有 走进“智听

高效录音转文字:2024年四大工具精选!

在快节奏的工作生活中,能够快速将录音转换成文字是一项非常实用的能力。特别是在需要记录会议纪要、讲座内容或者是采访素材的时候,一款优秀的在线录音转文字工具能派上大用场。以下推荐几个好用的录音转文字工具! 365在线转文字 直达链接:https://www.pdf365.cn/ 365在线转文字是一款提供在线录音转文字服务的工具,它以其高效、便捷的特点受到用户的青睐。用户无需下载安装任何软件,只

pdfmake生成pdf的使用

实际项目中有时会有根据填写的表单数据或者其他格式的数据,将数据自动填充到pdf文件中根据固定模板生成pdf文件的需求 文章目录 利用pdfmake生成pdf文件1.下载安装pdfmake第三方包2.封装生成pdf文件的共用配置3.生成pdf文件的文件模板内容4.调用方法生成pdf 利用pdfmake生成pdf文件 1.下载安装pdfmake第三方包 npm i pdfma

poj 1258 Agri-Net(最小生成树模板代码)

感觉用这题来当模板更适合。 题意就是给你邻接矩阵求最小生成树啦。~ prim代码:效率很高。172k...0ms。 #include<stdio.h>#include<algorithm>using namespace std;const int MaxN = 101;const int INF = 0x3f3f3f3f;int g[MaxN][MaxN];int n

poj 1287 Networking(prim or kruscal最小生成树)

题意给你点与点间距离,求最小生成树。 注意点是,两点之间可能有不同的路,输入的时候选择最小的,和之前有道最短路WA的题目类似。 prim代码: #include<stdio.h>const int MaxN = 51;const int INF = 0x3f3f3f3f;int g[MaxN][MaxN];int P;int prim(){bool vis[MaxN];

poj 2349 Arctic Network uva 10369(prim or kruscal最小生成树)

题目很麻烦,因为不熟悉最小生成树的算法调试了好久。 感觉网上的题目解释都没说得很清楚,不适合新手。自己写一个。 题意:给你点的坐标,然后两点间可以有两种方式来通信:第一种是卫星通信,第二种是无线电通信。 卫星通信:任何两个有卫星频道的点间都可以直接建立连接,与点间的距离无关; 无线电通信:两个点之间的距离不能超过D,无线电收发器的功率越大,D越大,越昂贵。 计算无线电收发器D

hdu 1102 uva 10397(最小生成树prim)

hdu 1102: 题意: 给一个邻接矩阵,给一些村庄间已经修的路,问最小生成树。 解析: 把已经修的路的权值改为0,套个prim()。 注意prim 最外层循坏为n-1。 代码: #include <iostream>#include <cstdio>#include <cstdlib>#include <algorithm>#include <cstri

AI hospital 论文Idea

一、Benchmarking Large Language Models on Communicative Medical Coaching: A Dataset and a Novel System论文地址含代码 大多数现有模型和工具主要迎合以患者为中心的服务。这项工作深入探讨了LLMs在提高医疗专业人员的沟通能力。目标是构建一个模拟实践环境,人类医生(即医学学习者)可以在其中与患者代理进行医学

【生成模型系列(初级)】嵌入(Embedding)方程——自然语言处理的数学灵魂【通俗理解】

【通俗理解】嵌入(Embedding)方程——自然语言处理的数学灵魂 关键词提炼 #嵌入方程 #自然语言处理 #词向量 #机器学习 #神经网络 #向量空间模型 #Siri #Google翻译 #AlexNet 第一节:嵌入方程的类比与核心概念【尽可能通俗】 嵌入方程可以被看作是自然语言处理中的“翻译机”,它将文本中的单词或短语转换成计算机能够理解的数学形式,即向量。 正如翻译机将一种语言