Comprehensive analysis of the first open source MoE large model Mixtral 8x7B: from principle analysis to code interpretation

Preface

On December 8, 2023, Mistral AI posted a magnet link on the X platform (many people who opened it found that it was a torrent of nearly 87 GB)

The architecture of Mixtral 8x7B looks very similar to the previously rumored GPT-4 architecture, just a "shrunk version": 

  • 8 total experts instead of 16 (halved) 
  • 7B parameters per expert instead of 166B (24x reduction)
  • 42B total parameters (estimated) instead of 1.8T (42x reduction)
  • Same 32K context as original GPT-4

Within 24 hours after its release, a developer had already created an online experience website: https://replicate.com/nateraw/mixtral-8x7b-32kseqlen

The OpenAI team has been tight-lipped about GPT-4's parameter count and training details. Earlier, someone leaked that GPT-4 uses a system composed of 8 expert models; later there were also rumors that ChatGPT is only a model with tens of billions of parameters (probably around 20 billion).

These rumors cannot be verified, but Mixtral 8x7B may offer an open-source option that is "very close to GPT-4". This article therefore provides a comprehensive analysis, from principle analysis to code interpretation (before this article, there was no analysis of comparable detail).

Part 1 The first open source MoE large model Mixtral 8x7B

1.1 Overall architecture and model details of Mixtral 8x7B

Two days later, on December 11, 2023, the Mistral AI team officially released Mixtral 8x7B: it outperforms Llama 2 70B on most benchmarks with 6x faster inference, and matches or surpasses GPT-3.5 on most standard benchmarks.

To avoid ambiguity, a supplementary note: the Mistral AI team has so far released two models in total.

  • Mistral 7B released in October this year
  • The mixture-of-experts model released in December this year, called Mixtral 8x7B

One is "Mis-tral", the other is "Mix-tral": they are essentially different models

And what is the origin of this Mistral AI team?


According to the introduction in Part 4 of the article " July Paper Review GPT Version 2: From Meta Nougat and GPT4 Review to Mistral and LongLora Llama ":

  1. The Mistral AI team was co-founded in Paris in May this year by three former employees of DeepMind and Meta ( its CEO Arthur Mensch previously worked at DeepMind in Paris, while CTO Timothée Lacroix and chief scientist Guillaume Lample both took part in the development of the first-generation LLaMA at Meta; the situation is much like when some OpenAI employees left to found Anthropic )
  2. In October this year, they released their first large base model, Mistral 7B, which was hailed as the best 7B model at the time because it beat the best 13B-parameter model (Llama 2, the second generation) on every evaluation benchmark and surpassed Llama 34B in reasoning, mathematics and code generation ( yes, here the baseline is the first-generation LLaMA 34B )

1.1.1 Mixtral 8x7B is a sparse expert mixing network

Mixtral 8x7B is a pure decoder model

  1. It is a decoder-only model in which the feedforward block picks from a set of 8 distinct groups of parameters
  2. At every layer, for every token, a router network chooses two of these groups (the "experts") to process the token and combines their outputs additively

    Many readers may not pay particular attention to this sentence, but if you read it carefully you will notice an important difference: each individual token is handled by its own two experts, so the whole sequence is ultimately processed by a series of "different pairs of experts". This will be discussed in detail below.

As shown in the figure below, after passing through the attention layer and residual connection, each token fed into the model is directed by the routing (Gating/Router) to two experts (FFNs). The experts' outputs are then weighted and aggregated, and a further residual connection produces the output of the current layer

1.1.2  Why the total parameter size of Mixtral is 46.7B instead of 56B

Mixtral has 46.7B total parameters, but each token only uses 12.9B of them. It therefore processes input and generates output at the same speed and cost as a 12.9B model

  1. That is, although the model's full name is "Mixtral-8x7B-v0.1" and "8x7B = 56B" seems to suggest 56B parameters, the actual parameter count is about 47B rather than 56B, because only the expert part (the FFN) exists independently for each of the 8 experts, while the remaining parts (Attention, etc.) are shared by all experts.
  2. You can picture it as a "spindle" shape: the data passed from the shared modules into the expert modules corresponds to the diverging middle of the spindle, and the weighted aggregation of the experts' outputs corresponds to the converging end of the spindle. A back-of-the-envelope estimate of both parameter figures is sketched right after this list.
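To make the "46.7B total vs. 12.9B active" numbers concrete, here is a rough back-of-the-envelope estimate in Python. The configuration values (hidden size 4096, FFN intermediate size 14336, 32 layers, 32 query heads, 8 K/V heads, 32k vocabulary) are assumed from the publicly released Mixtral-8x7B-v0.1 config, and the estimate ignores small terms such as layer norms:

# Rough parameter-count estimate for Mixtral-8x7B (config values assumed from the released model)
hidden_dim = 4096      # model hidden size
ffn_dim    = 14336     # intermediate size of each expert FFN
n_layers   = 32
n_heads    = 32        # query heads
n_kv_heads = 8         # K/V heads (GQA)
head_dim   = hidden_dim // n_heads
n_experts  = 8
top_k      = 2
vocab      = 32000

# Attention (shared by all experts): Wq, Wo full size; Wk, Wv reduced by GQA
attn = hidden_dim * (n_heads * head_dim) * 2 + hidden_dim * (n_kv_heads * head_dim) * 2

# One expert FFN: w1, w2, w3 (SwiGLU-style, no bias)
expert_ffn = 3 * hidden_dim * ffn_dim

# Router: a single linear layer hidden_dim -> n_experts
gate = hidden_dim * n_experts

per_layer_total  = attn + gate + n_experts * expert_ffn   # all 8 experts are stored
per_layer_active = attn + gate + top_k * expert_ffn       # only 2 experts run per token

embeddings = 2 * vocab * hidden_dim                       # input embedding + LM head

total  = n_layers * per_layer_total  + embeddings
active = n_layers * per_layer_active + embeddings

print(f"total  ≈ {total / 1e9:.1f}B")    # ≈ 46.7B, matching the reported total
print(f"active ≈ {active / 1e9:.1f}B")   # ≈ 12.9B, matching the reported active count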

1.1.3 GQA mechanism adopted in Mixtral

Mixtral follows the GQA mechanism adopted in Mistral 7B. Compared with traditional MHA (Multi-Head Attention), GQA mainly shrinks the K and V representation dimensions in the attention mechanism, thereby reducing the number of parameters for K and V. Besides GQA there is also MQA (Multi-Query Attention), which can be regarded as a special case of GQA. The relevant dimensions are shown in the following table:

        Q            K                       V
MHA     hidden_dim   hidden_dim              hidden_dim
GQA     hidden_dim   hidden_dim/n            hidden_dim/n
MQA     hidden_dim   1 (a single K/V head)   1 (a single K/V head)

where n is the factor by which the K and V parameter counts are reduced relative to MHA. Specifically, n is 4 in Mixtral (32 query heads grouped over 8 K/V heads)

For more details about GQA, please see this article " One article explains all kinds of attention: from multi-head attention MHA to grouped query attention GQA, multi-query attention MQA "
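As a shape-level sketch only (not the model's actual implementation), the snippet below shows how GQA shrinks the K/V projections. The sizes (32 query heads, 8 K/V heads, head dimension 128) are assumed Mixtral-like values, which give exactly n = 4:

import torch
import torch.nn as nn

# Illustrative GQA projections (shapes only), assuming Mixtral-like sizes
hidden_dim, n_heads, n_kv_heads = 4096, 32, 8
head_dim = hidden_dim // n_heads                  # 128

q_proj = nn.Linear(hidden_dim, n_heads    * head_dim, bias=False)  # 4096 -> 4096
k_proj = nn.Linear(hidden_dim, n_kv_heads * head_dim, bias=False)  # 4096 -> 1024 (= hidden_dim / 4)
v_proj = nn.Linear(hidden_dim, n_kv_heads * head_dim, bias=False)  # 4096 -> 1024

x = torch.randn(1, 10, hidden_dim)                # (bs, seq_len, hidden_dim)
q = q_proj(x).view(1, 10, n_heads,    head_dim)
k = k_proj(x).view(1, 10, n_kv_heads, head_dim)
v = v_proj(x).view(1, 10, n_kv_heads, head_dim)

# Each K/V head is shared by n = 4 query heads; repeat_interleave expands them for attention
k = k.repeat_interleave(n_heads // n_kv_heads, dim=2)   # back to 32 heads
v = v.repeat_interleave(n_heads // n_kv_heads, dim=2)
print(q.shape, k.shape, v.shape)                  # all (1, 10, 32, 128) after expansion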

1.1.4 Routing in Mixtral (Gating/Router)

Routing (Gating/Router) is essentially a linear layer: its input dimension is the hidden dimension hidden_dim and its output dimension is the number of experts num_experts. The forward pass predicts, for each given token, a score for every expert.

self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)
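A quick toy illustration of what this gate produces (hypothetical sizes, not the real model's dimensions): one score per expert for every token, from which the top-2 experts will later be chosen:

import torch
import torch.nn as nn
import torch.nn.functional as F

hidden_dim, num_experts = 16, 8                   # toy sizes for illustration only
gate = nn.Linear(hidden_dim, num_experts, bias=False)

tokens = torch.randn(5, hidden_dim)               # 5 tokens, one hidden vector each
logits = gate(tokens)                             # (5, 8): one score per expert per token
print(F.softmax(logits, dim=-1).topk(2, dim=-1))  # each token's top-2 experts and weights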

As for the object of routing processing, it can be Sentence-Level, Token-Level or Task-Level.

  • Sentence-Level routes each sample separately.
  • Token-Level is to route each token in the sample separately.
  • Task-Level requires different experts to be clearly responsible for different tasks

With Task-Level routing, each sample is still routed individually, but the target expert is explicitly determined: for example, a sample carries a "task it belongs to" label, and that label directs the sample straight to the expert dedicated to that task

Mixtral adopts Token-Level routing

  1. As for the first use of Token-Level MOE in NLP tasks, it can be traced back to " Outrageously Large Neural Networks: The Sparsely-Gated Mixture-of-Experts Layer " in 2017.

  2. That paper reports some interesting Token-Level phenomena: by looking at the statistics of the tokens each expert handles, it finds that different experts really do acquire some syntax-level specialization. For example, when the indefinite article "a" is needed to introduce a direct object in an important verb phrase, a dedicated expert, No. 752, is responsible for emitting that "a"

1.2 Model performance: Match or surpass Llama 2 70B and GPT3.5

We compare Mixtral with the Llama 2 series and the GPT3.5 base model. Mixtral matches or outperforms Llama 2 70B as well as GPT3.5 in most benchmarks

Performance overview

The test in the figure below measures the trade-off between quality and inference budget. Mistral 7B and Mixtral 8x7B are more efficient than Llama 2

performance scale

The table below gives the detailed results for the above graph

Detailed benchmarking

To identify possible deficiencies to correct via fine-tuning/preference modeling, its performance on BBQ/BOLD was measured

BBQ BOLD Benchmark

Compared with Llama 2, Mixtral shows less bias on the BBQ benchmark. Overall, Mixtral displays more positive sentiment than Llama 2 on BOLD

1.3 Instruction compliance model Mixtral 8x7B Instruct

Released together with Mixtral 8x7B is Mixtral 8x7B Instruct, which is optimized based on Mixtral 8x7B through supervised fine-tuning and direct preference optimization (DPO) to strictly follow the instructions.

For details about what DPO is and its principle, please refer to this article " Analysis of DPO Principle of Replacement of RLHF: From RLHF and Claude's RAILF to DPO and Zephyr "

On MT-Bench, it achieved a score of 8.30, making it the best open source model with performance comparable to GPT3.5

Part 2 Implementation details of Mixtral (MOE architecture): code interpretation

As A Xun noted ( the base version of this part was provided by A Xun from the second project group of our large-model project team, and I added a good number of supplements and explanations on top of it ), the counter-intuitive point in the description of Mixtral above is:

  • At every layer, for every token, a router network chooses two of these groups ( the "experts" ) to process the token and combine their output additively
  • If you do not read this carefully, it is easy to mistakenly think that "an entire input sequence" is assigned to its TOP 2 experts. In fact each token is assigned its own TOP 2 experts, and once you dig into the Mixtral code you will find that this is indeed the case...

2.1 Forward propagation of MOE module: overall process

A single Mixtral layer can be roughly divided into an Attention module and a MOE module. The following focuses on the forward propagation process of the MOE module.

2.1.1 Obtain the top2 experts and their weights corresponding to each token

To make sure everyone can grasp the meaning of each line of code as quickly as possible, I broke the process down into the following six steps based on A Xun's analysis and added extra explanations to each step (a consolidated, runnable sketch of the six steps follows the list).

  1. Since the dimensions of hidden_states typically include the batch size (batch_size), the sequence length (sequence_length) and the hidden dimension (hidden_dim), we have
    # hidden_states output by the Attention module is the input of this part
    batch_size, sequence_length, hidden_dim = hidden_states.shape
  2. Reshape hidden_states into a two-dimensional tensor so that it is processed as one representation per token
    # reshape to (bs*seq_len, hidden_dim), i.e. token-level
    hidden_states = hidden_states.view(-1, hidden_dim)
  3. Produce the routing logits (router_logits) through the gate, which will later be used to decide which experts should handle each token
    # router_logits: (batch * sequence_length, n_experts)
    # (bs * seq_len, n_experts)
    router_logits = self.gate(hidden_states)
  4. Apply softmax to each token's routing logits to compute every expert's weight for that token
    # softmax at token level (dim=1): each token independently produces a distribution over n_experts
    routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
  5. Select the top_k most important experts and their weights for each token
    # routing_weights: (bs * seq_len, topk), the raw weights of the selected experts
    # selected_experts: (bs * seq_len, topk), the indices of the selected experts
    routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
  6. Renormalize each token's selected expert weights so that they sum to 1
    # renormalize the raw weights so that the selected experts' weights sum to 1
    # see [Code Block A] below for a concrete example of routing_weights
    routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
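Putting the six steps together, here is a minimal self-contained sketch of the routing half of such an MoE block, with toy dimensions and assumed names (the real Mixtral class also holds the 8 expert modules, which come into play in the next subsection):

import torch
import torch.nn as nn
import torch.nn.functional as F

# Minimal sketch of the routing steps above (toy sizes, assumed names)
batch_size, sequence_length, hidden_dim = 2, 8, 16
num_experts, top_k = 8, 2

hidden_states = torch.randn(batch_size, sequence_length, hidden_dim)   # stand-in for the Attention output
gate = nn.Linear(hidden_dim, num_experts, bias=False)                  # the router

hidden_states = hidden_states.view(-1, hidden_dim)                     # (bs*seq_len, hidden_dim)
router_logits = gate(hidden_states)                                    # (bs*seq_len, num_experts)
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)   # per-token distribution over experts
routing_weights, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)           # renormalize to sum to 1

print(routing_weights.shape, selected_experts.shape)                   # (16, 2) (16, 2)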

2.1.2 Pass each token into the corresponding expert model for forward propagation to obtain the output

  1. First, initialize the output tensor
    # final_hidden_states: (bs * seq_len, hidden_dim)
    # initialized as an all-zero tensor
    # final_hidden_states will store the aggregated expert outputs for each token
    final_hidden_states = torch.zeros(
        (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
    )
  2. Use the given selected_experts as the indices of the positions whose value is 1 and construct one-hot encodings whose vector length is num_experts.
    Imagine, say, 24 tokens, each to be processed by 2 of the 8 experts: for each token I build a length-8 0/1 code in which the 8 positions stand for the 8 experts.
    Whichever two experts a token selects, the corresponding code bits become 1 and the rest stay 0 (shown here with the two one-hot vectors merged into a single multi-hot code for readability).

    For example, if the token "July" selects experts 3 and 7, its 0/1 code is: 0 0 1 0 0 0 1 0.
    If the token "Edu" selects experts 2 and 4, its 0/1 code is: 0 1 0 1 0 0 0 0
    and so on...
    # selected_experts.shape: (bs*seq_len, topk)
    # torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).shape: (bs*seq_len, topk, num_experts)
  3. Use a fairly clever trick for the forward propagation
    expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
    Specifically, the physical meaning of the tensor
    torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0).shape: (num_experts, topk, bs*seq_len)
    is that the question "which top-k experts does each token select" is turned into "for each expert, acting as the rank-1 or rank-2 choice, which tokens does it need to process".
    The advantage is that the subsequent loop only needs num_experts forward passes instead of bs*seq_len forward passes.

    To help everyone understand the line of code above, I drew a schematic diagram to speed things up
    [Schematic: the letters A, B, C, ... stand for the tokens to be processed, and the numbers 1-8 stand for the 8 experts]
    ( as A Xun said, this shifts the focus from "each token" to "each expert". Of course, in most cases the number of tokens is far larger than the handful shown in the figure, not to mention the number of experts; in short, this conversion ultimately saves a great many loop iterations. A toy demo of this one-hot + permute + where trick follows this list )

  4. So next we only need to loop num_experts times
    # take out each expert model in turn
    for expert_idx in range(self.num_experts):
        expert_layer = self.experts[expert_idx]
        idx, top_x = torch.where(expert_mask[expert_idx])
    These lines deserve careful explanation.
    Since expert_mask records, for each expert acting at each rank, which tokens it needs to process, expert_mask[expert_idx].shape is (topk, bs*seq_len): it is the slice of expert_mask belonging to the current expert, see [Code Block B] below for details.
    The right-hand side of the last of the three lines above, torch.where(expert_mask[expert_idx]), finds the position indices of all elements of expert_mask[expert_idx] that equal 1, see [Code Block C] below.

    Here idx and top_x are 1-D tensors whose length equals the number of tokens routed to the current expert:
    idx is the row index of each 1 in expert_mask[expert_idx] (i.e. whether this expert is the token's rank-1 or rank-2 choice),
    and top_x is the column index of each 1 (i.e. which token it is).

    Continuing with the code inside the for loop:
        # skip this expert if expert_mask[expert_idx] contains no element equal to 1
        if top_x.shape[0] == 0:
            continue

        # convert the index tensors to Python lists for the fancy indexing below
        top_x_list = top_x.tolist()
        idx_list = idx.tolist()

        # from the hidden vectors of all tokens, take out those of the tokens assigned to the current expert
        # current_state.shape: (top_x_length, hidden_dim)
        current_state = hidden_states[None, top_x_list].reshape(-1, hidden_dim)

        # pass the selected token hidden vectors through the expert model to get its output
        # current_hidden_states.shape: (top_x_length, hidden_dim)
        # see [Code Block D] below for the forward pass of expert_layer
        current_hidden_states = expert_layer(current_state, routing_weights[top_x_list, idx_list, None])

        # add the current expert's output into the pre-defined final_hidden_states tensor
        final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype)) 
  5. After the for loop ends, i.e. after all experts have been processed, the accumulated final_hidden_states is reshaped from (bs * seq_len, hidden_dim) back to (bs, seq_len, hidden_dim) and returned as the result of this forward pass.
    For more details, see [Code Block E] below
    final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
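To see the one-hot/permute/where bookkeeping in isolation, here is a toy demo with 4 tokens, 3 experts and top-2 routing (hand-picked expert indices, purely illustrative):

import torch

# Toy demo of the expert-dispatch bookkeeping: 4 tokens, 3 experts, top_k = 2
num_experts = 3
selected_experts = torch.tensor([[0, 2],     # token 0 -> experts 0 and 2
                                 [1, 0],     # token 1 -> experts 1 and 0
                                 [2, 1],     # token 2 -> experts 2 and 1
                                 [0, 1]])    # token 3 -> experts 0 and 1

# (bs*seq_len, topk, num_experts) -> (num_experts, topk, bs*seq_len)
expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=num_experts).permute(2, 1, 0)

for expert_idx in range(num_experts):
    idx, top_x = torch.where(expert_mask[expert_idx])
    # idx   = the rank at which this expert was chosen (0 = top-1, 1 = top-2)
    # top_x = which tokens are routed to this expert
    print(expert_idx, idx.tolist(), top_x.tolist())

# expected output:
# 0 [0, 0, 1] [0, 3, 1]   -> expert 0 handles tokens 0 and 3 as rank-1, token 1 as rank-2
# 1 [0, 1, 1] [1, 2, 3]   -> expert 1 handles token 1 as rank-1, tokens 2 and 3 as rank-2
# 2 [0, 1] [2, 0]         -> expert 2 handles token 2 as rank-1, token 0 as rank-2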

2.2 Detailed analysis of the five code blocks in MOE forward propagation: getting inside

2.2.1 Code block A: specific example of routing_weights

# [Code Block A] routing_weights
# each row corresponds to one token; column 0 holds its rank-1 expert's weight, column 1 its rank-2 expert's weight
[[0.5310, 0.4690],
 [0.5087, 0.4913],
 [0.5775, 0.4225],
 [0.5014, 0.4986],
 [0.5030, 0.4970],
 [0.5479, 0.4521],
 [0.5794, 0.4206],
 [0.5545, 0.4455],
 [0.5310, 0.4690],
 [0.5294, 0.4706],
 [0.5375, 0.4625],
 [0.5417, 0.4583],
 [0.5014, 0.4986],
 [0.5239, 0.4761],
 [0.5817, 0.4183],
 [0.5126, 0.4874]]

2.2.2 Code block B: expert_mask[expert_idx]

Recall that expert_mask records, for each expert acting at each rank, which tokens it needs to process.
So expert_mask[expert_idx] takes the slice belonging to the expert_idx-th expert:
row 0 lists the tokens this expert processes when it acts as the rank-1 choice,
and row 1 lists the tokens it processes when it acts as the rank-2 choice.

# [Code Block B] expert_mask[expert_idx]
# the physical meaning of the two example rows below:
# row 1: "when this expert acts as the rank-1 expert, it needs to process token 9"
# row 2: "when this expert acts as the rank-2 expert, it needs to process tokens 10 and 11"
[[0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0],
 [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0]]

2.2.3 Code block C: idx, top_x = torch.where(expert_mask[expert_idx])

# [Code Block C] idx, top_x = torch.where(expert_mask[expert_idx])
# using the expert_mask[expert_idx] example above, the corresponding torch.where(expert_mask[expert_idx]) result is
idx: [0, 1, 1]
top_x: [9, 10, 11]

idx corresponds to the row index and top_x to the column index: in the tensor expert_mask[expert_idx] above, the elements equal to 1 sit at positions (0, 9), (1, 10), (1, 11).
Physically, top_x is the "indices of the tokens related to the current expert": tokens 9, 10 and 11 are "routed" to the expert we are currently looking at. Through top_x we can fetch the "input that needs to be passed to this expert", i.e. the hidden vectors of tokens 9, 10 and 11

  • Therefore top_x will be used as an index to extract the hidden vector of the corresponding token from the hidden_states of all tokens.
  • The combination of idx and top_x will also be used to extract the corresponding weight from the expert weight tensor routing_weights.

Concretely, routing_weights[top_x_list, idx_list, None] uses the column index (which token) together with the row index (which rank) to pick out, from routing_weights, each selected token's weight for the current expert; see the toy example below.
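A tiny toy example of how top_x and idx together index routing_weights (the trailing None simply keeps a broadcastable column dimension):

import torch

# routing_weights: (num_tokens, topk); toy values for 4 tokens
routing_weights = torch.tensor([[0.53, 0.47],
                                [0.51, 0.49],
                                [0.58, 0.42],
                                [0.50, 0.50]])

# suppose torch.where gave us, for the current expert:
top_x = torch.tensor([1, 2])   # tokens 1 and 2 are routed to this expert
idx   = torch.tensor([0, 1])   # as their rank-1 and rank-2 choice respectively

# pick each token's weight for this expert; None keeps a (n, 1) column for broadcasting
print(routing_weights[top_x, idx, None])   # tensor([[0.5100], [0.4200]])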

2.2.4 Code block D: forward propagation inside expert

# [Code Block D] forward propagation inside an expert
def forward(self, hidden_states, routing_weights):
    current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states)
    current_hidden_states = self.w2(current_hidden_states)
    return routing_weights * current_hidden_states

Its inputs include not only the hidden vectors of the tokens assigned to this expert, but also the corresponding routing weights. The block as a whole is an FFN with a SwiGLU-style activation.

Finally, the output of FFN is weighted to obtain the actual output of the expert, so the weighting process is already performed inside the expert.
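For reference, here is a minimal self-contained version of such an expert block, mirroring Code Block D. The SiLU activation and the default 4096/14336 dimensions are assumptions based on the released Mixtral configuration; the usage example uses toy sizes:

import torch
import torch.nn as nn

class ExpertMLP(nn.Module):
    """A single SwiGLU-style expert FFN, sketched after Code Block D."""
    def __init__(self, hidden_dim: int = 4096, ffn_dim: int = 14336):
        super().__init__()
        self.w1 = nn.Linear(hidden_dim, ffn_dim, bias=False)   # gate projection
        self.w3 = nn.Linear(hidden_dim, ffn_dim, bias=False)   # up projection
        self.w2 = nn.Linear(ffn_dim, hidden_dim, bias=False)   # down projection
        self.act_fn = nn.SiLU()                                # assumed activation

    def forward(self, hidden_states, routing_weights):
        current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states)
        current_hidden_states = self.w2(current_hidden_states)
        # the expert's output is already scaled by its routing weight here
        return routing_weights * current_hidden_states

# usage sketch: 3 tokens routed to this expert, each with its own weight
expert = ExpertMLP(hidden_dim=16, ffn_dim=32)       # toy sizes for the demo
x = torch.randn(3, 16)
w = torch.tensor([[0.51], [0.42], [0.55]])          # (top_x_length, 1), broadcast over hidden_dim
print(expert(x, w).shape)                           # torch.Size([3, 16])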

2.2.5 Code block E: final_hidden_states

  1. Initially final_hidden_states is an all-0 tensor
    # view the part of final_hidden_states related to the current expert, i.e. final_hidden_states[top_x]
    [[0., 0., 0.,  ..., 0., 0., 0.],
     [0., 0., 0.,  ..., 0., 0., 0.],
     [0., 0., 0.,  ..., 0., 0., 0.]]
  2. After calling .index_add_, the given values (current_hidden_states) are added at the given row positions (top_x)
    final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
  3. Looking again at the part of final_hidden_states related to the current expert, it now reads
    [[ 0.0938,  0.0509, -0.0689,  ..., -0.0182, -0.0246,  0.0468],
     [ 0.1246,  0.0642,  0.0015,  ...,  0.0100, -0.0110,  0.0219],
     [ 0.0478, -0.0192,  0.0139,  ..., -0.0039, -0.0197,  0.0475]]
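A small toy demo of what .index_add_ does here, i.e. accumulating rows at the given positions along dimension 0:

import torch

final = torch.zeros(5, 3)                   # 5 "tokens", hidden_dim 3
top_x = torch.tensor([1, 3])                # rows (tokens) to update
values = torch.tensor([[1., 1., 1.],
                       [2., 2., 2.]])
final.index_add_(0, top_x, values)          # adds the rows of `values` at positions 1 and 3
print(final[1], final[3])                   # tensor([1., 1., 1.]) tensor([2., 2., 2.])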

Part 3 The development history and more practical details of the mixed expert model MOE

// To be updated

References and Recommended Reading

  1. A magnet link sweeps the AI circle: an 87GB torrent directly open-sources the 8x7B MoE model
  2. Mistral AI's introduction to Mixtral of experts: Mixtral of experts | Mistral AI | Open source models
  3. Open source large models surpass GPT-3.5! Explosive MoE actual measurement results released
  4. https://github.com/nateraw/replicate-examples/tree/main/mixtral
  5. Pre-trained large model: Baidu UFO (Unified Feature Optimization)
  6. Three papers recommended by friend wstart
    LoRAMoE: Revolutionizing Mixture of Experts for Maintaining World Knowledge in Language Model Alignment
    MegaBlocks: Efficient Sparse Training with Mixture-of-Experts
    Weak-to-Strong Generalization: Eliciting Strong Capabilities With Weak Supervision
  7. ..

Origin blog.csdn.net/v_JULY_v/article/details/135176583