It has been a long time since I last read a paper carefully, so getting back into it was not quick. This is also the first CV paper I have read that uses the Transformer architecture.
I chose this paper because I previously worked on a handwritten font generation project; this work can be used to synthesize handwriting datasets to assist the training of handwriting recognition models.
This article walks through the paper and the code side by side, which makes it easier to find the key points of the paper, see how they are implemented in code, and absorb them faster. The code for this project is nicely written, with clear instructions and tidy coding conventions; by following the repository README you can quickly run the whole project.
If you can read English, I recommend reading the original paper first, to get a more direct picture of the whole work.
This work proposes to decouple writer- and character-level style representations from individual handwriting to synthesize realistic stylized online handwritten characters.
From the architecture diagram above, the model can be divided into three parts: the Style Encoder, the Content Encoder, and the Transformer Decoder.
Style Encoder: learns the Writer-wise and Glyph-wise style representations from the given style samples, which are used to guide the synthesis of stylized text. It contains two parts: a CNN Encoder and a Transformer Encoder.
Content Encoder: extracts features from the input text (the content reference), and likewise contains two parts: a CNN Encoder and a Transformer Encoder.
❓Question: Why use CNN Encoder + Transformer Encoder in combination?
On this question the paper only says that the Content Encoder uses both: the CNN part learns a compact feature map from the content reference, and the Transformer encoder then extracts the textual content representation. Thanks to the Transformer's ability to capture long-range dependencies, the Content Encoder obtains content features with global context. This reminds me of the classic CRNN architecture, which likewise combines a CNN with an RNN.
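To make the combination concrete, here is a minimal sketch of the CNN-then-Transformer-encoder pattern. This is not the repo's actual code; the ResNet18 backbone and d_model=512 follow the description above, everything else (class name, input size) is illustrative:

import torch
import torch.nn as nn
from torchvision import models

class CNNTransformerEncoder(nn.Module):
    """Minimal sketch: a CNN backbone feeding a Transformer encoder (illustrative only)."""
    def __init__(self, d_model=512, nhead=8, num_layers=2):
        super().__init__()
        # ResNet18 up to its last conv stage (drop avgpool/fc) -> [B, 512, h, w]
        self.cnn = nn.Sequential(*list(models.resnet18(weights=None).children())[:-2])
        layer = nn.TransformerEncoderLayer(d_model, nhead)
        self.encoder = nn.TransformerEncoder(layer, num_layers)

    def forward(self, x):                          # x: [B, 3, H, W]
        feat = self.cnn(x)                         # compact feature map [B, 512, h, w]
        tokens = feat.flatten(2).permute(2, 0, 1)  # [h*w, B, 512] token sequence
        return self.encoder(tokens)                # global-context features, same shape

The CNN shrinks the image into a small feature map, and the Transformer encoder then lets every spatial token attend to every other token, which is where the global context comes from.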
Correspondence between code and paper
The core code corresponding to the paper has two parts: building the model and processing the dataset.
Building the model
This part of the code lives in models/model.py in the repository. I will only extract the most important pieces and add comments to explain them; please dig into the remaining details yourself.
class SDT_Generator(nn.Module):
    def __init__(self, d_model=512, nhead=8, num_encoder_layers=2, num_head_layers=1,
                 wri_dec_layers=2, gly_dec_layers=2, dim_feedforward=2048, dropout=0.1,
                 activation="relu", normalize_before=True, return_intermediate_dec=True):
        super(SDT_Generator, self).__init__()

        ### style encoder with dual heads
        # Feat_Encoder: the CNN Encoder in the paper, used to extract the feature map of the
        # style images; the backbone is ResNet18
        self.Feat_Encoder = nn.Sequential(*(
            [nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)] +
            list(models.resnet18(pretrained=True).children())[1:-2]))
        # base_encoder: the Transformer Encoder part of the Style Encoder in the paper
        encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward,
                                                dropout, activation, normalize_before)
        self.base_encoder = TransformerEncoder(encoder_layer, num_encoder_layers, None)
        writer_norm = nn.LayerNorm(d_model) if normalize_before else None
        glyph_norm = nn.LayerNorm(d_model) if normalize_before else None
        # writer_head and glyph_head correspond to the Writer Head and Glyph Head in the paper;
        # as seen here, both branches are 1-layer Transformer Encoders
        self.writer_head = TransformerEncoder(encoder_layer, num_head_layers, writer_norm)
        self.glyph_head = TransformerEncoder(encoder_layer, num_head_layers, glyph_norm)

        ### content encoder
        # content_encoder: the Content Encoder in the paper. From the Content_TR source,
        # ResNet18 is again the backbone of its CNN Encoder, and its Transformer Encoder part
        # is a 3-layer structure. For details see:
        # https://github.com/dailenson/SDT/blob/1352b5cb779d47c5a8c87f6735e9dde94aa58f07/models/encoder.py#L8
        self.content_encoder = Content_TR(d_model, num_encoder_layers)

        ### decoder for receiving writer-wise and character-wise styles
        # These correspond to the two successive parts of the Transformer Decoder in the diagram
        decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward,
                                                dropout, activation, normalize_before)
        wri_decoder_norm = nn.LayerNorm(d_model) if normalize_before else None
        self.wri_decoder = TransformerDecoder(decoder_layer, wri_dec_layers, wri_decoder_norm,
                                              return_intermediate=return_intermediate_dec)
        gly_decoder_norm = nn.LayerNorm(d_model) if normalize_before else None
        self.gly_decoder = TransformerDecoder(decoder_layer, gly_dec_layers, gly_decoder_norm,
                                              return_intermediate=return_intermediate_dec)

        ### two mlps that project style features into the space where nce_loss is applied
        self.pro_mlp_writer = nn.Sequential(
            nn.Linear(512, 4096), nn.GELU(), nn.Linear(4096, 256))
        self.pro_mlp_character = nn.Sequential(
            nn.Linear(512, 4096), nn.GELU(), nn.Linear(4096, 256))

        self.SeqtoEmb = SeqtoEmb(hid_dim=d_model)
        self.EmbtoSeq = EmbtoSeq(hid_dim=d_model)
        # the positional encoding comes from the paper "Attention Is All You Need"
        self.add_position = PositionalEncoding(dropout=0.1, dim=d_model)
        self._reset_parameters()

    # the shape of style_imgs is [B, 2*N, C, H, W] during training
    def forward(self, style_imgs, seq, char_img):
        batch_size, num_imgs, in_planes, h, w = style_imgs.shape

        # style_imgs: [B, 2*N, C:1, H, W] -> FEAT_ST_ENC: [4*N, B, C:512]
        style_imgs = style_imgs.view(-1, in_planes, h, w)  # [B*2N, C:1, H, W]
        # pass the style images through the CNN Encoder
        style_embe = self.Feat_Encoder(style_imgs)  # [B*2N, C:512, 2, 2]

        anchor_num = num_imgs // 2
        style_embe = style_embe.view(batch_size * num_imgs, 512, -1).permute(2, 0, 1)  # [4, B*2N, C:512]
        FEAT_ST_ENC = self.add_position(style_embe)

        memory = self.base_encoder(FEAT_ST_ENC)  # [4, B*2N, C]
        writer_memory = self.writer_head(memory)
        glyph_memory = self.glyph_head(memory)

        writer_memory = rearrange(writer_memory, 't (b p n) c -> t (p b) n c',
                                  b=batch_size, p=2, n=anchor_num)  # [4, 2*B, N, C]
        glyph_memory = rearrange(glyph_memory, 't (b p n) c -> t (p b) n c',
                                 b=batch_size, p=2, n=anchor_num)  # [4, 2*B, N, C]

        # writer-nce
        memory_fea = rearrange(writer_memory, 't b n c -> (t n) b c')  # [4*N, 2*B, C]
        compact_fea = torch.mean(memory_fea, 0)  # [2*B, C]
        # compact_fea: [2*B, C:512] -> nce_emb: [B, 2, C:128]
        pro_emb = self.pro_mlp_writer(compact_fea)
        query_emb = pro_emb[:batch_size, :]
        pos_emb = pro_emb[batch_size:, :]
        nce_emb = torch.stack((query_emb, pos_emb), 1)  # [B, 2, C]
        nce_emb = nn.functional.normalize(nce_emb, p=2, dim=2)

        # glyph-nce
        patch_emb = glyph_memory[:, :batch_size]  # [4, B, N, C]
        # sample the positive pair
        anc, positive = self.random_double_sampling(patch_emb)
        n_channels = anc.shape[-1]
        anc = anc.reshape(batch_size, -1, n_channels)
        anc_compact = torch.mean(anc, 1, keepdim=True)
        anc_compact = self.pro_mlp_character(anc_compact)  # [B, 1, C]
        positive = positive.reshape(batch_size, -1, n_channels)
        positive_compact = torch.mean(positive, 1, keepdim=True)
        positive_compact = self.pro_mlp_character(positive_compact)  # [B, 1, C]

        nce_emb_patch = torch.cat((anc_compact, positive_compact), 1)  # [B, 2, C]
        nce_emb_patch = nn.functional.normalize(nce_emb_patch, p=2, dim=2)

        # input the writer-wise & character-wise styles into the decoder
        writer_style = memory_fea[:, :batch_size, :]  # [4*N, B, C]
        glyph_style = glyph_memory[:, :batch_size]  # [4, B, N, C]
        glyph_style = rearrange(glyph_style, 't b n c -> (t n) b c')  # [4*N, B, C]

        # QUERY: [char_emb, seq_emb]
        seq_emb = self.SeqtoEmb(seq).permute(1, 0, 2)
        T, N, C = seq_emb.shape

        # ======================== Content Encoder part ========================
        char_emb = self.content_encoder(char_img)  # [4, N, 512]
        char_emb = torch.mean(char_emb, 0)  # [N, 512]
        char_emb = repeat(char_emb, 'n c -> t n c', t=1)
        tgt = torch.cat((char_emb, seq_emb), 0)  # [1+T], put the content token as the first token
        tgt_mask = generate_square_subsequent_mask(sz=(T + 1)).to(tgt)
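        # generate_square_subsequent_mask builds a causal (look-ahead) mask so that each query
        # position can only attend to itself and earlier positions. A minimal sketch of the
        # usual PyTorch idiom (the repo's own helper may differ in details):
        #   def generate_square_subsequent_mask(sz):
        #       return torch.triu(torch.full((sz, sz), float('-inf')), diagonal=1)
        #   # zeros on/below the diagonal (allowed), -inf above it (masked)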
        tgt = self.add_position(tgt)

        # Note the execution order here:
        # Content Encoder output -> Writer Decoder -> Glyph Decoder -> Embedding to Sequence
        # [wri_dec_layers, T, B, C]
        wri_hs = self.wri_decoder(tgt, writer_style, tgt_mask=tgt_mask)
        # [gly_dec_layers, T, B, C]
        hs = self.gly_decoder(wri_hs[-1], glyph_style, tgt_mask=tgt_mask)

        h = hs.transpose(1, 2)[-1]  # B T C
        pred_sequence = self.EmbtoSeq(h)
        return pred_sequence, nce_emb, nce_emb_patch
Output: the original characters belonging to the style
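For reference, here is a hypothetical call to the generator with dummy tensors. The shapes are only inferred from the shape comments in forward() (style images that reduce to a 2x2 ResNet feature map, a trajectory assumed to be 5-dimensional per point) and may not match the repo's actual data pipeline:

import torch

# Hypothetical shapes, inferred from the comments in forward(); the repo's dataloader may differ.
model = SDT_Generator()                          # default hyper-parameters from above
style_imgs = torch.randn(8, 2 * 15, 1, 64, 64)   # [B, 2*N, 1, H, W] style reference images
seq = torch.randn(8, 120, 5)                     # [B, T, 5] online trajectory points (assumed 5-dim)
char_img = torch.randn(8, 1, 64, 64)             # [B, 1, H, W] content reference image

pred_sequence, nce_emb, nce_emb_patch = model(style_imgs, seq, char_img)
# pred_sequence: predicted stylized online handwriting sequence
# nce_emb / nce_emb_patch: writer-wise and glyph-wise pairs used by the contrastive (NCE) losses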
Summary
I feel the way the Transformer is used here is fairly simple and direct; then again, the Transformer itself is a fairly brute-force architecture.
The model is only about 69 MB (position_layer2_dim512_iter138k_test_acc0.9443.pth), which is quite acceptable and very different from the huge Transformer-family models I had imagined; this corrected a blind spot in my own understanding.
I learned how to use the einops library and its semantic tensor operations (rearrange, repeat), which are very interesting and worth picking up.
This was also the first time I learned about the NCE (Noise Contrastive Estimation) loss; it mainly tackles the problem of having too many classes by converting the task into a binary classification between real (positive) pairs and noise (negative) ones.
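For example, the rearrange / repeat calls in forward() can be reproduced in isolation; the shapes below are made up purely for illustration:

import torch
from einops import rearrange, repeat

x = torch.randn(4, 2 * 3 * 5, 512)                  # [t, (b p n), c]
y = rearrange(x, 't (b p n) c -> t (p b) n c', b=2, p=3, n=5)
print(y.shape)                                      # torch.Size([4, 6, 5, 512])

z = repeat(torch.randn(5, 512), 'n c -> t n c', t=1)
print(z.shape)                                      # torch.Size([1, 5, 512])

The axis names in the pattern string double as documentation, which is why the calls in the model read almost like the shape comments next to them.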
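Below is a minimal sketch of an InfoNCE-style contrastive loss over a [B, 2, C] tensor of (query, positive) pairs like the nce_emb returned above. This is the generic formulation, not necessarily the exact loss implemented in the repo:

import torch
import torch.nn.functional as F

def info_nce_loss(emb, temperature=0.07):
    """emb: [B, 2, C] L2-normalized (query, positive) pairs."""
    query, positive = emb[:, 0], emb[:, 1]           # [B, C] each
    logits = query @ positive.t() / temperature      # [B, B] cosine similarities
    labels = torch.arange(emb.size(0), device=emb.device)
    # the diagonal (matching pair) is the "correct class"; off-diagonal entries act as negatives
    return F.cross_entropy(logits, labels)

Each sample's own positive sits on the diagonal of the similarity matrix while the other B-1 samples in the batch serve as negatives, so the huge multi-class problem is reduced to telling the matching pair apart from noise, which is the spirit of NCE.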