Convert the PP-OCRv3 model to PyTorch

Preface

PP-OCRv3 was released a while ago with updated detection and recognition models and a large jump in accuracy. True to my habit of freeloading whenever I can, I started trying it on the day it came out. The new models are a big improvement over the previous ones, but at first glance their structure is much more complicated, which makes deployment considerably more troublesome. At the moment, the only way to move a Paddle model into another deployment framework is to convert it with paddle2onnx and then go from ONNX to the target framework. So I decided to work through the pitfalls myself and provide an "import paddle as torch" version of the model: the Paddle model weights are transferred to PyTorch to give more options for deployment. Once the model is in PyTorch, you can continue to other deployment targets from there; as in my earlier example, pnnx can convert the PyTorch model to an ncnn model.

Performance compared with the previous models:
[Figure: performance comparison with the previous PP-OCR models]
The code implementation of this project is based on:

1. paddle2torch

First, the conversion principle. Because PaddlePaddle and PyTorch are both dynamic-graph frameworks, the conversion is relatively simple: rebuild the same network structure in torch, then take the Paddle weights out one by one and assign them to the corresponding layers. The process sounds simple, but these are still different frameworks and some OPs are implemented differently, so there are inevitably plenty of pitfalls.
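The core idea, taking the Paddle weights out one by one and assigning them to the torch layers in the same order, can be sketched in plain Python. This is a minimal sketch under assumptions, not the project's code: `map_by_order`, `transpose_keys` and the toy key names are hypothetical, and NumPy arrays stand in for both frameworks' tensors so the shape check runs without Paddle or PyTorch installed.

```python
import numpy as np


def map_by_order(paddle_state, torch_shapes, transpose_keys=(),
                 skip=("StructuredToParameterName@@",)):
    """Pair Paddle tensors with torch parameter names purely by position.

    paddle_state:   ordered {paddle_name: np.ndarray}
    torch_shapes:   ordered {torch_name: expected shape}
    transpose_keys: torch names whose 2-D weight must be transposed (Linear layers)
    """
    pd_items = [(k, v) for k, v in paddle_state.items() if k not in skip]
    if len(pd_items) != len(torch_shapes):
        raise ValueError(f"{len(pd_items)} Paddle tensors vs {len(torch_shapes)} torch parameters")
    out = {}
    for (pd_name, arr), (pt_name, shape) in zip(pd_items, torch_shapes.items()):
        if pt_name in transpose_keys:
            arr = arr.T  # Paddle Linear: [in, out]; torch Linear: [out, in]
        if tuple(arr.shape) != tuple(shape):
            # A mismatch almost always means the two layer orders have drifted apart
            raise ValueError(f"{pd_name} {arr.shape} -> {pt_name} {shape}")
        out[pt_name] = arr
    return out


paddle_state = {
    "conv1.weight": np.zeros((8, 3, 3, 3), np.float32),  # conv: [out, in, kh, kw] in both
    "fc.weight": np.zeros((8, 10), np.float32),          # Linear: [in, out] in Paddle
    "fc.bias": np.zeros((10,), np.float32),
    "StructuredToParameterName@@": {},                   # metadata entry, skipped
}
torch_shapes = {
    "backbone.conv1.weight": (8, 3, 3, 3),
    "head.fc.weight": (10, 8),                           # Linear: [out, in] in torch
    "head.fc.bias": (10,),
}
state = map_by_order(paddle_state, torch_shapes, transpose_keys={"head.fc.weight"})
```

Checking the shape at every step turns a silent misalignment into an immediate error; each array can then be wrapped with `torch.from_numpy` before `load_state_dict`.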

Before converting, let's look at which modules PP-OCRv3 has updated compared with the previous version of the model:

Detection module:

  1. LK-PAN: a PAN structure with a large receptive field
  2. DML: a mutual learning strategy for the teacher model
  3. RSE-FPN: an FPN structure with a residual attention mechanism

Recognition module:

  • SVTR_LCNet: a lightweight text recognition network
  • GTC: an attention-guided CTC training strategy
  • TextConAug: a data augmentation strategy that mines textual context
  • TextRotNet: a self-supervised pre-trained model
  • UDML: a unified deep mutual learning strategy
  • UIM: an unlabeled data mining scheme

For details, see the official PP-OCRv3 technical report. Here we only need to look at the modules that matter for the conversion.

2. Detection model conversion

Start with the detection module. Of its three updates, only RSE-FPN matters here, because the first two (LK-PAN and DML) are distillation-time optimizations of the teacher model during training.

RSE-FPN (Residual Squeeze-and-Excitation FPN), shown in the figure below, introduces a residual structure and a channel attention structure: it replaces the convolutional layers in the FPN with RSEConv layers that carry channel attention, further improving the representational power of the feature maps. Because the FPN in the PP-OCRv2 detection model has very few channels (only 96), directly replacing its convolutions with SE blocks would suppress the features of some channels and reduce accuracy; the residual connection in RSEConv alleviates this and improves text detection. Upgrading the FPN of the PP-OCRv2 CML student model to RSE-FPN further raises the student model's hmean from 84.3% to 85.4%:
[Figure: RSE-FPN structure]
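To make the residual argument concrete, here is a NumPy sketch of an SE block with and without the shortcut on a single (C, H, W) feature map, mirroring what `RSELayer` does. It is an illustration under assumptions, not the project's code: the 1x1 convolutions are written as matrix products on the pooled vector, and the reduction factor of 4 and the hard-sigmoid gate `clip(0.2*x + 0.5, 0, 1)` follow PaddleOCR's SE module as far as I know (verify against the Paddle source); all function names are hypothetical.

```python
import numpy as np


def hard_sigmoid(x, slope=0.2, offset=0.5):
    # Piecewise-linear gate: clip(slope * x + offset, 0, 1).
    # slope=0.2 is the assumed Paddle setting; torch.nn.Hardsigmoid uses slope=1/6.
    return np.clip(slope * x + offset, 0.0, 1.0)


def se_block(x, w1, b1, w2, b2, slope=0.2):
    """Squeeze-and-excitation on x of shape (C, H, W)."""
    s = x.mean(axis=(1, 2))                   # squeeze: global average pool -> (C,)
    z = np.maximum(w1 @ s + b1, 0.0)          # 1x1 conv C -> C // 4, then ReLU
    gate = hard_sigmoid(w2 @ z + b2, slope)   # 1x1 conv C // 4 -> C, gate in [0, 1]
    return x * gate[:, None, None]            # re-weight each channel


def rse(x, params, shortcut=True, slope=0.2):
    """Residual SE as in RSELayer: x + SE(x) when the shortcut is on."""
    y = se_block(x, *params, slope=slope)
    return x + y if shortcut else y


C = 8
x = np.random.default_rng(0).standard_normal((C, 4, 4))
w1, b1 = np.zeros((C // 4, C)), np.zeros(C // 4)
w2 = np.zeros((C, C // 4))
b2 = np.full(C, 10.0)
b2[0] = -10.0                                 # force the gate of channel 0 to 0
params = (w1, b1, w2, b2)

plain = rse(x, params, shortcut=False)        # channel 0 is wiped out
residual = rse(x, params, shortcut=True)      # channel 0 passes through unchanged
```

Without the shortcut a closed gate zeroes the channel entirely; with it the channel survives, which is the point for a narrow 96-channel FPN. The gate also shows a conversion pitfall: for an input of 1.0, slope 0.2 gives 0.7 while torch's default slope of 1/6 gives about 0.667, so the Paddle slope has to be reproduced exactly.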
PyTorch implementation of RSE-FPN:

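import torch
from torch import nn
import torch.nn.functional as F
# SEBlock is the squeeze-and-excitation module from the project's MobileNetV3 code
# (assumed location; PaddleOCR's RSE-FPN likewise reuses its backbone's SE module)
from det.DetMobilenetV3 import SEBlock


class RSELayer(nn.Module):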
    def __init__(self, in_channels, out_channels, kernel_size, shortcut=True):
        super(RSELayer, self).__init__()
        self.out_channels = out_channels
        self.in_conv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=self.out_channels,
            kernel_size=kernel_size,
            padding=int(kernel_size // 2),
            bias=False)
        self.se_block = SEBlock(self.out_channels,self.out_channels)
        self.shortcut = shortcut

    def forward(self, ins):
        x = self.in_conv(ins)
        if self.shortcut:
            out = x + self.se_block(x)
        else:
            out = self.se_block(x)
        return out


class RSEFPN(nn.Module):
    def __init__(self, in_channels, out_channels=256, shortcut=True, **kwargs):
        super(RSEFPN, self).__init__()
        self.out_channels = out_channels
        self.ins_conv = nn.ModuleList()
        self.inp_conv = nn.ModuleList()

        for i in range(len(in_channels)):
            self.ins_conv.append(
                RSELayer(
                    in_channels[i],
                    out_channels,
                    kernel_size=1,
                    shortcut=shortcut))
            self.inp_conv.append(
                RSELayer(
                    out_channels,
                    out_channels // 4,
                    kernel_size=3,
                    shortcut=shortcut))

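    def _upsample_add(self, x, y):
        # 2x nearest-neighbour upsampling (F.interpolate's default mode) + element-wise add
        return F.interpolate(x, scale_factor=2) + y

    def _upsample_cat(self, p2, p3, p4, p5):
        # Bring every level to p2's resolution (1/4 of the input) and stack the channels
        p3 = F.interpolate(p3, scale_factor=2)
        p4 = F.interpolate(p4, scale_factor=4)
        p5 = F.interpolate(p5, scale_factor=8)
        return torch.cat([p5, p4, p3, p2], dim=1)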

    def forward(self, x):
        c2, c3, c4, c5 = x

        in5 = self.ins_conv[3](c5)
        in4 = self.ins_conv[2](c4)
        in3 = self.ins_conv[1](c3)
        in2 = self.ins_conv[0](c2)

        out4 = self._upsample_add(in5, in4)
        out3 = self._upsample_add(out4, in3)
        out2 = self._upsample_add(out3, in2)

        p5 = self.inp_conv[3](in5)
        p4 = self.inp_conv[2](out4)
        p3 = self.inp_conv[1](out3)
        p2 = self.inp_conv[0](out2)

        x = self._upsample_cat(p2, p3, p4, p5)
        return x
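The top-down path in `RSEFPN` only works if each level is exactly twice the size of the next, because `_upsample_add` adds a 2x upsampled map to the finer one. A small NumPy sketch illustrates this under two assumptions: every downsampling step maps a side length n to ceil(n / 2), as a stride-2 convolution with "same"-style padding does, and `upsample_nearest` (a hypothetical helper) reproduces `F.interpolate(x, scale_factor=2)`, whose default mode is nearest.

```python
import math
import numpy as np


def level_sizes(side, n_down=5):
    """Side lengths of c2..c5 (strides 4, 8, 16, 32) for an input side length."""
    sizes = [side]
    for _ in range(n_down):
        sizes.append(math.ceil(sizes[-1] / 2))  # stride-2 conv: n -> ceil(n / 2)
    return sizes[2:]


def top_down_compatible(side):
    """True if every 2x upsample in the top-down path lands on the next level's size."""
    s = level_sizes(side)
    return all(s[i + 1] * 2 == s[i] for i in range(len(s) - 1))


def upsample_nearest(x, scale):
    """Nearest-neighbour upsampling of the last two axes by an integer factor."""
    return x.repeat(scale, axis=-2).repeat(scale, axis=-1)


print(level_sizes(640))          # [160, 80, 40, 20]
print(top_down_compatible(640))  # True
print(top_down_compatible(100))  # False: 13 * 2 != 25
```

This is why detection inputs are normally resized so both sides are multiples of 32. With `out_channels=96`, each `inp_conv` branch emits 96 // 4 = 24 channels, so `_upsample_cat` stacks four 24-channel maps into a 96-channel map at 1/4 of the input resolution.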

The complete network has three parts: Backbone (MobileNetV3), Neck (RSEFPN) and Head (DBHead). Each part is implemented separately, borrowing from the PytorchOCR project, and the network is then assembled:

from torch import nn
from det.DetMobilenetV3 import MobileNetV3
from det.DB_fpn import DB_fpn, RSEFPN, LKPAN
from det.DetDbHead import DBHead


class AttrDict(dict):
    """Minimal dict with attribute access (cfg.backbone == cfg['backbone']).
    Stand-in for the AttrDict helper the snippet uses but does not import."""
    def __getattr__(self, key):
        try:
            return self[key]
        except KeyError as e:
            raise AttributeError(key) from e

    def __setattr__(self, key, value):
        self[key] = value


# Registries mapping the config 'type' string to a module class
backbone_dict = {'MobileNetV3': MobileNetV3}
neck_dict = {'DB_fpn': DB_fpn, 'RSEFPN': RSEFPN, 'LKPAN': LKPAN}
head_dict = {'DBHead': DBHead}


class DetModel(nn.Module):
    def __init__(self, config):
        super().__init__()
        assert 'in_channels' in config, 'in_channels must be in model config'
        backbone_type = config.backbone.pop('type')
        assert backbone_type in backbone_dict, f'backbone.type must be in {backbone_dict}'
        self.backbone = backbone_dict[backbone_type](config.in_channels, **config.backbone)

        neck_type = config.neck.pop('type')
        assert neck_type in neck_dict, f'neck.type must be in {neck_dict}'
        # The neck consumes the backbone's per-stage output channels
        self.neck = neck_dict[neck_type](self.backbone.out_channels, **config.neck)

        head_type = config.head.pop('type')
        assert head_type in head_dict, f'head.type must be in {head_dict}'
        self.head = head_dict[head_type](self.neck.out_channels, **config.head)

        self.name = f'DetModel_{backbone_type}_{neck_type}_{head_type}'

    def load_3rd_state_dict(self, _3rd_name, _state):
        self.backbone.load_3rd_state_dict(_3rd_name, _state)
        self.neck.load_3rd_state_dict(_3rd_name, _state)
        self.head.load_3rd_state_dict(_3rd_name, _state)

    def forward(self, x):
        x = self.backbone(x)  # multi-scale features c2..c5
        x = self.neck(x)      # fused feature map (RSEFPN)
        x = self.head(x)      # DB head output maps
        return x


if __name__ == "__main__":
    # PP-OCRv3 detection: MobileNetV3-large x0.5 backbone, RSEFPN neck (96 channels), DB head
    db_config = AttrDict(
        in_channels=3,
        backbone=AttrDict(type='MobileNetV3', model_name='large', scale=0.5, pretrained=True),
        neck=AttrDict(type='RSEFPN', out_channels=96),
        head=AttrDict(type='DBHead')
    )

    model = DetModel(db_config)
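One side effect of this builder is worth knowing before converting: `DetModel.__init__` calls `config.backbone.pop('type')` (and likewise for the neck and head), so it mutates the config it receives, and building a second model from the same config fails with a `KeyError`. That is easy to hit if you build the model once to list its keys and again to load weights. A minimal sketch of the problem and a `copy.deepcopy` fix; `build_in_place`, `build_copy` and the registry are hypothetical stand-ins for the real classes.

```python
import copy


def build_in_place(cfg, registry):
    # Same pattern as DetModel: pop the type, pass the remaining keys as kwargs
    kind = cfg.pop('type')
    return registry[kind](**cfg)


def build_copy(cfg, registry):
    # Work on a deep copy so the caller's config keeps its 'type' key
    cfg = copy.deepcopy(cfg)
    kind = cfg.pop('type')
    return registry[kind](**cfg)


registry = {'RSEFPN': lambda out_channels: ('RSEFPN', out_channels)}

neck_cfg = {'type': 'RSEFPN', 'out_channels': 96}
print(build_copy(neck_cfg, registry))      # ('RSEFPN', 96)
print(build_copy(neck_cfg, registry))      # ('RSEFPN', 96) again; neck_cfg is untouched
print(build_in_place(neck_cfg, registry))  # ('RSEFPN', 96), but 'type' is now gone
try:
    build_in_place(neck_cfg, registry)
except KeyError as err:
    print('second build failed:', err)     # second build failed: 'type'
```

With the classes above, the equivalent fix would be to pass a copy, e.g. `DetModel(copy.deepcopy(db_config))`.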

Next, take the PP-OCRv3 text detection training model (note that it must be the training model, i.e. the .pdparams checkpoint, not the exported inference model), extract its weights together with their keys, and assign them to the torch model one by one. The complete code is linked at the end of the article.

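import os
import shutil
import tempfile

import torch
import paddle.fluid as fluid  # legacy API; may not exist in newer Paddle releases


def load_state(path, trModule_state):
    """
    Load PaddlePaddle parameters and rename them to the torch model's keys.
    :param path: checkpoint prefix (the training model path without '.pdparams')
    :param trModule_state: torch parameter names, in the same order as the Paddle parameters
    :return: state_dict for torch_model.load_state_dict()
    """
    if os.path.exists(path + '.pdopt'):
        # XXX another hack to ignore the optimizer state:
        # copy only the .pdparams file to a temp dir so the optimizer file is not loaded
        tmp = tempfile.mkdtemp()
        dst = os.path.join(tmp, os.path.basename(os.path.normpath(path)))
        shutil.copy(path + '.pdparams', dst + '.pdparams')
        state = fluid.io.load_program_state(dst)
        shutil.rmtree(tmp)
    else:
        state = fluid.io.load_program_state(path)

    # for i, key in enumerate(state.keys()):
    #     print("{}  {} ".format(i, key))

    state_dict = {}
    # Tensors are matched purely by position: the i-th Paddle tensor goes to the i-th torch key
    for i, key in enumerate(state.keys()):
        if key == "StructuredToParameterName@@":
            continue  # name-mapping metadata, not a weight tensor
        state_dict[trModule_state[i]] = torch.from_numpy(state[key])

    return state_dict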
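Because `load_state` matches tensors purely by position, `trModule_state` must list the torch names in the same order and with the same count as the Paddle checkpoint. One known difference is batch normalization: as far as I know, a Paddle `BatchNorm2D` saves `weight`, `bias`, `_mean` and `_variance`, while a torch `BatchNorm2d` additionally stores a `num_batches_tracked` buffer with no Paddle counterpart. A sketch of building `trModule_state` under that assumption; `paddle_compatible_keys` is a hypothetical helper and the key names are illustrative only.

```python
def paddle_compatible_keys(torch_keys):
    """Drop torch-only BatchNorm buffers so keys line up one-to-one with Paddle's."""
    return [k for k in torch_keys if not k.endswith("num_batches_tracked")]


# Illustrative key order for one conv + BN block
torch_keys = [
    "backbone.conv.weight",
    "backbone.bn.weight", "backbone.bn.bias",
    "backbone.bn.running_mean", "backbone.bn.running_var",
    "backbone.bn.num_batches_tracked",
]
paddle_keys = [
    "backbone.conv._conv.weight",
    "backbone.conv._batch_norm.weight", "backbone.conv._batch_norm.bias",
    "backbone.conv._batch_norm._mean", "backbone.conv._batch_norm._variance",
]

trModule_state = paddle_compatible_keys(torch_keys)
for pd_key, pt_key in zip(paddle_keys, trModule_state):
    print(f"{pd_key:40s} -> {pt_key}")
```

With a real model, `torch_keys` would be `list(model.state_dict().keys())`. Since the filtered buffers are absent from the converted dict, load it with `model.load_state_dict(state_dict, strict=False)` and check that the reported missing keys are only `num_batches_tracked` entries.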

3. Recognition model conversion

Converting the recognition model is much more involved than the detection model. The PP-OCRv3 recognition module is optimized from the SVTR text recognition algorithm. SVTR drops the RNN structure and uses Transformer blocks instead, which mine the contextual information of text-line images more effectively and thus improve recognition. Of the recognition optimizations listed above, we only need the first one, SVTR_LCNet; the rest are training-time techniques that play no role in model conversion.

SVTR_LCNet is a lightweight text recognition network that fuses the Transformer-based SVTR network with the lightweight CNN PP-LCNet. The overall network is shown below:
[Figure: SVTR_LCNet network structure]
With this network, prediction is 20% faster than the PP-OCRv2 recognition model, but because no distillation strategy is used, its accuracy is slightly lower. In addition, the input image height used for normalization is raised from 32 to 48; this makes prediction slightly slower but improves accuracy substantially, reaching 73.98% (+2.08%), close to the PP-OCRv2 recognition model trained with distillation. Ablation results:
[Figure: SVTR_LCNet ablation experiments]
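Because the input height changed from 32 to 48, the inference-time preprocessing has to change with it. The sketch below follows the scheme PaddleOCR uses for its recognition models, as far as I know: resize to height 48 keeping the aspect ratio, cap the width, scale pixels to [-1, 1] with (x / 255 - 0.5) / 0.5, and zero-pad on the right. Treat the 3x48x320 target shape as an assumption to check against your config (PaddleOCR's own batched inference widens the canvas to fit the longest line in a batch); `resize_norm_img` borrows the PaddleOCR function's name, and `resize_nearest` is a hypothetical NumPy stand-in for `cv2.resize` so the example needs no OpenCV.

```python
import math
import numpy as np


def resize_nearest(img, out_h, out_w):
    # Nearest-neighbour stand-in for cv2.resize; img is (H, W, C)
    in_h, in_w = img.shape[:2]
    rows = np.arange(out_h) * in_h // out_h
    cols = np.arange(out_w) * in_w // out_w
    return img[rows][:, cols]


def resize_norm_img(img, img_h=48, img_w=320):
    """(H, W, 3) uint8 image -> (3, img_h, img_w) float32 in [-1, 1], right-padded."""
    h, w = img.shape[:2]
    resized_w = min(img_w, math.ceil(img_h * w / h))     # keep aspect ratio, cap width
    x = resize_nearest(img, img_h, resized_w).astype(np.float32)
    x = x.transpose(2, 0, 1) / 255.0                      # HWC -> CHW, scale to [0, 1]
    x = (x - 0.5) / 0.5                                   # scale to [-1, 1]
    out = np.zeros((3, img_h, img_w), dtype=np.float32)   # zero padding on the right
    out[:, :, :resized_w] = x
    return out


line = np.full((32, 100, 3), 255, dtype=np.uint8)         # a white 32x100 text-line crop
batch = resize_norm_img(line)[None]                        # (1, 3, 48, 320)
```

Here the 32x100 crop is resized to 48x150 and the remaining 170 columns are padding.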

As with detection, the torch network is built to mirror Paddle's recognition network. The model has three parts: Backbone (LCNet, instantiated as MobileNetV1Enhance in the config below), Encoder (SVTR Transformer blocks) and Head (MultiHead). The encoder uses SVTR's Transformer structure:

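import torch
from torch import nn
from torch.nn.init import trunc_normal_, zeros_, ones_
# ConvBNLayer and Block (the SVTR mixing block) come from the project's SVTR code;
# the module path below is an assumption
from rec.RecSVTR import ConvBNLayer, Block


class EncoderWithSVTR(nn.Module):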
    def __init__(
            self,
            in_channels,
            dims=64,  # XS
            depth=2,
            hidden_dims=120,
            use_guide=False,
            num_heads=8,
            qkv_bias=True,
            mlp_ratio=2.0,
            drop_rate=0.1,
            attn_drop_rate=0.1,
            drop_path=0.,
            qk_scale=None):
        super(EncoderWithSVTR, self).__init__()
        self.depth = depth
        self.use_guide = use_guide
        self.conv1 = ConvBNLayer(
            in_channels, in_channels // 8, padding=1)
        self.conv2 = ConvBNLayer(
            in_channels // 8, hidden_dims, kernel_size=1)

        self.svtr_block = nn.ModuleList([
            Block(
                dim=hidden_dims,
                num_heads=num_heads,
                mixer='Global',
                HW=None,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                drop=drop_rate,
                act_layer="Swish",
                attn_drop=attn_drop_rate,
                drop_path=drop_path,
                norm_layer='nn.LayerNorm',
                epsilon=1e-05,
                prenorm=False) for i in range(depth)
        ])
        self.norm = nn.LayerNorm(hidden_dims, eps=1e-6)
        self.conv3 = ConvBNLayer(
            hidden_dims, in_channels, kernel_size=1)
        # last conv-nxn, the input is concat of input tensor and conv3 output tensor
        self.conv4 = ConvBNLayer(
            2 * in_channels, in_channels // 8, padding=1)

        self.conv1x1 = ConvBNLayer(
            in_channels // 8, dims, kernel_size=1)
        self.out_channels = dims
        self.apply(self._init_weights)

    def _init_weights(self, m):
        # Random init only; every weight is overwritten by the converted Paddle weights
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                zeros_(m.bias)
        elif isinstance(m, nn.LayerNorm):
            zeros_(m.bias)
            ones_(m.weight)

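    def forward(self, x):
        # for use guide: work on a gradient-detached copy of the backbone output
        # (Paddle sets z.stop_gradient = True; in torch the equivalent is detach())
        if self.use_guide:
            z = x.detach().clone()
        else:
            z = x
        # for short cut
        h = z
        # reduce dim: C -> C // 8 -> hidden_dims
        z = self.conv1(z)
        z = self.conv2(z)
        # SVTR global block: (B, C, H, W) -> (B, H*W, C) tokens
        B, C, H, W = z.shape
        z = z.flatten(2).permute(0, 2, 1)
        for blk in self.svtr_block:
            z = blk(z)
        z = self.norm(z)
        # last stage: tokens back to (B, C, H, W), restore channels, fuse with the shortcut
        z = z.reshape(-1, H, W, C).permute(0, 3, 1, 2)
        z = self.conv3(z)
        z = torch.cat((h, z), dim=1)
        z = self.conv1x1(self.conv4(z))
        return z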
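The least obvious part of `forward` is the hop between the convolutional layout (B, C, H, W) and the token layout (B, H*W, C) consumed by the global-attention blocks. A NumPy sketch of the same two reshapes checks that they invert each other, plus a trace of the channel sizes through the encoder. The input shape is illustrative only, and the helper names are hypothetical.

```python
import numpy as np


def to_tokens(z):
    # (B, C, H, W) -> (B, H*W, C), like z.flatten(2).permute(0, 2, 1)
    b, c, h, w = z.shape
    return z.reshape(b, c, h * w).transpose(0, 2, 1)


def to_feature_map(tokens, h, w):
    # (B, H*W, C) -> (B, C, H, W), like z.reshape(-1, H, W, C).permute(0, 3, 1, 2)
    b, n, c = tokens.shape
    return tokens.reshape(b, h, w, c).transpose(0, 3, 1, 2)


def encoder_channels(in_channels, hidden_dims=120, dims=64):
    """Channel count after each stage of EncoderWithSVTR.forward."""
    return {
        "conv1": in_channels // 8,
        "conv2 / SVTR blocks": hidden_dims,
        "conv3": in_channels,
        "concat with shortcut": 2 * in_channels,
        "conv4": in_channels // 8,
        "conv1x1 (output)": dims,
    }


z = np.random.default_rng(0).standard_normal((1, 120, 1, 40))  # after conv2 (illustrative)
tokens = to_tokens(z)
print(tokens.shape)                                             # (1, 40, 120)
print(np.array_equal(to_feature_map(tokens, 1, 40), z))         # True
print(encoder_channels(512))
```

Each spatial position becomes one token, so attention in the SVTR blocks runs over all H*W positions of the feature map.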

The head is a MultiHead, but only the CTC head is used at inference; the SAR head used during training is dropped, so it does not need to be built when constructing the network.

class MultiHead(nn.Module):
    def __init__(self, in_channels, **kwargs):
        super().__init__()
        self.out_c = kwargs.get('n_class')
        self.head_list = kwargs.get('head_list')
        self.gtc_head = 'sar'
        # assert len(self.head_list) >= 2
        for idx, head_name in enumerate(self.head_list):
            # name = list(head_name)[0]
            name = head_name
            if name == 'SARHead':
                # sar head (training only)
                sar_args = self.head_list[name]
                self.sar_head = eval(name)(in_channels=in_channels, out_channels=self.out_c, **sar_args)
            elif name == 'CTC':
                # ctc neck: SVTR encoder in front of the CTC head
                self.encoder_reshape = Im2Seq(in_channels)
                neck_args = self.head_list[name]['Neck']
                encoder_type = neck_args.pop('name')
                self.encoder = encoder_type
                self.ctc_encoder = SequenceEncoder(in_channels=in_channels, encoder_type=encoder_type, **neck_args)
                # ctc head
                head_args = self.head_list[name]
                self.ctc_head = eval(name)(in_channels=self.ctc_encoder.out_channels, n_class=self.out_c, **head_args)
            else:
                raise NotImplementedError(
                    '{} is not supported in MultiHead yet'.format(name))

    def forward(self, x, targets=None):
        ctc_encoder = self.ctc_encoder(x)
        ctc_out = self.ctc_head(ctc_encoder, targets)
        head_out = dict()
        head_out['ctc'] = ctc_out
        head_out['ctc_neck'] = ctc_encoder
        return ctc_out                          # inference: return the CTC output directly, skipping SAR
        
        # # eval mode
        # print(not self.training)
        # if not self.training:                 # training
        #     return ctc_out
        # if self.gtc_head == 'sar':
        #     sar_out = self.sar_head(x, targets[1:])
        #     head_out['sar'] = sar_out
        #     return head_out
        # else:
        #     return head_out
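At inference the model returns the CTC head's per-timestep class scores, which still need decoding into text. A minimal greedy CTC decoder sketch: take the argmax at each timestep, merge consecutive repeats, then drop the blank. Assumptions: the blank is class 0 and the character list is `['blank'] + dictionary characters + [' ']`, which is how PaddleOCR's CTC label decoder is set up as far as I know (presumably why `n_class=6625` appears in the config below); the function name is hypothetical and the dictionary is a toy.

```python
import numpy as np


def ctc_greedy_decode(probs, charset, blank=0):
    """probs: (T, n_class) scores for one text line -> (text, mean confidence)."""
    best = probs.argmax(axis=1)                    # best class per timestep
    keep = np.ones(len(best), dtype=bool)
    keep[1:] = best[1:] != best[:-1]               # merge consecutive repeats
    keep &= best != blank                          # then drop blanks
    idx = np.flatnonzero(keep)
    text = "".join(charset[i] for i in best[idx])
    conf = float(probs[idx, best[idx]].mean()) if len(idx) else 0.0
    return text, conf


charset = ["blank", "a", "b", "c", " "]            # toy dictionary
path = [1, 1, 0, 1, 2, 2, 0, 0, 3]                # a a _ a b b _ _ c
probs = np.full((len(path), len(charset)), 0.1)
probs[np.arange(len(path)), path] = 0.5
print(ctc_greedy_decode(probs, charset))          # ('aabc', 0.5)
```

The blank between the two runs of "a" is what lets CTC emit a genuinely doubled character.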

Building the complete network:

from torch import nn

from rec.RNN import SequenceEncoder, Im2Seq, Im2Im
from rec.RecSVTR import SVTRNet
from rec.RecMv1_enhance import MobileNetV1Enhance

from rec.RecCTCHead import CTC, MultiHead


class AttrDict(dict):
    """Minimal dict with attribute access (cfg.head == cfg['head']).
    Stand-in for the AttrDict helper the snippet uses but does not import."""
    def __getattr__(self, key):
        try:
            return self[key]
        except KeyError as e:
            raise AttributeError(key) from e

    def __setattr__(self, key, value):
        self[key] = value


# Registries mapping the config 'type' string to a module class
backbone_dict = {"SVTR": SVTRNet, "MobileNetV1Enhance": MobileNetV1Enhance}
neck_dict = {'PPaddleRNN': SequenceEncoder, 'Im2Seq': Im2Seq, 'None': Im2Im}
head_dict = {'CTC': CTC, 'Multi': MultiHead}


class RecModel(nn.Module):
    def __init__(self, config):
        super().__init__()
        assert 'in_channels' in config, 'in_channels must be in model config'
        backbone_type = config.backbone.pop('type')
        assert backbone_type in backbone_dict, f'backbone.type must be in {backbone_dict}'
        self.backbone = backbone_dict[backbone_type](config.in_channels, **config.backbone)

        neck_type = config.neck.pop('type')
        assert neck_type in neck_dict, f'neck.type must be in {neck_dict}'
        self.neck = neck_dict[neck_type](self.backbone.out_channels, **config.neck)

        head_type = config.head.pop('type')
        assert head_type in head_dict, f'head.type must be in {head_dict}'
        self.head = head_dict[head_type](self.neck.out_channels, **config.head)

        self.name = f'RecModel_{backbone_type}_{neck_type}_{head_type}'

    def load_3rd_state_dict(self, _3rd_name, _state):
        self.backbone.load_3rd_state_dict(_3rd_name, _state)
        self.neck.load_3rd_state_dict(_3rd_name, _state)
        self.head.load_3rd_state_dict(_3rd_name, _state)

    def forward(self, x):
        x = self.backbone(x)  # MobileNetV1Enhance feature map
        x = self.neck(x)      # 'None' neck (Im2Im), presumably an identity mapping
        x = self.head(x)      # MultiHead: SVTR encoder + CTC head
        return x


if __name__ == "__main__":
    # PP-OCRv3 recognition: MobileNetV1Enhance x0.5 backbone, no neck,
    # MultiHead with an SVTR encoder in front of the CTC head (SAR head disabled for inference)
    rec_config = AttrDict(
        in_channels=3,
        backbone=AttrDict(type='MobileNetV1Enhance', scale=0.5, last_conv_stride=[1, 2], last_pool_type='avg'),
        neck=AttrDict(type='None'),
        head=AttrDict(
            type='Multi',
            head_list=AttrDict(
                CTC=AttrDict(Neck=AttrDict(name="svtr", dims=64, depth=2, hidden_dims=120, use_guide=True)),
                # SARHead=AttrDict(enc_dim=512, max_text_length=70)
            ),
            n_class=6625)
    )

    model = RecModel(rec_config)

Similarly, load the PP-OCRv3 recognition training model, take out the weights with their keys, and assign them to the torch model. One thing to watch is the weight layout of fully connected layers: Paddle's Linear stores its weight as [in_features, out_features], while torch's nn.Linear stores [out_features, in_features], so when a Paddle fully connected weight is assigned to a torch fully connected layer it must be transposed with transpose():

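import os
import shutil
import tempfile

import torch
import paddle.fluid as fluid  # legacy API; may not exist in newer Paddle releases


def load_state(path, trModule_state):
    """
    Load PaddlePaddle parameters and rename them to the torch model's keys.
    :param path: checkpoint prefix (the training model path without '.pdparams')
    :param trModule_state: torch parameter names, in the same order as the Paddle parameters
    :return: state_dict for torch_model.load_state_dict()
    """
    if os.path.exists(path + '.pdopt'):
        # XXX another hack to ignore the optimizer state
        tmp = tempfile.mkdtemp()
        dst = os.path.join(tmp, os.path.basename(os.path.normpath(path)))
        shutil.copy(path + '.pdparams', dst + '.pdparams')
        state = fluid.io.load_program_state(dst)
        shutil.rmtree(tmp)
    else:
        state = fluid.io.load_program_state(path)

    # for i, key in enumerate(state.keys()):
    #     print("{}  {} ".format(i, key))

    # Linear-layer weights: Paddle stores [in_features, out_features],
    # torch stores [out_features, in_features], so these get transposed
    keys = ["head.ctc_encoder.encoder.svtr_block.0.mixer.qkv.weight",
            "head.ctc_encoder.encoder.svtr_block.0.mixer.proj.weight",
            "head.ctc_encoder.encoder.svtr_block.0.mlp.fc1.weight",
            "head.ctc_encoder.encoder.svtr_block.0.mlp.fc2.weight",
            "head.ctc_encoder.encoder.svtr_block.1.mixer.qkv.weight",
            "head.ctc_encoder.encoder.svtr_block.1.mixer.proj.weight",
            "head.ctc_encoder.encoder.svtr_block.1.mlp.fc1.weight",
            "head.ctc_encoder.encoder.svtr_block.1.mlp.fc2.weight",
            "head.ctc_head.fc.weight",
            ]

    state_dict = {}
    for i, key in enumerate(state.keys()):
        if key == "StructuredToParameterName@@":
            continue
        # Only checkpoint entries 239..434 (196 tensors) are mapped; the rest presumably
        # belong to the other distillation sub-network and to the training-only SAR head
        if i > 238:
            j = i - 239
            if j <= 195:
                if trModule_state[j] in keys:
                    state_dict[trModule_state[j]] = torch.from_numpy(state[key]).transpose(0, 1)
                else:
                    state_dict[trModule_state[j]] = torch.from_numpy(state[key])

    return state_dict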
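The hard-coded key list and the 239/195 offsets work for this particular checkpoint but are brittle. A sketch of a name-based alternative, plus a NumPy check of why the transpose matters. Assumptions, all to verify against your checkpoint: each sub-model's keys share a prefix (the `"Teacher."`/`"Student."` prefixes here are hypothetical), every 2-D `.weight` in the model is a Linear weight (convolution weights are 4-D, LayerNorm weights 1-D), and the helper names are illustrative.

```python
import numpy as np


def select_submodel(paddle_state, prefix):
    """Keep one sub-network's tensors and strip its prefix (order preserved)."""
    return {k[len(prefix):]: v for k, v in paddle_state.items() if k.startswith(prefix)}


def to_torch_layout(state):
    """Transpose 2-D '.weight' tensors: Paddle Linear [in, out] -> torch Linear [out, in]."""
    return {k: (v.T if k.endswith(".weight") and v.ndim == 2 else v) for k, v in state.items()}


def paddle_linear(x, w, b):
    return x @ w + b                              # w: [in, out]


def torch_linear(x, w, b):
    return np.einsum("bi,oi->bo", x, w) + b       # w: [out, in], like F.linear


rng = np.random.default_rng(0)
x = rng.standard_normal((2, 120))
w_proj = rng.standard_normal((120, 120))          # square, like the attention proj layer
b = rng.standard_normal(120)

y_ref = paddle_linear(x, w_proj, b)
print(np.allclose(torch_linear(x, w_proj.T, b), y_ref))  # True: transposed correctly
print(np.allclose(torch_linear(x, w_proj, b), y_ref))    # False: no shape error, wrong result

ckpt = {  # hypothetical key prefixes for a two-model distillation checkpoint
    "Teacher.head.ctc_head.fc.weight": np.zeros((64, 6625)),
    "Student.head.ctc_head.fc.weight": np.zeros((64, 6625)),
    "Student.head.ctc_head.fc.bias": np.zeros(6625),
}
student = to_torch_layout(select_submodel(ckpt, "Student."))
print({k: v.shape for k, v in student.items()})
```

The square `proj` case is why transposition has to be driven by layer type or name, as the original key list does, rather than by waiting for a shape mismatch. Training-only tensors could be dropped the same way by name, e.g. keys starting with `head.sar_head.`.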

The training models are available from the PaddleOCR project:
[Figure: PaddleOCR training model download links]
The complete code is on GitHub; feel free to use it:

paddle2torch_PPOCRv3


Origin blog.csdn.net/qq_39056987/article/details/124921515