Paper reading: (CVPR 2023, SDT) Handwritten text generation by decoupling writer-level and character-level styles, with paper-to-code correspondence

Introduction

  • It has been a long time since I last read a paper carefully, so it took some effort to pick one back up. This is also the first CV paper I have read that uses the Transformer architecture.
  • I chose this paper because I previously worked on a handwritten font generation project. This work can be used to synthesize handwriting datasets to help train handwriting recognition models.
  • This article analyzes the paper and its code side by side, which makes it easy to locate the key points of the paper, see how each is implemented in code, and absorb them faster. The code for this project is beautifully written, with clear instructions and tidy coding conventions; by following the repository README you can quickly run the entire project.
  • If you can read English, I recommend reading the English paper directly first, to get a more direct view of the whole picture.
  • PDF | Code

Introduction to the overall structure of SDT

  • Overall framework:
    (figure: the overall SDT framework diagram from the paper)
  • This work proposes to decouple writer- and character-level style representations from individual handwriting to synthesize realistic stylized online handwritten characters.
  • From the framework diagram above, the whole model can be divided into three parts: the Style Encoder, the Content Encoder, and the Transformer Decoder.
    • Style Encoder: learns the Writer-level and Glyph-level style representations from the given style samples, which guide the synthesis of stylized text. It consists of two parts: a CNN Encoder and a Transformer Encoder.
    • Content Encoder: extracts features from the input content character. It likewise consists of two parts: a CNN Encoder and a Transformer Encoder.
  • ❓ Question: why combine a CNN Encoder with a Transformer Encoder?
    • On this point, the paper only says that the Content Encoder uses both: the CNN part learns a compact feature map from the content reference, while the Transformer encoder extracts the textual content representation. Thanks to the Transformer's powerful ability to capture long-range dependencies, the Content Encoder obtains content features with global context. This reminds me of the classic CRNN structure, which likewise combines CNN + RNN. A minimal sketch of this CNN → Transformer pattern follows below.
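  • To make the CNN Encoder + Transformer Encoder combination concrete, here is a minimal sketch of the pattern (illustrative only, not the paper's exact code): the CNN backbone compresses the image into a small feature map, the feature map is flattened into a token sequence, and the Transformer encoder then models global dependencies among those tokens.

import torch
import torch.nn as nn
from torchvision import models

# CNN backbone: ResNet18 with a 1-channel first conv, avgpool/fc dropped
cnn = nn.Sequential(
    nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False),
    *list(models.resnet18(weights=None).children())[1:-2],
)
encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8)
transformer = nn.TransformerEncoder(encoder_layer, num_layers=3)

x = torch.randn(2, 1, 64, 64)              # [B, 1, H, W] grayscale character images
feat = cnn(x)                              # [B, 512, 2, 2] compact feature map
tokens = feat.flatten(2).permute(2, 0, 1)  # [4, B, 512]: each spatial cell becomes a token
out = transformer(tokens)                  # [4, B, 512] globally contextualized features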

Correspondence between code and paper

  • The core code of the paper has two parts: the model-building part and the dataset-processing part.
Model-building part
  • This part of the code lives in models/model.py in the repository. I extract only the most critical parts and add comments; please dig into the remaining details yourself.
class SDT_Generator(nn.Module):
    def __init__(self, d_model=512, nhead=8, num_encoder_layers=2, num_head_layers= 1,
                 wri_dec_layers=2, gly_dec_layers=2, dim_feedforward=2048, dropout=0.1,
                 activation="relu", normalize_before=True, return_intermediate_dec=True):
        super(SDT_Generator, self).__init__()
        
        ### style encoder with dual heads
        # Feat_Encoder: the CNN Encoder in the paper, used to extract image features;
        # the backbone is ResNet18, with its first conv replaced to accept 1-channel input
        self.Feat_Encoder = nn.Sequential(*([nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)] +list(models.resnet18(pretrained=True).children())[1:-2]))
        
        # self.base_encoder: the Transformer Encoder part of the Style Encoder in the paper
        encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward,
                                                dropout, activation, normalize_before)
        self.base_encoder = TransformerEncoder(encoder_layer, num_encoder_layers, None)
        
        writer_norm = nn.LayerNorm(d_model) if normalize_before else None
        glyph_norm = nn.LayerNorm(d_model) if normalize_before else None
 
        # writer_head and glyph_head correspond to the Writer Head and Glyph Head in the paper
        # as seen here, each of the two branches is a 1-layer Transformer Encoder
        self.writer_head = TransformerEncoder(encoder_layer, num_head_layers, writer_norm)
        self.glyph_head = TransformerEncoder(encoder_layer, num_head_layers, glyph_norm)

        ### content encoder
        # content_encoder: the Content Encoder in the paper
        # judging from the Content_TR source, it likewise uses ResNet18 as the CNN Encoder
        # backbone, followed by a 3-layer Transformer Encoder; for details see:
        # https://github.com/dailenson/SDT/blob/1352b5cb779d47c5a8c87f6735e9dde94aa58f07/models/encoder.py#L8
        self.content_encoder = Content_TR(d_model, num_encoder_layers)

        ### decoder for receiving writer-wise and character-wise styles
        # these correspond to the two successive parts of the Transformer Decoder in the framework diagram
        decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward,
                                                dropout, activation, normalize_before)
        wri_decoder_norm = nn.LayerNorm(d_model) if normalize_before else None
        self.wri_decoder = TransformerDecoder(decoder_layer, wri_dec_layers, wri_decoder_norm,
                                              return_intermediate=return_intermediate_dec)
        gly_decoder_norm = nn.LayerNorm(d_model) if normalize_before else None
        self.gly_decoder = TransformerDecoder(decoder_layer, gly_dec_layers, gly_decoder_norm,
                                          return_intermediate=return_intermediate_dec)
        
        ### two mlps that project style features into the space where nce_loss is applied
        self.pro_mlp_writer = nn.Sequential(
            nn.Linear(512, 4096), nn.GELU(), nn.Linear(4096, 256))
        self.pro_mlp_character = nn.Sequential(
            nn.Linear(512, 4096), nn.GELU(), nn.Linear(4096, 256))

        self.SeqtoEmb = SeqtoEmb(hid_dim=d_model)
        self.EmbtoSeq = EmbtoSeq(hid_dim=d_model)
  
        # positional encoding as in the paper "Attention Is All You Need"
        self.add_position = PositionalEncoding(dropout=0.1, dim=d_model)        
        self._reset_parameters()

    # the shape of style_imgs is [B, 2*N, C, H, W] during training
    def forward(self, style_imgs, seq, char_img):
        batch_size, num_imgs, in_planes, h, w = style_imgs.shape

        # style_imgs: [B, 2*N, C:1, H, W] -> FEAT_ST_ENC: [4, B*2N, C:512]
        style_imgs = style_imgs.view(-1, in_planes, h, w)  # [B*2N, C:1, H, W]
        
        # pass through the CNN Encoder
        style_embe = self.Feat_Encoder(style_imgs)  # [B*2N, C:512, 2, 2]

        anchor_num = num_imgs//2
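        # num_imgs == 2*N: each sample provides two groups of N style images from the
        # same writer; the two groups later form the positive pair for the writer-wise NCE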
        style_embe = style_embe.view(batch_size*num_imgs, 512, -1).permute(2, 0, 1)  # [4, B*2N, C:512]
        FEAT_ST_ENC = self.add_position(style_embe)

        memory = self.base_encoder(FEAT_ST_ENC)  # [4, B*2N, C]
        writer_memory = self.writer_head(memory)
        glyph_memory = self.glyph_head(memory)

        writer_memory = rearrange(writer_memory, 't (b p n) c -> t (p b) n c',
                           b=batch_size, p=2, n=anchor_num)  # [4, 2*B, N, C]
        glyph_memory = rearrange(glyph_memory, 't (b p n) c -> t (p b) n c',
                           b=batch_size, p=2, n=anchor_num)  # [4, 2*B, N, C]

        # writer-nce
        memory_fea = rearrange(writer_memory, 't b n c ->(t n) b c')  # [4*N, 2*B, C]
        compact_fea = torch.mean(memory_fea, 0) # [2*B, C]
        
        # compact_fea: [2*B, C:512] -> nce_emb: [B, 2, C:256]
        pro_emb = self.pro_mlp_writer(compact_fea)
        query_emb = pro_emb[:batch_size, :]
        pos_emb = pro_emb[batch_size:, :]
        nce_emb = torch.stack((query_emb, pos_emb), 1) # [B, 2, C]
        nce_emb = nn.functional.normalize(nce_emb, p=2, dim=2)

        # glyph-nce
        patch_emb = glyph_memory[:, :batch_size]  # [4, B, N, C]
        
        # sample the positive pair
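        # (random_double_sampling draws two random subsets of the glyph tokens to act
        # as the anchor and positive views of the same character)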
        anc, positive = self.random_double_sampling(patch_emb)
        n_channels = anc.shape[-1]
        anc = anc.reshape(batch_size, -1, n_channels)
        anc_compact = torch.mean(anc, 1, keepdim=True) 
        anc_compact = self.pro_mlp_character(anc_compact) # [B, 1, C]
        positive = positive.reshape(batch_size, -1, n_channels)
        positive_compact = torch.mean(positive, 1, keepdim=True)
        positive_compact = self.pro_mlp_character(positive_compact) # [B, 1, C]

        nce_emb_patch = torch.cat((anc_compact, positive_compact), 1) # [B, 2, C]
        nce_emb_patch = nn.functional.normalize(nce_emb_patch, p=2, dim=2)

        # input the writer-wise & character-wise styles into the decoder
        writer_style = memory_fea[:, :batch_size, :]  # [4*N, B, C]
        glyph_style = glyph_memory[:, :batch_size]  # [4, B, N, C]
        glyph_style = rearrange(glyph_style, 't b n c -> (t n) b c') # [4*N, B, C]

        # QUERY: [char_emb, seq_emb]
        seq_emb = self.SeqtoEmb(seq).permute(1, 0, 2)
        T, N, C = seq_emb.shape

        # ======================== Content Encoder part ========================
        char_emb = self.content_encoder(char_img) # [4, N, 512]
        char_emb = torch.mean(char_emb, 0) #[N, 512]
        char_emb = repeat(char_emb, 'n c -> t n c', t = 1)
        tgt = torch.cat((char_emb, seq_emb), 0) # [1+T], put the content token as the first token
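        # causal mask of size T+1: position i may only attend to positions <= i
        # (the content token occupies index 0)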
        tgt_mask = generate_square_subsequent_mask(sz=(T+1)).to(tgt)
        tgt = self.add_position(tgt)

        # note the execution order here: Content Encoder output → Writer Decoder → Glyph Decoder → Embedding-to-Sequence
        # [wri_dec_layers, T, B, C]
        wri_hs = self.wri_decoder(tgt, writer_style, tgt_mask=tgt_mask)
        # [gly_dec_layers, T, B, C]
        hs = self.gly_decoder(wri_hs[-1], glyph_style, tgt_mask=tgt_mask)  

        h = hs.transpose(1, 2)[-1]  # B T C
        pred_sequence = self.EmbtoSeq(h)
        return pred_sequence, nce_emb, nce_emb_patch
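Besides the predicted sequence, the forward pass returns two L2-normalized pair embeddings, nce_emb and nce_emb_patch, of shape [B, 2, C], which feed the writer-wise and glyph-wise contrastive losses. The repository has its own loss implementation; the following is only a minimal InfoNCE sketch over such paired embeddings, in which the other samples in the batch act as negatives:

import torch
import torch.nn.functional as F

def info_nce(pairs: torch.Tensor, temperature: float = 0.07) -> torch.Tensor:
    """Minimal InfoNCE for L2-normalized pairs of shape [B, 2, C]:
    pairs[:, 0] are anchors, pairs[:, 1] their positives; every other
    sample in the batch serves as a negative."""
    anchors, positives = pairs[:, 0], pairs[:, 1]     # [B, C] each
    logits = anchors @ positives.t() / temperature    # [B, B] cosine similarities
    labels = torch.arange(pairs.size(0), device=pairs.device)
    return F.cross_entropy(logits, labels)            # the diagonal is the positive pair

# e.g. contrastive_loss = info_nce(nce_emb) + info_nce(nce_emb_patch)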
Dataset section
  • CASIA_CHINESE
    data/CASIA_CHINESE
    ├── character_dict.pkl   # character dictionary
    ├── Chinese_content.pkl  # content reference
    ├── test
    ├── test_style_samples
    ├── train
    ├── train_style_samples  # 1300 pkl files; each holds the characters written by one writer (varying lengths)
    └── writer_dict.pkl
    
  • Format of a single training sample
    {
        'coords': torch.Tensor(coords),                # point sequence of every stroke of this character
        'character_id': torch.Tensor([character_id]),  # index of the content character
        'writer_id': torch.Tensor([writer_id]),        # id of the writer providing the style
        'img_list': torch.Tensor(img_list),            # randomly selected style images
        'char_img': torch.Tensor(char_img),            # image of the content character
        'img_label': torch.Tensor([label_id]),         # labels of the style images
    }
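  • A quick way to sanity-check this layout is to load the pickles and print what they contain. A minimal sketch (file names from the tree above; everything else is assumed):

import pickle

# Peek at the dictionaries shipped with the dataset
for name in ("character_dict.pkl", "writer_dict.pkl"):
    with open(f"data/CASIA_CHINESE/{name}", "rb") as f:
        obj = pickle.load(f)
    print(name, type(obj), len(obj))

# For a single training sample `sample` (the dict shown above):
# for key, value in sample.items():
#     print(key, tuple(value.shape))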
    
  • At inference time:
    • Input:
      • 15 style character images written by the target writer
      • the content characters to be generated
    • Output: the content characters rendered in that writer's style
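  • Shape-wise, the inputs look like the following (a hedged sketch; the repo's real entry point is test.py, and decoding at inference is autoregressive rather than the teacher-forced forward shown above; shapes follow the comments in the model code):

import torch

style_imgs = torch.randn(1, 15, 1, 64, 64)  # 15 style reference images from one writer
char_img = torch.randn(1, 1, 64, 64)        # the content character to render
# A trained SDT_Generator then decodes the pen-point trajectory of char_img
# written in that writer's style.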

Summary

  1. The way the Transformer is used here feels fairly brute-force. Then again, the Transformer is inherently a brute-force tool.
  2. The model is only 69 MB (position_layer2_dim512_iter138k_test_acc0.9443.pth), which is easy to accept, and quite different from what I assumed Transformer-family models had to be; this corrected a blind spot in my own understanding.
  3. I learned the usage of the einops library and its semantic tensor operations, which are very interesting and worth learning (see the small example after this list).
  4. This was my first time learning about the NCE (Noise Contrastive Estimation) loss; it mainly handles a very large number of classes by converting the problem into binary classification (data vs. noise) — see the InfoNCE sketch after the model code above.
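  As promised in point 3, here is a tiny einops demo mirroring the rearrange/repeat calls in the forward pass above (shapes are made up for illustration):

import torch
from einops import rearrange, repeat

x = torch.randn(4, 3 * 2 * 5, 512)  # [t, b*p*n, c], like writer_memory before rearrange
y = rearrange(x, 't (b p n) c -> t (p b) n c', b=3, p=2, n=5)
print(y.shape)  # torch.Size([4, 6, 5, 512])

z = repeat(torch.randn(5, 512), 'n c -> t n c', t=1)  # like the char_emb repeat
print(z.shape)  # torch.Size([1, 5, 512])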

Origin: blog.csdn.net/shiwanghualuo/article/details/131430113