Dieses Artikelverzeichnis
Ursprung: Kommunikationseffizientes Lernen tiefer Netzwerke aus dezentralisierten Daten
1. Einführung in den cifar10-Datensatz
CIFAR-10 ist ein Farbbilddatensatz, der näher an universellen Objekten liegt. CIFAR-10 ist ein kleiner Datensatz für die allgegenwärtige Objekterkennung, der von Hintons Studenten Alex Krizhevsky und Ilya Sutskever zusammengestellt wurde. Insgesamt 10 Kategorien von RGB-Farbbildern: Flugzeug (Flugzeug), Auto (Automobil), Vogel (Vogel), Katze (Katze), Hirsch (Hirsch), Hund (Hund), Frosch (Frosch), Pferd (Pferd), Schiff (Schiff) und LKW (Lastwagen). Die Größe jedes Bildes beträgt 32 × 32, jede Kategorie hat 6000 Bilder und es gibt 50000 Trainingsbilder und 10000 Testbilder im Datensatz.
2. Einführung in Federated Learning
Daten für FL müssen die folgenden Eigenschaften (Kriterien) haben:
- Das Training mit echten Daten von mobilen Geräten hat klare Vorteile gegenüber Proxy-Daten, die von Rechenzentren bereitgestellt werden;
- Die Daten sind datenschutzrelevant oder umfangreich und müssen nicht nur zum Trainieren des Modells im Rechenzentrum aufgezeichnet werden;
- Für überwachte Aufgaben können die Bezeichnungen der Daten natürlich aus Benutzerinteraktionen abgeleitet werden.
Die Daten sind sensibel: das Foto des Benutzers oder der über die Tastatur eingegebene Text;
Die Verteilung von Daten unterscheidet sich auch von der durch Proxy-Daten bereitgestellten und hat mehr Benutzereigenschaften und Vorteile;
Daten-Tags können auch direkt erhalten werden : Beispielsweise werden die Fotos und Texteingaben des Benutzers markiert; Fotos können durch Benutzerinteraktion markiert (löschen, teilen, anzeigen) werden.
Föderierter Lernprozess:
- Wählen Sie zu Beginn jeder Aktualisierungsrunde zufällig einige Clients aus, deren Größe C-Fraktion ist (sollte ein Verhältnis sein, C≤1);
- Der Server sendet dann den aktuellen Zustand des globalen Algorithmus an diese Clients (z. B. aktuelle Modellparameter);
- Jeder Client führt dann lokale Berechnungen basierend auf dem globalen Zustand und seinem lokalen Datensatz durch und sendet Aktualisierungen an den Server;
- Schließlich wendet der Server diese Aktualisierungen auf seinen globalen Status an, und der Vorgang wird wiederholt.
3. Clientseitiges CNN-Modell
import torch.nn.functional as F
from torch import nn
import numpy as np
import torch
from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader
class CNN(nn.Module):
def __init__(self, in_channels=3, n_kernels=16, out_dim=10):
super(CNN, self).__init__()
self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=n_kernels, kernel_size=5)
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(in_channels=n_kernels, out_channels=2 * n_kernels, kernel_size=5)
self.fc1 = nn.Linear(in_features=2 * n_kernels * 5 * 5, out_features=120)
self.fc2 = nn.Linear(in_features=120, out_features=84)
self.fc3 = nn.Linear(in_features=84, out_features=out_dim)
def forward(self, x):
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x)))
x = x.view(x.shape[0], -1)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
class Client(object):
def __int__(self, trainDataSet, dev):
self.train_ds = trainDataSet
self.dev = dev
self.train_dl = None
self.local_parameter = None
3. FedAvg-Implementierung
def evaluate(net, global_parameters, testDataLoader, dev):
net.load_state_dict(global_parameters, strict=True)
running_correct = 0
running_samples = 0
net.eval()
# 载入测试集
for data, label in testDataLoader:
data, label = data.to(dev), label.to(dev)
pred = net(data)
running_correct += pred.argmax(1).eq(label).sum().item()
running_samples += len(label)
print("\t" + 'accuracy: %.2f' % (running_correct / running_samples), end="")
def local_upload(train_data_set, local_epoch, net, loss_fun, opt, global_parameters, dev):
# 加载当前通信中最新全局参数
net.load_state_dict(global_parameters, strict=True)
# 设置迭代次数
net.train()
for epoch in range(local_epoch):
for data, label in train_data_set:
data, label = data.to(dev), label.to(dev)
# 模型上传入数据
predict = net(data)
loss = loss_fun(predict, label)
# 反向传播
loss.backward()
# 计算梯度,并更新梯度
opt.step()
# 将梯度归零,初始化梯度
opt.zero_grad()
# 返回当前Client基于自己的数据训练得到的新的模型参数
return net.state_dict()
def train(data_name: str, data_path: str, classes_per_node: int, num_nodes: int,
steps: int, node_iter: int, optim: str, lr: float, inner_lr: float,
embed_lr: float, wd: float, inner_wd: float, embed_dim: int, hyper_hid: int,
n_hidden: int, n_kernels: int, bs: int, device, eval_every: int, save_path: Path,
seed: int) -> None:
###############################
# init nodes, hnet, local net #
###############################
steps = 5
node_iter = 5
nodes = BaseNodes(data_name, data_path, num_nodes, classes_per_node=classes_per_node,
batch_size=bs)
net = CNN(n_kernels=n_kernels)
# hnet = hnet.to(device)
net = net.to(device)
##################
# init optimizer #
##################
# embed_lr = embed_lr if embed_lr is not None else lr
optimizer = torch.optim.SGD(
net.parameters(), lr=inner_lr, momentum=.9, weight_decay=inner_wd
)
criteria = torch.nn.CrossEntropyLoss()
################
# init metrics #
################
# step_iter = trange(steps)
step_iter = range(steps)
# train process
# record the global parameters
global_parameters = {
}
for key, parameter in net.state_dict().items():
global_parameters[key] = parameter.clone()
for step in step_iter:
local_parameters_list = {
}
# 需要训练的node数目
for i in range(node_iter):
# 随机选择一个客户端
node_id = random.choice(range(num_nodes))
# 用全局模型参数训练当前客户端
local_parameters = local_upload(nodes.train_loaders[node_id], 5, net, criteria, optimizer,
global_parameters, dev='cpu')
print("\nEpoch: {}, Node Count: {}, Node ID: {}".format(step + 1, i + 1, node_id), end="")
evaluate(net, local_parameters, nodes.val_loaders[node_id], 'cpu')
local_parameters_list[i] = local_parameters
# 更新当前轮次模型的参数
sum_parameters = None
for node_id, parameters in local_parameters_list.items():
if sum_parameters is None:
sum_parameters = parameters
else:
for key in parameters.keys():
sum_parameters[key] += parameters[key]
for var in global_parameters:
global_parameters[var] = (sum_parameters[var] / node_iter)
# test
net.load_state_dict(global_parameters, strict=True)
net.eval()
for data_set in nodes.test_loaders:
running_correct = 0
running_samples = 0
for data, label in data_set:
pred = net(data)
running_correct += pred.argmax(1).eq(label).sum().item()
running_samples += len(label)
print("\t" + 'accuracy: %.2f' % (running_correct / running_samples), end="")