过程
在tensorflow federated框架下的FedAvg实验中,一段代码总是捕捉到异常退出:
for predicted_y in batch_predicted_y:
max = -1000
flag = -1
for j in range(10):
if predicted_y[j] > max:
max = predicted_y[j]
flag = j
if(flag==-1):
sys.exit()
p.append(flag)
而这段代码位于手写的计算模型准确率的函数eval(model)
中
#验证模型准确率
loss,acc = eval(federated_model)
于是我花了30分钟检查这个函数,没查出什么错误。
最后从头到尾一个函数一个函数的检查,发现问题不在于这个函数的写法,而在于传入的模型参数是NaN:
于是又按照调用栈debug,发现batch_train()
的结果是正常的,local_train()
的结果是NaN
def batch_train(initial_model, batch, learning_rate):
model_vars = collections.OrderedDict([
(name, tf.Variable(name=name, initial_value=value))
for name, value in initial_model.items()
])
optimizer = tf.keras.optimizers.SGD(learning_rate)
def _train_on_batch(model_vars, batch):
with tf.GradientTape() as tape:
loss = forward_pass(model_vars, batch)
grads = tape.gradient(loss, model_vars)
optimizer.apply_gradients(
zip(tf.nest.flatten(grads), tf.nest.flatten(model_vars)))
return model_vars
return _train_on_batch(model_vars, batch)
# 设备训练
def local_train(initial_model, learning_rate, all_batches):
def batch_fn(model, batch):
return batch_train(model, batch, learning_rate)
model=initial_model
model=tff.sequence_reduce(all_batches,model, batch_fn)
return model
然后思考local_train()
函数的意义,它是对每个batch调用batch_train()
函数,这其中模型是不断被更新的。那么问题就出在这个模型更新迭代的过程,某一次迭代参数变为了NaN,也许是一个超出表示范围的数。
于是试图修改学习率,从0.1改为0.01,最后模型参数正常。
for round_num in range(ROUND):
learning_rate = 0.01 / (1+round_num)
之后的训练也正常了,只不过学习率太低学的太慢,可以适当再调整学习率。
总结
之前也遇到过因为学习率设置不当,模型参数变为NaN的情况,这次却没想到,将大部分时间花在了检查函数是否写错,无果,最后用地毯式搜索才找出bug。一怒之下水篇博客记录debug损失的几个小时,后续补上关于学习率对模型训练影响的深度总结(又挖坑)。