import matplotlib.pyplot as plt
# Key code excerpt
# --- Model compilation and training (key excerpt) ---
# NOTE(review): the variable name suggests Nadam, but Adam is instantiated —
# confirm which optimizer was intended. `lr` is the legacy keyword
# (newer Keras uses `learning_rate`) — verify against the installed version.
mNadam = Adam(lr=1e-4, beta_1=0.95, beta_2=0.96)
# Binary-classification setup; AUC/Precision/Recall tracked alongside the loss.
deep.compile(optimizer=mNadam, loss='binary_crossentropy',
metrics=['AUC', 'Precision', 'Recall'])
batch_size = 20000
# Total number of training samples — presumably len(data) counts rows; verify.
train_nums = len(data)
# Train from a custom batch generator; one epoch = train_nums // batch_size steps.
# NOTE(review): fit_generator is deprecated in TF2 — Model.fit accepts generators.
# NOTE(review): shuffle has no effect when training from a Python generator.
history = deep.fit_generator(
GeneratorRandomPatchs(train_x, train_y, batch_size, train_nums, feature_names),
validation_data=(val_x, val_y),
steps_per_epoch=train_nums // batch_size,
epochs=10,
verbose=1,
shuffle=True,
# callbacks=[earlystop_callback]
)
# Dump the per-epoch metric history (sample single-epoch output shown below).
print(history.history)
# {'loss': [0.5298165678977966], 'auc': [0.6162796020507812], 'precision': [0.3609336018562317], 'recall': [0.13993297517299652], 'val_loss': [0.5318868160247803], 'val_auc': [0.6150773167610168], 'val_precision': [0.2931766211986542], 'val_recall': [0.1587975174188614]}
# NOTE(review): `visualization` is called here but defined further down the
# file — in straight top-to-bottom script execution this raises NameError.
# Confirm the real file defines it earlier (this may be a re-ordered excerpt).
visualization(history, flag=True, path1=loss_plt_path.format('loss_auc.jpg'), path2=loss_plt_path.format('precision_recall.jpg'))
''' Visualization '''
def _plot_two_panels(panels, save, path):
    """Draw two side-by-side panels on a fresh figure.

    Args:
        panels: sequence of two (curves, title) tuples, where curves is a
            list of (series, label) pairs to plot on that panel.
        save: when True, write the figure to *path* before showing it.
        path: output file path (used only when *save* is True).
    """
    # Start a new figure so a second call never overlays axes left behind
    # by an earlier plot (plt.subplot reuses the current figure otherwise).
    plt.figure()
    for idx, (curves, title) in enumerate(panels, start=1):
        plt.subplot(1, 2, idx)
        for series, label in curves:
            plt.plot(series, label=label)
        plt.title(title)
        plt.legend()
    if save:
        # Save before show(): show() may clear the canvas on some backends,
        # which would write an empty image.
        plt.savefig(path)
    plt.show()


def visualization(history, flag=True, path1='', path2=''):
    """Plot training/validation loss, AUC, precision and recall curves.

    Produces two figures: (loss, AUC) and (precision, recall), each with
    the training and validation series overlaid.

    Args:
        history: Keras History object; ``history.history`` must contain the
            keys 'loss', 'val_loss', 'auc', 'val_auc', 'precision',
            'val_precision', 'recall', 'val_recall'.
            NOTE(review): Keras suffixes metric keys (e.g. 'auc_1') when a
            metric is instantiated more than once — confirm against the model.
        flag: when True, each figure is saved to disk before being shown.
        path1: output path for the loss/AUC figure (used only if flag).
        path2: output path for the precision/recall figure (used only if flag).
    """
    h = history.history
    _plot_two_panels(
        [
            ([(h['loss'], 'Training loss'),
              (h['val_loss'], 'Validation loss')][0],
             'Training and Validation loss'),
        ][0:0] or [
            ([(h['loss'], 'Training loss'),
              (h['val_loss'], 'Validation loss')],
             'Training and Validation loss'),
            ([(h['auc'], 'Training auc'),
              (h['val_auc'], 'Validation auc')],
             'Training and Validation auc'),
        ],
        flag, path1)
    _plot_two_panels(
        [
            ([(h['precision'], 'Training precision'),
              (h['val_precision'], 'Validation precision')],
             'Training and Validation precision'),
            ([(h['recall'], 'Training recall'),
              (h['val_recall'], 'Validation recall')],
             'Training and Validation recall'),
        ],
        flag, path2)