PyTorch 动态学习率调整代码

def adjust_learning_rate(optimizer, factor=0.1):
    """Multiply the learning rate of every parameter group by ``factor``.

    Args:
        optimizer: Object exposing ``param_groups`` (a list of dicts with an
            ``"lr"`` key), e.g. any ``torch.optim.Optimizer``.
        factor: Multiplicative decay applied to each group's ``"lr"``.
    """
    for param_group in optimizer.param_groups:
        param_group["lr"] *= factor


def test(losses=(), optimizer=None, window=40, tolerance=1.01, factor=0.1):
    """Decay the learning rate whenever the windowed training loss rises.

    Per-batch losses are summed over consecutive, non-overlapping windows of
    exactly ``window`` batches. At the end of each window the sum is compared
    with the baseline (the most recent window whose loss did not rise):

    * if it exceeds ``baseline * tolerance`` the learning rate is decayed and
      the baseline is kept;
    * otherwise the window's sum becomes the new baseline.

    The first complete window only establishes the baseline.

    Args:
        losses: Iterable of per-batch losses. Each item may be a number or a
            tensor/array; anything with a ``sum()`` method is summed first
            (mirroring the original ``torch.sum(loss)``).
        optimizer: Optimizer whose learning rate is decayed. ``None`` only
            counts and reports decays.
        window: Number of batches per comparison window (must be >= 1).
        tolerance: Allowed relative increase before decaying (1.01 == 1%).
        factor: Decay factor passed to ``adjust_learning_rate``.

    Returns:
        The number of times the learning rate was decayed.

    Raises:
        ValueError: If ``window`` is less than 1.
    """
    if window < 1:
        raise ValueError("window must be >= 1")

    last_total_loss = 0.0
    total_loss = 0.0
    decays = 0
    for batch_i, loss in enumerate(losses):
        batch_loss = loss.sum() if hasattr(loss, "sum") else loss
        # float() detaches the value from autograd. Accumulating tensors here
        # would keep every batch's computation graph alive until the reset.
        total_loss += float(batch_loss)

        # Compare only after the current batch is included, so every window
        # holds exactly `window` batches. The old code compared a 39-batch sum
        # with 40-batch sums, which decayed the lr even on a flat loss.
        if batch_i % window != window - 1:
            continue

        if last_total_loss > 0 and total_loss > last_total_loss * tolerance:
            print("total_loss", total_loss)
            if optimizer is not None:
                adjust_learning_rate(optimizer, factor)
            decays += 1
        else:
            print("total_loss", total_loss, last_total_loss)
            last_total_loss = total_loss
        total_loss = 0.0
    return decays


if __name__ == '__main__':
    test()

猜你喜欢

转载自：https://blog.csdn.net/jacke121/article/details/80776711
今日推荐