TensorFlow(tf.keras)训练过程可视化

import matplotlib.pyplot as plt

下面给出关键代码:

# Optimizer for the binary classifier. `learning_rate` is the current keyword;
# `lr` is a deprecated alias of it.
# NOTE(review): the variable is named mNadam but holds a plain Adam optimizer;
# the name is kept unchanged in case other code refers to it.
mNadam = Adam(learning_rate=1e-4, beta_1=0.95, beta_2=0.96)
deep.compile(
    optimizer=mNadam,
    loss='binary_crossentropy',
    metrics=['AUC', 'Precision', 'Recall'],
)

batch_size = 20000
train_nums = len(data)
# NOTE(review): steps are derived from len(data) while the batches are drawn
# from train_x / train_y -- confirm both describe the same number of samples.

# `Model.fit` accepts generators directly; `fit_generator` is deprecated.
history = deep.fit(
    GeneratorRandomPatchs(train_x, train_y, batch_size, train_nums, feature_names),
    validation_data=(val_x, val_y),
    # Never request zero steps when there are fewer samples than one batch.
    steps_per_epoch=max(1, train_nums // batch_size),
    epochs=10,
    verbose=1,
    shuffle=True,
    # callbacks=[earlystop_callback]
)

# `history.history` maps each metric name to a list with one value per epoch.
print(history.history)
# Sample output after a single epoch:
# {'loss': [0.5298165678977966], 'auc': [0.6162796020507812],
#  'precision': [0.3609336018562317], 'recall': [0.13993297517299652],
#  'val_loss': [0.5318868160247803], 'val_auc': [0.6150773167610168],
#  'val_precision': [0.2931766211986542], 'val_recall': [0.1587975174188614]}

# Save and display the training curves. `visualization` must already be
# defined (or imported) when this statement runs.
visualization(
    history,
    flag=True,
    path1=loss_plt_path.format('loss_auc.jpg'),
    path2=loss_plt_path.format('precision_recall.jpg'),
)
def _plot_metric_pair(series, flag, path):
    """Draw two train/validation panels side by side on a fresh figure.

    Args:
        series: Two ``(metric_name, train_values, val_values)`` tuples, for
            the left and right panel respectively.
        flag: When true, save the figure (if ``path`` is non-empty) and show it.
        path: Destination file for the figure; an empty string skips saving.
    """
    # A new figure per pair keeps these curves from being drawn on top of a
    # previous pair when the earlier figure is still the current one.
    plt.figure()
    for position, (metric, train_values, val_values) in enumerate(series, start=1):
        plt.subplot(1, 2, position)
        plt.plot(train_values, label='Training ' + metric)
        plt.plot(val_values, label='Validation ' + metric)
        plt.title('Training and Validation ' + metric)
        plt.legend()
    if flag:
        if path:
            plt.savefig(path)
        plt.show()


def visualization(history, flag=True, path1='', path2=''):
    """Plot training vs. validation curves from a fit history object.

    Two figures are produced: loss next to auc, then precision next to recall.

    Args:
        history: Object whose ``history`` attribute is a mapping containing
            the keys ``loss``, ``auc``, ``precision``, ``recall`` and their
            ``val_``-prefixed counterparts.
        flag: When true, each figure is saved to its path and then shown.
            When false, the figures are only drawn and left open.
        path1: Output file for the loss/auc figure; empty skips saving.
        path2: Output file for the precision/recall figure; empty skips saving.

    Raises:
        KeyError: If any of the expected metric keys is missing.
    """
    records = history.history
    # Resolve every series up front so a missing key fails before any plotting.
    loss_auc, precision_recall = (
        [(metric, records[metric], records['val_' + metric]) for metric in pair]
        for pair in (('loss', 'auc'), ('precision', 'recall'))
    )

    _plot_metric_pair(loss_auc, flag, path1)
    _plot_metric_pair(precision_recall, flag, path2)

(原文此处为一张图片,推测为上述代码生成的训练曲线图。)

猜你喜欢

转载自blog.csdn.net/qq_42363032/article/details/121394030
tf
今日推荐