Process
In a FedAvg experiment under the TensorFlow Federated framework, this piece of code kept triggering an abnormal exit:
```python
import sys

for predicted_y in batch_predicted_y:
    max_val = -1000
    flag = -1
    # pick the index of the largest of the 10 class scores
    for j in range(10):
        if predicted_y[j] > max_val:
            max_val = predicted_y[j]
            flag = j
    if flag == -1:
        # no score ever exceeded -1000, which should never happen
        sys.exit()
    p.append(flag)
```
This code lives in a hand-written function eval(model) that computes the model's accuracy:

```python
# verify the model's accuracy
loss, acc = eval(federated_model)
```
So I spent 30 minutes checking this function and found no errors. Finally, after going over it once more from beginning to end, it turned out that the problem was not how the function was written: the model parameters being passed in were NaN.
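That also explains the abnormal exit: every ordered comparison against NaN evaluates to False, so flag never moves off -1 and sys.exit() fires. A quick illustration of this NaN behavior:

```python
import math

nan = float("nan")
print(nan > -1000)      # False: any ordered comparison with NaN is False
print(nan < -1000)      # also False
print(math.isnan(nan))  # True: the reliable way to test for NaN
```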
So I debugged along the call stack again: the result of batch_train() was normal, while the result of local_train() was NaN.
```python
import collections
import tensorflow as tf

def batch_train(initial_model, batch, learning_rate):
    # copy the incoming weights into fresh tf.Variables
    model_vars = collections.OrderedDict([
        (name, tf.Variable(name=name, initial_value=value))
        for name, value in initial_model.items()
    ])
    optimizer = tf.keras.optimizers.SGD(learning_rate)

    def _train_on_batch(model_vars, batch):
        # one SGD step on a single batch
        with tf.GradientTape() as tape:
            loss = forward_pass(model_vars, batch)
        grads = tape.gradient(loss, model_vars)
        optimizer.apply_gradients(
            zip(tf.nest.flatten(grads), tf.nest.flatten(model_vars)))
        return model_vars

    return _train_on_batch(model_vars, batch)
```
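As an aside, a common guard against this kind of blow-up (my addition; the original experiment does not use it) is to clip gradients before applying them. A sketch of the same training step with global-norm clipping, reusing forward_pass and optimizer from above:

```python
def _train_on_batch_clipped(model_vars, batch):
    with tf.GradientTape() as tape:
        loss = forward_pass(model_vars, batch)
    grads = tape.gradient(loss, model_vars)
    # rescale all gradients so that their combined norm is at most 5.0
    clipped, _ = tf.clip_by_global_norm(tf.nest.flatten(grads), 5.0)
    optimizer.apply_gradients(zip(clipped, tf.nest.flatten(model_vars)))
    return model_vars
```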
```python
# on-device (client) training
def local_train(initial_model, learning_rate, all_batches):
    def batch_fn(model, batch):
        return batch_train(model, batch, learning_rate)

    model = initial_model
    model = tff.sequence_reduce(all_batches, model, batch_fn)
    return model
```
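To pin down which stage produces the NaNs without inspecting tensors by hand, one option (my addition, not part of the original code) is tf.debugging.check_numerics, which raises as soon as a tensor contains NaN or inf:

```python
import tensorflow as tf

# hypothetical helper: assert that every weight in a model dict is finite
def assert_finite(model, stage=""):
    for name, value in model.items():
        tf.debugging.check_numerics(
            tf.convert_to_tensor(value),
            message=f"NaN/inf in {name} after {stage}")
```

Calling it on the outputs of batch_train() and local_train() would have localized the failure immediately.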
Then think about what local_train() does with batch_train(): it calls batch_train() once per batch, so the model is updated over and over. The problem therefore lies in this iterative update: at some iteration the parameters turn into NaN, most likely because a value has gone outside the representable range.
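To see how a too-large learning rate alone can do this, here is a minimal sketch (plain Python with made-up numbers, not the federated code) of SGD diverging on f(x) = x**2:

```python
x = 1.0
learning_rate = 10.0  # stable SGD on x**2 needs learning_rate < 1
for step in range(300):
    grad = 2 * x                  # gradient of x**2
    x = x - learning_rate * grad  # multiplies x by -19 every step
print(x)  # |x| explodes past the float range to inf, then becomes nan
```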
So I tried lowering the learning rate from 0.1 to 0.01, and the model parameters finally came out normal:
```python
for round_num in range(ROUND):
    learning_rate = 0.01 / (1 + round_num)
```
Subsequent training was also normal, although such a low learning rate makes learning slow; you can tune it up as appropriate.
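If hand-rolling 0.01 / (1 + round_num) feels too crude, Keras also ships ready-made learning-rate schedules; for example (the constants here are assumptions, not values from the experiment):

```python
import tensorflow as tf

# decay the rate by 10% every 10 steps, starting from 0.05
schedule = tf.keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate=0.05,
    decay_steps=10,
    decay_rate=0.9)
optimizer = tf.keras.optimizers.SGD(learning_rate=schedule)
```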
To sum up
I had run into model parameters turning NaN because of a bad learning-rate setting before, but this time I did not expect to spend most of my time checking whether the function itself was written incorrectly, find nothing, and only catch the bug with an exhaustive line-by-line search. In frustration, I dashed off this blog post to record the hours lost to debugging, and I plan to follow up with a deeper summary of how the learning rate affects model training (a hole I'm deliberately digging for a future post).