【Lecture de code】Interprétation de code de test comparatif byzantin

  Aujourd'hui, j'ai commencé à lire en détail le code du test de comparaison d'attaques byzantines.Ce code provient du matériel supplémentaire de FedCut.
Mise à jour 31.3.2023 :
  Le code de ce blog à lecture intensive est lié ici .


1 Algorithme d'agrégation multi-krum

1.1 paquet d'importation

import argparse, os, sys, csv, shutil, time, random, operator, pickle, ast, math
import numpy as np
import pandas as pd
from torch.optim import Optimizer
import torch.nn.functional as F
import torch
import pickle
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data as data
import torch.multiprocessing as mp
os.chdir(sys.path[0])
sys.path.append("..")
sys.path.insert(0,'./../utils/')
from utils.logger import *
from utils.eval import *
from utils.misc import *

from cifar10_normal_train import *
from cifar10_util import *
from adam import Adam
from sgd import SGD

Il y a aussi quelques petits épisodes ici. Lors de l'importation des fichiers de code des fichiers frères, VScode ne pourra pas trouver ces fichiers. Après avoir résolu le problème en cours d'exécution, il y aura toujours des problèmes de débogage. La solution est recommandée pour lire l'article écrit précédemment .

1.2 Traiter le jeu de données Cifar10 et le diviser en 50 utilisateurs

# 获取cifar10数据,并以IID的形式划分给50个客户
import torchvision.transforms as transforms
import torchvision.datasets as datasets
data_loc='D:/FedAffine/data/cifar' # 自己改一改
# 加载训练集和测试机
train_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])
cifar10_train = datasets.CIFAR10(root=data_loc, train=True, download=True, transform=train_transform)
cifar10_test = datasets.CIFAR10(root=data_loc, train=False, download=True, transform=train_transform)

Voici un ensemble de données de chargement commun : 50 000 images d'entraînement et 10 000 images de test.

X=[]
Y=[]
for i in range(len(cifar10_train)):
    X.append(cifar10_train[i][0].numpy())
    Y.append(cifar10_train[i][1])
for i in range(len(cifar10_test)):
    X.append(cifar10_test[i][0].numpy())
    Y.append(cifar10_test[i][1])
X=np.array(X)
Y=np.array(Y)
print('total data len: ',len(X)) #长度60000

if not os.path.isfile('./cifar10_shuffle.pkl'):
    all_indices = np.arange(len(X))
    np.random.shuffle(all_indices)
    pickle.dump(all_indices,open('./cifar10_shuffle.pkl','wb'))
else:
    all_indices=pickle.load(open('./cifar10_shuffle.pkl','rb'))
X=X[all_indices]
Y=Y[all_indices]

Ici, les données d'image et les étiquettes sont stockées X(60000, 3, 32, 32)et Y(60000,)stockées. Ensuite, un fichier "cifar10_shuffle.pkl" est chargé, qui est l'index mélangé. XAprès cela, la somme est brouillée selon cet index Y.

# data loading
nusers=50
user_tr_len=1000
total_tr_len=user_tr_len*nusers
val_len=5000
te_len=5000
# 取前50000个作为训练集
total_tr_data=X[:total_tr_len]
total_tr_label=Y[:total_tr_len]
# 取之后5000个作为评估集
val_data=X[total_tr_len:(total_tr_len+val_len)]
val_label=Y[total_tr_len:(total_tr_len+val_len)]
# 取最后5000个作为评估集
te_data=X[(total_tr_len+val_len):(total_tr_len+val_len+te_len)]
te_label=Y[(total_tr_len+val_len):(total_tr_len+val_len+te_len)]
#全部转成tensor
total_tr_data_tensor=torch.from_numpy(total_tr_data).type(torch.FloatTensor)
total_tr_label_tensor=torch.from_numpy(total_tr_label).type(torch.LongTensor)
val_data_tensor=torch.from_numpy(val_data).type(torch.FloatTensor)
val_label_tensor=torch.from_numpy(val_label).type(torch.LongTensor)
te_data_tensor=torch.from_numpy(te_data).type(torch.FloatTensor)
te_label_tensor=torch.from_numpy(te_label).type(torch.LongTensor)

print('total tr len %d | val len %d | test len %d'%(len(total_tr_data_tensor),len(val_data_tensor),len(te_data_tensor)))

Donc, après division, c'est :total tr len 50000 | val len 5000 | test len 5000

user_tr_data_tensors=[]
user_tr_label_tensors=[]
for i in range(nusers): # 遍历五十个users    
    user_tr_data_tensor=torch.from_numpy(total_tr_data[user_tr_len*i:user_tr_len*(i+1)]).type(torch.FloatTensor)
    user_tr_label_tensor=torch.from_numpy(total_tr_label[user_tr_len*i:user_tr_len*(i+1)]).type(torch.LongTensor)

    user_tr_data_tensors.append(user_tr_data_tensor)
    user_tr_label_tensors.append(user_tr_label_tensor)
    print('user %d tr len %d'%(i,len(user_tr_data_tensor)))

Ce mouvement consiste à allouer 1000 données d'entraînement à chacun des 50 utilisateurs, ce qui vient de se terminer.

1.3 Algorithme de Krum (Multi-Krum) sujet

  Tout d'abord, nous devons savoir quel est le processus de l'algorithme. Ici, j'ai beaucoup cité cet article : Hypothèses de l'algorithme Krum et Multi-Krum
  :

  nous avons nn clients{ c 1 , c 2 , ⋯ , cn } \{c_1,c_2,\cdots,c_n\}{ c1,c2,,cn} , et un Serveur, chaque clientci c_icjeavec les données D je D_iDje. On suppose généralement que les données sont distribuées de manière indépendante et identique, c'est-à-dire IID.
  ici mParmi n clients, il y a fff est un attaquant byzantin, satisfaisant : 2 f + 2 < n 2f+2<n2 f+2<n
  c'est-à-dire quandn = 10 n=10n=A 10 , fff maximum est 3 ; lorsquen = 100 n=100n=100 ,ffLe maximum f est 48.

  Étapes de l'algorithme de Krum :

  1. Le serveur définira le paramètre global WWW est distribué à tous les clients ;
  2. Pour chaque client ci c_icjeSimultanément :
    Calculer le gradient local gi g_igje, puis envoyé au serveur ;
  3. Une fois que le serveur a reçu les gradients du client, il calcule la distance entre n'importe quelle paire de gradients :
    dij = ∥ gi − gj ∥ F 2 d_{ij}=\Vert g_i-g_j\Vert_F^2dje=∥g _jegjeF2
  4. Pour chaque gradient gi g_igje, choisir le n − f − 1 nf-1 le plus proche de luinF1个距离,即{ di , 1 , di , 2 , ⋯ , di , n } \{d_{i,1},d_{i,2},\cdots,d_{i,n}\}{ je , 1,dje , 2,,dje , n} dans le plus petitn − f − 1 nf-1nF1 , autant définir{ di , 1 , di , 2 , ⋯ , di , n − f − 1 } \{d_{i,1},d_{i,2},\cdots,d_{i,nf- 1}\}{ je , 1,dje , 2,,dje , n - F - 1} , puis additionné en tant que score de gradient K r ( i ) = ∑ j = 1 n − f − 1 dij Kr(i)=\sum^{nf-1}_{j=1}d_{ij}Kr ( je )=j = 1n - f - 1dje
  5. Après avoir calculé les scores de tous les gradients, trouvez le gradient g ∗ g^* avec le plus petit scoreg
  6. 更新:W = W − lr × g ∗ W=W-lr\times g^*O=Ol r×g

  Dans la cinquième étape de l'algorithme Multi-Krum, sélectionnez le mm avec le plus petit scorem gradients, le gradient final est cemmLa moyenne de m .
  Vient ensuite la partie code :

# Code for Multi-krum aggregation
def multi_krum(all_updates, n_attackers, multi_k=False):

    candidates = []
    candidate_indices = []
    remaining_updates = all_updates
    all_indices = np.arange(len(all_updates))
	#
    while len(remaining_updates) > 2 * n_attackers + 2:
        torch.cuda.empty_cache()
        distances = []
        for update in remaining_updates:
            distance = []
            for update_ in remaining_updates:
            	# 计算距离
                distance.append(torch.norm((update - update_)) ** 2)
            distance = torch.Tensor(distance).float()
            # None的作用主要是在使用None的位置新增一个维度
            distances = distance[None, :] if not len(distances) else torch.cat((distances, distance[None, :]), 0)

        distances = torch.sort(distances, dim=1)[0] # 排序
        scores = torch.sum(distances[:, :len(remaining_updates) - 2 - n_attackers], dim=1) # 计算得分,上述算法的第四步
        indices = torch.argsort(scores)[:len(remaining_updates) - 2 - n_attackers] # 返回排序后的值所对应的下标

        candidate_indices.append(all_indices[indices[0].cpu().numpy()]) # 添加一个下标
        all_indices = np.delete(all_indices, indices[0].cpu().numpy())
        candidates = remaining_updates[indices[0]][None, :] if not len(candidates) else torch.cat((candidates, remaining_updates[indices[0]][None, :]), 0)
        remaining_updates = torch.cat((remaining_updates[:indices[0]], remaining_updates[indices[0] + 1:]), 0)
        if not multi_k: # 如果不是multi-krum算法,就只取一个分数最好的candidate
            break
    # print(len(remaining_updates))
    aggregate = torch.mean(candidates, dim=0)
    return aggregate, np.array(candidate_indices)

  Le code ci-dessus utilise une structure de boucle plus subtile et réalise les algorithmes krum et multi-krum en même temps.

1.4 Génération de gradient d'attaque de crocs

  Ensuite, le code commence à concevoir le code pour empêcher l'attaque. La première est une fonction outil pour calculer lambda :

# Code for Fang attack on Multi-krum
def compute_lambda_fang(all_updates, model_re, n_attackers):

    distances = []
    n_benign, d = all_updates.shape
    # 计算每个梯度到其他梯度的距离
    for update in all_updates:
        distance = torch.norm((all_updates - update), dim=1)
        distances = distance[None, :] if not len(distances) else torch.cat((distances, distance[None, :]), 0)
	# 将所有为零的距离改为10000
    distances[distances == 0] = 10000
    distances = torch.sort(distances, dim=1)[0] # 选出距离最小的点
    scores = torch.sum(distances[:, :n_benign - 2 - n_attackers], dim=1)
    min_score = torch.min(scores)
    term_1 = min_score / ((n_benign - n_attackers - 1) * torch.sqrt(torch.Tensor([d]))[0])
    max_wre_dist = torch.max(torch.norm((all_updates - model_re), dim=1)) / (torch.sqrt(torch.Tensor([d]))[0])

    return (term_1 + max_wre_dist)

  Ensuite, obtenez le gradient d'attaque :

def get_malicious_updates_fang(all_updates, model_re, deviation, n_attackers):

    lamda = compute_lambda_fang(all_updates, model_re, n_attackers)
    threshold = 1e-5

    mal_updates = []
    while lamda > threshold:
        mal_update = (- lamda * deviation)

        mal_updates = torch.stack([mal_update] * n_attackers)
        mal_updates = torch.cat((mal_updates, all_updates), 0)

        agg_grads, krum_candidate = multi_krum(mal_updates, n_attackers, multi_k=False)
        
        if krum_candidate < n_attackers:
            return mal_updates
        
        lamda *= 0.5

    if not len(mal_updates):
        print(lamda, threshold)
        mal_update = (model_re - lamda * deviation)
        
        mal_updates = torch.stack([mal_update] * n_attackers)
        mal_updates = torch.cat((mal_updates, all_updates), 0)

    return mal_updates

1.5 Commencer à tester l'algorithme

Je suppose que tu aimes

Origine blog.csdn.net/m0_51562349/article/details/129560575
conseillé
Classement