Walking through the CAN code from a recent paper on handwritten mathematical formula recognition, and training it on your own dataset

Foreword

Counting-Aware Network (CAN) is a handwritten mathematical formula recognition network proposed in a paper published in 2022 by TAL and Bai Xiang's team and accepted at ECCV. Most current attention-based handwritten formula recognizers tend to attend to the wrong regions when a formula is long or structurally complex, and the paper aims to alleviate this. It strengthens the model's awareness of symbol positions by jointly optimizing a symbol counting task and the formula recognition task, and shows that both the joint optimization and the counting results themselves improve recognition accuracy. The official code is available on GitHub.

Overview of code structure

Download the official code and unzip it. The overall code structure is relatively clear and simple.
[Figure: overall code structure]
The code consists mainly of the training script train.py, the data loading code dataset.py, the model code (mostly in the models folder), and the inference script inference.py.

A first look at the data loading code
[Figure: definition of the HMERDataset class used for data loading]

The core of the data loading code is the HMERDataset class. By default it reads a pkl file containing the image arrays and a text file containing the image names and labels. Its __getitem__(self, idx) method takes one label line, looks up the matching image array, applies a simple normalization and converts the image to a tensor, then appends an 'eos' token to the label and encodes the tokens into vocabulary ids. The code is as follows:

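    def __getitem__(self, idx):
        # Each label line is "<image name> <token> <token> ..."
        name, *labels = self.labels[idx].strip().split()
        # Drop the .jpg extension so the name matches the keys of the pkl image dict
        name = name.split('.')[0] if name.endswith('jpg') else name
        image = self.images[name]
        # Invert intensities (dark ink -> values near 1, white background -> 0) and scale to [0, 1]
        image = torch.Tensor(255-image) / 255
        image = image.unsqueeze(0)  # add the channel dimension: (1, H, W)
        labels.append('eos')  # mark the end of the sequence
        words = self.words.encode(labels)  # LaTeX tokens -> vocabulary ids
        words = torch.LongTensor(words)
        return image, words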
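The snippet relies on self.words, a vocabulary object whose definition is not shown here. As a hypothetical stand-in, the sketch below builds a small vocabulary from a token list (the 'eos' = 0 and 'sos' = 1 indices are assumptions) and reproduces the label-line parsing of __getitem__ in plain Python, so you can check which ids a label line turns into.

```python
class Words:
    """Hypothetical stand-in for the vocabulary object behind self.words."""

    def __init__(self, tokens):
        # Reserve 0 for 'eos' and 1 for 'sos' (an assumption for this sketch).
        self.words_dict = {'eos': 0, 'sos': 1}
        for tok in tokens:
            if tok not in self.words_dict:
                self.words_dict[tok] = len(self.words_dict)
        self.words_index_dict = {i: w for w, i in self.words_dict.items()}

    def __len__(self):
        return len(self.words_dict)

    def encode(self, labels):
        # Map each LaTeX token to its integer id; unknown tokens raise KeyError.
        return [self.words_dict[tok] for tok in labels]

    def decode(self, ids):
        return ' '.join(self.words_index_dict[i] for i in ids)


def parse_label_line(line):
    """Split 'name.jpg tok1 tok2 ...' into (image key, tokens + 'eos'), as __getitem__ does."""
    name, *labels = line.strip().split()
    name = name.split('.')[0] if name.endswith('jpg') else name
    labels.append('eos')
    return name, labels


vocab = Words(['x', '^', '{', '2', '}', '+', 'y'])
name, tokens = parse_label_line('18_em_0.jpg x ^ { 2 } + y\n')
ids = vocab.encode(tokens)
print(name, ids)  # 18_em_0 [2, 3, 4, 5, 6, 7, 8, 0]
```

Every label ends with the 'eos' id, which is what lets the decoder learn when to stop.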

Next, the HMERDataset is wrapped in PyTorch's DataLoader with shuffled (random) sampling. Note that a custom collate function, collate_fn, is passed to the DataLoader. It pads every image and label in the batch to the largest size in that batch and builds an image mask and a label mask of the same sizes: each mask is 1 over real content and 0 over padding. These masks are later fed to the model as inputs during training. The code is as follows:

import torch

def collate_fn(batch_images):
    max_width, max_height, max_length = 0, 0, 0
    batch, channel = len(batch_images), batch_images[0][0].shape[0]
    proper_items = []
    for item in batch_images:
        # Skip a sample if combining its height (or width) with the current max width (or height)
        # would make the padded image area exceed 1600 x 320 pixels
        if item[0].shape[1] * max_width > 1600 * 320 or item[0].shape[2] * max_height > 1600 * 320:
            continue
        # Track the largest height, width and label length in the batch
        max_height = item[0].shape[1] if item[0].shape[1] > max_height else max_height
        max_width = item[0].shape[2] if item[0].shape[2] > max_width else max_width
        max_length = item[1].shape[0] if item[1].shape[0] > max_length else max_length
        proper_items.append(item)

    # Zero-initialized tensors: padding stays 0, and masks become 1 only over real content
    images, image_masks = torch.zeros((len(proper_items), channel, max_height, max_width)), torch.zeros((len(proper_items), 1, max_height, max_width))
    labels, labels_masks = torch.zeros((len(proper_items), max_length)).long(), torch.zeros((len(proper_items), max_length))

    for i in range(len(proper_items)):
        _, h, w = proper_items[i][0].shape
        images[i][:, :h, :w] = proper_items[i][0]  # copy the image into the top-left corner
        image_masks[i][:, :h, :w] = 1
        l = proper_items[i][1].shape[0]
        labels[i][:l] = proper_items[i][1]
        labels_masks[i][:l] = 1
    return images, image_masks, labels, labels_masks
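To see what the padding and masks look like in practice, here is a simplified sketch rather than the repo's code: ToyHMERDataset and pad_collate are hypothetical names, pad_collate drops the 1600 × 320 area filter, and shuffle=True stands in for whatever sampler the repo actually uses.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class ToyHMERDataset(Dataset):
    """Hypothetical dataset yielding (1 x H x W image, label ids) pairs of different sizes."""

    def __init__(self):
        self.items = [
            (torch.rand(1, 32, 64), torch.LongTensor([2, 3, 0])),
            (torch.rand(1, 48, 40), torch.LongTensor([4, 5, 6, 7, 0])),
        ]

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        return self.items[idx]


def pad_collate(batch):
    """Simplified collate_fn: pad to the batch maximum and build 0/1 masks (no area filter)."""
    max_h = max(img.shape[1] for img, _ in batch)
    max_w = max(img.shape[2] for img, _ in batch)
    max_len = max(lab.shape[0] for _, lab in batch)
    n, c = len(batch), batch[0][0].shape[0]
    images = torch.zeros(n, c, max_h, max_w)
    image_masks = torch.zeros(n, 1, max_h, max_w)
    labels = torch.zeros(n, max_len, dtype=torch.long)
    labels_masks = torch.zeros(n, max_len)
    for i, (img, lab) in enumerate(batch):
        _, h, w = img.shape
        images[i, :, :h, :w] = img       # real pixels go in the top-left corner
        image_masks[i, :, :h, :w] = 1    # 1 = real pixel, 0 = padding
        labels[i, :lab.shape[0]] = lab
        labels_masks[i, :lab.shape[0]] = 1
    return images, image_masks, labels, labels_masks


loader = DataLoader(ToyHMERDataset(), batch_size=2, shuffle=True, collate_fn=pad_collate)
images, image_masks, labels, labels_masks = next(iter(loader))
print(images.shape, labels.shape)  # torch.Size([2, 1, 48, 64]) torch.Size([2, 5])
```

The masks sum to the number of real pixels and real tokens, which is what the model later uses to ignore padding.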

Overall model code

The model code is clean and well organized. The entry point is can.py; opening it shows the overall structure of the model.
[Figure: overall CAN model code]
The model consists of a CNN feature extraction module, two counting_decoder modules (together, the multi-scale counting module MSCM described in the paper), and a decoder module (the counting-combined attentional decoder CCAD).
[Figure: overall architecture of the model]
The CNN feature extraction module, in densenet.py, needs little comment: it is a DenseNet that takes an image as input and outputs 684 feature maps.

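The multi-scale counting module MSCM, in counting.py, is also fairly simple. Its input is the feature map extracted by the CNN. Each counting_decoder first applies trans_layer (a convolution followed by BatchNorm), then channel_att (an AdaptiveAvgPool2d, two fully connected layers each followed by an activation, and finally a multiplication of the input feature map by the resulting per-channel weights), and finally a 1×1 convolution with a sigmoid activation that maps the features to one density map per symbol class. Summing each density map over its spatial positions gives the count vector, which is returned together with the density maps. The two counting_decoder instances differ in kernel_size (3×3 and 5×5 in the paper), which is what makes the module multi-scale.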

import torch
import torch.nn as nn

# ChannelAtt (channel attention) is defined elsewhere in the repo and not shown here.
class CountingDecoder(nn.Module):
    def __init__(self, in_channel, out_channel, kernel_size):
        super(CountingDecoder, self).__init__()
        self.in_channel = in_channel
        self.out_channel = out_channel  # number of symbol classes to count
        # Convolution (kernel_size differs between the two branches) + BatchNorm
        self.trans_layer = nn.Sequential(
            nn.Conv2d(self.in_channel, 512, kernel_size=kernel_size, padding=kernel_size//2, bias=False),
            nn.BatchNorm2d(512))
        self.channel_att = ChannelAtt(512, 16)
        # 1x1 convolution + sigmoid: one density map per symbol class
        self.pred_layer = nn.Sequential(
            nn.Conv2d(512, self.out_channel, kernel_size=1, bias=False),
            nn.Sigmoid())

    def forward(self, x, mask):
        b, c, h, w = x.size()
        x = self.trans_layer(x)
        x = self.channel_att(x)
        x = self.pred_layer(x)
        if mask is not None:
            x = x * mask  # zero out padded regions
        x = x.view(b, self.out_channel, -1)
        x1 = torch.sum(x, dim=-1)  # count vector: sum of each class's density map
        return x1, x.view(b, self.out_channel, h, w)
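ChannelAtt is not shown above. Based on the description (average pooling, two fully connected layers with activations, then multiplying the input by the result), it behaves like a squeeze-and-excitation block. The sketch below is such a block; the ReLU/Sigmoid pair is the usual SE design and an assumption here, not necessarily the repo's exact code.

```python
import torch
import torch.nn as nn


class ChannelAtt(nn.Module):
    """Squeeze-and-excitation style channel attention (a sketch matching the description)."""

    def __init__(self, channel, reduction):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)   # squeeze: one value per channel
        self.fc = nn.Sequential(                  # excitation: two FC layers (assumed activations)
            nn.Linear(channel, channel // reduction),
            nn.ReLU(),
            nn.Linear(channel // reduction, channel),
            nn.Sigmoid())

    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1)           # per-channel weights in (0, 1)
        return x * y                              # reweight the input feature map


att = ChannelAtt(512, 16)
x = torch.randn(2, 512, 8, 20)
out = att(x)
```

Because the weights come from a sigmoid, the block can only scale channels down, never amplify them, and the spatial size is unchanged.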

The counting-combined attentional decoder CCAD is more involved. It is mainly implemented in decoder.py, and its architecture is shown below.
[Figure: CCAD architecture]
Its inputs are mainly the feature map extracted by the DenseNet (called cnn_features below), the counting vector produced by MSCM, positional encoding information, and the prediction from the previous step; its output is the prediction y_t for the current step.

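The current output y_t is computed by summing four terms, optionally applying dropout, and passing the result through a fully connected layer (word_convert) that produces a score for every symbol in the vocabulary; the softmax is left to the cross-entropy loss. The corresponding code is: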

# Sum the four projected terms; dropout is optional (params['dropout'])
if self.params['dropout']:
    word_out_state = self.dropout(current_state + word_weighted_embedding + word_context_weighted + counting_context_weighted)
else:
    word_out_state = current_state + word_weighted_embedding + word_context_weighted + counting_context_weighted
# Project to scores over the symbol vocabulary
word_prob = self.word_convert(word_out_state)
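The fusion step can be sketched as a standalone module. OutputFusion and its projection layer names are hypothetical, and the sizes in the usage line (256-d hidden state and embedding, a 684-channel context vector from the DenseNet features, 111 symbol classes) are illustrative assumptions; only word_convert is named after the snippet.

```python
import torch
import torch.nn as nn


class OutputFusion(nn.Module):
    """Hypothetical sketch of the CCAD output step: four projected terms summed, then a Linear."""

    def __init__(self, hidden_size, embed_size, context_size, num_classes, vocab_size, dropout=0.5):
        super().__init__()
        self.state_proj = nn.Linear(hidden_size, hidden_size)     # GRU hidden -> current_state
        self.embed_proj = nn.Linear(embed_size, hidden_size)      # previous symbol -> word_weighted_embedding
        self.context_proj = nn.Linear(context_size, hidden_size)  # attention context -> word_context_weighted
        self.count_proj = nn.Linear(num_classes, hidden_size)     # count vector -> counting_context_weighted
        self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
        self.word_convert = nn.Linear(hidden_size, vocab_size)    # scores over the vocabulary

    def forward(self, hidden, prev_embedding, context_vec, counting_vec):
        word_out_state = (self.state_proj(hidden) + self.embed_proj(prev_embedding)
                          + self.context_proj(context_vec) + self.count_proj(counting_vec))
        return self.word_convert(self.dropout(word_out_state))   # logits (no softmax here)


fusion = OutputFusion(hidden_size=256, embed_size=256, context_size=684, num_classes=111, vocab_size=111)
fusion.eval()  # disable dropout for a deterministic forward pass
logits = fusion(torch.randn(4, 256), torch.randn(4, 256), torch.randn(4, 684), torch.rand(4, 111))
```

Because the four terms are simply added, every projection must map to the same hidden size; that is the constraint the summation imposes on the layer shapes.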

current_state: the hidden state that the GRU computes from the previous step's output, passed through a Linear layer;
word_weighted_embedding: the embedding of the previous step's output symbol, passed through a Linear layer;
counting_context_weighted: the counting vector produced by MSCM, passed through a Linear layer;
word_context_weighted: the most involved term, the context vector produced by the word_attention module and then projected by a Linear layer. The inputs of word_attention are cnn_features; cnn_features_trans, which is cnn_features after an encoder convolution with the positional encoding added; the hidden state output by the GRU; the coverage attention accumulated over the previous steps (word_alpha_sum in the code); and the image mask. The call is:

word_context_vec, word_alpha, word_alpha_sum = self.word_attention(cnn_features, cnn_features_trans, hidden, word_alpha_sum, images_mask)
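word_attention is a coverage-style attention: the accumulated attention map (word_alpha_sum) tells the decoder where it has already looked. The sketch below is a generic coverage attention in that spirit, not CAN's exact implementation; CoverageAttention, attn_dim and the 11 × 11 coverage kernel are assumptions, and cnn_features_trans is assumed to already have attn_dim channels.

```python
import torch
import torch.nn as nn


class CoverageAttention(nn.Module):
    """Generic coverage attention sketch; not CAN's exact word_attention."""

    def __init__(self, attn_dim, hidden_size, cov_kernel=11):
        super().__init__()
        self.hidden_proj = nn.Linear(hidden_size, attn_dim)  # query from the GRU hidden state
        # Convolve the accumulated attention map so earlier focus regions inform this step.
        self.coverage_conv = nn.Conv2d(1, attn_dim, cov_kernel, padding=cov_kernel // 2, bias=False)
        self.energy = nn.Linear(attn_dim, 1)

    def forward(self, cnn_features, cnn_features_trans, hidden, alpha_sum, image_mask):
        # cnn_features: (B, C, H, W); cnn_features_trans: (B, attn_dim, H, W)
        # hidden: (B, hidden_size); alpha_sum and image_mask: (B, 1, H, W)
        query = self.hidden_proj(hidden)[:, :, None, None]               # (B, attn_dim, 1, 1)
        scores = torch.tanh(query + self.coverage_conv(alpha_sum) + cnn_features_trans)
        energy = self.energy(scores.permute(0, 2, 3, 1)).squeeze(-1)      # (B, H, W)
        energy = energy.masked_fill(image_mask[:, 0] == 0, float('-inf'))  # ignore padding
        b, h, w = energy.shape
        alpha = torch.softmax(energy.reshape(b, -1), dim=-1).view(b, 1, h, w)
        context = (alpha * cnn_features).sum(dim=(2, 3))                   # (B, C)
        return context, alpha, alpha_sum + alpha                            # updated coverage


torch.manual_seed(0)
B, H, W = 2, 4, 6
att = CoverageAttention(attn_dim=512, hidden_size=256)
mask = torch.ones(B, 1, H, W)
mask[:, :, :, 4:] = 0  # the last two columns are padding
context, alpha, alpha_sum = att(torch.randn(B, 684, H, W), torch.randn(B, 512, H, W),
                                torch.randn(B, 256), torch.zeros(B, 1, H, W), mask)
```

The attention weights sum to 1 over real pixels, padded pixels get exactly 0, and the returned alpha_sum is fed back in at the next decoding step.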

Training and loss function modules

The training script is fairly standard and needs little explanation.
The model's loss has two parts. The first, counting_loss, supervises the counting predictions of MSCM against counting_labels, the per-class symbol counts derived from the ground-truth LaTeX label. It is a Smooth L1 loss computed separately for counting_preds1 and counting_preds2 (the two branches) and for their average counting_preds, and the three terms are summed:

# The two MSCM branches predict count vectors; their average is supervised as well
counting_preds1, _ = self.counting_decoder1(cnn_features, counting_mask)
counting_preds2, _ = self.counting_decoder2(cnn_features, counting_mask)
counting_preds = (counting_preds1 + counting_preds2) / 2
# Smooth L1 loss on each branch and on the average, summed
counting_loss = self.counting_loss(counting_preds1, counting_labels) + self.counting_loss(counting_preds2, counting_labels) \
                        + self.counting_loss(counting_preds, counting_labels)
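The counting labels need no extra annotation: for each symbol class, count how many times it occurs in the LaTeX label. The sketch below shows that idea together with the summed Smooth L1 loss. gen_counting_labels is a hypothetical helper, and which ids to ignore (here only 0, used as eos/padding) is an assumption; the repo may also skip structural tokens.

```python
import torch
import torch.nn as nn


def gen_counting_labels(labels, num_classes, ignore_ids=()):
    """Count each symbol class per sequence. labels: (B, T) padded LongTensor of token ids."""
    counts = torch.zeros(labels.shape[0], num_classes)
    for b in range(labels.shape[0]):
        for tok in labels[b].tolist():
            if tok not in ignore_ids:   # skip eos/padding (and any other ignored ids)
                counts[b, tok] += 1
    return counts


counting_loss_fn = nn.SmoothL1Loss()
labels = torch.LongTensor([[2, 3, 2, 0], [4, 0, 0, 0]])   # 0 = eos / padding here
counting_labels = gen_counting_labels(labels, num_classes=5, ignore_ids={0})

# Stand-ins for the two MSCM branch outputs (B x num_classes count vectors).
counting_preds1 = counting_labels + 0.5
counting_preds2 = counting_labels - 0.5
counting_preds = (counting_preds1 + counting_preds2) / 2
counting_loss = (counting_loss_fn(counting_preds1, counting_labels)
                 + counting_loss_fn(counting_preds2, counting_labels)
                 + counting_loss_fn(counting_preds, counting_labels))
```

With the two stand-in branches off by ±0.5, each branch contributes 0.125 while their average matches the labels exactly, so the total loss is 0.25.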

The other part is a cross-entropy loss between the symbol predicted at each step and the label. When use_label_mask is enabled, the per-token losses are multiplied by labels_mask so that padded positions are excluded, and then averaged over the real tokens:

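# Per-token cross entropy over the flattened (batch * length) predictions
word_loss = self.cross(word_probs.contiguous().view(-1, word_probs.shape[-1]), labels.view(-1))
# Zero out padded positions and average over the real tokens
word_average_loss = (word_loss * labels_mask.view(-1)).sum() / (labels_mask.sum() + 1e-10) if self.use_label_mask else word_loss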
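The masking can be checked on toy tensors: with reduction='none' the loss is kept per token, so zeroing the padded entries removes them from the average. The shapes and values below are made up for illustration, and masked_word_loss is a hypothetical helper wrapping the two lines above.

```python
import torch
import torch.nn as nn

cross = nn.CrossEntropyLoss(reduction='none')  # keep one loss value per token


def masked_word_loss(word_probs, labels, labels_mask):
    """Average cross entropy over real tokens only (hypothetical helper)."""
    word_loss = cross(word_probs.contiguous().view(-1, word_probs.shape[-1]), labels.view(-1))
    return (word_loss * labels_mask.view(-1)).sum() / (labels_mask.sum() + 1e-10)


torch.manual_seed(0)
word_probs = torch.randn(2, 4, 6)              # (batch, max_length, vocab_size) logits
labels = torch.LongTensor([[1, 2, 0, 0],       # row 0: two symbols + eos (id 0), then padding
                           [3, 4, 5, 0]])      # row 1: three symbols + eos, no padding
labels_mask = torch.Tensor([[1, 1, 1, 0],
                            [1, 1, 1, 1]])

loss = masked_word_loss(word_probs, labels, labels_mask)

# Changing a logit at the padded position leaves the masked loss unchanged.
perturbed = word_probs.clone()
perturbed[0, 3, 0] += 100.0
loss_perturbed = masked_word_loss(perturbed, labels, labels_mask)
```

Dividing by labels_mask.sum() rather than the padded length keeps batches with a lot of padding from producing artificially small losses.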

The total loss of the model is obtained by adding counting_loss and word_average_loss.

Training on your own dataset

Once you understand the overall structure of the model, training it on your own dataset is fairly simple. There are two main options:
(1) Read the images of your dataset and save them in pkl format, and write the labels in the same format as the original data: a multi-line txt file in which each line is an image name followed by its label.
(2) If you do not want to convert the images to pkl, generate a list file containing the paths of the training images, which looks like this:
[Figure: example list of training image paths]
The labels are also stored in a text file, with content like this:
[Figure: example label file]
A small trick: the symbols in a handwritten formula label are already separated by spaces, so I use a special character sequence, "#$", to separate the image name from the label. Separating them with a space also works fine.
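Reading such a label file back can be sketched as follows. parse_label_file is a hypothetical helper: it splits on '#$' when present and otherwise on the first space, so both formats load. One benefit of a dedicated separator is that image names containing spaces stay intact.

```python
def parse_label_file(lines, sep='#$'):
    """Parse 'image_name<sep>tok1 tok2 ...' lines into {image_name: [tokens]} (hypothetical helper)."""
    labels = {}
    for line in lines:
        line = line.strip()
        if not line:
            continue                              # skip blank lines
        if sep in line:
            name, label = line.split(sep, 1)      # split only at the first separator
        else:
            name, _, label = line.partition(' ')  # original space-separated format
        labels[name.strip()] = label.split()      # LaTeX tokens are space-separated
    return labels


sample = [
    'train_0001.jpg#$\\frac { a } { b } + c\n',
    'train 0002.jpg#$x ^ { 2 }\n',   # a name containing a space survives with '#$'
    'train_0003.jpg y = 1\n',        # original space-separated format
]
labels = parse_label_file(sample)
```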

With these two files prepared, a small change to the data loading code is enough to train on your own dataset.
[Figure: modified data loading code]
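The exact modification is shown only as a screenshot, so here is one possible sketch of option (2): a hypothetical PathHMERDataset that reads images from the paths in the list file instead of the pkl dictionary, reusing the same inversion, scaling and 'eos' handling as HMERDataset. The default loader assumes Pillow and NumPy are installed, and the label dictionary is assumed to be keyed by image file name.

```python
import torch
from torch.utils.data import Dataset


def load_grayscale(path):
    """Default loader: read an image file as a 2-D uint8 array (needs Pillow and NumPy)."""
    import numpy as np
    from PIL import Image
    return np.array(Image.open(path).convert('L'))


class PathHMERDataset(Dataset):
    """Hypothetical HMERDataset variant that reads images from disk instead of a pkl file."""

    def __init__(self, image_paths, labels, words, image_loader=load_grayscale):
        self.image_paths = image_paths    # paths read from the list file
        self.labels = labels              # {image file name: [tokens]}
        self.words = words                # vocabulary object with an encode() method
        self.image_loader = image_loader

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        path = self.image_paths[idx]
        name = path.replace('\\', '/').split('/')[-1]   # file name, e.g. 'a.jpg'
        image = torch.as_tensor(self.image_loader(path), dtype=torch.float32)
        image = ((255 - image) / 255).unsqueeze(0)      # same inversion + scaling, shape (1, H, W)
        tokens = self.labels[name] + ['eos']            # new list; stored labels stay unchanged
        return image, torch.LongTensor(self.words.encode(tokens))


# Usage with a stub image loader and vocabulary (both hypothetical).
class StubWords:
    table = {'eos': 0, 'x': 1, '^': 2, '{': 3, '2': 4, '}': 5}

    def encode(self, tokens):
        return [self.table[t] for t in tokens]


dataset = PathHMERDataset(['imgs/a.jpg'], {'a.jpg': ['x', '^', '{', '2', '}']}, StubWords(),
                          image_loader=lambda p: [[255, 0], [0, 255]])
image, words = dataset[0]
```

Injecting the loader keeps the dataset testable without real image files; in training you would pass the default loader and the collate_fn shown earlier to the DataLoader.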

Conclusion

The paper designs a novel multi-scale counting module that can count symbols of multiple classes using only the original formula recognition annotations (that is, LaTeX sequences), without any symbol position annotations. Inserting this counting module into an existing attention-based encoder-decoder formula recognition network improves that network's recognition accuracy. The paper also shows that, through joint optimization, the formula recognition task in turn improves symbol counting accuracy.

Finally, the data I used to train the handwritten formula recognition model is real data that I collected myself (about 70,000 samples). If you need it, feel free to send me a private message. A small sample of the data can be downloaded from my resources page.


Original post: blog.csdn.net/weixin_42280271/article/details/128111352