pytorch绘制loss和accuracy曲线

1.前言

pytorch虽然使用起来很方便,但在一点上并没有tensorflow方便,就是绘制模型训练时在训练集和验证集上的loss和accuracy曲线(共四条)。tensorflow模型训练时,每次epoch的模型,以及在训练集和验证集上的loss和acc都保存在一个对象中,当我们要绘制四条曲线时,直接从对象中取值即可。

2.Loss曲线

Loss_list = []  #存储每次epoch损失值
def draw_loss(Loss_list,epoch):
    # 我这里迭代了200次,所以x的取值范围为(0,200),然后再将每次相对应的准确率以及损失率附在x上
    plt.cla()
    x1 = range(1, epoch+1)
    print(x1)
    y1 = Loss_list
    print(y1)
    plt.title('Train loss vs. epoches', fontsize=20)
    plt.plot(x1, y1, '.-')
    plt.xlabel('epoches', fontsize=20)
    plt.ylabel('Train loss', fontsize=20)
    plt.grid()
    plt.savefig("./lossAndacc/Train_loss.png")
    plt.savefig("./lossAndacc/Train_loss.png")
    plt.show()

3.acc曲线

def draw_fig(list,name,epoch):
    # 我这里迭代了200次,所以x的取值范围为(0,200),然后再将每次相对应的准确率以及损失率附在x上
    x1 = range(1, epoch+1)
    print(x1)
    y1 = list
    if name=="loss":
        plt.cla()
        plt.title('Train loss vs. epoch', fontsize=20)
        plt.plot(x1, y1, '.-')
        plt.xlabel('epoch', fontsize=20)
        plt.ylabel('Train loss', fontsize=20)
        plt.grid()
        plt.savefig("./lossAndacc/Train_loss.png")
        plt.show()
    elif name =="acc":
        plt.cla()
        plt.title('Train accuracy vs. epoch', fontsize=20)
        plt.plot(x1, y1, '.-')
        plt.xlabel('epoch', fontsize=20)
        plt.ylabel('Train accuracy', fontsize=20)
        plt.grid()
        plt.savefig("./lossAndacc/Train _accuracy.png")
        plt.show()

这里我把绘制loss和acc曲线的代码进行了合并。
用法如下,测试模型在验证集上的loss和acc时,让结果返回两个list对象,分别存储了每次epoch时的loss和acc的值。然后调用draw_fig方法,把对象作为参数传递进去。

if __name__ == '__main__':
    # val(model)

    with torch.no_grad():
        criterion = nn.BCEWithLogitsLoss().cuda()
        epoch = 30
        loss=[]
        acc=[]
        for i in range(1, epoch + 1):
            dir = "./result/20201029_2110/checkpoints/" + str(i) + ".pth"
            model = torch.load(dir)
            model.eval()  # 需要加上model.eval(). 否则的话,有输入数据,即使不训练,它也会改变权值
            loss1,acc1=auto_val(model,criterion)
            loss.append(loss1)
            acc.append(acc1)
        draw_fig(loss,"loss",epoch)
        draw_fig(acc,"acc",epoch)

猜你喜欢

转载自blog.csdn.net/t18438605018/article/details/121895338
今日推荐