FedAvg-Verbundlernaufgabe basierend auf dem cifar10-Datensatz


Ursprung: Kommunikationseffizientes Lernen tiefer Netzwerke aus dezentralisierten Daten

1. Einführung in den cifar10-Datensatz

CIFAR-10 ist ein Farbbilddatensatz, der näher an universellen Objekten liegt. CIFAR-10 ist ein kleiner Datensatz für die allgegenwärtige Objekterkennung, der von Hintons Studenten Alex Krizhevsky und Ilya Sutskever zusammengestellt wurde. Insgesamt 10 Kategorien von RGB-Farbbildern: Flugzeug (Flugzeug), Auto (Automobil), Vogel (Vogel), Katze (Katze), Hirsch (Hirsch), Hund (Hund), Frosch (Frosch), Pferd (Pferd), Schiff (Schiff) und LKW (Lastwagen). Die Größe jedes Bildes beträgt 32 × 32, jede Kategorie hat 6000 Bilder und es gibt 50000 Trainingsbilder und 10000 Testbilder im Datensatz.

2. Einführung in Federated Learning

Daten für FL müssen die folgenden Eigenschaften (Kriterien) haben:

  1. Das Training mit echten Daten von mobilen Geräten hat klare Vorteile gegenüber Proxy-Daten, die von Rechenzentren bereitgestellt werden;
  2. Die Daten sind datenschutzrelevant oder umfangreich und müssen nicht nur zum Trainieren des Modells im Rechenzentrum aufgezeichnet werden;
  3. Für überwachte Aufgaben können die Bezeichnungen der Daten natürlich aus Benutzerinteraktionen abgeleitet werden.

Die Daten sind sensibel: das Foto des Benutzers oder der über die Tastatur eingegebene Text;

Die Verteilung von Daten unterscheidet sich auch von der durch Proxy-Daten bereitgestellten und hat mehr Benutzereigenschaften und Vorteile;

Daten-Tags können auch direkt erhalten werden : Beispielsweise werden die Fotos und Texteingaben des Benutzers markiert; Fotos können durch Benutzerinteraktion markiert (löschen, teilen, anzeigen) werden.
Föderierter Lernprozess:

  1. Wählen Sie zu Beginn jeder Aktualisierungsrunde zufällig einige Clients aus, deren Größe C-Fraktion ist (sollte ein Verhältnis sein, C≤1);
  2. Der Server sendet dann den aktuellen Zustand des globalen Algorithmus an diese Clients (z. B. aktuelle Modellparameter);
  3. Jeder Client führt dann lokale Berechnungen basierend auf dem globalen Zustand und seinem lokalen Datensatz durch und sendet Aktualisierungen an den Server;
  4. Schließlich wendet der Server diese Aktualisierungen auf seinen globalen Status an, und der Vorgang wird wiederholt.
    Bildbeschreibung hier einfügen

3. Clientseitiges CNN-Modell

import torch.nn.functional as F
from torch import nn
import numpy as np
import torch
from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader


class CNN(nn.Module):
    def __init__(self, in_channels=3, n_kernels=16, out_dim=10):
        super(CNN, self).__init__()

        self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=n_kernels, kernel_size=5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(in_channels=n_kernels, out_channels=2 * n_kernels, kernel_size=5)
        self.fc1 = nn.Linear(in_features=2 * n_kernels * 5 * 5, out_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=84)
        self.fc3 = nn.Linear(in_features=84, out_features=out_dim)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(x.shape[0], -1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


class Client(object):
    def __int__(self, trainDataSet, dev):
        self.train_ds = trainDataSet
        self.dev = dev
        self.train_dl = None
        self.local_parameter = None

3. FedAvg-Implementierung

def evaluate(net, global_parameters, testDataLoader, dev):
    net.load_state_dict(global_parameters, strict=True)
    running_correct = 0
    running_samples = 0
    net.eval()
    # 载入测试集
    for data, label in testDataLoader:
        data, label = data.to(dev), label.to(dev)
        pred = net(data)
        running_correct += pred.argmax(1).eq(label).sum().item()
        running_samples += len(label)
    print("\t" + 'accuracy: %.2f' % (running_correct / running_samples), end="")


def local_upload(train_data_set, local_epoch, net, loss_fun, opt, global_parameters, dev):
    # 加载当前通信中最新全局参数
    net.load_state_dict(global_parameters, strict=True)
    # 设置迭代次数
    net.train()
    for epoch in range(local_epoch):
        for data, label in train_data_set:
            data, label = data.to(dev), label.to(dev)
            # 模型上传入数据
            predict = net(data)
            loss = loss_fun(predict, label)
            # 反向传播
            loss.backward()
            # 计算梯度,并更新梯度
            opt.step()
            # 将梯度归零,初始化梯度
            opt.zero_grad()
    # 返回当前Client基于自己的数据训练得到的新的模型参数
    return net.state_dict()


def train(data_name: str, data_path: str, classes_per_node: int, num_nodes: int,
          steps: int, node_iter: int, optim: str, lr: float, inner_lr: float,
          embed_lr: float, wd: float, inner_wd: float, embed_dim: int, hyper_hid: int,
          n_hidden: int, n_kernels: int, bs: int, device, eval_every: int, save_path: Path,
          seed: int) -> None:
    ###############################
    # init nodes, hnet, local net #
    ###############################
    steps = 5
    node_iter = 5
    nodes = BaseNodes(data_name, data_path, num_nodes, classes_per_node=classes_per_node,
                      batch_size=bs)
    net = CNN(n_kernels=n_kernels)
    # hnet = hnet.to(device)
    net = net.to(device)

    ##################
    # init optimizer #
    ##################
    # embed_lr = embed_lr if embed_lr is not None else lr
    optimizer = torch.optim.SGD(
        net.parameters(), lr=inner_lr, momentum=.9, weight_decay=inner_wd
    )
    criteria = torch.nn.CrossEntropyLoss()

    ################
    # init metrics #
    ################
    # step_iter = trange(steps)
    step_iter = range(steps)
    # train process
    # record  the global parameters
    global_parameters = {
    
    }
    for key, parameter in net.state_dict().items():
        global_parameters[key] = parameter.clone()
    for step in step_iter:

        local_parameters_list = {
    
    }
        # 需要训练的node数目
        for i in range(node_iter):
            # 随机选择一个客户端
            node_id = random.choice(range(num_nodes))
            # 用全局模型参数训练当前客户端
            local_parameters = local_upload(nodes.train_loaders[node_id], 5, net, criteria, optimizer,
                                            global_parameters, dev='cpu')
            print("\nEpoch: {}, Node Count: {}, Node ID: {}".format(step + 1, i + 1, node_id), end="")
            evaluate(net, local_parameters, nodes.val_loaders[node_id], 'cpu')
            local_parameters_list[i] = local_parameters

        # 更新当前轮次模型的参数
        sum_parameters = None
        for node_id, parameters in local_parameters_list.items():
            if sum_parameters is None:
                sum_parameters = parameters
            else:
                for key in parameters.keys():
                    sum_parameters[key] += parameters[key]
        for var in global_parameters:
            global_parameters[var] = (sum_parameters[var] / node_iter)
    # test
    net.load_state_dict(global_parameters, strict=True)
    net.eval()
    for data_set in nodes.test_loaders:
        running_correct = 0
        running_samples = 0
        for data, label in data_set:
            pred = net(data)
            running_correct += pred.argmax(1).eq(label).sum().item()
            running_samples += len(label)
        print("\t" + 'accuracy: %.2f' % (running_correct / running_samples), end="")

Operationsergebnis

Bildbeschreibung hier einfügen

おすすめ

転載: blog.csdn.net/qq_45724216/article/details/125104881