记录一次联邦学习FedAvg算法实现的debug过程

过程

在tensorflow federated框架下的FedAvg实验中,一段代码总是捕捉到异常退出:

for predicted_y in batch_predicted_y:
   max = -1000
   flag = -1
   for j in range(10):
       if predicted_y[j] > max:
           max = predicted_y[j]
           flag = j
   if(flag==-1):
       sys.exit()
   p.append(flag)

而这段代码位于手写的计算模型准确率的函数eval(model)

 #验证模型准确率
    loss,acc = eval(federated_model)

于是我花了30分钟检查这个函数,没查出什么错误。

最后从头到尾一个函数一个函数的检查,发现问题不在于这个函数的写法,而在于传入的模型参数是NaN:
在这里插入图片描述
于是又按照调用栈debug,发现batch_train()的结果是正常的,local_train()的结果是NaN


def batch_train(initial_model, batch, learning_rate):
    model_vars = collections.OrderedDict([
        (name, tf.Variable(name=name, initial_value=value))
        for name, value in initial_model.items()
    ])
    optimizer = tf.keras.optimizers.SGD(learning_rate)
    def _train_on_batch(model_vars, batch):     
        with tf.GradientTape() as tape:
            loss = forward_pass(model_vars, batch)
        grads = tape.gradient(loss, model_vars)
        optimizer.apply_gradients(
            zip(tf.nest.flatten(grads), tf.nest.flatten(model_vars)))
        return model_vars
    return _train_on_batch(model_vars, batch)
    
# 设备训练
def local_train(initial_model, learning_rate, all_batches):
    def batch_fn(model, batch):
        return  batch_train(model, batch, learning_rate)
    model=initial_model
    model=tff.sequence_reduce(all_batches,model, batch_fn)
    return model

然后思考local_train()函数的意义,它是对每个batch调用batch_train()函数,这其中模型是不断被更新的。那么问题就出在这个模型更新迭代的过程,某一次迭代参数变为了NaN,也许是一个超出表示范围的数。

于是试图修改学习率,从0.1改为0.01,最后模型参数正常。

在这里插入图片描述

for round_num in range(ROUND):
    learning_rate = 0.01 / (1+round_num)

之后的训练也正常了,只不过学习率太低学的太慢,可以适当再调整学习率。
在这里插入图片描述

总结

之前也遇到过因为学习率设置不当,模型参数变为NaN的情况,这次却没想到,将大部分时间花在了检查函数是否写错,无果,最后用地毯式搜索才找出bug。一怒之下水篇博客记录debug损失的几个小时,后续补上关于学习率对模型训练影响的深度总结(又挖坑)。

猜你喜欢

转载自blog.csdn.net/Protocols7/article/details/112000634