Adaptive Personalized Federated Learning 论文解读+代码解析

论文地址点这里

一. 介绍

联邦学习强调确保本地隐私情况下,对多个客户端进行训练,客户端之间不交换数据而交换参数来进行通信。目的是聚合成一个全局的模型,使得这个模型再各个客户端上读能取得较好的成果。联邦学习中FedAvg方法最为广泛,但由于本地数据分片之间的固有多样性和数据再客户端的高度非iid(独立同分布),FedAvg对超参数非常敏感,不能从良好的手链保证中获益。因此在设备异质性存在的情况下,全局模型不能很好的概括每个客户单独的本地数据。
随着客户端数据的多样性增加,全局模型和个性化(客户端)模型的误差将会越来越大,好的全局模型回导致一个差的本地客户端的模型。
在这项工作中,作者提出了一个新的联邦学习框架,该框架优化了所有客户急的性能。减少泛化误差依赖于局部数据的分布特征。因此,该模型目标为倾向于学习一种混合了全球模式和本地模式的个性化模式。但难点在于如何确保局部数据是适合所有客户的全局模型。

二. 相关工作

联邦学习的主要目标为学习一个全局模型,这个全局模型对与尚未看到的数据足够好,并且能快速收敛到局部最优,这一点和元学习有一些相似之处。但尽管存在这种相似性,元学习方法主要是常识学习多个模型,针对每个新任务进行个性化学习,而在大多数联邦学习中,更关注单个全局模型。而全局模型和本地模型的差异性就是个性化的重要表现。联邦学习中个性化主要有三大类,本地微调,多任务学习和情景化。
本地微调(Local fine tuning): 本地微调即每个客户端接收到一个全局模型,并使用自己的局部数据和几个梯度下降步骤对其进行调优,这种方法主要结合了元学习。
多任务学习(multi_task learning): 对个性化问题的另一种观点是视为多任务学习问题。这种设置下对每个客户端的优化可以看做是一个新的任务。
情景化(Contextualization): 个性化联邦学习中的一个重要应用是在不同情境下使用模型。我们需要在不同的环境下对一个客户端进行个性化的模型。
通过模型混合进行个性化(Personalization via mixing models): 通过混合全局和局部的模型引入不同的个性化方法来进行联邦学习。基于此,有三种不同的个性化方法,即客户聚类、数据插值和模型插值。而前两种对数据隐私性造成破坏,只有第三种是较为合理的模式。

三. 个性化联邦学习

3.1 定义:

D i : D_i : Di第i个客户端上数据集(有标签)
D ˉ = ( 1 / n ) ∑ i = 1 n D i \bar{D} = (1/n)\sum_{i=1}^{n}D_i Dˉ=(1/n)i=1nDi:所有客户端的平均分布
L D i ( h ) = E ( x , y ) ∈ D i [ l ( h ( x ) , y ) ] : \mathcal{L}_{D_i}(h) = \mathbb{E}_{(x,y)\in D_i}[\mathcal{l(h(x),y)}]: LDi(h)=E(x,y)Di[l(h(x),y)]:在客户端i上的真实风险。
L ^ D i ( h ) : \widehat{\mathcal{L}}_{D_i}(h): L Di(h):在客户端i上对于h的经验风险

3.2 个性化模型

在一个标准的联邦学习场景中,目的是为所有设备合作学习一个全局模型。同时各个客户端存在着本地模型,在自适应个性化联邦学习中,目标是找到全局模型和局部模型的最优组合,以实现更好的针对客户的模型。在这种设置中,每个用户训练一个局部模型,同时合并部分全局模型,并使用一些混合权重,数学表达如下:
h α i = α i h ^ i ∗   +   ( 1 − α i ) h ˉ ∗ h_{\alpha_i} = \alpha_i \widehat{h}_i^*\ +\ (1 -\alpha_i )\bar{h}^* hαi=αih i + (1αi)hˉ
其中 h ˉ ∗ = a r g min ⁡ h ∈ H L ^ D ˉ ( h ) \bar{h}^* = arg\min_{h\in\mathcal{H}}\widehat{\mathcal{L}}_{\bar{D}}(h) hˉ=argminhHL Dˉ(h)为全局的经验优化最小,
h ^ i ∗ = a r g min ⁡ h ∈ H L ^ D ˉ ( α i h + ( 1 − α i ) h ˉ ∗ ) \widehat{h}_i^* = arg\min_{ {h\in\mathcal{H}}}\widehat{\mathcal{L}}_{\bar{D}}(\alpha_ih+(1-\alpha_i)\bar{h}^*) h i=argminhHL Dˉ(αih+(1αi)hˉ)是一个在第i个客户端上取得最小损失的混合模型。
(这里我解释一下,就是说我们的模型由两部分组成,一个是全局的模型,另一个是客户端的模型,至于为什么客户端的模型又是由一个混合组成呢?这里考虑成多轮训练即可,假设t-1轮全局模型为w,本地模型为v,然后我们融合成混合模型为h=w+v。在t轮的时候,全局模型为新的w,而本地模型则是继承t-1轮的混合模型h,所以对应的v可以代指为本地模型)

3.3 APFL算法

就像传统的联邦学习意义,服务器需要解决目标如下:
min ⁡ w ∈ R d [ F ( w ) = 1 n ∑ i = 1 n { f i ( w ) = E ξ [ f i ( w , ξ i ) ] } ] \min_{\mathcal{w}\in \mathbb{R^d}}[F(w)=\frac{1}{n}\sum_{i=1}^n\{f_i(w)=\mathbb{E_\xi[f_i(w,\xi_i)]}\}] wRdmin[F(w)=n1i=1n{ fi(w)=Eξ[fi(w,ξi)]}]
而客户端采取上面的方式(个性化)
min ⁡ v ∈ R d f i ( α i v + ( 1 − a l p h a i ) w ∗ ) \min_{\mathcal{v}\in \mathbb{R^d}}f_i(\alpha_iv+(1-alpha_i)w^*) vRdminfi(αiv+(1alphai)w)
其中 w ∗ = a r g min ⁡ w F ( w ) w^*=arg\min_w F(w) w=argminwF(w)
具体步骤如下:
在这里插入图片描述
对于参与训练的客户端来说,存在着两个参数。一个是w:全局参数,一个是v:自己的参数。首先根据数据集对w进行更新(用t-1轮的参数)。对v更新方式也是如此,v通过混合参数( v ˉ \bar{v} vˉ)计算梯度来进行更新,之后将新的w和v合成我们当前的混合参数,再把w传到服务端进行合并。

3.4 α \alpha α的取值

直观来看,当本地数据较为均匀,每个客户端的局部模型接近全局模型时我们需要较小的 α \alpha α;相反,当本地数据多样性较强时, α \alpha α应该接近1。我们需要再不同分布的场景下更新我们的 α \alpha α
α i ∗ = a r g min ⁡ α i ∈ [ 0 , 1 ] f i ( α i v + ( 1 − α i ) w ) \alpha^*_i = arg\min_{\alpha_i \in[0,1]}f_i(\alpha_iv+(1-\alpha_i)w) αi=argαi[0,1]minfi(αiv+(1αi)w)
我们可以使用梯度下降来更新一次 α \alpha α
α i ( t ) = α i ( t − 1 ) − η t ∇ α f i ( v ˉ i ( t − 1 ) ; ξ i t ) = α i ( t − 1 ) − η t < v i ( t − 1 ) − w i ( t − 1 ) , ∇ f i ( v ˉ i ( t − 1 ) ; ξ i t ) > \begin{aligned} \alpha_i^{(t)}&=\alpha_i^{(t-1)}-\eta_t \nabla_\alpha f_i(\bar{v}_i^{(t-1)};\xi_i^t)\\ &=\alpha_i^{(t-1)}-\eta_t <v_i^{(t-1) }-w_i^{(t-1)},\nabla f_i(\bar{v}_i^{(t-1)};\xi_i^t)> \end{aligned} αi(t)=αi(t1)ηtαfi(vˉi(t1);ξit)=αi(t1)ηt<vi(t1)wi(t1),fi(vˉi(t1);ξit)>

四. 关键代码解析

作者的代码github地址点这里,这个github还包括很多其他的联邦学习算法,这里只针对APFL算法进行讲解。
APFL主要不同在客户端的更新上,因此我们针对客户端训练进行解读。
首先是对全局模型参数w的更新
直接读取数据,求损失,用SGD更新

_input, _target = load_data_batch(client.args, _input, _target, tracker)
# Skip batches with one sample because of BatchNorm issue in some models!
if _input.size(0)==1:
    is_sync = is_sync_fed(client.args)
    break

# inference and get current performance.
client.optimizer.zero_grad()
loss, performance = inference(client.model, client.criterion, client.metrics, _input, _target)

# compute gradient and do local SGD step.
loss.backward()
client.optimizer.step(
    apply_lr=True,
    apply_in_momentum=client.args.in_momentum, apply_out_momentum=False
)

接下来是本地模型参数的更新v:

client.optimizer_personal.zero_grad()
loss_personal, performance_personal = inference_personal(client.model_personal, client.model, 
                                                         client.args.fed_personal_alpha, client.criterion, 
                                                         client.metrics, _input, _target)

# compute gradient and do local SGD step.
loss_personal.backward()
client.optimizer_personal.step(
    apply_lr=True,
    apply_in_momentum=client.args.in_momentum, apply_out_momentum=False
)

也是一样的,拿到一个batch数据,求损失,注意这里求损失对应的参数为上一轮的混合参数,而不仅仅是本地参数,求混合参数损失代码如下:
其实就是用 α \alpha α来合成混合参数算损失

def inference_personal(model1, model2, alpha, criterion, metrics, _input, _target):
    """Inference on the given model and get loss and accuracy."""
    # TODO: merge with inference
    output1 = model1(_input)
    output2 = model2(_input)
    output = alpha * output1 + (1-alpha) * output2
    loss = criterion(output, _target)
    performance = accuracy(output.data, _target, topk=metrics)
    return loss, performance

到这里其实就实现了APFL,但还有个关键的地方在于,每一轮在训练前更细呢一次 α \alpha α,通过3.4节讲解的方式更新:

def alpha_update(model_local, model_personal,alpha, eta):
    grad_alpha = 0
    for l_params, p_params in zip(model_local.parameters(), model_personal.parameters()):
    	## 这里为 v - w
        dif = p_params.data - l_params.data
        ## 这里为f(\bar{v}的损失)
        grad = alpha * p_params.grad.data + (1-alpha)*l_params.grad.data
        ## 乘起来即可
        grad_alpha += dif.view(-1).T.dot(grad.view(-1))
    grad_alpha += 0.02 * alpha
    ## 进行更新
    alpha_n = alpha - eta*grad_alpha
    ## 确保在0,1之间
    alpha_n = np.clip(alpha_n.item(),0.0,1.0)
    return alpha_n

到这里,APFL算法就介绍完了,最后附上apfl整个训练的代码,方便大家查看:

def train_and_validate_federated_apfl(client):
    """The training scheme of Personalized Federated Learning.
        Official implementation for https://arxiv.org/abs/2003.13461
    """
    log('start training and validation with Federated setting.', client.args.debug)

    if client.args.evaluate and client.args.graph.rank==0:
        # Do the testing on the server and return
        do_validate(client.args, client.model, client.optimizer,  client.criterion, client.metrics,
                         client.test_loader, client.all_clients_group, data_mode='test')
        return

    
    tracker = define_local_training_tracker()
    start_global_time = time.time()
    tracker['start_load_time'] = time.time()
    log('enter the training.', client.args.debug)

    # Number of communication rounds in federated setting should be defined
    for n_c in range(client.args.num_comms):
        client.args.rounds_comm += 1
        client.args.comm_time.append(0.0)
        # Configuring the devices for this round of communication
        # TODO: not make the server rank hard coded
        log("Starting round {} of training".format(n_c), client.args.debug)
        online_clients = set_online_clients(client.args)
        if (n_c == 0) and  (0 not in online_clients):
            online_clients += [0]
        online_clients_server = online_clients if 0 in online_clients else online_clients + [0]
        online_clients_group = dist.new_group(online_clients_server)
        
        if client.args.graph.rank in online_clients_server: 
            client.model_server = distribute_model_server(client.model_server, online_clients_group, src=0)
            client.model.load_state_dict(client.model_server.state_dict())
            if client.args.graph.rank in online_clients:
                is_sync = False
                ep = -1 # counting number of epochs
                while not is_sync:
                    ep += 1
                    for i, (_input, _target) in enumerate(client.train_loader):
                        client.model.train()

                        # update local step.
                        logging_load_time(tracker)

                        # update local index and get local step
                        client.args.local_index += 1
                        client.args.local_data_seen += len(_target)
                        get_current_epoch(client.args)
                        local_step = get_current_local_step(client.args)

                        # adjust learning rate (based on the # of accessed samples)
                        lr = adjust_learning_rate(client.args, client.optimizer, client.scheduler)

                        # load data
                        _input, _target = load_data_batch(client.args, _input, _target, tracker)
                        # Skip batches with one sample because of BatchNorm issue in some models!
                        if _input.size(0)==1:
                            is_sync = is_sync_fed(client.args)
                            break

                        # inference and get current performance.
                        client.optimizer.zero_grad()
                        loss, performance = inference(client.model, client.criterion, client.metrics, _input, _target)

                        # compute gradient and do local SGD step.
                        loss.backward()
                        client.optimizer.step(
                            apply_lr=True,
                            apply_in_momentum=client.args.in_momentum, apply_out_momentum=False
                        )
                        
                        client.optimizer.zero_grad()
                        client.optimizer_personal.zero_grad()
                        loss_personal, performance_personal = inference_personal(client.model_personal, client.model, 
                                                                                 client.args.fed_personal_alpha, client.criterion, 
                                                                                 client.metrics, _input, _target)

                        # compute gradient and do local SGD step.
                        loss_personal.backward()
                        client.optimizer_personal.step(
                            apply_lr=True,
                            apply_in_momentum=client.args.in_momentum, apply_out_momentum=False
                        )

                        # update alpha
                        if client.args.fed_adaptive_alpha and i==0 and ep==0:
                            client.args.fed_personal_alpha = alpha_update(client.model, client.model_personal, client.args.fed_personal_alpha, lr) #0.1/np.sqrt(1+args.local_index))
                            average_alpha = client.args.fed_personal_alpha
                            average_alpha = global_average(average_alpha, client.args.graph.n_nodes, group=online_clients_group)
                            log("New alpha is:{}".format(average_alpha.item()), client.args.debug)
                        
                        # logging locally.
                        # logging_computing(tracker, loss, performance, _input, lr)
                        
                        # display the logging info.
                        # logging_display_training(args, tracker)
                        
                        # reset load time for the tracker.
                        tracker['start_load_time'] = time.time()
                        is_sync = is_sync_fed(client.args)
                        if is_sync:
                            break
            else:
                log("Offline in this round. Waiting on others to finish!", client.args.debug)

            do_validate(client.args, client.model, client.optimizer_personal, client.criterion, client.metrics, 
                        client.train_loader, online_clients_group, data_mode='train', personal=True, 
                        model_personal=client.model_personal, alpha=client.args.fed_personal_alpha)
            if client.args.fed_personal:
                do_validate(client.args, client.model, client.optimizer_personal, client.criterion, client.metrics, 
                            client.val_loader, online_clients_group, data_mode='validation', personal=True, 
                            model_personal=client.model_personal, alpha=client.args.fed_personal_alpha)

            # Sync the model server based on model_clients
            log('Enter synching', client.args.debug)
            tracker['start_sync_time'] = time.time()
            client.args.global_index += 1
            client.model_server = fedavg_aggregation(client.args, client.model_server, client.model, 
                                                     online_clients_group, online_clients, client.optimizer)
            # evaluate the sync time
            logging_sync_time(tracker)

            # Do the validation on the server model
            do_validate(client.args, client.model_server, client.optimizer, client.criterion, client.metrics, 
                        client.train_loader, online_clients_group, data_mode='train')
            if client.args.fed_personal:
                do_validate(client.args, client.model_server, client.optimizer, client.criterion, client.metrics, 
                            client.val_loader, online_clients_group, data_mode='validation')

           
            # logging.
            logging_globally(tracker, start_global_time)
            
            # reset start round time.
            start_global_time = time.time()

            # validate the models at the test data
            if client.args.fed_personal_test:
                do_validate(client.args, client.model_client, client.optimizer_personal, client.criterion, 
                            client.metrics, client.test_loader, online_clients_group, data_mode='test', personal=True,
                            model_personal=client.model_personal, alpha=client.args.fed_personal_alpha)
            elif client.args.graph.rank == 0:
                do_validate(client.args, client.model_server, client.optimizer, client.criterion, 
                            client.metrics, client.test_loader, online_clients_group, data_mode='test')
        else:
            log("Offline in this round. Waiting on others to finish!", client.args.debug)
        dist.barrier(group=client.all_clients_group)

注意,这里求 α \alpha α是在每一轮训练的第一个batch之后进行更新,我觉得目的是防止一开始的w和v初始化的结果影响太大,因此改为训练一个batch后更新。

猜你喜欢

转载自blog.csdn.net/qq_45478482/article/details/121568515