As a leader in the interactive artificial intelligence market, Shanghai Shengtong Information Technology Co., Ltd. has strong technical advantages and distinctive products. Building on its two self-developed core technologies, integrated communication and artificial intelligence, the company has created a rich set of highly standardized product modules that give customers an efficient and stable product experience. Its main business scenarios are smart cities, smart travel, smart communication and smart finance, and it is also actively developing other scenarios and innovative applications of its products.
Paper introduction
Branchformer is a new encoder architecture proposed by Carnegie Mellon University. Its structure is more flexible, easier to interpret and easier to configure. In ESPnet, with the same number of parameters, it matches or outperforms Conformer on several commonly used datasets (AISHELL and others). The paper was accepted at ICML 2022. This article outlines its overall structure and describes our reproduction of it in the WeNet framework.
Since it was proposed, Conformer has been widely adopted in speech tasks such as ASR because of its efficiency, and it has achieved state-of-the-art results on many of them. Compared with Transformer, it captures local and global features better. However, each Conformer encoder_layer processes the audio serially: the features pass through the self-attention module and then the convolution module before moving on to the next layer. Although this design works very well, it is hard to interpret. How do local and global features relate to each other? How are they combined, and are they equally important? If not, which one plays the larger role?
Branchformer was proposed to address these questions. Compared with Conformer's macaron-style sandwich stacking, Branchformer makes the following changes:
- A parallel two-branch structure. One branch uses multi-headed self-attention to extract global features from the input sequence; the other uses a cgMLP, which is intended to capture local features in the audio sequence.
- The MLP with convolutional gating (cgMLP) learns feature representations of the sequence by combining a depth-wise convolution with linear gating.
- Several ways of merging the two branches, such as concatenation and weighting with learnable parameters.
- Stochastic layer skipping, which makes the model more robust by randomly dropping whole encoder layers during training (added in the ESPnet code but not mentioned in the paper).
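To make the contrast with Conformer's serial stacking concrete, here is a toy NumPy sketch of one Branchformer-style layer. It is our own illustration, not the ESPnet or WeNet code: a global self-attention branch and a local convolution-gated branch run side by side on the same input and are merged by concatenation. It assumes single-head attention and leaves out layer normalization, activations and dropout; all function names and weight shapes are hypothetical.

```python
import numpy as np


def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)  # subtract the max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


def self_attention(x, wq, wk, wv):
    """Global branch: single-head scaled dot-product attention over all frames."""
    q, k, v = x @ wq, x @ wk, x @ wv
    probs = softmax(q @ k.T / np.sqrt(q.shape[-1]))  # (T, T): every frame attends to every frame
    return probs @ v


def depthwise_conv1d(x, kernel):
    """Per-channel convolution over time with 'same' zero padding.

    x: (T, C); kernel: (K, C) with odd K.
    """
    k = kernel.shape[0]
    xp = np.pad(x, ((k // 2, k // 2), (0, 0)))
    return np.stack([(xp[t:t + k] * kernel).sum(axis=0) for t in range(x.shape[0])])


def cgmlp(x, w_up, kernel, w_down):
    """Local branch: project up, gate one half with a depthwise conv of the other, project down."""
    h_r, h_g = np.split(x @ w_up, 2, axis=-1)
    return (h_r * depthwise_conv1d(h_g, kernel)) @ w_down


def branchformer_layer(x, p):
    x1 = self_attention(x, p["wq"], p["wk"], p["wv"])    # global features
    x2 = cgmlp(x, p["w_up"], p["kernel"], p["w_down"])   # local features, computed in parallel
    merged = np.concatenate([x1, x2], axis=-1) @ p["w_merge"]  # "concat" merge + projection
    return x + merged                                     # residual connection


rng = np.random.default_rng(0)
T, D, U, K = 6, 4, 8, 3  # frames, model size, linear units, conv kernel size
p = {
    "wq": rng.normal(size=(D, D)),
    "wk": rng.normal(size=(D, D)),
    "wv": rng.normal(size=(D, D)),
    "w_up": rng.normal(size=(D, U)),
    "kernel": rng.normal(size=(K, U // 2)),
    "w_down": rng.normal(size=(U // 2, D)),
    "w_merge": rng.normal(size=(2 * D, D)),
}
x = rng.normal(size=(T, D))
y = branchformer_layer(x, p)
print(y.shape)  # (6, 4)

# Perturb only the last frame: the local branch output at frame 0 is unchanged,
# while the global branch output at frame 0 generally changes.
x_far = x.copy()
x_far[-1] += 1.0
print(np.allclose(cgmlp(x, p["w_up"], p["kernel"], p["w_down"])[0],
                  cgmlp(x_far, p["w_up"], p["kernel"], p["w_down"])[0]))  # True
```

The perturbation check is a quick way to see the division of labor: with a kernel of size 3, frame 0 of the cgMLP branch depends only on frames 0 and 1, whereas the attention branch sees the whole sequence.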
Model implementation
After reading the paper and the source code, we found that Branchformer differs from Conformer mainly in how each encoder_layer extracts and combines features; at a higher level, the overall processing flow is much the same. We based our WeNet implementation on the Branchformer code in ESPnet.
cgMLP module
# forward() of the cgMLP module (class definition omitted)
import torch
from typing import Tuple
def forward(
    self,
    x: torch.Tensor,
    mask: torch.Tensor,
    cache: torch.Tensor = torch.zeros((0, 0, 0))
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Forward cgMLP"""
    xs_pad = x
    # size -> linear_units
    xs_pad = self.channel_proj1(xs_pad)
    # linear_units -> linear_units/2 (CSGU also returns the updated conv cache)
    xs_pad, new_cnn_cache = self.csgu(xs_pad, cache)
    # linear_units/2 -> size
    xs_pad = self.channel_proj2(xs_pad)
    out = xs_pad
    return out, new_cnn_cache
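The comments in forward() describe how the feature dimension changes. The short NumPy shape trace below follows those comments with zero-valued stand-in weights. The sizes (size=256, linear_units=2048) are illustrative values we chose, and the gating step is reduced to a plain product because only the shapes matter here.

```python
import numpy as np


def cgmlp_shape_trace(batch, time, size, linear_units):
    """Return the tensor shapes at each step of the cgMLP forward pass (weights are zeros)."""
    x = np.zeros((batch, time, size))
    h = x @ np.zeros((size, linear_units))             # channel_proj1: size -> linear_units
    x_r, x_g = np.split(h, 2, axis=-1)                 # CSGU splits the features into two halves
    gated = x_r * x_g                                  # stand-in for the gating (conv omitted)
    out = gated @ np.zeros((linear_units // 2, size))  # channel_proj2: linear_units/2 -> size
    return [x.shape, h.shape, gated.shape, out.shape]


print(cgmlp_shape_trace(batch=2, time=50, size=256, linear_units=2048))
# [(2, 50, 256), (2, 50, 2048), (2, 50, 1024), (2, 50, 256)]
```

Because CSGU halves the channels, linear_units has to be even, and the cgMLP output returns to the model size so that it can be merged with the attention branch.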
The convolutional spatial gating unit (CSGU) is the core of cgMLP. It splits the input in half along the feature dimension; one half goes through layer normalization and a depth-wise convolution and is then multiplied element-wise with the other half to produce the output. Because WeNet supports streaming with a convolution cache, cnn_cache is also computed and updated here.
# forward() of the CSGU module (class definition omitted)
import torch
from torch import nn
from typing import Tuple
def forward(
    self,
    x: torch.Tensor,
    cache: torch.Tensor = torch.zeros((0, 0, 0))
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Forward CSGU"""
    x_r, x_g = x.chunk(2, dim=-1)
    # exchange the temporal dimension and the feature dimension
    x_g = x_g.transpose(1, 2)  # (#batch, channels, time)
    if self.lorder > 0:
        if cache.size(2) == 0:  # cache_t == 0
            x_g = nn.functional.pad(x_g, (self.lorder, 0), 'constant', 0.0)
        else:
            assert cache.size(0) == x_g.size(0)  # equal batch
            assert cache.size(1) == x_g.size(1)  # equal channel
            x_g = torch.cat((cache, x_g), dim=2)
        assert (x_g.size(2) > self.lorder)
        new_cache = x_g[:, :, -self.lorder:]
    else:
        # It's better we just return None if no cache is required,
        # However, for JIT export, here we just fake one tensor instead of
        # None.
        new_cache = torch.zeros((0, 0, 0), dtype=x_g.dtype, device=x_g.device)
    x_g = x_g.transpose(1, 2)
    x_g = self.norm(x_g)  # (N, T, D/2)
    x_g = self.conv(x_g.transpose(1, 2)).transpose(1, 2)  # (N, T, D/2)
    if self.linear is not None:
        x_g = self.linear(x_g)
    x_g = self.act(x_g)
    out = x_r * x_g  # (N, T, D/2)
    out = self.dropout(out)
    return out, new_cache
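Why does keeping the last lorder frames as cnn_cache make chunk-wise processing exact? The NumPy sketch below reproduces only the padding and cache logic of CSGU for a causal depth-wise convolution (lorder = kernel_size - 1), leaving out the layer norm, activation and gating; the function name is hypothetical. It checks that convolving a sequence in two chunks, carrying the cache across, gives the same result as convolving it in one pass.

```python
import numpy as np


def causal_depthwise_conv(x, kernel, cache=None):
    """Causal per-channel 1-D convolution with a left-context cache.

    x:      (channels, time), the gate half after transpose(1, 2)
    kernel: (channels, lorder + 1)
    cache:  (channels, lorder) left context from the previous chunk, or None
    Returns (y, new_cache) with y of shape (channels, time).
    """
    lorder = kernel.shape[1] - 1
    assert lorder > 0, "this sketch only covers the causal (lorder > 0) case"
    if cache is None:                  # first chunk: pad lorder zeros on the left
        x = np.pad(x, ((0, 0), (lorder, 0)))
    else:                              # later chunks: prepend the cached frames
        assert cache.shape == (x.shape[0], lorder)
        x = np.concatenate([cache, x], axis=1)
    new_cache = x[:, -lorder:]         # keep the last lorder frames for the next chunk
    steps = x.shape[1] - lorder
    y = np.stack([(x[:, t:t + lorder + 1] * kernel).sum(axis=1) for t in range(steps)], axis=1)
    return y, new_cache


rng = np.random.default_rng(0)
channels, time, lorder = 3, 10, 2
x = rng.normal(size=(channels, time))
kernel = rng.normal(size=(channels, lorder + 1))

full, _ = causal_depthwise_conv(x, kernel)                  # whole utterance at once
y1, cache = causal_depthwise_conv(x[:, :4], kernel)         # first chunk
y2, cache = causal_depthwise_conv(x[:, 4:], kernel, cache)  # second chunk reuses the cache
print(np.allclose(full, np.concatenate([y1, y2], axis=1)))  # True
```

In the real module the cached frames are the gate features before layer normalization. Since layer norm acts on each frame independently, normalizing them again in the next chunk yields the same values as in a single full pass.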
Merging the two branches
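The authors propose three ways to fuse the branch features: direct concatenation (concat), averaging with a fixed weight (fixed_ave, where cgmlp_weight sets the share of the cgMLP branch, so 0.5 gives equal weights), and averaging with learned weights (learned_ave). Porting this code required few changes: we mainly replaced numpy and other non-PyTorch calls so that the model can still be exported with torch.jit.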
if self.merge_method == "concat":
    x = x + stoch_layer_coeff * self.dropout(
        self.merge_proj(torch.cat([x1, x2], dim=-1))
    )
elif self.merge_method == "learned_ave":
    if (
        self.training
        and self.attn_branch_drop_rate > 0
        and torch.rand(1).item() < self.attn_branch_drop_rate
    ):
        # Drop the attn branch
        w1, w2 = torch.tensor(0.0), torch.tensor(1.0)
    else:
        # branch1
        score1 = (self.pooling_proj1(x1).transpose(1, 2) / self.size**0.5)
        score1 = score1.masked_fill(mask_pad.eq(0), -float('inf'))
        score1 = torch.softmax(score1, dim=-1).masked_fill(
            mask_pad.eq(0), 0.0
        )
        pooled1 = torch.matmul(score1, x1).squeeze(1)  # (batch, size)
        weight1 = self.weight_proj1(pooled1)  # (batch, 1)
        # branch2
        score2 = (self.pooling_proj2(x2).transpose(1, 2) / self.size**0.5)
        score2 = score2.masked_fill(mask_pad.eq(0), -float('inf'))
        score2 = torch.softmax(score2, dim=-1).masked_fill(
            mask_pad.eq(0), 0.0
        )
        pooled2 = torch.matmul(score2, x2).squeeze(1)  # (batch, size)
        weight2 = self.weight_proj2(pooled2)  # (batch, 1)
        # normalize weights of two branches
        merge_weights = torch.softmax(
            torch.cat([weight1, weight2], dim=-1), dim=-1
        )  # (batch, 2)
        merge_weights = merge_weights.unsqueeze(-1).unsqueeze(
            -1
        )  # (batch, 2, 1, 1)
        w1, w2 = merge_weights[:, 0], merge_weights[:, 1]  # (batch, 1, 1)
    x = x + stoch_layer_coeff * self.dropout(
        self.merge_proj(w1 * x1 + w2 * x2)
    )
elif self.merge_method == "fixed_ave":
    x = x + stoch_layer_coeff * self.dropout(
        self.merge_proj(
            (1.0 - self.cgmlp_weight) * x1 + self.cgmlp_weight * x2
        )
    )
else:
    raise RuntimeError(f"unknown merge method: {self.merge_method}")
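The learned_ave branch is the hardest to follow because of the masking and reshaping. The NumPy sketch below mirrors its logic for one merge step under simplifying assumptions: pooling_proj1/2 and weight_proj1/2 become plain weight vectors without bias, and branch dropping, dropout and merge_proj are left out. The names attention_pool and learned_ave_weights are ours. It shows that the two branch weights form a softmax per utterance and that padded frames have no influence on them.

```python
import numpy as np


def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


def attention_pool(x, mask, w_pool):
    """Masked attention pooling over time: (B, T, D) -> (B, D).

    mask: (B, T), 1 for real frames and 0 for padding (each row needs >= 1 real frame).
    w_pool: (D,), standing in for pooling_proj (Linear(size, 1)) without bias.
    """
    score = x @ w_pool / np.sqrt(x.shape[-1])      # (B, T)
    score = np.where(mask == 1, score, -np.inf)    # padded frames get zero probability
    return np.einsum("bt,btd->bd", softmax(score), x)


def learned_ave_weights(x1, x2, mask, p):
    """Per-utterance weights (w1, w2) for the attention and cgMLP branches: (B, 2)."""
    weight1 = attention_pool(x1, mask, p["pool1"]) @ p["proj1"]  # (B,)
    weight2 = attention_pool(x2, mask, p["pool2"]) @ p["proj2"]  # (B,)
    return softmax(np.stack([weight1, weight2], axis=-1))


rng = np.random.default_rng(0)
B, T, D = 2, 5, 4
x1, x2 = rng.normal(size=(B, T, D)), rng.normal(size=(B, T, D))
mask = np.array([[1, 1, 1, 1, 1],
                 [1, 1, 1, 0, 0]])  # the second utterance has two padded frames
p = {k: rng.normal(size=D) for k in ("pool1", "pool2", "proj1", "proj2")}

w = learned_ave_weights(x1, x2, mask, p)
merged = w[:, 0, None, None] * x1 + w[:, 1, None, None] * x2  # (B, T, D), before merge_proj
print(w.sum(axis=-1))  # [1. 1.]

# Changing only the padded frames leaves the weights unchanged.
x1_noisy = x1.copy()
x1_noisy[1, 3:] += 100.0
print(np.allclose(w, learned_ave_weights(x1_noisy, x2, mask, p)))  # True
```

Because the weights are computed per utterance, each input can lean on the attention or the cgMLP branch differently, which is what makes the learned weights useful for the layer-by-layer analysis in the paper.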
In the paper, the authors compare how the different merge methods affect the model and visualize the learned branch weights in encoder layers at different depths.
Stochastic Layer Skip
The ESPnet code adds stochastic depth: when this option is enabled in the configuration, whole layers are randomly skipped during training. This makes it practical to train deeper Branchformer networks, and skipping layers at random speeds up training and makes the model more robust.
stoch_layer_coeff = 1.0
skip_layer = False
# with stochastic depth, residual connection `x + f(x)` becomes
# `x <- x + 1 / (1 - p) * f(x)` at training time.
if self.training and self.stochastic_depth_rate > 0:
    skip_layer = torch.rand(1).item() < self.stochastic_depth_rate
    stoch_layer_coeff = 1.0 / (1 - self.stochastic_depth_rate)
if skip_layer:
    # skip the whole layer: pass the input and the caches through unchanged
    return x, mask, att_cache, cnn_cache
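The comment in the snippet states the rescaling rule. A minimal NumPy sketch of a residual block with stochastic depth (our own illustrative function, with f standing in for the whole layer body) shows why the 1 / (1 - p) factor is used: averaged over many training steps, the output matches the plain x + f(x) used at inference.

```python
import numpy as np


def stochastic_residual(x, f, p, training, rng):
    """Residual connection with stochastic depth.

    Training: skip the layer with probability p (return x unchanged); otherwise
    return x + f(x) / (1 - p), so the expected output equals x + f(x).
    Inference: always x + f(x).
    """
    if training and p > 0:
        if rng.random() < p:
            return x
        return x + f(x) / (1.0 - p)
    return x + f(x)


def f(v):
    return 2.0 * v  # toy layer body


rng = np.random.default_rng(0)
x = np.ones(3)
p = 0.3

samples = np.stack([stochastic_residual(x, f, p, True, rng) for _ in range(20000)])
print(samples.mean(axis=0))                       # close to [3. 3. 3.]
print(stochastic_residual(x, f, p, False, rng))   # [3. 3. 3.]
```

Scaling at training time rather than at inference keeps the inference path identical to a normal residual network, so no change is needed when the model is exported.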
Streaming inference
Although Branchformer computes global and local features in two separate branches, streaming inference is handled much as in Conformer: att_cache is computed and updated in the attention branch and cnn_cache in the cgMLP branch. The Conformer approach therefore carries over almost unchanged.
for i, layer in enumerate(self.encoders):
    # NOTE(xcsong): Before layer.forward
    #   shape(att_cache[i:i + 1]) is (1, head, cache_t1, d_k * 2),
    #   shape(cnn_cache[i]) is (b=1, hidden-dim, cache_t2)
    xs, _, new_att_cache, new_cnn_cache = layer(
        xs, att_mask, pos_emb,
        att_cache=att_cache[i:i + 1] if elayers > 0 else att_cache,
        cnn_cache=cnn_cache[i] if cnn_cache.size(0) > 0 else cnn_cache
    )
    # NOTE(xcsong): After layer.forward
    #   shape(new_att_cache) is (1, head, attention_key_size, d_k * 2),
    #   shape(new_cnn_cache) is (b=1, hidden-dim, cache_t2)
    r_att_cache.append(new_att_cache[:, :, next_cache_start:, :])
    r_cnn_cache.append(new_cnn_cache.unsqueeze(0))
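The slice new_att_cache[:, :, next_cache_start:, :] decides how much attention history is carried into the next chunk, and next_cache_start is computed before this loop. The pure-Python sketch below restates that rule as we understand it from WeNet's forward_chunk, where a negative required_cache_size means "keep all history"; treat the exact rule as an assumption and check it against the WeNet source. The simulate helper is hypothetical and only tracks cache lengths.

```python
def next_cache_start(cache_t1, chunk_size, required_cache_size):
    """Index from which new_att_cache is kept (assumed to mirror WeNet's rule)."""
    attention_key_size = cache_t1 + chunk_size  # frames attention can see in this step
    if required_cache_size < 0:                 # keep the full history
        return 0
    if required_cache_size == 0:                # keep nothing
        return attention_key_size
    return max(attention_key_size - required_cache_size, 0)


def simulate(num_chunks, chunk_size, required_cache_size):
    """Attention-cache length (cache_t1) after each chunk."""
    cache_t1 = 0
    sizes = []
    for _ in range(num_chunks):
        start = next_cache_start(cache_t1, chunk_size, required_cache_size)
        cache_t1 = cache_t1 + chunk_size - start  # length of new_att_cache[:, :, start:, :]
        sizes.append(cache_t1)
    return sizes


print(simulate(5, 16, 64))  # [16, 32, 48, 64, 64]
print(simulate(3, 16, -1))  # [16, 32, 48]
print(simulate(3, 16, 0))   # [0, 0, 0]
```

With 16-frame chunks and a required cache of 64 frames (for example, four left chunks), the attention cache grows until it holds 64 frames and then stays there, which bounds memory and computation per chunk during streaming decoding.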
Experimental results
We contributed a complete Branchformer training recipe to WeNet and ran experiments on the AISHELL dataset, varying the number of encoder layers and the number of linear units. The table below lists the character error rate (CER, %) for each decoding mode.
Model configuration | attention | attention_rescore | ctc_prefix_beam_search | ctc_greedy_search |
---|---|---|---|---|
24 layers + 2048 linear units | 5.12 | 4.81 | 5.28 | 5.28 |
24 layers + 1024 linear units | 5.33 | 4.88 | 5.41 | 5.40 |
12 layers + 2048 linear units | 5.37 | 5.08 | 5.69 | 5.69 |
References
Branchformer: https://arxiv.org/abs/2207.02971
ESPnet: https://github.com/espnet/espnet/blob/master/espnet2/asr/encoder/branchformer_encoder.py