論文読込:(CVPR2023 SDT)筆者のスタイルと文字スタイルの分離に基づく手書きテキスト生成とソースコード対応

序章

  • 長い間その論文を真剣に読んでいないので、すぐに手に取らないでください。これは、Transformer 構造を使用した初めて読んだ CV 論文でもあります。
  • この記事を読むことにした理由は、以前に手書きフォント生成プロジェクトを行ったことがあるからです。この作業を使用して、いくつかの手書きデータ セットを合成し、手書き認識モデルのトレーニングを支援できます。
  • この記事は論文とコードを 1 対 1 で分析する方法で書かれているため、論文の重要なポイントとそれをコードで実装する方法を見つけやすく、より早く重要なポイントを学ぶことができます。このプロジェクトのコードは美しく書かれており、明確な指示ときちんとしたコード仕様が含まれています。ウェアハウスの README に従って、プロジェクト全体をすばやく実行します。
  • 英語が読める読者であれば、全体の様子をより直接的に見るために、まず英語の論文を直接読むことをお勧めします。
  • PDF | コード

SDT の全体構造の紹介

  • 全体的な枠組み:
    SDT
  • この研究は、作家レベルおよび文字レベルのスタイル表現を個々の手書きから切り離して、リアルな様式化されたオンライン手書き文字を合成することを提案しています。
  • 上のフレーム図から、全体が 3 つの部分、スタイル エンコーダーコンテンツ エンコーダー、およびトランスフォーマー デコーダーに分割できることがわかります
    • スタイル エンコーダ: 主に、特定のスタイルのライターとグリフの 2 つのスタイル表現を学習します。これらは、様式化されたテキストの合成をガイドするために使用されます。CNN EncoderTransformer Encdoerの 2 つの部分が含まれています
    • Content Encoder : 主に入力テキストの特徴を抽出します。また、CNN EncoderTransformer Encdoer の2 つの部分も含まれます。
  • ❓質問: CNN エンコーダーとトランスフォーマー エンコーダーを組み合わせて使用​​する理由は何ですか?
    • この質問では、論文では Content Encoder が両方を使用するとのみ述べられています。CNN 部分は、コンテンツ参照からコンパクトな特徴マップを学習するために使用されます。Transformer エンコーダは、テキスト コンテンツ表現を抽出するために使用されます。長距離の依存関係をキャプチャする Transformer の強力な機能のおかげで、Content Encdoer はグローバル コンテキスト コンテンツ機能を取得できます。これは、CNN + RNN の 2 つの部分を組み合わせた古典的な CRNN 構造を思い出させます。
      ここに画像の説明を挿入

コードと紙の対応

  • ペーパー構造のコア コードには 2 つの部分があり、1 つはモデル構築部分、もう 1 つはデータ セット処理部分です。
モデル部分を構築する
  • コードのこの部分はウェアハウスのmodels/model.pyにあります。最も重要な部分だけを抽出し、コメントを追加して説明します。残りの詳細はご自身で掘り出してください。
class SDT_Generator(nn.Module):
    def __init__(self, d_model=512, nhead=8, num_encoder_layers=2, num_head_layers= 1,
                 wri_dec_layers=2, gly_dec_layers=2, dim_feedforward=2048, dropout=0.1,
                 activation="relu", normalize_before=True, return_intermediate_dec=True):
        super(SDT_Generator, self).__init__()
        
        ### style encoder with dual heads
        # Feat_Encoder:对应论文中的CNN Encoder,用来提取图像经过CNN之后的特征,backbone选的是ResNet18
        self.Feat_Encoder = nn.Sequential(*([nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)] +list(models.resnet18(pretrained=True).children())[1:-2]))
        
        # self.base_encoder:对应论文中Style Encoder的Transformer Encoderb部分
        encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward,
                                                dropout, activation, normalize_before)
        self.base_encoder = TransformerEncoder(encoder_layer, num_encoder_layers, None)
        
        writer_norm = nn.LayerNorm(d_model) if normalize_before else None
        glyph_norm = nn.LayerNorm(d_model) if normalize_before else None
 
        # writer_head和glyph_head分别对应论文中的Writer Head和Glyph Head
        # 从这里来看,这两个分支使用的是1层的Transformer Encoder结构
        self.writer_head = TransformerEncoder(encoder_layer, num_head_layers, writer_norm)
        self.glyph_head = TransformerEncoder(encoder_layer, num_head_layers, glyph_norm)

        ### content ecoder
        # content_encoder:对应论文中Content Encoder部分,
        # 从Content_TR源码来看,同样也是ResNet18作为CNN Encoder的backbone
        # Transformer Encoder部分用了3层的Transformer Encoder结构
        # 详情参见:https://github.com/dailenson/SDT/blob/1352b5cb779d47c5a8c87f6735e9dde94aa58f07/models/encoder.py#L8
        self.content_encoder = Content_TR(d_model, num_encoder_layers)

        ### decoder for receiving writer-wise and character-wise styles
        # 这里对应框图中Transformer Decoder中前后两个部分
        decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward,
                                                dropout, activation, normalize_before)
        wri_decoder_norm = nn.LayerNorm(d_model) if normalize_before else None
        self.wri_decoder = TransformerDecoder(decoder_layer, wri_dec_layers, wri_decoder_norm,
                                              return_intermediate=return_intermediate_dec)
        gly_decoder_norm = nn.LayerNorm(d_model) if normalize_before else None
        self.gly_decoder = TransformerDecoder(decoder_layer, gly_dec_layers, gly_decoder_norm,
                                          return_intermediate=return_intermediate_dec)
        
        ### two mlps that project style features into the space where nce_loss is applied
        self.pro_mlp_writer = nn.Sequential(
            nn.Linear(512, 4096), nn.GELU(), nn.Linear(4096, 256))
        self.pro_mlp_character = nn.Sequential(
            nn.Linear(512, 4096), nn.GELU(), nn.Linear(4096, 256))

        self.SeqtoEmb = SeqtoEmb(hid_dim=d_model)
        self.EmbtoSeq = EmbtoSeq(hid_dim=d_model)
  
        # 这里位置嵌入来源于论文Attention is all you need.
        self.add_position = PositionalEncoding(dropout=0.1, dim=d_model)        
        self._reset_parameters()

    # the shape of style_imgs is [B, 2*N, C, H, W] during training
    def forward(self, style_imgs, seq, char_img):
        batch_size, num_imgs, in_planes, h, w = style_imgs.shape

        # style_imgs: [B, 2*N, C:1, H, W] -> FEAT_ST_ENC: [4*N, B, C:512]
        style_imgs = style_imgs.view(-1, in_planes, h, w)  # [B*2N, C:1, H, W]
        
        # 经过CNN Encoder
        style_embe = self.Feat_Encoder(style_imgs)  # [B*2N, C:512, 2, 2]

        anchor_num = num_imgs//2
        style_embe = style_embe.view(batch_size*num_imgs, 512, -1).permute(2, 0, 1)  # [4, B*2N, C:512]
        FEAT_ST_ENC = self.add_position(style_embe)

        memory = self.base_encoder(FEAT_ST_ENC)  # [4, B*2N, C]
        writer_memory = self.writer_head(memory)
        glyph_memory = self.glyph_head(memory)

        writer_memory = rearrange(writer_memory, 't (b p n) c -> t (p b) n c',
                           b=batch_size, p=2, n=anchor_num)  # [4, 2*B, N, C]
        glyph_memory = rearrange(glyph_memory, 't (b p n) c -> t (p b) n c',
                           b=batch_size, p=2, n=anchor_num)  # [4, 2*B, N, C]

        # writer-nce
        memory_fea = rearrange(writer_memory, 't b n c ->(t n) b c')  # [4*N, 2*B, C]
        compact_fea = torch.mean(memory_fea, 0) # [2*B, C]
        
        # compact_fea:[2*B, C:512] ->  nce_emb: [B, 2, C:128]
        pro_emb = self.pro_mlp_writer(compact_fea)
        query_emb = pro_emb[:batch_size, :]
        pos_emb = pro_emb[batch_size:, :]
        nce_emb = torch.stack((query_emb, pos_emb), 1) # [B, 2, C]
        nce_emb = nn.functional.normalize(nce_emb, p=2, dim=2)

        # glyph-nce
        patch_emb = glyph_memory[:, :batch_size]  # [4, B, N, C]
        
        # sample the positive pair
        anc, positive = self.random_double_sampling(patch_emb)
        n_channels = anc.shape[-1]
        anc = anc.reshape(batch_size, -1, n_channels)
        anc_compact = torch.mean(anc, 1, keepdim=True) 
        anc_compact = self.pro_mlp_character(anc_compact) # [B, 1, C]
        positive = positive.reshape(batch_size, -1, n_channels)
        positive_compact = torch.mean(positive, 1, keepdim=True)
        positive_compact = self.pro_mlp_character(positive_compact) # [B, 1, C]

        nce_emb_patch = torch.cat((anc_compact, positive_compact), 1) # [B, 2, C]
        nce_emb_patch = nn.functional.normalize(nce_emb_patch, p=2, dim=2)

        # input the writer-wise & character-wise styles into the decoder
        writer_style = memory_fea[:, :batch_size, :]  # [4*N, B, C]
        glyph_style = glyph_memory[:, :batch_size]  # [4, B, N, C]
        glyph_style = rearrange(glyph_style, 't b n c -> (t n) b c') # [4*N, B, C]

        # QUERY: [char_emb, seq_emb]
        seq_emb = self.SeqtoEmb(seq).permute(1, 0, 2)
        T, N, C = seq_emb.shape

        # ========================Content Encoder部分=========================
        char_emb = self.content_encoder(char_img) # [4, N, 512]
        char_emb = torch.mean(char_emb, 0) #[N, 512]
        char_emb = repeat(char_emb, 'n c -> t n c', t = 1)
        tgt = torch.cat((char_emb, seq_emb), 0) # [1+T], put the content token as the first token
        tgt_mask = generate_square_subsequent_mask(sz=(T+1)).to(tgt)
        tgt = self.add_position(tgt)

		# 注意这里的执行顺序,Content Encoder输出 → Writer Decoder → Glyph Decoder → Embedding to Sequence
        # [wri_dec_layers, T, B, C]
        wri_hs = self.wri_decoder(tgt, writer_style, tgt_mask=tgt_mask)
        # [gly_dec_layers, T, B, C]
        hs = self.gly_decoder(wri_hs[-1], glyph_style, tgt_mask=tgt_mask)  

        h = hs.transpose(1, 2)[-1]  # B T C
        pred_sequence = self.EmbtoSeq(h)
        return pred_sequence, nce_emb, nce_emb_patch
データセットセクション
  • カシア_中国語
    data/CASIA_CHINESE
    ├── character_dict.pkl   # 词典
    ├── Chinese_content.pkl  # Content reference
    ├── test
    ├── test_style_samples
    ├── train
    ├── train_style_samples  # 1300个pkl,每个pkl中是同一个人写的各个字,长度不一致
    └── writer_dict.pkl
    
  • トレーニングセット内の単一のデータ形式の分析
    {
          
          
        'coords': torch.Tensor(coords),                # 写这个字,每一划的点阵
        'character_id': torch.Tensor([character_id]),  # content字的索引
        'writer_id': torch.Tensor([writer_id]),        # 某个人的style
        'img_list': torch.Tensor(img_list),            # 随机选中style的img_list
        'char_img': torch.Tensor(char_img),            # content字的图像
        'img_label': torch.Tensor([label_id]),         # style中图像的label
    }
    
  • 推論するとき:
    • 入力:
      • 15文字のスタイルイメージ
      • 生の入力文字
    • 出力: スタイルに属する元の文字

要約する

  1. Transformerの使い方は割と雑な気がします。もちろん、トランスフォーマーは本質的に粗いものです
  2. モデル69M(position_layer2_dim512_iter138k_test_acc0.9443.pth)は受け入れられやすいですが、私が思っていたトランスフォーマーシリーズとはかなり異なります。これは、自分自身の盲目的な認知を修正することとみなすことができます。
  3. einopsライブラリの使い方やセマンティックな操作などとても興味深く学びがいのある内容でした
  4. Loss of NCE (Noise Contrastive Estimation) について初めて知りましたが、主にクラスが多い場合に二項分類問題に変換する問題を解決します。

おすすめ

転載: blog.csdn.net/shiwanghualuo/article/details/131430113