PPOCRv3 モデルを pytorch に変換する

序文

PaddleOCRv3 バージョンはしばらく前にリリースされ、検出および認識モデルが更新されました。パフォーマンスは大幅に向上しました。できる限り売春するという原則に従って、私はリリース初日から売春を始めました。新しいモデルのパフォーマンスを以前のモデルと比較すると、大幅な改善が見られますが、一見すると、モデルの構造がはるかに複雑になり、デプロイがはるかに面倒になります。この段階では、パドル フレームワークを他のものに変換する必要があります。デプロイ フレームワークは、paddle2onnx を変換してから他のフレームワークに変換することによってのみ実現できるため、罠から抜け出して、import paddle をモデルのトーチバージョンとして提供する予定です。パドル フレームワークのモデルの重みを pytorch に転送して、より多くの選択肢を提供します。デプロイメント プラン。pytorch フレームワークに切り替えた後、pytorch から他のデプロイメント方法に切り替えることができます。前の例を見てみましょう: use pnnx to pytorch Model to ncnn model

以前のモデルのパフォーマンスとの比較:
ここに画像の説明を挿入します
このプロジェクトのコード実装は以下に基づいています。

1.パドル2トーチ

まず変換原理について話しましょう. paddlepaddle と pytorch はどちらも動的フレームワークであるため、変換は比較的簡単です. パドル モデルを変換するには、torch を使用して同じネットワーク モデル構造を再構築し、その後、パドルの重みを 1 つずつ、対応する値が各レイヤーに割り当てられます。プロセスは比較的単純なように見えますが、結局のところ、それらは異なるフレームワークであり、一部のOP実装も異なるため、多くの落とし穴があることは避けられません。

変換の前に、まず PaddleOCRV3 がモデルの以前のバージョンと比較してどのモジュールが更新されたかを見てみましょう。
まず検出モデルです。

検出モジュール:

  1. LK-PAN: 大きな受容野を持つPAN構造
  2. DML: 教師モデルの相互学習戦略
  3. RSE-FPN: 残留注意メカニズムの FPN 構造

識別モジュール:

  • SVTR_LCNet: 軽量のテキスト認識ネットワーク
  • GTC: 注意が CTC トレーニング戦略をガイドします
  • TextConAug: テキストのコンテキスト情報をマイニングするためのデータ拡張戦略
  • TextRotNet: 自己教師ありの事前トレーニング済みモデル
  • UDML: フェデレーション相互学習戦略
  • UIM: ラベルなしデータ マイニング ソリューション

詳細については、 PPOCRV3 の公式テクニカル レポートを参照してください。ここでは、変換プロセス中に注意する必要があるモジュールのみに注意を払う必要があります。

2. 検出モデルの変換

1 つ目は検出モジュールです。検出モジュールには更新する 3 つの部分があります。最初の 2 つはトレーニング プロセス中の蒸留学習による教師モデルの最適化であるため、RSE-FPN にのみ注目する必要があります。

RSE-FPN (Residual Squeeze-and-Exciltation FPN) は、下図に示すように、残差構造とチャネル アテンション構造を導入し、FPN の畳み込み層をチャネル アテンション構造の RSEConv 層に置き換え、さらに改良したものです。特徴マップの表現。PP-OCRv2 ​​の検出モデルの FPN チャネル数が 96 と非常に少ないことを考慮すると、FPN の畳み込みを置き換えるために SEblock を直接使用すると、一部のチャネルの特徴が抑制され、精度が低下します。RSEConv に残差構造を導入すると、上記の問題が軽減され、テキスト検出効果が向上します。PP-OCRv2 ​​の CML 学生モデルの FPN 構造を RSE-FPN にさらに更新すると、学生モデルの hmean が 84.3% から 85.4% にさらに改善されます。 RSE-FPN pytorch コードの実装
ここに画像の説明を挿入します
:

class RSELayer(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, shortcut=True):
        super(RSELayer, self).__init__()
        self.out_channels = out_channels
        self.in_conv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=self.out_channels,
            kernel_size=kernel_size,
            padding=int(kernel_size // 2),
            bias=False)
        self.se_block = SEBlock(self.out_channels,self.out_channels)
        self.shortcut = shortcut

    def forward(self, ins):
        x = self.in_conv(ins)
        if self.shortcut:
            out = x + self.se_block(x)
        else:
            out = self.se_block(x)
        return out


class RSEFPN(nn.Module):
    def __init__(self, in_channels, out_channels=256, shortcut=True, **kwargs):
        super(RSEFPN, self).__init__()
        self.out_channels = out_channels
        self.ins_conv = nn.ModuleList()
        self.inp_conv = nn.ModuleList()

        for i in range(len(in_channels)):
            self.ins_conv.append(
                RSELayer(
                    in_channels[i],
                    out_channels,
                    kernel_size=1,
                    shortcut=shortcut))
            self.inp_conv.append(
                RSELayer(
                    out_channels,
                    out_channels // 4,
                    kernel_size=3,
                    shortcut=shortcut))

    def _upsample_add(self, x, y):
        return F.interpolate(x, scale_factor=2) + y

    def _upsample_cat(self, p2, p3, p4, p5):
        p3 = F.interpolate(p3, scale_factor=2)
        p4 = F.interpolate(p4, scale_factor=4)
        p5 = F.interpolate(p5, scale_factor=8)
        return torch.cat([p5, p4, p3, p2], dim=1)

    def forward(self, x):
        c2, c3, c4, c5 = x

        in5 = self.ins_conv[3](c5)
        in4 = self.ins_conv[2](c4)
        in3 = self.ins_conv[1](c3)
        in2 = self.ins_conv[0](c2)

        out4 = self._upsample_add(in5, in4)
        out3 = self._upsample_add(out4, in3)
        out2 = self._upsample_add(out3, in2)

        p5 = self.inp_conv[3](in5)
        p4 = self.inp_conv[2](out4)
        p3 = self.inp_conv[1](out3)
        p2 = self.inp_conv[0](out2)

        x = self._upsample_cat(p2, p3, p4, p5)
        return x

完全なネットワークは、バックボーン (MobileNetV3)、ネック (RSEFPN)、ヘッド (DBHead) の 3 つの部分に分かれており、 PytorchOCRプロジェクトを利用して、これら 3 つの部分が個別に実装され、ネットワークが構築されます

from torch import nn
from det.DetMobilenetV3 import MobileNetV3
from det.DB_fpn import DB_fpn,RSEFPN,LKPAN
from det.DetDbHead import DBHead

backbone_dict = {
    
    'MobileNetV3': MobileNetV3}
neck_dict = {
    
    'DB_fpn': DB_fpn,'RSEFPN':RSEFPN,'LKPAN':LKPAN}
head_dict = {
    
    'DBHead': DBHead}

class DetModel(nn.Module):
    def __init__(self, config):
        super().__init__()
        assert 'in_channels' in config, 'in_channels must in model config'
        backbone_type = config.backbone.pop('type')
        assert backbone_type in backbone_dict, f'backbone.type must in {
      
      backbone_dict}'
        self.backbone = backbone_dict[backbone_type](config.in_channels, **config.backbone)

        neck_type = config.neck.pop('type')
        assert neck_type in neck_dict, f'neck.type must in {
      
      neck_dict}'
        self.neck = neck_dict[neck_type](self.backbone.out_channels, **config.neck)

        head_type = config.head.pop('type')
        assert head_type in head_dict, f'head.type must in {
      
      head_dict}'
        self.head = head_dict[head_type](self.neck.out_channels, **config.head)

        self.name = f'DetModel_{
      
      backbone_type}_{
      
      neck_type}_{
      
      head_type}'

    def load_3rd_state_dict(self, _3rd_name, _state):
        self.backbone.load_3rd_state_dict(_3rd_name, _state)
        self.neck.load_3rd_state_dict(_3rd_name, _state)
        self.head.load_3rd_state_dict(_3rd_name, _state)

    def forward(self, x):
        x = self.backbone(x)
        x = self.neck(x)
        x = self.head(x)
        return x

if __name__=="__main__":
    db_config = AttrDict(
        in_channels=3,
        backbone=AttrDict(type='MobileNetV3', model_name='large',scale=0.5,pretrained=True),
        neck=AttrDict(type='RSEFPN', out_channels=96),
        head=AttrDict(type='DBHead')
    )

    model = DetModel(db_config)

次に、paddleOCRV3 のテキスト検出トレーニング モデルを使用し (トレーニング モデルのみを使用できることに注意してください)、モデルの重みと対応するキー値を取り出し、それらをそれぞれトーチ モデルに初期化します。完全なコードは記事の最後にリンクされています。

def load_state(path,trModule_state):
    """
    记载paddlepaddle的参数
    :param path:
    :return:
    """
    if os.path.exists(path + '.pdopt'):
        # XXX another hack to ignore the optimizer state
        tmp = tempfile.mkdtemp()
        dst = os.path.join(tmp, os.path.basename(os.path.normpath(path)))
        shutil.copy(path + '.pdparams', dst + '.pdparams')
        state = fluid.io.load_program_state(dst)
        shutil.rmtree(tmp)
    else:
        state = fluid.io.load_program_state(path)

    # for i, key in enumerate(state.keys()):
    #     print("{}  {} ".format(i, key))

    state_dict = {
    
    }
    for i, key in enumerate(state.keys()):
        if key =="StructuredToParameterName@@":
            continue
        state_dict[trModule_state[i]] = torch.from_numpy(state[key])

    return state_dict

3. 識別モデルの変換

認識モデルの変換は検出モデルよりもはるかに複雑ですが、PP-OCRv3 の認識モジュールはテキスト認識アルゴリズム SVTR に基づいて最適化されています。SVTR は RNN 構造を使用しなくなりました。Transformers 構造を導入することで、テキスト行画像のコンテキスト情報をより効率的にマイニングできるようになり、テキスト認識機能が向上します。上記の多くの認識最適化のうち、最初の最適化だけに注目する必要があります。 SVTR_LCNet などは、トレーニング プロセスで使用されるトレーニング手法をモデル変換プロセスで使用する必要はありません。


ここに画像の説明を挿入します
SVTR_LCNet は、Transformer ベースの SVTR ネットワークと、テキスト認識タスク用の軽量 CNN ネットワーク PP-LCNet を統合した軽量テキスト認識ネットワークです。全体的なネットワークは次のとおりです。このネットワークを使用すると、予測速度は PP- OCRv2​​よりも優れています。モデルは 20% ですが、蒸留戦略が使用されていないため、認識モデルの効果はわずかに低くなります。さらに、入力画像の正規化高さが 32 から 48 にさらに増加し​​、予測速度が若干遅くなりますが、モデル効果は大幅に向上し、認識精度は 73.98% (+2.08%) に達し、それに近い値になります。蒸留戦略を使用した PP-OCRv2 ​​の認識モデル効果へのアブレーション実験プロセス:
ここに画像の説明を挿入します

同様に、トーチ ネットワーク モデルは、パドルの認識ネットワーク構造に基づいて構築されます。モデルは、バックボーン (LCNet)、エンコーダー (SVTR Transformers)、およびヘッド (MultiHead) の 3 つの部分に分かれています。エンコーダー部分は、SVTR のトランスフォーマー構造エンコーディングを使用します。

class EncoderWithSVTR(nn.Module):
    def __init__(
            self,
            in_channels,
            dims=64,  # XS
            depth=2,
            hidden_dims=120,
            use_guide=False,
            num_heads=8,
            qkv_bias=True,
            mlp_ratio=2.0,
            drop_rate=0.1,
            attn_drop_rate=0.1,
            drop_path=0.,
            qk_scale=None):
        super(EncoderWithSVTR, self).__init__()
        self.depth = depth
        self.use_guide = use_guide
        self.conv1 = ConvBNLayer(
            in_channels, in_channels // 8, padding=1)
        self.conv2 = ConvBNLayer(
            in_channels // 8, hidden_dims, kernel_size=1)

        self.svtr_block = nn.ModuleList([
            Block(
                dim=hidden_dims,
                num_heads=num_heads,
                mixer='Global',
                HW=None,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                drop=drop_rate,
                act_layer="Swish",
                attn_drop=attn_drop_rate,
                drop_path=drop_path,
                norm_layer='nn.LayerNorm',
                epsilon=1e-05,
                prenorm=False) for i in range(depth)
        ])
        self.norm = nn.LayerNorm(hidden_dims, eps=1e-6)
        self.conv3 = ConvBNLayer(
            hidden_dims, in_channels, kernel_size=1)
        # last conv-nxn, the input is concat of input tensor and conv3 output tensor
        self.conv4 = ConvBNLayer(
            2 * in_channels, in_channels // 8, padding=1)

        self.conv1x1 = ConvBNLayer(
            in_channels // 8, dims, kernel_size=1)
        self.out_channels = dims
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight)
            if isinstance(m, nn.Linear) and m.bias is not None:
                zeros_(m.bias)
        elif isinstance(m, nn.LayerNorm):
            zeros_(m.bias)
            ones_(m.weight)

    def forward(self, x):
        # for use guide
        if self.use_guide:
            z = x.clone()
            z.stop_gradient = True
        else:
            z = x
        # for short cut
        h = z
        # reduce dim
        z = self.conv1(z)
        z = self.conv2(z)
        # SVTR global block
        B, C, H, W = z.shape
        z = z.flatten(2).permute([0, 2, 1])
        for blk in self.svtr_block:
            z = blk(z)
        z = self.norm(z)
        # last stage
        z = z.reshape([-1, H, W, C]).permute([0, 3, 1, 2])
        z = self.conv3(z)
        z = torch.cat((h, z), dim=1)
        z = self.conv1x1(self.conv4(z))
        return z

Head部分はマルチヘッドですが、実際に推論時に使用するのはCTCHeadのみで、学習時のSARHeadは削除されているため、ネットワーク構築時にこの部分を追加する必要はありません。

class MultiHead(nn.Module):
    def __init__(self, in_channels, **kwargs):
        super().__init__()
        self.out_c = kwargs.get('n_class')
        self.head_list = kwargs.get('head_list')
        self.gtc_head = 'sar'
        # assert len(self.head_list) >= 2
        for idx, head_name in enumerate(self.head_list):
            # name = list(head_name)[0]
            name = head_name
            if name == 'SARHead':
                # sar head
                sar_args = self.head_list[name]
                self.sar_head = eval(name)(in_channels=in_channels, out_channels=self.out_c, **sar_args)
            if name == 'CTC':
                # ctc neck
                self.encoder_reshape = Im2Seq(in_channels)
                neck_args = self.head_list[name]['Neck']
                encoder_type = neck_args.pop('name')
                self.encoder = encoder_type
                self.ctc_encoder = SequenceEncoder(in_channels=in_channels,encoder_type=encoder_type, **neck_args)
                # ctc head
                head_args = self.head_list[name]
                self.ctc_head = eval(name)(in_channels=self.ctc_encoder.out_channels,n_class=self.out_c, **head_args)
            else:
                raise NotImplementedError(
                    '{} is not supported in MultiHead yet'.format(name))

    def forward(self, x, targets=None):
        ctc_encoder = self.ctc_encoder(x)
        ctc_out = self.ctc_head(ctc_encoder, targets)
        head_out = dict()
        head_out['ctc'] = ctc_out
        head_out['ctc_neck'] = ctc_encoder
        return ctc_out                          # infer   不经过SAR直接返回
        
        # # eval mode
        # print(not self.training)
        # if not self.training:                 # training
        #     return ctc_out
        # if self.gtc_head == 'sar':
        #     sar_out = self.sar_head(x, targets[1:])
        #     head_out['sar'] = sar_out
        #     return head_out
        # else:
        #     return head_out

完全なネットワーク構築:

from torch import nn

from rec.RNN import SequenceEncoder, Im2Seq,Im2Im
from rec.RecSVTR import SVTRNet
from rec.RecMv1_enhance import MobileNetV1Enhance

from rec.RecCTCHead import CTC,MultiHead

backbone_dict = {
    
    "SVTR":SVTRNet,"MobileNetV1Enhance":MobileNetV1Enhance}
neck_dict = {
    
    'PPaddleRNN': SequenceEncoder, 'Im2Seq': Im2Seq,'None':Im2Im}
head_dict = {
    
    'CTC': CTC,'Multi':MultiHead}


class RecModel(nn.Module):
    def __init__(self, config):
        super().__init__()
        assert 'in_channels' in config, 'in_channels must in model config'
        backbone_type = config.backbone.pop('type')
        assert backbone_type in backbone_dict, f'backbone.type must in {
      
      backbone_dict}'
        self.backbone = backbone_dict[backbone_type](config.in_channels, **config.backbone)

        neck_type = config.neck.pop('type')
        assert neck_type in neck_dict, f'neck.type must in {
      
      neck_dict}'
        self.neck = neck_dict[neck_type](self.backbone.out_channels, **config.neck)

        head_type = config.head.pop('type')
        assert head_type in head_dict, f'head.type must in {
      
      head_dict}'
        self.head = head_dict[head_type](self.neck.out_channels, **config.head)

        self.name = f'RecModel_{
      
      backbone_type}_{
      
      neck_type}_{
      
      head_type}'

    def load_3rd_state_dict(self, _3rd_name, _state):
        self.backbone.load_3rd_state_dict(_3rd_name, _state)
        self.neck.load_3rd_state_dict(_3rd_name, _state)
        self.head.load_3rd_state_dict(_3rd_name, _state)

    def forward(self, x):
        x = self.backbone(x)
        x = self.neck(x)
        x = self.head(x)
        return x

if __name__=="__main__":

    rec_config = AttrDict(
        in_channels=3,
        backbone=AttrDict(type='MobileNetV1Enhance', scale=0.5,last_conv_stride=[1,2],last_pool_type='avg'),
        neck=AttrDict(type='None'),
   head=AttrDict(type='Multi',head_list=AttrDict(CTC=AttrDict(Neck=AttrDict(name="svtr",dims=64,depth=2,hidden_dims=120,use_guide=True)),
                                                       # SARHead=AttrDict(enc_dim=512,max_text_length=70)
                                                      ),
                      n_class=6625)
    )

    model = RecModel(rec_config)

同様にpaddleocrv3の認識学習モデルを読み込み、重みに相当するキー値を取り出してトーチモデルに初期化しますが、ここで注意が必要なのはpaddleのフルリンク層の重み形状問題と、トーチの完全なリンク層リンク層がトーチの完全なリンク層に割り当てられている場合、重みを転置する必要があります (transpose():

def load_state(path,trModule_state):
    """
    记载paddlepaddle的参数
    :param path:
    :return:
    """
    if os.path.exists(path + '.pdopt'):
        # XXX another hack to ignore the optimizer state
        tmp = tempfile.mkdtemp()
        dst = os.path.join(tmp, os.path.basename(os.path.normpath(path)))
        shutil.copy(path + '.pdparams', dst + '.pdparams')
        state = fluid.io.load_program_state(dst)
        shutil.rmtree(tmp)
    else:
        state = fluid.io.load_program_state(path)

    # for i, key in enumerate(state.keys()):
    #     print("{}  {} ".format(i, key))
    keys = ["head.ctc_encoder.encoder.svtr_block.0.mixer.qkv.weight",
            "head.ctc_encoder.encoder.svtr_block.0.mixer.proj.weight",
            "head.ctc_encoder.encoder.svtr_block.0.mlp.fc1.weight",
            "head.ctc_encoder.encoder.svtr_block.0.mlp.fc2.weight",
            "head.ctc_encoder.encoder.svtr_block.1.mixer.qkv.weight",
            "head.ctc_encoder.encoder.svtr_block.1.mixer.proj.weight",
            "head.ctc_encoder.encoder.svtr_block.1.mlp.fc1.weight",
            "head.ctc_encoder.encoder.svtr_block.1.mlp.fc2.weight",
            "head.ctc_head.fc.weight",
            ]

    state_dict = {
    
    }
    for i, key in enumerate(state.keys()):
        if key =="StructuredToParameterName@@":
            continue
        if i > 238:
            j = i-239
            if j <= 195:
                if trModule_state[j] in keys:
                    state_dict[trModule_state[j]] = torch.from_numpy(state[key]).transpose(0,1)
                else:
                    state_dict[trModule_state[j]] = torch.from_numpy(state[key])

    return state_dict

PaddleOCR トレーニング モデルのリンクPaddleOCR :
ここに画像の説明を挿入します
完全なコードは github にスローされています。そこから学ぶことを歓迎します。

paddle2torch_PPOCRv3

おすすめ

転載: blog.csdn.net/qq_39056987/article/details/124921515