"Horizontal split learning" of split learning in "Little Lesson of Hidden Words"


1. Introduction

Split learning is a distributed learning algorithm first proposed by researchers at MIT in 2018 [2]. Drawing on the English-language literature in this field, this article introduces the basic methods of horizontal split learning and compares the efficiency and accuracy of split models against centralized and federated models under different conditions. As one of the mainstream privacy-preserving computation paradigms, split learning is widely used to build privacy-preserving machine learning algorithms.

2. Basic methods

2.1 Core idea

Split learning divides a neural network model into two parts. The client runs the bottom part of the model on its local data, computes the hidden-layer activations at the cut layer, and transmits them to the server; the server then continues the computation through the top part of the model, as shown in Figure 1 [1]. A minimal code sketch of this idea follows the figure below.

Figure 1 Schematic diagram of split learning
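The sketch below illustrates the core idea, assuming PyTorch; the layer sizes and the two nn.Sequential models are illustrative placeholders, not taken from the papers.

```python
# A minimal sketch of the split-learning cut, assuming PyTorch.
# Layer sizes are illustrative placeholders.
import torch
import torch.nn as nn

# Bottom (client) part: raw data -> cut-layer activations
client_model = nn.Sequential(nn.Linear(784, 256), nn.ReLU())

# Top (server) part: cut-layer activations -> prediction
server_model = nn.Sequential(nn.Linear(256, 128), nn.ReLU(), nn.Linear(128, 10))

x = torch.randn(32, 784)   # a batch of the client's local data
h = client_model(x)        # client computes up to the cut layer
y_pred = server_model(h)   # server continues from the hidden layer h
```

In a real deployment only h (not the raw data x) crosses the network, which is the source of split learning's privacy benefit.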

For split learning with horizontally partitioned client data, there are three main variants: centralized split learning, peer-to-peer (P2P) split learning, and U-shaped split learning.

2.2 Centralized split learning

Figure 2 Centralized split learning model

(1) Algorithm

As shown in Figure 2 [2], Alice is the client and Bob is the server. The client and server models are initialized first, and each training round then proceeds as follows (a minimal code sketch follows the steps below).

  • Client i obtains the encrypted client-side model parameters from the server, decrypts them, and updates its local client model.

  • Client i performs forward propagation up to the cut layer and sends the hidden layer h together with the true labels y to the server.

  • The server receives h and y from client i, continues forward propagation to obtain the predicted labels y_pred, and computes Loss(y, y_pred).

  • The server performs backward propagation, updates the server-side model, obtains the gradient G of the loss with respect to the hidden layer, and transmits G to the client.

  • The client continues backward propagation with the gradient G to update its local model, then encrypts the local model and uploads it to the server.

  • The remaining clients participating in training carry out the above steps in turn.
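Below is a minimal sketch of one such training round for a single client, assuming PyTorch; the in-process tensor hand-offs stand in for real network communication, and the encryption/decryption of the client model is omitted.

```python
import torch
import torch.nn as nn

client_model = nn.Sequential(nn.Linear(784, 256), nn.ReLU())
server_model = nn.Sequential(nn.Linear(256, 10))
client_opt = torch.optim.SGD(client_model.parameters(), lr=0.1)
server_opt = torch.optim.SGD(server_model.parameters(), lr=0.1)
loss_fn = nn.CrossEntropyLoss()

x = torch.randn(32, 784)          # client i's local batch
y = torch.randint(0, 10, (32,))   # client i's true labels

# client side: forward to the cut layer, send h and y to the server
h = client_model(x)
h_sent = h.detach().requires_grad_()   # what the server receives

# server side: finish the forward pass and compute the loss
y_pred = server_model(h_sent)
loss = loss_fn(y_pred, y)

# server side: backward pass, update the server model
server_opt.zero_grad()
loss.backward()                   # also fills h_sent.grad
server_opt.step()
G = h_sent.grad                   # gradient of the loss w.r.t. the hidden layer

# client side: continue backward from G, update the client model
client_opt.zero_grad()
h.backward(G)
client_opt.step()
# The client would now encrypt client_model's weights and upload them;
# the next client downloads and decrypts them before its own round.
```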

(2) Features

  • Clients are updated one after another (asynchronously); they cannot be updated in parallel;

  • Before each training round, the client must download the encrypted client-side model from the server;

  • The server holds the sample labels and the encrypted client model (a privacy-leakage risk).

2.3 P2P split learning

Figure 3 Peer-to-peer split learning

(1) Algorithm

As shown in Figure 3 [2], client i performs forward propagation, computes the hidden layer, and transmits the hidden layer h together with the true labels y to the server (a minimal code sketch follows the steps below).

  • The server receives h and y from client i, continues forward propagation to obtain the predicted labels y_pred, and computes Loss(y, y_pred);

  • The server performs backward propagation, updates the server-side model, obtains the gradient G of the loss with respect to the hidden layer, and passes G to the client;

  • The client continues backward propagation with the gradient G to update its local model, then passes the updated local model directly to the next client;

  • The next client carries out the above steps in turn.
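Below is a minimal sketch of the peer-to-peer schedule, assuming PyTorch; the hand-off between clients is modeled as copying the client-side weights (a state_dict) rather than a real network transfer, and the client count, data, and layer sizes are illustrative.

```python
import torch
import torch.nn as nn

def make_client():
    return nn.Sequential(nn.Linear(784, 256), nn.ReLU())

server_model = nn.Sequential(nn.Linear(256, 10))
server_opt = torch.optim.SGD(server_model.parameters(), lr=0.1)
loss_fn = nn.CrossEntropyLoss()

clients = [make_client() for _ in range(3)]   # three illustrative clients
local_data = [(torch.randn(32, 784), torch.randint(0, 10, (32,)))
              for _ in clients]

for i, client_model in enumerate(clients):
    # receive the latest client-side weights from the previous client
    if i > 0:
        client_model.load_state_dict(clients[i - 1].state_dict())
    client_opt = torch.optim.SGD(client_model.parameters(), lr=0.1)

    x, y = local_data[i]
    h = client_model(x)
    h_sent = h.detach().requires_grad_()   # sent to the server with y

    y_pred = server_model(h_sent)          # server forward pass
    loss = loss_fn(y_pred, y)
    server_opt.zero_grad()
    loss.backward()                        # server backward, fills h_sent.grad
    server_opt.step()

    client_opt.zero_grad()
    h.backward(h_sent.grad)                # client backward with gradient G
    client_opt.step()
    # client i now passes its updated weights to client i+1 (top of the loop)
```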

(2) Features

  • Clients perform their training updates in sequence.

  • The server sees the sample labels.

  • Before each training round, a client must obtain the latest client-side model from the previous client (which raises a client-dropout problem).

2.4 U-shaped split learning

Figure 4 U-shaped split learning

(1) Algorithm

As shown in Figure 4 [1], the model is divided into three parts: submodel-1, submodel-2 (which carries most of the computation), and submodel-3 (which computes the loss). Submodel-1 and submodel-3 run on the client, while submodel-2 runs on the server. Taking U-shaped centralized split learning as an example (a minimal code sketch follows the steps below):

  • Client i obtains the encrypted parameters of submodel-1 and submodel-3 from the server, decrypts them, and updates its local models.

  • Client i performs the forward propagation of submodel-1, computes the hidden layer, and passes the hidden layer h1 to the server.

  • The server receives h1 from client i, continues the forward propagation of submodel-2 to obtain the hidden layer h2, and passes h2 back to the client.

  • The client receives h2, continues the forward propagation of submodel-3 to obtain y_pred, and computes the loss against its locally held true labels y.

  • The client and server carry out backward propagation and update their respective models.

  • The client encrypts its local submodel-1 and submodel-3 and uploads them to the server.

  • The remaining clients participating in training carry out the above steps in turn.
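Below is a minimal sketch of one U-shaped round, assuming PyTorch; submodel-1 and submodel-3 stay on the client, submodel-2 on the server, and the label y never leaves the client. Encryption and network transport are again omitted, and all sizes are illustrative.

```python
import torch
import torch.nn as nn

submodel_1 = nn.Sequential(nn.Linear(784, 256), nn.ReLU())   # client, bottom
submodel_2 = nn.Sequential(nn.Linear(256, 128), nn.ReLU())   # server, middle
submodel_3 = nn.Sequential(nn.Linear(128, 10))               # client, top
opt_1 = torch.optim.SGD(submodel_1.parameters(), lr=0.1)
opt_2 = torch.optim.SGD(submodel_2.parameters(), lr=0.1)
opt_3 = torch.optim.SGD(submodel_3.parameters(), lr=0.1)
loss_fn = nn.CrossEntropyLoss()

x = torch.randn(32, 784)          # client's local batch
y = torch.randint(0, 10, (32,))   # label stays on the client

# client: submodel-1 forward, send h1 to the server
h1 = submodel_1(x)
h1_sent = h1.detach().requires_grad_()

# server: submodel-2 forward, send h2 back to the client
h2 = submodel_2(h1_sent)
h2_sent = h2.detach().requires_grad_()

# client: submodel-3 forward, compute the loss locally
y_pred = submodel_3(h2_sent)
loss = loss_fn(y_pred, y)

# backward pass retraces the U: client -> server -> client
opt_1.zero_grad(); opt_2.zero_grad(); opt_3.zero_grad()
loss.backward()            # client backward through submodel-3, fills h2_sent.grad
h2.backward(h2_sent.grad)  # server backward through submodel-2, fills h1_sent.grad
h1.backward(h1_sent.grad)  # client backward through submodel-1
opt_1.step(); opt_2.step(); opt_3.step()
```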

(2) Features

  • Compared with the previous two schemes, the server cannot see the sample labels.

3. Experimental results

3.1 Split learning vs stand-alone model

Paper [2] compares the accuracy of split learning (with 10 clients) against a stand-alone model, obtaining the results shown in the table below.

[Table omitted: accuracy of split learning vs. the stand-alone model; see paper [2]]

Experimental conclusion: split learning matches the accuracy of the stand-alone model [2].

3.2 Split learning vs federated learning

Paper [2] compares the performance of split learning and federated learning under the same client-side FLOPs and under the same communication cost.

Paper [3] compares the performance of split learning and federated learning with many clients and under non-IID data distributions.

(1) Performance with the same client-side FLOPs

[Figure omitted: convergence with the same client-side FLOPs; see paper [2]]

Conclusion: with the same amount of client-side computation, split learning converges faster and reaches higher accuracy than federated learning and large-scale SGD.

(2) Performance with the same communication cost

[Figure omitted: convergence with the same communication cost; see paper [2]]

Conclusion: with the same communication cost, split learning converges faster and reaches higher accuracy than federated learning and large-scale SGD.

(3) Performance with different numbers of clients

[Figure omitted: performance as the number of clients varies; see paper [3]]

Conclusion: as the number of clients increases, model performance fluctuates significantly.

(4) Performance in the non-IID setting

[Figure omitted: performance under non-IID data distributions; see paper [3]]

Conclusion: under non-IID data distributions, split learning performs worse than federated learning, and may even fail to converge.

4. References

[1] Thapa C, Chamikara M A P, Camtepe S A. Advancements of federated learning towards privacy preservation: from federated learning to split learning[M]//Federated Learning Systems. Springer, Cham, 2021: 79-109.

[2] Gupta O, Raskar R. Distributed learning of deep neural network over multiple agents[J]. Journal of Network and Computer Applications, 2018, 116: 1-8.

[3] Gao Y, Kim M, Abuadbba S, et al. End-to-end evaluation of federated learning and split learning for Internet of Things[J]. arXiv preprint arXiv:2003.13376, 2020.
