Code Walkthrough and Takeaways from the Pioneering Work of Federated Learning

Reference: Interpretation of federated learning code, super detailed

Reference: [1602.05629] Communication-Efficient Learning of Deep Networks from Decentralized Data (arxiv.org)

Reference code: GitHub - AshwinRJ/Federated-Learning-PyTorch: Implementation of Communication-Efficient Learning of Deep Networks from Decentralized Data


        Today, let's try reading through the code of this pioneering work.

Table of contents

1. Load parameters - options.py

2. Data IID and non-IID sampling - sampling.py

1.mnist_iid()

2.mnist_noniid()

3.mnist_noniid_unequal()

4.cifar_iid(), cifar_noniid()

3. Local model parameter update - update.py

1.DatasetSplit(Dataset)

2.LocalUpdate(object)

3.test_inference(args, model, test_dataset)

4. Application set - utils.py

1.get_dataset(args)

2.average_weights(w)

3.exp_details(args)

5. Model settings - models.py

1. MLP multi-layer perceptron model

2. CNN convolutional neural network

3. Create your own model

6. The main function - federated_main.py

7. Plotting

8. Personal summary


1. Load parameters - options.py

import argparse


def args_parser():
    parser = argparse.ArgumentParser()

    # federated arguments (Notation for the arguments followed from paper)
    parser.add_argument('--epochs', type=int, default=10,
                        help="number of rounds of training")
    parser.add_argument('--num_users', type=int, default=100,
                        help="number of users: K")
    parser.add_argument('--frac', type=float, default=0.1,
                        help='the fraction of clients: C')
    parser.add_argument('--local_ep', type=int, default=10,
                        help="the number of local epochs: E")
    parser.add_argument('--local_bs', type=int, default=10,
                        help="local batch size: B")
    parser.add_argument('--lr', type=float, default=0.01,
                        help='learning rate')
    parser.add_argument('--momentum', type=float, default=0.5,
                        help='SGD momentum (default: 0.5)')

    # model arguments
    parser.add_argument('--model', type=str, default='mlp', help='model name')
    parser.add_argument('--kernel_num', type=int, default=9,
                        help='number of each kind of kernel')
    parser.add_argument('--kernel_sizes', type=str, default='3,4,5',
                        help='comma-separated kernel size to \
                        use for convolution')
    parser.add_argument('--num_channels', type=int, default=1, help="number \
                        of channels of imgs")
    parser.add_argument('--norm', type=str, default='batch_norm',
                        help="batch_norm, layer_norm, or None")
    parser.add_argument('--num_filters', type=int, default=32,
                        help="number of filters for conv nets -- 32 for \
                        mini-imagenet, 64 for omiglot.")
    parser.add_argument('--max_pool', type=str, default='True',
                        help="Whether use max pooling rather than \
                        strided convolutions")

    # other arguments
    parser.add_argument('--dataset', type=str, default='mnist', help="name \
                        of dataset")
    parser.add_argument('--num_classes', type=int, default=10, help="number \
                        of classes")
    parser.add_argument('--gpu', default=None, help="To use cuda, set \
                        to a specific GPU ID. Default set to use CPU.")
    parser.add_argument('--optimizer', type=str, default='sgd', help="type \
                        of optimizer")
    parser.add_argument('--iid', type=int, default=1,
                        help='Default set to IID. Set to 0 for non-IID.')
    parser.add_argument('--unequal', type=int, default=0,
                        help='whether to use unequal data splits for  \
                        non-i.i.d setting (use 0 for equal splits)')
    parser.add_argument('--stopping_rounds', type=int, default=10,
                        help='rounds of early stopping')
    parser.add_argument('--verbose', type=int, default=1, help='verbose')
    parser.add_argument('--seed', type=int, default=1, help='random seed')
    args = parser.parse_args()
    return args

        Here, argparse is used to read three groups of parameters: federated parameters, model parameters, and other parameters. The federated parameters are:

  • epochs: number of global training rounds, default 10
  • num_users: number of users K, default 100
  • frac: fraction of clients selected each round, C, default 0.1
  • local_ep: number of local epochs E, default 10
  • local_bs: local batch size B, default 10
  • lr: learning rate, default 0.01
  • momentum: SGD momentum, default 0.5

        Model parameters:

  • model: model name, default mlp (a fully connected neural network)
  • kernel_num: number of convolution kernels of each size, default 9
  • kernel_sizes: convolution kernel sizes, default 3,4,5
  • num_channels: number of image channels, default 1
  • norm: normalization method, batch_norm or layer_norm
  • num_filters: number of filters, default 32
  • max_pool: whether to use max pooling instead of strided convolutions, default True

        Other parameters:

  • dataset: which dataset to use, default mnist
  • num_classes: number of classes, default 10
  • gpu: default None, i.e. CPU; set it to a specific CUDA device ID to use the GPU
  • optimizer: optimizer, default SGD
  • iid: whether the data is independent and identically distributed, default 1 (IID); set to 0 for non-IID
  • unequal: whether to use unequal data splits in the non-IID setting, default 0 (equal splits)
  • stopping_rounds: rounds for early stopping, default 10
  • verbose: verbosity; the default 1 prints the per-batch training logs, 0 silences them
  • seed: random seed, default 1

        Finally, args_parser() returns args, which contains the parameters entered on the command line.
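
        A minimal usage sketch, assuming it is run from the repo's source directory so that options.py is importable (the surrounding script is my own illustration, not part of the repo):

from options import args_parser

if __name__ == '__main__':
    args = args_parser()   # reads sys.argv, e.g. --epochs 20 --iid 0
    print(args.epochs, args.num_users, args.frac)   # 10 100 0.1 with the defaults above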


2. Data IID and non-IID sampling - sampling.py

        This file samples IID and non-IID client data from MNIST and CIFAR-10.

1.mnist_iid()

def mnist_iid(dataset, num_users):
    """
    Sample I.I.D. client data from MNIST dataset
    :param dataset:
    :param num_users:
    :return: dict of image index
    """
    num_items = int(len(dataset)/num_users)
    dict_users, all_idxs = {}, [i for i in range(len(dataset))]
    for i in range(num_users):
        dict_users[i] = set(np.random.choice(all_idxs, num_items,
                                             replace=False))
        all_idxs = list(set(all_idxs) - dict_users[i])
    return dict_users

        Each of the 100 users is assigned 600 samples drawn at random without replacement (60,000 / 100 = 600).
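
        A hedged usage sketch (the dataset path and ToTensor transform are my own assumptions):

from torchvision import datasets, transforms

train_dataset = datasets.MNIST('../data/mnist/', train=True, download=True,
                               transform=transforms.ToTensor())
dict_users = mnist_iid(train_dataset, num_users=100)
print(len(dict_users))     # 100 users
print(len(dict_users[0]))  # 600 image indices per user (60000 / 100)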

2.mnist_noniid()

def mnist_noniid(dataset, num_users):
    """
    Sample non-I.I.D client data from MNIST dataset
    :param dataset:
    :param num_users:
    :return:
    """
    # 60,000 training imgs -->  200 imgs/shard X 300 shards
    num_shards, num_imgs = 200, 300
    idx_shard = [i for i in range(num_shards)]
    dict_users = {i: np.array([]) for i in range(num_users)}
    idxs = np.arange(num_shards*num_imgs)
    labels = dataset.train_labels.numpy()

    # sort labels
    idxs_labels = np.vstack((idxs, labels))
    idxs_labels = idxs_labels[:, idxs_labels[1, :].argsort()]
    idxs = idxs_labels[0, :]

    # divide and assign 2 shards/client
    for i in range(num_users):
        rand_set = set(np.random.choice(idx_shard, 2, replace=False))
        idx_shard = list(set(idx_shard) - rand_set)
        for rand in rand_set:
            dict_users[i] = np.concatenate(
                (dict_users[i], idxs[rand*num_imgs:(rand+1)*num_imgs]), axis=0)
    return dict_users
  • num_shards, num_imgs: the 60,000 training images are divided into 200 shards of 300 images each
  • [i for i in range(num_shards)]: generates the list of shard indices
  • {i: np.array([]) for i in range(num_users)}: a dict comprehension creating an empty index array for each of the 100 users
  • np.vstack((idxs, labels)): stacks the image indices and labels into a (2, 60000) array
  • idxs_labels = idxs_labels[:, idxs_labels[1, :].argsort()]: argsort returns the positions that would sort the labels in ascending order, so this line reorders the columns by label (see the toy demo below)

        After this step, idxs contains the image indices sorted by label. The shards are then assigned to users.

  • np.random.choice(): picks 2 shard numbers from the remaining shards; replace=False means sampling without replacement
  • idxs[rand*num_imgs:(rand+1)*num_imgs]: takes the 300 consecutive sorted indices of one shard
  • np.concatenate(): appends along the given axis. Here each user receives 2 shards chosen at random from the 200, and the 300 indices of each shard are concatenated into the user's array.

        Finally, the function returns a dictionary mapping each user to 600 image indices (2 shards × 300 images).
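
        To make the sort-by-label step concrete, here is a tiny toy demo with made-up labels (not MNIST data):

import numpy as np

idxs = np.arange(6)
labels = np.array([2, 0, 1, 0, 2, 1])            # fake labels for 6 images
idxs_labels = np.vstack((idxs, labels))           # shape (2, 6)
idxs_labels = idxs_labels[:, idxs_labels[1, :].argsort()]
print(idxs_labels[0, :])                          # [1 3 2 5 0 4]: image indices sorted by label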

3.mnist_noniid_unequal()

def mnist_noniid_unequal(dataset, num_users):
"""
    Sample non-I.I.D client data from MNIST dataset s.t clients
    have unequal amount of data
    :param dataset:
    :param num_users:
    :returns a dict of clients with each clients assigned certain
    number of training imgs
    """

        It's a bit long, so I'll go through it piece by piece. First, the 60,000 images are divided into 1,200 shards:

    # 60,000 training imgs --> 50 imgs/shard X 1200 shards
    num_shards, num_imgs = 1200, 50
    idx_shard = [i for i in range(num_shards)]
    dict_users = {i: np.array([]) for i in range(num_users)}
    idxs = np.arange(num_shards*num_imgs)
    labels = dataset.train_labels.numpy()

        Get the indices sorted by label:

    # sort labels
    idxs_labels = np.vstack((idxs, labels))
    idxs_labels = idxs_labels[:, idxs_labels[1, :].argsort()]
    idxs = idxs_labels[0, :]

        Set the range of the number of shards each user can hold:

    # Minimum and maximum shards assigned per client:
    min_shard = 1
    max_shard = 30

        That is, each user holds at least 1×50 = 50 images and at most 30×50 = 1,500 images.

        Next, the 1,200 shards must be distributed among the users: every user must receive at least one shard, and every shard must be assigned.

    # Divide the shards into random chunks for every client
    # s.t the sum of these chunks = num_shards
    random_shard_size = np.random.randint(min_shard, max_shard+1,
                                          size=num_users)
    random_shard_size = np.around(random_shard_size /
                                  sum(random_shard_size) * num_shards)
    random_shard_size = random_shard_size.astype(int)
  • np.random.randint: samples integers from the half-open interval [min_shard, max_shard+1); the result has length num_users
  • np.around: rounds to the nearest value (ties round to the nearest even number)

        After this step, the shard counts are rescaled proportionally so that their sum is close to 1,200 (because of rounding it is not exactly 1,200). The next block adjusts for this mismatch while allocating the shards; a toy illustration of the rescaling follows.
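
        A toy illustration of the rescaling step, with 5 clients and 20 shards (the numbers are made up and have nothing to do with MNIST):

import numpy as np

random_shard_size = np.random.randint(1, 31, size=5)                 # raw draws in [1, 30]
scaled = np.around(random_shard_size / sum(random_shard_size) * 20)  # rescale to ~20 shards total
scaled = scaled.astype(int)
print(scaled, scaled.sum())   # the sum is close to 20, but not always exactly 20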

    # Assign the shards randomly to each client
    if sum(random_shard_size) > num_shards:

        for i in range(num_users):
            # First assign each client 1 shard to ensure every client has
            # atleast one shard of data
            rand_set = set(np.random.choice(idx_shard, 1, replace=False))
            idx_shard = list(set(idx_shard) - rand_set)
            for rand in rand_set:
                dict_users[i] = np.concatenate(
                    (dict_users[i], idxs[rand*num_imgs:(rand+1)*num_imgs]),
                    axis=0)

        random_shard_size = random_shard_size-1

        # Next, randomly assign the remaining shards
        for i in range(num_users):
            if len(idx_shard) == 0:
                continue
            shard_size = random_shard_size[i]
            if shard_size > len(idx_shard):
                shard_size = len(idx_shard)
            rand_set = set(np.random.choice(idx_shard, shard_size,
                                            replace=False))
            idx_shard = list(set(idx_shard) - rand_set)
            for rand in rand_set:
                dict_users[i] = np.concatenate(
                    (dict_users[i], idxs[rand*num_imgs:(rand+1)*num_imgs]),
                    axis=0)
    else:

        for i in range(num_users):
            shard_size = random_shard_size[i]
            rand_set = set(np.random.choice(idx_shard, shard_size,
                                            replace=False))
            idx_shard = list(set(idx_shard) - rand_set)
            for rand in rand_set:
                dict_users[i] = np.concatenate(
                    (dict_users[i], idxs[rand*num_imgs:(rand+1)*num_imgs]),
                    axis=0)

        if len(idx_shard) > 0:
            # Add the leftover shards to the client with minimum images:
            shard_size = len(idx_shard)
            # Add the remaining shard to the client with lowest data
            k = min(dict_users, key=lambda x: len(dict_users.get(x)))
            rand_set = set(np.random.choice(idx_shard, shard_size,
                                            replace=False))
            idx_shard = list(set(idx_shard) - rand_set)
            for rand in rand_set:
                dict_users[k] = np.concatenate(
                    (dict_users[k], idxs[rand*num_imgs:(rand+1)*num_imgs]),
                    axis=0)

    return dict_users

        Finally, the function returns a dictionary mapping each user to a randomly assigned, unequally sized set of non-IID data indices.

4.cifar_iid(), cifar_noniid()

        These follow the same pattern as the MNIST versions, so I won't go through them; a sketch of cifar_iid() is below for reference.
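
        A hedged sketch of cifar_iid(), following the same logic as mnist_iid() above (the actual repo code may differ in minor details):

def cifar_iid(dataset, num_users):
    """
    Sample I.I.D. client data from the CIFAR-10 dataset
    (same logic as mnist_iid above).
    """
    num_items = int(len(dataset)/num_users)
    dict_users, all_idxs = {}, [i for i in range(len(dataset))]
    for i in range(num_users):
        dict_users[i] = set(np.random.choice(all_idxs, num_items,
                                             replace=False))
        all_idxs = list(set(all_idxs) - dict_users[i])
    return dict_users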


3. Local model parameter update - update.py

1.DatasetSplit(Dataset)

        Let's first look at the official explanation of the Dataset class: a Dataset can wrap anything, but it must implement a __len__ function (called by Python's built-in len) and a __getitem__ function used to index into its contents.

class DatasetSplit(Dataset):
    """An abstract Dataset class wrapped around Pytorch Dataset class.
    """

    def __init__(self, dataset, idxs):
        self.dataset = dataset
        self.idxs = [int(i) for i in idxs]

    def __len__(self):
        return len(self.idxs)

    def __getitem__(self, item):
        image, label = self.dataset[self.idxs[item]]
        return torch.tensor(image), torch.tensor(label)

        This class wraps the PyTorch Dataset class and overrides two methods:

  • __len__(self), which returns the length of the index list, i.e. the number of samples for this client
  • __getitem__(self, item), which returns the image and label at the given index as tensors (a usage sketch follows)
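
        A hedged usage sketch, reusing train_dataset and dict_users from the sampling section above (wiring them together like this is my own assumption):

from torch.utils.data import DataLoader

user_idxs = list(dict_users[0])                      # the 600 indices of user 0
user_data = DatasetSplit(train_dataset, user_idxs)
loader = DataLoader(user_data, batch_size=10, shuffle=True)
images, labels = next(iter(loader))
print(images.shape, labels.shape)                    # e.g. torch.Size([10, 1, 28, 28]) torch.Size([10])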

2.LocalUpdate(object)

        This class performs the local model update. It's a bit long, so I'll go through it piece by piece:

class LocalUpdate(object):...

        First comes the constructor, which stores the arguments and the logger, obtains the data loaders from the train_val_test() function, and then selects the computing device.

        More importantly, the loss function here is the NLL loss, which is closely related to cross-entropy: cross-entropy is equivalent to applying log-softmax to the outputs and then the NLL loss, so NLL expects log-probabilities as its input.

    def __init__(self, args, dataset, idxs, logger):
        self.args = args
        self.logger = logger
        self.trainloader, self.validloader, self.testloader = self.train_val_test(
            dataset, list(idxs))
        self.device = 'cuda' if args.gpu else 'cpu'
        # Default criterion set to NLL loss function
        self.criterion = nn.NLLLoss().to(self.device)

        Next is the train_val_test() function, which splits the given indices into train, validation, and test sets in an 8:1:1 ratio. Note the batch sizes: the training loader uses local_bs from args, while the validation and test loaders each use one tenth of their split size.

    def train_val_test(self, dataset, idxs):
        """
        Returns train, validation and test dataloaders for a given dataset
        and user indexes.
        """
        # split indexes for train, validation, and test (80, 10, 10)
        idxs_train = idxs[:int(0.8*len(idxs))]
        idxs_val = idxs[int(0.8*len(idxs)):int(0.9*len(idxs))]
        idxs_test = idxs[int(0.9*len(idxs)):]

        trainloader = DataLoader(DatasetSplit(dataset, idxs_train),
                                 batch_size=self.args.local_bs, shuffle=True)
        validloader = DataLoader(DatasetSplit(dataset, idxs_val),
                                 batch_size=int(len(idxs_val)/10), shuffle=False)
        testloader = DataLoader(DatasetSplit(dataset, idxs_test),
                                batch_size=int(len(idxs_test)/10), shuffle=False)
        return trainloader, validloader, testloader

        Next is the local weight update function, which takes the model and the current global round number, and returns the updated weights and the average loss. First an optimizer is chosen, then the training loop starts.

    def update_weights(self, model, global_round):
        # Set mode to train model
        model.train()
        epoch_loss = []

        # Set optimizer for the local updates
        if self.args.optimizer == 'sgd':
            optimizer = torch.optim.SGD(model.parameters(), lr=self.args.lr,
                                        momentum=0.5)
        elif self.args.optimizer == 'adam':
            optimizer = torch.optim.Adam(model.parameters(), lr=self.args.lr,
                                         weight_decay=1e-4)

        for iter in range(self.args.local_ep):
            batch_loss = []
            for batch_idx, (images, labels) in enumerate(self.trainloader):
                images, labels = images.to(self.device), labels.to(self.device)

                model.zero_grad()
                log_probs = model(images)
                loss = self.criterion(log_probs, labels)
                loss.backward()
                optimizer.step()

                if self.args.verbose and (batch_idx % 10 == 0):
                    print('| Global Round : {} | Local Epoch : {} | [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                        global_round, iter, batch_idx * len(images),
                        len(self.trainloader.dataset),
                        100. * batch_idx / len(self.trainloader), loss.item()))
                self.logger.add_scalar('loss', loss.item())
                batch_loss.append(loss.item())
            epoch_loss.append(sum(batch_loss)/len(batch_loss))

        return model.state_dict(), sum(epoch_loss) / len(epoch_loss)
  • self.logger.add_scalar('loss', loss.item()): records the scalar so it can later be visualized with TensorBoard (the logger is a tensorboardX SummaryWriter)
  • after each local epoch, the average batch loss is recorded; these values are used for the final average-loss statistic
  • model.state_dict(): PyTorch's dictionary of model parameters, which can be saved to a .pth file with torch.save() (see the sketch below)
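
        A minimal, self-contained sketch of the state_dict() round trip mentioned in the last bullet (the toy model and file name are only for illustration):

import torch
import torch.nn as nn

toy = nn.Linear(4, 2)
w = toy.state_dict()                      # an OrderedDict of parameter tensors
torch.save(w, 'toy_weights.pth')          # save it exactly like the local weights returned above
toy.load_state_dict(torch.load('toy_weights.pth'))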

        Next is the evaluation function inference(self, model). Given a model, it computes the accuracy and the loss. The code here is quite instructive:

    def inference(self, model):
        """ Returns the inference accuracy and loss.
        """

        model.eval()
        loss, total, correct = 0.0, 0.0, 0.0

        for batch_idx, (images, labels) in enumerate(self.testloader):
            images, labels = images.to(self.device), labels.to(self.device)

            # Inference
            outputs = model(images)
            batch_loss = self.criterion(outputs, labels)
            loss += batch_loss.item()

            # Prediction
            _, pred_labels = torch.max(outputs, 1)
            pred_labels = pred_labels.view(-1)
            correct += torch.sum(torch.eq(pred_labels, labels)).item()
            total += len(labels)

        accuracy = correct/total
        return accuracy, loss
  • model.eval(): switches the model to evaluation mode (disabling dropout, etc.)
  • torch.max(outputs, 1): the second argument is the dimension; it returns the maximum value and its index along dimension 1, i.e. the predicted class for each row
  • pred_labels.view(-1): view(-1) normally infers one dimension from the others; here it simply flattens the predictions into a 1-D tensor
  • torch.eq(): compares two tensors element-wise, returning True where the elements match and False otherwise

        The function iterates over the test-set images and labels, accumulates the loss over batches, counts the correct predictions, and returns the accuracy and total loss. A toy demo of the prediction bookkeeping follows.
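
        A toy demo of the prediction bookkeeping described in the bullets above (the scores and labels are made up):

import torch

outputs = torch.tensor([[0.1, 0.7, 0.2],
                        [0.8, 0.1, 0.1]])          # fake scores for a batch of 2
labels = torch.tensor([1, 2])
_, pred_labels = torch.max(outputs, 1)             # class index of the maximum per row
pred_labels = pred_labels.view(-1)                 # flatten to shape (2,)
correct = torch.sum(torch.eq(pred_labels, labels)).item()
print(pred_labels, correct)                        # tensor([1, 0]) 1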

 3.test_inference(args, model, test_dataset)

        It is essentially the same as the inference function in LocalUpdate, except that it is a standalone function whose inputs are args, the model, and a test_dataset:

def test_inference(args, model, test_dataset):
    """ Returns the test accuracy and loss.
    """

    model.eval()
    loss, total, correct = 0.0, 0.0, 0.0

    device = 'cuda' if args.gpu else 'cpu'
    criterion = nn.NLLLoss().to(device)
    testloader = DataLoader(test_dataset, batch_size=128,
                            shuffle=False)

    for batch_idx, (images, labels) in enumerate(testloader):
        images, labels = images.to(device), labels.to(device)

        # Inference
        outputs = model(images)
        batch_loss = criterion(outputs, labels)
        loss += batch_loss.item()

        # Prediction
        _, pred_labels = torch.max(outputs, 1)
        pred_labels = pred_labels.view(-1)
        correct += torch.sum(torch.eq(pred_labels, labels)).item()
        total += len(labels)

    accuracy = correct/total
    return accuracy, loss

4. Application set - utils.py

        Several utility functions are collected here: get_dataset(), average_weights(), and exp_details().

1.get_dataset(args)

        get_dataset(args) returns the appropriate dataset and the user-to-data dictionary according to the command-line arguments. It's essentially one if/else over the dataset name, so I won't go through it in detail; a sketch is given below.
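
        For completeness, a hedged sketch of the MNIST branch (the data path and the Normalize constants are common MNIST conventions, not necessarily the repo's exact values):

from torchvision import datasets, transforms
from sampling import mnist_iid, mnist_noniid, mnist_noniid_unequal

def get_dataset(args):
    if args.dataset == 'mnist':
        data_dir = '../data/mnist/'
        apply_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))])
        train_dataset = datasets.MNIST(data_dir, train=True, download=True,
                                       transform=apply_transform)
        test_dataset = datasets.MNIST(data_dir, train=False, download=True,
                                      transform=apply_transform)
        if args.iid:
            user_groups = mnist_iid(train_dataset, args.num_users)
        elif args.unequal:
            user_groups = mnist_noniid_unequal(train_dataset, args.num_users)
        else:
            user_groups = mnist_noniid(train_dataset, args.num_users)
        return train_dataset, test_dataset, user_groups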

2.average_weights(w)

        Returns the average of the weights, i.e. performs the federated averaging step:

def average_weights(w):
    """
    Returns the average of the weights.
    """
    w_avg = copy.deepcopy(w[0])
    for key in w_avg.keys():
        for i in range(1, len(w)):
            w_avg[key] += w[i][key]
        w_avg[key] = torch.div(w_avg[key], len(w))
    return w_avg
  • w: the list of weight dictionaries produced by the clients' local training. With the default parameters it is a list of length 10, and each element is a state dict mapping parameter names (such as layer_input.weight or layer_hidden.bias) to their weight tensors.
  • copy.deepcopy(): a deep copy, so the copy does not change when the original changes. Here the first user's weight dictionary is copied as the accumulator.

        Then, for each parameter name, the corresponding values from every user's model are accumulated and divided by the number of users, yielding the averaged model. A toy check follows.
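
        A toy check of average_weights() with two fake client state dicts (the layer names are borrowed from the MLP model below):

import torch

w1 = {'layer_input.weight': torch.ones(2, 2),     'layer_input.bias': torch.zeros(2)}
w2 = {'layer_input.weight': 3 * torch.ones(2, 2), 'layer_input.bias': torch.ones(2)}
w_avg = average_weights([w1, w2])
print(w_avg['layer_input.weight'])   # every entry is 2.0
print(w_avg['layer_input.bias'])     # every entry is 0.5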

3.exp_details(args)

        Prints the command-line arguments args:

def exp_details(args):
    print('\nExperimental details:')
    print(f'    Model     : {args.model}')
    print(f'    Optimizer : {args.optimizer}')
    print(f'    Learning  : {args.lr}')
    print(f'    Global Rounds   : {args.epochs}\n')

    print('    Federated parameters:')
    if args.iid:
        print('    IID')
    else:
        print('    Non-IID')
    print(f'    Fraction of users  : {args.frac}')
    print(f'    Local Batch size   : {args.local_bs}')
    print(f'    Local Epochs       : {args.local_ep}\n')
    return

5. Model settings - models.py

        This file defines several common network models.

1. MLP multi-layer perceptron model

class MLP(nn.Module):
    def __init__(self, dim_in, dim_hidden, dim_out):
        super(MLP, self).__init__()
        self.layer_input = nn.Linear(dim_in, dim_hidden)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout()
        self.layer_hidden = nn.Linear(dim_hidden, dim_out)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        x = x.view(-1, x.shape[1]*x.shape[-2]*x.shape[-1])
        x = self.layer_input(x)
        x = self.dropout(x)
        x = self.relu(x)
        x = self.layer_hidden(x)
        return self.softmax(x)
  • nn.Dropout(): randomly zeroes activations during training to reduce overfitting (p=0.5 by default); a quick shape check of the MLP follows
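
        A quick shape check of this MLP on fake 28×28 single-channel images (batch of 4), assuming the MLP class and the torch/nn imports from models.py are available:

import torch

model = MLP(dim_in=28 * 28, dim_hidden=64, dim_out=10)
x = torch.randn(4, 1, 28, 28)        # fake batch of MNIST-sized images
out = model(x)
print(out.shape)                     # torch.Size([4, 10])
print(out.sum(dim=1))                # each row sums to ~1 because of the final Softmax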

2. CNN convolutional neural network

        There are several of them (for MNIST, Fashion-MNIST, and CIFAR), so they are not reproduced here.

3. Create your own model

        The original code here defines modelC. In its constructor, the first argument passed to super() is AllConvNet, which the IDE flags as an error. This is not a typo, though; it is left that way so that users can plug in their own model.


6. The main function - federated_main.py

        (In the code posted here, I have rewritten the comments myself.)

        First, the library imports:

import os
import copy
import time
import pickle
import numpy as np
from tqdm import tqdm

import torch
from tensorboardX import SummaryWriter

from options import args_parser
from update import LocalUpdate, test_inference
from models import MLP, CNNMnist, CNNFashion_Mnist, CNNCifar
from utils import get_dataset, average_weights, exp_details

        Then start the main function directly:

if __name__ == '__main__':
    start_time = time.time()

    # define paths
    path_project = os.path.abspath('..')  # absolute path of the parent directory
    logger = SummaryWriter('../logs')  # visualization tool (tensorboardX)

    args = args_parser()  # parse command-line arguments
    exp_details(args)  # print the experiment settings

        Since it is run in the debugger without changing anything, the parameters are simply the defaults listed in options.py above.

         Next load the dataset and user data dictionary:

    # check whether the GPU should be used:
    if args.gpu:
        torch.cuda.set_device(args.gpu)
    device = 'cuda' if args.gpu else 'cpu'

    # load the dataset and the per-user data index dictionary
    train_dataset, test_dataset, user_groups = get_dataset(args)

        This returns the 60,000-image training set, the 10,000-image test set, and a user dictionary of length 100 mapping each user to 600 IID training indices.
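
        A quick sanity check of what comes back with the default (IID MNIST) settings; adding these prints right after the call above is my own suggestion:

    print(len(train_dataset), len(test_dataset))    # 60000 10000
    print(len(user_groups), len(user_groups[0]))    # 100 users, 600 indices each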

        Then the model is built; here the multi-layer perceptron is chosen:

    # build the model
    if args.model == 'cnn':
        # convolutional neural network
        if args.dataset == 'mnist':
            global_model = CNNMnist(args=args)
        elif args.dataset == 'fmnist':
            global_model = CNNFashion_Mnist(args=args)
        elif args.dataset == 'cifar':
            global_model = CNNCifar(args=args)

    elif args.model == 'mlp':
        # multi-layer perceptron
        img_size = train_dataset[0][0].shape
        len_in = 1
        for x in img_size:
            len_in *= x
            global_model = MLP(dim_in=len_in, dim_hidden=64,
                               dim_out=args.num_classes)
    else:
        exit('Error: unrecognized model')

        The next step is to move the model to the device, set it to training mode, and copy the initial global weights:

    # set the model to train mode and move it to the computing device
    global_model.to(device)
    global_model.train()
    print(global_model)

    # copy the global weights
    global_weights = global_model.state_dict()

        The model looks like this:
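
        With the default MLP, print(global_model) produces output roughly like the following (the exact formatting depends on the PyTorch version):

MLP(
  (layer_input): Linear(in_features=784, out_features=64, bias=True)
  (relu): ReLU()
  (dropout): Dropout(p=0.5, inplace=False)
  (layer_hidden): Linear(in_features=64, out_features=10, bias=True)
  (softmax): Softmax(dim=1)
)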

        This is a multi-layer perceptron with 784 input units, 64 hidden units, and 10 output units, with dropout probability 0.5.

        Then the actual training starts:

    # training
    train_loss, train_accuracy = [], []
    val_acc_list, net_list = [], []
    cv_loss, cv_acc = [], []
    print_every = 2
    val_loss_pre, counter = 0, 0

    for epoch in tqdm(range(args.epochs)):
        local_weights, local_losses = [], []
        print(f'\n | Global Training Round : {epoch + 1} |\n')

        global_model.train()
        m = max(int(args.frac * args.num_users), 1)  # randomly select a fraction frac of the users
        idxs_users = np.random.choice(range(args.num_users), m, replace=False)

        for idx in idxs_users:
            local_model = LocalUpdate(args=args, dataset=train_dataset,
                                      idxs=user_groups[idx], logger=logger)
            w, loss = local_model.update_weights(
                model=copy.deepcopy(global_model), global_round=epoch)
            local_weights.append(copy.deepcopy(w))
            local_losses.append(copy.deepcopy(loss))

        # federated averaging: update the global weights
        global_weights = average_weights(local_weights)

        # load the updated global weights into the model
        global_model.load_state_dict(global_weights)

        loss_avg = sum(local_losses) / len(local_losses)
        train_loss.append(loss_avg)

        # every round, compute the average training accuracy over all users
        list_acc, list_loss = [], []
        global_model.eval()
        for c in range(args.num_users):
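            # note: the original code indexes user_groups with `idx` here (left over
            # from the client loop above) rather than with the loop variable `c`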
            local_model = LocalUpdate(args=args, dataset=train_dataset,
                                      idxs=user_groups[idx], logger=logger)
            acc, loss = local_model.inference(model=global_model)
            list_acc.append(acc)
            list_loss.append(loss)
        train_accuracy.append(sum(list_acc) / len(list_acc))

        # print the global loss every print_every rounds
        if (epoch + 1) % print_every == 0:
            print(f' \nAvg Training Stats after {epoch + 1} global rounds:')
            print(f'Training Loss : {np.mean(np.array(train_loss))}')
            print('Train Accuracy: {:.2f}% \n'.format(100 * train_accuracy[-1]))
  • To be honest, apart from train_loss, train_accuracy, and print_every, I'm not sure what the other bookkeeping variables are for
  • tqdm is a handy progress bar that shows running time and progress over a for loop
  • global_model.train(): set the model to training mode
  • idxs_users: the randomly selected list of user indices. Here the selection fraction is 0.1 and there are 100 users in total, so 100×0.1 = 10 users are randomly chosen to participate in each round.
  • Local update: each selected user performs a local update; the data indices come from user_groups[idx], and the updated local parameters and loss values are recorded
  • Federated averaging: the list of local parameter dictionaries is passed to average_weights(), and the averaged dictionary is loaded back into the global model

        At the end of each round, the training accuracy over all 100 users is computed, and the average global loss is printed every print_every rounds.

        (Note: the '| Global Round : ... | Local Epoch : ... |' lines that keep scrolling come from the print statement inside update_weights() in the LocalUpdate class in update.py. If you don't want it to print so often, comment that print out or set --verbose 0.)

        After global training, the model is evaluated on the test set:

    # after training, evaluate the model on the test set
    test_acc, test_loss = test_inference(args, global_model, test_dataset)

    print(f' \n Results after {args.epochs} global rounds of training:')
    print("|---- Avg Train Accuracy: {:.2f}%".format(100 * train_accuracy[-1]))
    print("|---- Test Accuracy: {:.2f}%".format(100 * test_acc))

        The result:

        Finally, the training loss and training accuracy are saved to disk, and the total run time is printed.

    # save the training loss and training accuracy
    file_name = '../save/objects/{}_{}_{}_C[{}]_iid[{}]_E[{}]_B[{}].pkl'. \
        format(args.dataset, args.model, args.epochs, args.frac, args.iid,
               args.local_ep, args.local_bs)

    with open(file_name, 'wb') as f:
        pickle.dump([train_loss, train_accuracy], f)

    print('\n Total Run Time: {0:0.4f}'.format(time.time() - start_time))
  • pkl file: pickle.dump(data, f) writes and pickle.load(f) reads; the loss and accuracy lists are what get saved here (see the sketch below)
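
        A hedged sketch of reading the saved objects back later for analysis (file_name is the same path constructed in the snippet above):

import pickle

with open(file_name, 'rb') as f:
    train_loss, train_accuracy = pickle.load(f)
print(train_loss[-1], train_accuracy[-1])   # final average loss and accuracy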

7. Plotting

        At the end of the file, the author left the plotting code commented out:

    # plotting
    import matplotlib
    import matplotlib.pyplot as plt
    matplotlib.use('Agg')

    # training loss curve
    plt.figure()
    plt.title('Training Loss vs Communication Rounds')
    plt.plot(range(len(train_loss)), train_loss, color='r')
    plt.ylabel('Training Loss')
    plt.xlabel('Communication Rounds')
    plt.savefig('../save/fed_{}_{}_{}_C[{}]_iid[{}]_E[{}]_B[{}]_loss.png'.
                format(args.dataset, args.model, args.epochs, args.frac,
                       args.iid, args.local_ep, args.local_bs))

    # average accuracy curve
    plt.figure()
    plt.title('Average Accuracy vs Communication Rounds')
    plt.plot(range(len(train_accuracy)), train_accuracy, color='k')
    plt.ylabel('Average Accuracy')
    plt.xlabel('Communication Rounds')
    plt.savefig('../save/fed_{}_{}_{}_C[{}]_iid[{}]_E[{}]_B[{}]_acc.png'.
                format(args.dataset, args.model, args.epochs, args.frac,
                       args.iid, args.local_ep, args.local_bs))

        The resulting figures look like this:

 


8. Personal summary

        I learned a lot from reading this code: how the project is organized, how several libraries are used, and, most importantly, how the federated learning mechanism works. The author implemented such an influential paper with simple, easy-to-understand code, which I really admire. It not only improved my coding ability but also marked my formal entry into federated learning.
