Registre un proceso de depuración implementado por el algoritmo FedAvg de aprendizaje federado

proceso

En el experimento FedAvg bajo el marco federado de tensorflow, un fragmento de código siempre detecta una salida anormal:

for predicted_y in batch_predicted_y:
   max = -1000
   flag = -1
   for j in range(10):
       if predicted_y[j] > max:
           max = predicted_y[j]
           flag = j
   if(flag==-1):
       sys.exit()
   p.append(flag)

Y este código está en la función manuscrita eval(model)para calcular la precisión del modelo.

 #验证模型准确率
    loss,acc = eval(federated_model)

Así que pasé 30 minutos comprobando esta función y no encontré ningún error.

Finalmente, desde el principio hasta el final de la inspección de una función, se encuentra que el problema no es la forma en que está escrita la función, sino que el parámetro del modelo pasado es NaN:
Inserte la descripción de la imagen aquí
Entonces, de acuerdo con la depuración de la pila de llamadas nuevamente, batch_train()el resultado encontrado es normal, y local_train()el resultado es NaN


def batch_train(initial_model, batch, learning_rate):
    model_vars = collections.OrderedDict([
        (name, tf.Variable(name=name, initial_value=value))
        for name, value in initial_model.items()
    ])
    optimizer = tf.keras.optimizers.SGD(learning_rate)
    def _train_on_batch(model_vars, batch):     
        with tf.GradientTape() as tape:
            loss = forward_pass(model_vars, batch)
        grads = tape.gradient(loss, model_vars)
        optimizer.apply_gradients(
            zip(tf.nest.flatten(grads), tf.nest.flatten(model_vars)))
        return model_vars
    return _train_on_batch(model_vars, batch)
    
# 设备训练
def local_train(initial_model, learning_rate, all_batches):
    def batch_fn(model, batch):
        return  batch_train(model, batch, learning_rate)
    model=initial_model
    model=tff.sequence_reduce(all_batches,model, batch_fn)
    return model

Luego, piense en local_train()el significado de la batch_train()función: llama a la función para cada lote y el modelo se actualiza constantemente. Entonces, el problema radica en el proceso iterativo de actualización del modelo: un parámetro de iteración cambia a NaN, que puede ser un número más allá del rango de representación.

Así que intenté modificar la tasa de aprendizaje de 0,1 a 0,01 y, finalmente, los parámetros del modelo fueron normales.

Inserte la descripción de la imagen aquí

for round_num in range(ROUND):
    learning_rate = 0.01 / (1+round_num)

El entrenamiento posterior también es normal, pero la velocidad de aprendizaje es demasiado baja y el aprendizaje es demasiado lento. Puede ajustar la velocidad de aprendizaje de forma adecuada.
Inserte la descripción de la imagen aquí

para resumir

Me he encontrado con una situación en la que los parámetros del modelo se convierten en NaN debido a una configuración incorrecta de la tasa de aprendizaje. Esta vez no esperaba que la mayor parte del tiempo se dedicara a verificar si la función estaba escrita incorrectamente. No se encontraron resultados. Finalmente, se realizó una búsqueda de alfombra utilizado para encontrar el error. En mi enojo, el blog de Shuibian registra las horas perdidas por depuración y luego agrega un resumen en profundidad del impacto de la tasa de aprendizaje en el entrenamiento de modelos (y cavando agujeros).

Supongo que te gusta

Origin blog.csdn.net/Protocols7/article/details/112000634
Recomendado
Clasificación