フェデレーション学習FedAvgアルゴリズムによって実装されたデバッグプロセスを記録します

処理する

tensorflowフェデレーションフレームワークでのFedAvg実験では、コードの一部が常に異常な終了をキャッチします。

for predicted_y in batch_predicted_y:
   max = -1000
   flag = -1
   for j in range(10):
       if predicted_y[j] > max:
           max = predicted_y[j]
           flag = j
   if(flag==-1):
       sys.exit()
   p.append(flag)

そして、このコードは、モデルの精度eval(model)計算するための手書き関数に含まれています

 #验证模型准确率
    loss,acc = eval(federated_model)

そのため、この関数を30分間チェックしましたが、エラーは見つかりませんでした。

最後に、関数検査の最初から最後まで、問題は関数の記述方法ではなく、渡されたモデルパラメーターがNaN
ここに画像の説明を挿入
batch_train()あることがわかりますしたがって、コールスタックのデバッグによると、見つかった結果は次のようになります。通常、local_train()結果はNaNです


def batch_train(initial_model, batch, learning_rate):
    model_vars = collections.OrderedDict([
        (name, tf.Variable(name=name, initial_value=value))
        for name, value in initial_model.items()
    ])
    optimizer = tf.keras.optimizers.SGD(learning_rate)
    def _train_on_batch(model_vars, batch):     
        with tf.GradientTape() as tape:
            loss = forward_pass(model_vars, batch)
        grads = tape.gradient(loss, model_vars)
        optimizer.apply_gradients(
            zip(tf.nest.flatten(grads), tf.nest.flatten(model_vars)))
        return model_vars
    return _train_on_batch(model_vars, batch)
    
# 设备训练
def local_train(initial_model, learning_rate, all_batches):
    def batch_fn(model, batch):
        return  batch_train(model, batch, learning_rate)
    model=initial_model
    model=tff.sequence_reduce(all_batches,model, batch_fn)
    return model

次にlocal_train()batch_train()関数の意味を考えます。バッチごとに関数を呼び出し、モデルは常に更新されます。次に、問題はモデルを更新する反復プロセスにあります.1つの反復パラメーターがNaNに変更されます。これは、表現の範囲を超える数である可能性があります。

そこで、学習率を0.1から0.01に変更しようとしましたが、最終的にモデルパラメーターは正常でした。

ここに画像の説明を挿入

for round_num in range(ROUND):
    learning_rate = 0.01 / (1+round_num)

その後のトレーニングも正常ですが、学習率が低すぎて学習が遅すぎます。学習率は適切に調整できます。
ここに画像の説明を挿入

総括する

学習率の設定が不適切なため、モデルパラメータがNaNになる状況に遭遇しました。今回は、関数が正しく記述されていないかどうかの確認にほとんどの時間が費やされるとは思っていませんでした。結果は見つかりませんでした。最後に、カーペット検索を行いました。バグを見つけるために使用されます。私の怒りで、Shuibianブログは、数時間のデバッグ損失を記録し、モデルトレーニング(および穴掘り)に対する学習率の影響の詳細な要約を追加しました。

おすすめ

転載: blog.csdn.net/Protocols7/article/details/112000634