Implementing a local simulation of horizontal federated learning in Python

We use Python to simulate multiple clients locally, managed uniformly by a server, to run federated learning. Each client trains the model locally on its own data; the server aggregates the training results, updates the global model, and distributes it back to the clients, which then continue training.

Table of contents

1. Introduction
    1. What is federated learning?
    2. Privacy issues in federated learning
2. Environment preparation
3. Specific implementation
    1. Write the configuration file
    2. Obtain the training dataset
    3. Write the server-side code
    4. Write the client code
    5. Write the main function
    6. Write the model file
4. Testing
5. Comparing federated learning with centralized training
6. Summary


1. Introduction

1. What is federated learning?

Federated learning is a machine learning paradigm. The server first distributes a common model to the clients; each client then trains it on its local data and sends the updated model back to the server. After the server receives the new models from all parties, it performs the corresponding computation to update the global model, then distributes the updated model for further training. This repeats until the round limit or convergence is reached, finally yielding a model trained jointly by multiple parties. The core idea of federated learning is that "the data stays put while the model moves; the data is available but not visible."

(PS: Only a brief introduction is given here. For a more detailed introduction to federated learning, please refer to the following article:)

https://blog.csdn.net/cao812755156/article/details/89598410?ops_request_misc=%257B%2522request%255Fid%2522%253A%2522164941377116780357226006%2522%252C%2522scm%2522%253A%252220140713.130102334..%2522%257D&request_id=164941377116780357226006&biz_id=0&utm_medium=distribute.pc_search_result.none-task-blog-2~blog~top_positive~default-1-89598410.nonecase&utm_term=%E8%81%94%E9%82%A6%E5%AD%A6%E4%B9%A0&spm=1018.2226.3001.4450

2. Privacy issues in federated learning

Federated learning is still machine learning in essence, but traditional centralized learning over pooled data is replaced by each client learning on its own data, with only model parameters being exchanged. This protects the security and privacy of the data, but it is only the most basic protection: given the variety of attacks against federated learning, it should be combined with other techniques to secure the learning process.

(PS: I will not analyze the data security issues of federated learning in depth here. If you are interested, refer to the following paper, which is well written and easy to follow:)

http://www.jos.org.cn/jos/article/abstract/6446


2. Environment preparation

This experiment is implemented in Python, using the machine learning library PyTorch.

  • Environment: Anaconda, Python, PyTorch
  • IDE: PyCharm
  • Dataset: CIFAR-10
  • Model: ResNet-18

Basic process:

  1. The server builds an initial model from the configuration file, and each client takes a non-overlapping horizontal slice of the dataset according to its ID.
  2. The server sends the global model to the clients.
  3. Each client receives the global model from the server, runs several local training iterations, computes the difference between its local parameters and the global ones, and returns it to the server.
  4. The server aggregates the differences from the clients to update the global model, then evaluates the current model's performance.
  5. If performance is not yet satisfactory, repeat from step 2; otherwise stop.
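The following is a minimal, self-contained sketch of one such round. The toy "model" is just a dict of tensors, and the names here (client_update, accumulator, LAMBDA, K) are illustrative rather than taken from the project code:

import torch

LAMBDA = 0.1   # server-side update step, corresponding to "lambda" in conf.json
K = 3          # number of clients sampled per round

# Stand-in for a client's local training: a real client would run several
# SGD epochs and return (weights after training) - (global weights).
def client_update(global_weights):
    return {name: 0.01 * torch.randn_like(w) for name, w in global_weights.items()}

global_weights = {"w": torch.zeros(4), "b": torch.zeros(1)}

# One global round: sum the K client differences, then apply them scaled by LAMBDA.
accumulator = {name: torch.zeros_like(w) for name, w in global_weights.items()}
for _ in range(K):
    diff = client_update(global_weights)
    for name in accumulator:
        accumulator[name] += diff[name]
for name in global_weights:
    global_weights[name] += LAMBDA * accumulator[name]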

3. Specific implementation

1. Write the configuration file

Create a utils folder under the project folder and create the configuration file conf.json inside it. The values can be changed as needed. (Because JSON files do not allow comments, each key is assigned twice: the first value serves as a comment and the second is the real value.)

{
    "model_name" : "name of the model",
    "model_name" : "resnet18",

    "no_models" : "total number of clients",
    "no_models" : 5,

    "type" : "which dataset to use",
    "type" : "cifar",

    "global_epochs" : "number of global epochs",
    "global_epochs" : 5,

    "local_epochs" : "number of local epochs",
    "local_epochs" : 2,

    "k" : "number of clients selected to train in each round",
    "k" : 3,

    "batch_size" : "number of samples per batch in local training",
    "batch_size" : 32,

    "notes" : "hyperparameters for local training",
    "lr" : 0.001,
    "momentum" : 0.0001,
    "lambda" : 0.1
}
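This comment trick works because Python's json module keeps only the last occurrence of a duplicate key, so the comment value is silently discarded when the file is parsed:

import json

# The later (real) value wins when a key appears twice
conf = json.loads('{"k" : "number of clients selected per round", "k" : 3}')
print(conf["k"])  # 3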

2. Obtain the training dataset

Create the datasets.py file under the project folder.

from torchvision import datasets, transforms

# Fetch the training and evaluation datasets
def get_dataset(dir, name):

    if name == 'mnist':
        # dir: path where the data is stored
        # train: whether to load the training split or the test split
        # download=True: download the dataset from the internet into dir
        # transform: the transformation applied to the images
        train_dataset = datasets.MNIST(dir, train=True, download=True, transform=transforms.ToTensor())
        eval_dataset = datasets.MNIST(dir, train=False, transform=transforms.ToTensor())

    elif name == 'cifar':
        # Define two transform pipelines
        # transforms.Compose chains multiple transforms together (a list of transforms)
        transform_train = transforms.Compose([
            # transforms.RandomCrop: crop at a randomly chosen position
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            # transforms.Normalize: normalize the tensor with the given
            # per-channel (R,G,B) means and standard deviations
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ])

        transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ])

        train_dataset = datasets.CIFAR10(dir, train=True, download=True, transform=transform_train)
        eval_dataset = datasets.CIFAR10(dir, train=False, transform=transform_test)

    return train_dataset, eval_dataset
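A quick sanity check of the loader (run from the project directory): CIFAR-10 ships 50,000 training and 10,000 test images, each 3×32×32 after ToTensor.

from datasets import get_dataset

train_dataset, eval_dataset = get_dataset("./data/", "cifar")
print(len(train_dataset), len(eval_dataset))  # 50000 10000
print(train_dataset[0][0].shape)              # torch.Size([3, 32, 32])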

3. Write the server-side code

Create the server.py file in the project folder. The server's main job is to aggregate, evaluate, and distribute models; the class consists of a constructor, an aggregation function (implementing the FedAvg algorithm), and an evaluation function.

import models
import torch

# Server class
class Server(object):
    # Constructor
    def __init__(self, conf, eval_dataset):
        # Load the configuration
        self.conf = conf
        # Build the global model named in the configuration
        self.global_model = models.get_model(self.conf["model_name"])
        # Build the evaluation data loader
        self.eval_loader = torch.utils.data.DataLoader(
          eval_dataset,
          # Batch size from the configuration file (32)
          batch_size=self.conf["batch_size"],
          # Shuffle the dataset
          shuffle=True
        )

    # Model aggregation function
    # weight_accumulator stores the accumulated parameter updates uploaded by the clients
    def model_aggregate(self, weight_accumulator):
        # Iterate over the global model's parameters
        for name, data in self.global_model.state_dict().items():
            # Scale each accumulated update by the server learning rate from the configuration
            update_per_layer = weight_accumulator[name] * self.conf["lambda"]
            # Apply the update
            if data.type() != update_per_layer.type():
                # update_per_layer is a FloatTensor, while this entry is a LongTensor,
                # so cast the update to int64 (losing some precision)
                data.add_(update_per_layer.to(torch.int64))
            else:
                data.add_(update_per_layer)

    # Model evaluation function
    def model_eval(self):
        # Switch the model to evaluation mode
        self.global_model.eval()
        total_loss = 0.0
        correct = 0
        dataset_size = 0
        # Iterate over the evaluation set
        for batch_id, batch in enumerate(self.eval_loader):
            data, target = batch
            # Count the total number of samples
            dataset_size += data.size()[0]
            # Move to the GPU if one is available
            if torch.cuda.is_available():
                data = data.cuda()
                target = target.cuda()
            # Forward pass through the global model
            output = self.global_model(data)
            # Accumulate the loss; cross_entropy computes the cross-entropy loss
            total_loss += torch.nn.functional.cross_entropy(
              output,
              target,
              reduction='sum'
            ).item()
            # Take the index of the largest log-probability, i.e. the most likely
            # class among all predictions
            pred = output.data.max(1)[1]
            # Count predictions that match the ground-truth labels
            correct += pred.eq(target.data.view_as(pred)).cpu().sum().item()
        # Compute the accuracy
        acc = 100.0 * (float(correct) / float(dataset_size))
        # Compute the average loss
        total_l = total_loss / dataset_size

        return acc, total_l
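The dtype check in model_aggregate exists because a ResNet's state_dict is not all floats: BatchNorm layers carry an integer num_batches_tracked buffer, while the accumulated differences are float tensors. A quick way to see which entries are affected (an illustrative snippet, assuming the models.py from section 6 below is available):

import torch
import models

model = models.get_model("resnet18")
for name, data in model.state_dict().items():
    if data.dtype != torch.float32:
        # e.g. layer1.0.bn1.num_batches_tracked torch.int64
        print(name, data.dtype)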

4. Write the client code

Create the client.py file in the project folder. The client's main job is to receive the global model from the server, train it on local data, and return the parameter difference; the class consists of a constructor and a local training function.

import models
import torch

# Client class
class Client(object):
    # Constructor
    def __init__(self, conf, model, train_dataset, id=-1):
        # Load the configuration
        self.conf = conf
        # Build the client's local model from the configuration
        # (normally delivered by the server)
        self.local_model = models.get_model(self.conf["model_name"])
        # Client ID
        self.client_id = id
        # The client's local dataset
        self.train_dataset = train_dataset
        # Split the dataset horizontally by client ID
        all_range = list(range(len(self.train_dataset)))
        data_len = int(len(self.train_dataset) / self.conf['no_models'])
        train_indices = all_range[id * data_len: (id + 1) * data_len]
        # Build the training data loader
        self.train_loader = torch.utils.data.DataLoader(
            # The parent dataset
            self.train_dataset,
            # How many samples to load per batch
            batch_size=conf["batch_size"],
            # The subset assigned to this client:
            # sampler defines the strategy for drawing samples from the dataset
            sampler=torch.utils.data.sampler.SubsetRandomSampler(train_indices)
        )

    # Local training function
    def local_train(self, model):
        # The client receives the server's model, then trains it on part of the local data
        for name, param in model.state_dict().items():
            # Overwrite the local model with the global model sent by the server
            self.local_model.state_dict()[name].copy_(param.clone())
        # Define the optimizer for local model training
        optimizer = torch.optim.SGD(
            self.local_model.parameters(),
            lr=self.conf['lr'],
            momentum=self.conf['momentum']
        )
        # Switch the local model to training mode
        self.local_model.train()
        # Train the model
        for e in range(self.conf["local_epochs"]):
            for batch_id, batch in enumerate(self.train_loader):
                data, target = batch
                # Move to the GPU if one is available
                if torch.cuda.is_available():
                    data = data.cuda()
                    target = target.cuda()
                # Reset the gradients to zero
                optimizer.zero_grad()
                # Forward pass
                output = self.local_model(data)
                # Compute the cross-entropy loss
                loss = torch.nn.functional.cross_entropy(output, target)
                # Backward pass
                loss.backward()
                # Update the parameters
                optimizer.step()
            print("Epoch %d done." % e)
        # Build the difference dict (same shapes as the model parameters)
        diff = dict()
        for name, data in self.local_model.state_dict().items():
            # Difference between the parameters after and before local training
            diff[name] = (data - model.state_dict()[name])
        print("Client %d local train done" % self.client_id)
        # Return the differences to the server
        return diff
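What local_train returns is simply a per-parameter difference dict. A toy illustration with an nn.Linear standing in for the real model (all names here are illustrative):

import torch
import torch.nn as nn

global_model = nn.Linear(2, 1)
local_model = nn.Linear(2, 1)
local_model.load_state_dict(global_model.state_dict())  # start from the global weights

with torch.no_grad():
    local_model.weight += 0.5  # pretend local SGD moved the weights

# Same computation as at the end of local_train
diff = {name: data - global_model.state_dict()[name]
        for name, data in local_model.state_dict().items()}
print(diff["weight"])  # tensor([[0.5000, 0.5000]])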

5. Write the main function

Create the main.py file in the project folder to integrate the code.

import argparse
import json
import random

import datasets
from client import *
from server import *

if __name__ == '__main__':
    # Set up the command-line parser
    parser = argparse.ArgumentParser(description='Federated Learning')
    parser.add_argument('-c', '--conf', dest='conf')
    # Parse the arguments
    args = parser.parse_args()
    # Read the configuration file, specifying UTF-8 encoding
    with open(args.conf, 'r', encoding='utf-8') as f:
        conf = json.load(f)
    # Load the training and evaluation datasets
    train_datasets, eval_datasets = datasets.get_dataset("./data/", conf["type"])
    # Start the server
    server = Server(conf, eval_datasets)
    # Client list
    clients = []
    # Create no_models clients and add them to the list
    for c in range(conf["no_models"]):
        clients.append(Client(conf, server.global_model, train_datasets, c))

    print("\n\n")
    # Global training loop
    for e in range(conf["global_epochs"]):
        print("Global Epoch %d" % e)
        # In each round, randomly sample k clients from the list to participate
        candidates = random.sample(clients, conf["k"])
        print("selected clients:")
        for c in candidates:
            print(c.client_id)
        # Accumulated weight updates
        weight_accumulator = {}
        # Initialize weight_accumulator with zeroed parameters
        for name, params in server.global_model.state_dict().items():
            # A zero tensor with the same shape as the parameter
            weight_accumulator[name] = torch.zeros_like(params)
        # Each selected client trains locally
        for c in candidates:
            diff = c.local_train(server.global_model)
            # Accumulate the parameter differences returned by the client
            for name, params in server.global_model.state_dict().items():
                weight_accumulator[name].add_(diff[name])
        # Aggregate the model parameters
        server.model_aggregate(weight_accumulator)
        # Evaluate the global model
        acc, loss = server.model_eval()

        print("Epoch %d, acc: %f, loss: %f\n" % (e, acc, loss))

6. Write the model file

Create the models.py file in the project folder, defining the machine learning models available for use.

import torch
from torchvision import models

# Return the requested machine learning model
def get_model(name="vgg16", pretrained=True):
    if name == "resnet18":
        model = models.resnet18(pretrained=pretrained)
    elif name == "resnet50":
        model = models.resnet50(pretrained=pretrained)
    elif name == "densenet121":
        model = models.densenet121(pretrained=pretrained)
    elif name == "alexnet":
        model = models.alexnet(pretrained=pretrained)
    elif name == "vgg16":
        model = models.vgg16(pretrained=pretrained)
    elif name == "vgg19":
        model = models.vgg19(pretrained=pretrained)
    elif name == "inception_v3":
        model = models.inception_v3(pretrained=pretrained)
    elif name == "googlenet":
        model = models.googlenet(pretrained=pretrained)
    else:
        # Fail early on unknown names instead of hitting an UnboundLocalError below
        raise ValueError("Unknown model name: %s" % name)

    # Move the model to the GPU if one is available
    if torch.cuda.is_available():
        return model.cuda()
    else:
        return model
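Note that torchvision 0.13 and later deprecates the pretrained=True argument in favor of an explicit weights argument; if you see a deprecation warning, the equivalent call looks like this (a sketch, adjust to your torchvision version):

from torchvision import models

# Newer torchvision API: explicit weight enums instead of pretrained=True
model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)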

4. Testing

The full project structure (the data folder appears after the first run):

project/
├── utils/
│   └── conf.json
├── data/
├── datasets.py
├── server.py
├── client.py
├── main.py
└── models.py

Enter the project directory and use the command line to run the following command:

python main.py -c ./utils/conf.json

On the first run the dataset is downloaded, and a data folder containing it appears in the project directory. Clients are then randomly selected for training in each round, and the program stops after the configured number of global epochs.
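With the default configuration (k = 3, local_epochs = 2), the console output produced by the print statements above has roughly this shape per global epoch; the client IDs vary per round, and the metrics shown are placeholders rather than real results:

Global Epoch 0
selected clients:
1
4
0
Epoch 0 done.
Epoch 1 done.
Client 1 local train done
...
Epoch 0, acc: xx.xxxxxx, loss: x.xxxxxx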

5. Comparing federated learning with centralized training

Federated training: 10 client devices in total, 5 of which are selected to participate in each round; each client runs 3 local epochs, and there are 20 global epochs.

Centralized training: set the number of client devices to 1 and select that single device in every round. With just this configuration change, the same code reproduces the effect of centralized training.
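A sketch of the corresponding conf.json for the centralized baseline (comment values omitted; the hyperparameters other than no_models and k are assumed to stay as in the federated run described above):

{
    "model_name" : "resnet18",
    "no_models" : 1,
    "type" : "cifar",
    "global_epochs" : 20,
    "local_epochs" : 3,
    "k" : 1,
    "batch_size" : 32,
    "lr" : 0.001,
    "momentum" : 0.0001,
    "lambda" : 0.1
}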

 

[Figure: bar chart comparing the accuracy of single-point training, federated training with different k, and centralized training]

Single-point training in the figure means training the model on a single client's local data only.

  • We can see that the accuracy of single-point training (blue bars) is significantly lower than that of federated training (green and red bars), which shows that a single client's data alone cannot capture the global data distribution, so the resulting model generalizes poorly.

  • In addition, performance varies with the number of clients (the k value) participating in each round of federated training: the larger k is, the more clients train in each round and the better the performance, but each round also takes correspondingly longer to complete.


6. Summary

Federated learning is now widely used, and many large companies have developed their own federated learning frameworks, such as WeBank's FATE, Google's TensorFlow Federated, OpenMined's PySyft, Baidu's PaddleFL, and ByteDance's FedLearner, all of which are excellent frameworks. This article only simulates clients and a server locally; training across multiple machines requires further study.
