Model Aggregation in Federated Learning

Table of contents

Model Aggregation in Federated Learning

1. Client-server algorithm

2. Fully decentralized algorithm


Model Aggregation in Federated Learning

In the context of federated learning, multi-task learning is introduced: the training data of each client/task node follows a different distribution, so each task node learns a different model, and both the per-node models and the global model are ensembles of several component models. The most critical part of this work is how the models learned by the individual task nodes are aggregated/communicated. Depending on how the models are aggregated, the algorithms can be divided into client-server methods and fully decentralized methods.

Because several kinds of task aggregators (Aggregators) need to be implemented, the approach taken is to first implement an abstract base class `Aggregator` that provides the common methods and declares the interfaces of the abstract methods; each concrete aggregator class then inherits from this base class and provides the concrete implementation.

Let's first look at the abstract base class of the task aggregator (Aggregator):

import random
import time
from abc import ABC, abstractmethod

import numpy as np
import torch


class Aggregator(ABC):
    r"""Base class for aggregators. An `Aggregator` dictates the communication between clients."""
    def __init__(
            self,
            clients,
            global_learners_ensemble,
            log_freq,
            global_train_logger,
            global_test_logger,
            sampling_rate=1.,
            sample_with_replacement=False,
            test_clients=None,
            verbose=0,
            seed=None,
            *args,
            **kwargs
    ):

        rng_seed = (seed if (seed is not None and seed >= 0) else int(time.time()))
        self.rng = random.Random(rng_seed)  # random number generator
        self.np_rng = np.random.default_rng(rng_seed)  # numpy random number generator

        if test_clients is None:
            test_clients = []

        self.clients = clients #  List[Client]
        self.test_clients = test_clients #  List[Client]

        self.global_learners_ensemble = global_learners_ensemble # List[Learner]
        self.device = self.global_learners_ensemble.device


        self.log_freq = log_freq
        self.verbose = verbose
        # verbose: controls the verbosity of printed output;
        # `0` means quiet (no output), `1` prints global logs, `2` prints all local logs; default is `0`
        self.global_train_logger = global_train_logger
        self.global_test_logger = global_test_logger

        self.model_dim = self.global_learners_ensemble.model_dim  # dimension of the model parameters

        self.n_clients = len(clients)
        self.n_test_clients = len(test_clients)
        self.n_learners = len(self.global_learners_ensemble)

        # weight assigned to each client (a fraction between 0 and 1, proportional to its number of training samples)
        self.clients_weights =\
            torch.tensor(
                [client.n_train_samples for client in self.clients],
                dtype=torch.float32
            )
        self.clients_weights = self.clients_weights / self.clients_weights.sum()

        self.sampling_rate = sampling_rate  # fraction of clients used in each round, defaults to `1.`
        self.sample_with_replacement = sample_with_replacement  # whether clients are sampled with replacement (True) or without replacement (False)

        # number of clients used in each round
        self.n_clients_per_round = max(1, int(self.sampling_rate * self.n_clients))

        # list of sampled clients
        self.sampled_clients = list()

        # current communication round counter
        self.c_round = 0
        self.write_logs()

    @abstractmethod
    def mix(self): 
        """
        This method performs the parameter aggregation and communication between clients.
        """
        pass

    @abstractmethod
    def update_clients(self): 
        """
        This method copies all of the global component models to every client, i.e. a broadcast operation.
        """
        pass

    def update_test_clients(self):
        """
        Copy all of the global component models to each test client.
        """

    def write_logs(self):
        """
        Record the loss and accuracy on the global train and test datasets.
        The metrics are accumulated over all samples of all clients and then divided by the total number of samples to obtain the average.
        """

    def save_state(self, dir_path):
        """
        Save the aggregator's model state. For example, the state dict of every component model ('learner') in `global_learners_ensemble` (saved as `.pt` files), and the `learners_weights` of every client in `self.clients`. Note that these weights are not the model's internal parameters, but the weights assigned to the component models when forming the ensemble; they cover both train and test clients and are saved as numpy arrays of shape n_clients (or n_test_clients) x n_learners, i.e. as `.npy` files.
        """

    def load_state(self, dir_path):
        """
        Load the aggregator's model state, i.e. everything saved by `save_state`.
        """

    def sample_clients(self):
        """
        Sample a subset of clients.
        If self.sample_with_replacement is True, sampling is done with replacement;
        otherwise, it is done without replacement.
        The resulting list of clients is stored in self.sampled_clients.
        """

1. Client-server algorithm

This method of communication/aggregation is also called a centralized method, because at the end of every round it gathers the weight data of all clients at the server node. The pseudo-code of the optimization-iteration part of this method is as follows:

[pseudo-code figure of the client-server optimization loop: not reproduced here]

In terms of the concrete code, the Aggregator for this method is implemented as follows:

class CentralizedAggregator(Aggregator):
    r""" 标准的中心化Aggreagator
    所有clients在每一轮迭代末和average client完全同步.
    """
    def mix(self):
        self.sample_clients()

        # run local optimization for each client in self.sampled_clients
        for client in self.sampled_clients:
            # equivalent to the LocalSolver call in line 11 of the pseudo-code
            client.step()

        # iterate over every component model (learner) of the global model (self.global_learners_ensemble)
        # equivalent to line 13 of the pseudo-code
        for learner_id, learner in enumerate(self.global_learners_ensemble):
            # collect the component model with this learner_id from every client
            learners = [client.learners_ensemble[learner_id] for client in self.clients]
            # the global component model is the weighted average of all clients' corresponding component models,
            # equivalent to line 14 of the pseudo-code
            average_learners(learners, learner, weights=self.clients_weights)

        # push the updated models to all clients, equivalent to the broadcast operation in line 5 of the pseudo-code
        self.update_clients()

        # increment the communication round counter
        self.c_round += 1

        if self.c_round % self.log_freq == 0:
            self.write_logs()

    def update_clients(self):
        """
        This function copies all of the global component models to every client, equivalent to the broadcast operation in line 5 of the pseudo-code.
        """
        for client in self.clients:
            for learner_id, learner in enumerate(client.learners_ensemble):
                copy_model(learner.model, self.global_learners_ensemble[learner_id].model)

                if callable(getattr(learner.optimizer, "set_initial_params", None)):
                    learner.optimizer.set_initial_params(
                        self.global_learners_ensemble[learner_id].model.parameters()
                    )
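
Note that `mix` relies on a helper `average_learners(...)` that is not shown in this excerpt. Below is a hedged sketch of what such a weighted-averaging helper might look like, assuming every learner wraps a PyTorch model with identically shaped parameters; the function name and call signature match the call above, but the body here is an assumption, not the original implementation.

import torch

def average_learners(learners, target_learner, weights=None):
    """Write the weighted average of the learners' parameters into target_learner (sketch)."""
    if weights is None:
        # default to a uniform average
        weights = torch.ones(len(learners)) / len(learners)

    target_state = target_learner.model.state_dict()

    with torch.no_grad():
        for key in target_state:
            # weighted sum of this parameter over all learners
            avg = sum(
                w * learner.model.state_dict()[key].float()
                for w, learner in zip(weights, learners)
            )
            # copy_ writes the result back into the target model's tensor in place
            target_state[key].copy_(avg)

In a full training run, a driver script would then simply construct the aggregator and call `aggregator.mix()` once per communication round until the round budget is exhausted.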

2. Fully decentralized algorithm

This method is called decentralized because it does not require gathering all clients' weight data at a specific server node in every round; each node only needs to exchange parameters with its neighbours. The pseudo-code of the optimization-iteration part of this method is as follows:

[pseudo-code figure of the fully decentralized optimization loop: not reproduced here]

In terms of the concrete code, the Aggregator for this method is implemented as follows:

class DecentralizedAggregator(Aggregator):
    def __init__(
            self,
            clients,
            global_learners_ensemble,
            mixing_matrix,
            log_freq,
            global_train_logger,
            global_test_logger,
            sampling_rate=1.,
            sample_with_replacement=True,
            test_clients=None,
            verbose=0,
            seed=None):

        super(DecentralizedAggregator, self).__init__(
            clients=clients,
            global_learners_ensemble=global_learners_ensemble,
            log_freq=log_freq,
            global_train_logger=global_train_logger,
            global_test_logger=global_test_logger,
            sampling_rate=sampling_rate,
            sample_with_replacement=sample_with_replacement,
            test_clients=test_clients,
            verbose=verbose,
            seed=seed
        )

        self.mixing_matrix = mixing_matrix
        assert self.sampling_rate >= 1, "partial sampling is not supported with DecentralizedAggregator"

    def update_clients(self):
        pass

    def mix(self):
        
        # run local optimization on every client's model parameters
        for client in self.clients:
            client.step()

        # mixing matrix: the weights used to mix the clients' parameters;
        # rows and columns both index clients (an n_clients x n_clients matrix),
        # and the same mixing matrix is applied to every component learner
        mixing_matrix = torch.tensor(
            self.mixing_matrix.copy(),
            dtype=torch.float32,
            device=self.device
        )

        # iterate over every component model (learner) of the global model (self.global_learners_ensemble)
        # equivalent to line 14 of the pseudo-code
        for learner_id, global_learner in enumerate(self.global_learners_ensemble):
            # read out and temporarily store the model state of each client's learner with this learner_id
            state_dicts = [client.learners_ensemble[learner_id].model.state_dict() for client in self.clients]

            # iterate over the parameters of the global model; key is the parameter's name
            for key, param in global_learner.model.state_dict().items():
                shape_ = param.shape
                models_params = torch.zeros(self.n_clients, int(np.prod(shape_)), device=self.device)

                for ii, sd in enumerate(state_dicts):
                    # row ii of models_params stores the parameter (named key) of client ii
                    models_params[ii] = sd[key].view(1, -1) 

                # each row of models_params holds one client's parameters;
                # @ denotes matrix multiplication / matrix-vector multiplication,
                # so after mixing, each client's parameters are a weighted mixture of all clients' parameters
                models_params = mixing_matrix @ models_params

                for ii, sd in enumerate(state_dicts):
                    # write the mixed parameter (named key) of client ii back into its slot in state_dicts
                    sd[key] = models_params[ii].view(shape_)

            # load the updated parameters from state_dicts back into each client's model
            for client_id, client in enumerate(self.clients):
                client.learners_ensemble[learner_id].model.load_state_dict(state_dicts[client_id])

        # increment the communication round counter
        self.c_round += 1

        if self.c_round % self.log_freq == 0:
            self.write_logs()
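
The `mixing_matrix` passed to `DecentralizedAggregator` is expected to be an n_clients × n_clients matrix whose rows sum to one, typically a doubly stochastic matrix derived from the communication graph. As an illustration (an assumption for this write-up, not part of the original code), one common construction uses Metropolis-Hastings weights on a symmetric adjacency matrix:

import numpy as np

def metropolis_hastings_weights(adjacency):
    """Build a doubly stochastic mixing matrix from a symmetric 0/1 adjacency matrix (sketch)."""
    n = adjacency.shape[0]
    degrees = adjacency.sum(axis=1)
    W = np.zeros((n, n))
    for i in range(n):
        for j in range(n):
            if i != j and adjacency[i, j] > 0:
                # weight of edge (i, j): 1 / (1 + max degree of its two endpoints)
                W[i, j] = 1.0 / (1.0 + max(degrees[i], degrees[j]))
        # the remaining mass stays on node i itself so that each row sums to 1
        W[i, i] = 1.0 - W[i].sum()
    return W

# example: 4 clients connected in a ring
adjacency = np.array([
    [0, 1, 0, 1],
    [1, 0, 1, 0],
    [0, 1, 0, 1],
    [1, 0, 1, 0],
])
mixing_matrix = metropolis_hastings_weights(adjacency)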
