Shanghai Shengtong Team Open-Sources Branchformer in WeNet

As a leader in the interactive artificial intelligence market, Shanghai Shengtong Information Technology Co., Ltd. has strong technical capabilities and distinctive products. Building on its two self-developed core technologies, integrated communication and artificial intelligence, the company has created rich, highly standardized product modules that give customers an efficient and stable product experience. Its main business scenarios are smart city, smart travel, smart communication, and smart finance, and it is actively developing further scenarios and innovative applications of its products.

Paper Introduction

Branchformer is a new-generation encoder architecture proposed by Carnegie Mellon University, with a more flexible structure, stronger interpretability, and a more configurable design. In the ESPnet framework, with the same number of parameters, its results on several commonly used datasets (AISHELL-1, etc.) match or exceed the Conformer. The paper was accepted at ICML 2022. This article explains its overall structure and our reproduction of it in the WeNet framework.

Since it was proposed, the Conformer structure has been widely used in speech tasks including ASR thanks to its efficiency, and it has held state-of-the-art results on multiple tasks. Compared with the Transformer, it captures both local and global features better. However, Conformer is serial: in each encoder layer, the audio passes through the self-attention module and then the convolution module before moving on to the next layer. Although this approach achieves very good results, its interpretability is somewhat murky. What is the relationship between local features and global features? How are they fused, and are they equally important, or does one of them play the larger role?

Motivated by these questions, a new encoder structure, Branchformer, was proposed. Compared with Conformer's macaron-sandwich stacking, Branchformer makes the following changes:

  • A parallel two-branch structure. Branch one uses multi-headed self-attention to extract global features from the input sequence; branch two introduces the cgMLP structure to capture local features in the audio sequence (a minimal sketch follows this list).

  • The MLP with convolutional gating (cgMLP) module, which combines depthwise convolution and linear gating units to learn feature representations of the sequence.

  • Several ways to merge the branch features, such as concatenation and learnable weighted averaging.

  • Stochastic Layer Skip, which improves robustness by randomly dropping encoder layers during training (added in the ESPnet code; not mentioned in the paper).
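
To make the two-branch idea concrete, here is a minimal, self-contained sketch of the layer layout. It is not the WeNet code: the cgMLP branch is stubbed with a plain depthwise convolution, and concat is used as the merge method.

    import torch
    import torch.nn as nn

    class BranchformerLayerSketch(nn.Module):
        """Toy version of the parallel two-branch layer."""

        def __init__(self, size: int = 256, heads: int = 4, kernel: int = 31):
            super().__init__()
            self.norm_mha = nn.LayerNorm(size)  # pre-norm, attention branch
            self.norm_mlp = nn.LayerNorm(size)  # pre-norm, local branch
            self.attn = nn.MultiheadAttention(size, heads, batch_first=True)
            # stand-in for cgMLP: a depthwise conv that sees local context
            self.local = nn.Conv1d(size, size, kernel,
                                   padding=kernel // 2, groups=size)
            self.merge_proj = nn.Linear(size + size, size)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            # branch 1: global features via self-attention
            h = self.norm_mha(x)
            x1, _ = self.attn(h, h, h)
            # branch 2: local features (cgMLP in the real model)
            x2 = self.local(self.norm_mlp(x).transpose(1, 2)).transpose(1, 2)
            # merge the two branches and add the residual
            return x + self.merge_proj(torch.cat([x1, x2], dim=-1))

    x = torch.randn(2, 50, 256)       # (batch, time, feature)
    y = BranchformerLayerSketch()(x)  # -> (2, 50, 256)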

Model Implementation

Reading the paper and source code, we found that Branchformer differs from Conformer mainly in how its encoder layer extracts and combines features; zoomed out, the overall processing flow is much the same. We referred to the Branchformer code in ESPnet to complete its implementation in the WeNet framework.

cgMLP Module

    def forward(
        self,
        x: torch.Tensor,
        mask: torch.Tensor,
        cache: torch.Tensor = torch.zeros((0, 0, 0))
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward cgMLP"""

        xs_pad = x

        # size -> linear_units
        xs_pad = self.channel_proj1(xs_pad)

        # linear_units -> linear_units/2
        xs_pad, new_cnn_cache = self.csgu(xs_pad, cache)

        # linear_units/2 -> size
        xs_pad = self.channel_proj2(xs_pad)

        out = xs_pad

        return out, new_cnn_cache
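
For reference, the submodules this forward assumes could be wired up as follows. This is only a sketch: the argument names follow ESPnet's cgMLP, and the CSGU is the unit whose forward is shown next.

    import torch.nn as nn

    class ConvolutionalGatingMLPSketch(nn.Module):
        """Sketch of cgMLP's dimension flow: size -> linear_units -> linear_units/2 -> size."""

        def __init__(self, size: int, linear_units: int, csgu: nn.Module):
            super().__init__()
            # size -> linear_units, with GELU as in the paper
            self.channel_proj1 = nn.Sequential(
                nn.Linear(size, linear_units), nn.GELU())
            # the gating split inside CSGU halves the width:
            # linear_units -> linear_units // 2
            self.csgu = csgu
            # linear_units // 2 -> size
            self.channel_proj2 = nn.Linear(linear_units // 2, size)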

CSGU is the key part of cgMLP. It first splits the input sequence in two along the feature dimension; one half passes through layer norm and a depthwise convolution, and is then multiplied element-wise with the other half to produce the output. Since WeNet introduces caching, we compute and update cnn_cache here.

    def forward(
        self,
        x: torch.Tensor,
        cache: torch.Tensor = torch.zeros((0, 0, 0))
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward CSGU"""
        
        x_r, x_g = x.chunk(2, dim=-1)
        # exchange the temporal dimension and the feature dimension
        x_g = x_g.transpose(1, 2)  # (#batch, channels, time)

        if self.lorder > 0:
            if cache.size(2) == 0:  # cache_t == 0
                x_g = nn.functional.pad(x_g, (self.lorder, 0), 'constant', 0.0)
            else:
                assert cache.size(0) == x_g.size(0)  # equal batch
                assert cache.size(1) == x_g.size(1)  # equal channel
                x_g = torch.cat((cache, x_g), dim=2)
            assert (x_g.size(2) > self.lorder)
            new_cache = x_g[:, :, -self.lorder:]
        else:
            # It's better we just return None if no cache is required,
            # However, for JIT export, here we just fake one tensor instead of
            # None.
            new_cache = torch.zeros((0, 0, 0), dtype=x_g.dtype, device=x_g.device)

        x_g = x_g.transpose(1, 2)
        x_g = self.norm(x_g)  # (N, T, D/2)
        x_g = self.conv(x_g.transpose(1, 2)).transpose(1, 2)  # (N, T, D/2)
        if self.linear is not None:
            x_g = self.linear(x_g)

        x_g = self.act(x_g)
        out = x_r * x_g  # (N, T, D/2)
        out = self.dropout(out)
        return out, new_cache
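
A hypothetical usage example of this cache across two streaming chunks. We assume an instance `csgu` built with linear_units = 2048 (so the gate half has 1024 channels) and kernel_size = 31 (so lorder = 30); the instance name is illustrative.

    import torch

    chunk = torch.randn(1, 16, 2048)      # (batch, time, linear_units)
    cache = torch.zeros((0, 0, 0))        # empty cache for the first chunk
    out, cache = csgu(chunk, cache)       # out: (1, 16, 1024)
    assert cache.shape == (1, 1024, 30)   # last lorder frames of the padded gate input

    next_chunk = torch.randn(1, 16, 2048)
    out, cache = csgu(next_chunk, cache)  # the conv now sees 30 cached frames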

Merging the Two Branches

The author proposes three feature-fusion methods: direct concatenation, fixed-weight averaging, and learnable weighted averaging. The code changes little from ESPnet; we only had to replace numpy and similar modules so as not to break torch.jit model export.

    if self.merge_method == "concat":
        x = x + stoch_layer_coeff * self.dropout(
            self.merge_proj(torch.cat([x1, x2], dim=-1))
        )
    elif self.merge_method == "learned_ave":
        if (
            self.training
            and self.attn_branch_drop_rate > 0
            and torch.rand(1).item() < self.attn_branch_drop_rate
        ):
            # Drop the attn branch
            w1, w2 = torch.tensor(0.0), torch.tensor(1.0)
        else:
            # branch1
            score1 = (self.pooling_proj1(x1).transpose(1, 2) / self.size**0.5)
            score1 = score1.masked_fill(mask_pad.eq(0), -float('inf'))
            score1 = torch.softmax(score1, dim=-1).masked_fill(
                mask_pad.eq(0), 0.0
            )
            pooled1 = torch.matmul(score1, x1).squeeze(1)  # (batch, size)
            weight1 = self.weight_proj1(pooled1)  # (batch, 1)

            # branch2
            score2 = (self.pooling_proj2(x2).transpose(1, 2) / self.size**0.5)
            score2 = score2.masked_fill(mask_pad.eq(0), -float('inf'))
            score2 = torch.softmax(score2, dim=-1).masked_fill(
                mask_pad.eq(0), 0.0
            )
            pooled2 = torch.matmul(score2, x2).squeeze(1)  # (batch, size)
            weight2 = self.weight_proj2(pooled2)  # (batch, 1)

            # normalize weights of two branches
            merge_weights = torch.softmax(
                torch.cat([weight1, weight2], dim=-1), dim=-1
            )  # (batch, 2)
            merge_weights = merge_weights.unsqueeze(-1).unsqueeze(-1)  # (batch, 2, 1, 1)
            w1, w2 = merge_weights[:, 0], merge_weights[:, 1]  # (batch, 1, 1)

        x = x + stoch_layer_coeff * self.dropout(
            self.merge_proj(w1 * x1 + w2 * x2)
        )
    elif self.merge_method == "fixed_ave":
        x = x + stoch_layer_coeff * self.dropout(
            self.merge_proj(
                (1.0 - self.cgmlp_weight) * x1 + self.cgmlp_weight * x2
            )
        )
    else:
        raise RuntimeError(f"unknown merge method: {self.merge_method}")
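
One detail worth noting: the input width of merge_proj depends on the merge method, since concat doubles the feature dimension while both averaging variants preserve it. A sketch of the setup this implies (names assumed):

    import torch.nn as nn

    size = 256                   # model dimension (assumed)
    merge_method = "concat"      # one of the three methods above
    if merge_method == "concat":
        merge_proj = nn.Linear(size + size, size)  # concat doubles the width
    else:  # "learned_ave" or "fixed_ave"
        merge_proj = nn.Linear(size, size)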

In the paper, the author compares the effect of the different merge operations on the model and visualizes the distribution of the learnable weights across encoder layers of different depths.

Stochastic Layer Skip

Stochastic depth was added in the ESPnet code; enabling this option in the configuration randomly skips some layers during training. This lets Branchformer train deeper networks, and randomly skipping layers during training both speeds up training and makes the model more robust.

    stoch_layer_coeff = 1.0
    # with stochastic depth, residual connection `x + f(x)` becomes
    # `x <- x + 1 / (1 - p) * f(x)` at training time.
    if self.training and self.stochastic_depth_rate > 0:
        skip_layer = torch.rand(1).item() < self.stochastic_depth_rate
        stoch_layer_coeff = 1.0 / (1 - self.stochastic_depth_rate)
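
The `skip_layer` flag then short-circuits the whole layer. Sketched as a continuation of the snippet above (the real WeNet layer also returns its attention and convolution caches):

    # continuation sketch: skip_layer defaults to False outside training
    if skip_layer:
        # the whole layer is skipped: the input passes through unchanged
        return x, mask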

Streaming Inference

Although Branchformer uses two branches to compute global and local features, its streaming computation is similar to Conformer's: att_cache and cnn_cache are computed and updated separately. The mechanism is the same as Conformer's and can essentially be applied directly.

    for i, layer in enumerate(self.encoders):
        # NOTE(xcsong): Before layer.forward
        #   shape(att_cache[i:i + 1]) is (1, head, cache_t1, d_k * 2),
        #   shape(cnn_cache[i])       is (b=1, hidden-dim, cache_t2)
        xs, _, new_att_cache, new_cnn_cache = layer(
            xs, att_mask, pos_emb,
            att_cache=att_cache[i:i + 1] if elayers > 0 else att_cache,
            cnn_cache=cnn_cache[i] if cnn_cache.size(0) > 0 else cnn_cache
        )
        # NOTE(xcsong): After layer.forward
        #   shape(new_att_cache) is (1, head, attention_key_size, d_k * 2),
        #   shape(new_cnn_cache) is (b=1, hidden-dim, cache_t2)
        r_att_cache.append(new_att_cache[:, :, next_cache_start:, :])
        r_cnn_cache.append(new_cnn_cache.unsqueeze(0))

Experimental Results

We contributed a complete Branchformer training recipe to WeNet and ran experiments on the AISHELL-1 dataset, varying parameters such as the number of encoder layers and the linear units.

The results (CER %, AISHELL-1) under WeNet's four decoding modes are:

    Model configuration            attention  attention_rescoring  ctc_prefix_beam_search  ctc_greedy_search
    24 layers + 2048 linear units  5.12       4.81                 5.28                    5.28
    24 layers + 1024 linear units  5.33       4.88                 5.41                    5.40
    12 layers + 2048 linear units  5.37       5.08                 5.69                    5.69
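
For reference, the encoder settings behind the first row might look like the following. This is a hypothetical sketch in Python-dict form: the key names follow ESPnet's BranchformerEncoder arguments, which the WeNet port mirrors, so check the released recipe for the exact YAML.

    # hypothetical settings for "24 layers + 2048 linear units";
    # values marked "assumed" are typical defaults, not confirmed ones
    encoder_conf = dict(
        output_size=256,            # model dimension (assumed)
        attention_heads=4,          # assumed
        cgmlp_linear_units=2048,    # the linear units swept in the table
        cgmlp_conv_kernel=31,       # depthwise conv kernel (assumed)
        merge_method="concat",      # or "learned_ave" / "fixed_ave"
        num_blocks=24,              # the layer count swept in the table
        dropout_rate=0.1,           # assumed
        stochastic_depth_rate=0.0,  # > 0 enables Stochastic Layer Skip
    )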

References

Branchformer paper: "Branchformer: Parallel MLP-Attention Architectures to Capture Local and Global Context for Speech Recognition and Understanding" (ICML 2022), https://arxiv.org/abs/2207.02971

ESPnet implementation: https://github.com/espnet/espnet/blob/master/espnet2/asr/encoder/branchformer_encoder.py
