sklearn pitfalls: confusion_matrix and KFold

sklearn.metrics.confusion_matrix(y_true, y_pred, labels=None, sample_weight=None)

sklearn.metrics.confusion_matrix is a convenient way to get the confusion matrix directly, but in practice I ran into a problem: when the inputs contain only one class, the returned matrix shrinks to 1*1. So I wrote my own:

import numpy as np

def calculate_metric(gt, pred):
    # Threshold the scores at 0.5 on a copy, so the caller's array is not modified
    gt = np.asarray(gt)
    pred = (np.asarray(pred) > 0.5).astype(int)

    # Count each cell of the binary confusion matrix
    TP = int(np.sum((gt == 1) & (pred == 1)))
    FP = int(np.sum((gt == 0) & (pred == 1)))
    TN = int(np.sum((gt == 0) & (pred == 0)))
    FN = int(np.sum((gt == 1) & (pred == 0)))

    # Earlier attempt with sklearn (rows are true labels, columns are predictions):
    # confusion = confusion_matrix(gt,pred)
    # print(confusion.shape)
    # TP = confusion[1,1]
    # TN = confusion[0,0]
    # FP = confusion[0,1]
    # FN = confusion[1,0]

    return TP, FP, TN, FN
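To make the shape issue concrete, here is a minimal sketch that is not part of my original code. The toy arrays are made up for illustration. It also shows one possible workaround, assuming a binary task with labels 0 and 1: passing the labels parameter of confusion_matrix explicitly keeps the matrix at 2*2 even when a class is missing.

```python
import numpy as np
from sklearn.metrics import confusion_matrix

# Made-up fold in which only the negative class appears
y_true = np.array([0, 0, 0, 0])
y_pred = np.array([0, 0, 0, 0])

# Without `labels`, the matrix only covers the classes that are present
collapsed = confusion_matrix(y_true, y_pred)
print(collapsed.shape)

# With `labels` given explicitly, the matrix always covers both classes
full = confusion_matrix(y_true, y_pred, labels=[0, 1])
print(full.shape)

# For labels ordered [0, 1], ravel() yields TN, FP, FN, TP
tn, fp, fn, tp = full.ravel()
```

This only fixes the shape. Metrics such as recall are still undefined on a fold that has no positive samples.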

But later I found that confusion_matrix was not really to blame. The actual cause was the small number of samples: with 50-fold cross-validation, there was no guarantee that every fold contained both classes.

So I looked into KFold and StratifiedKFold. The two are constructed the same way; only the class name changes:

from sklearn.model_selection import train_test_split, KFold, StratifiedKFold


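# Plain K-fold: splits by position only and ignores the labels
kf = KFold(n_splits=5, random_state=2023, shuffle=True)

# Stratified K-fold: keeps the class ratio roughly the same in every fold
kf = StratifiedKFold(n_splits=5, random_state=2023, shuffle=True)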

But the two differ when calling split:

# KFold does not need the labels
for train_index, validate_index in kf.split(dataset):
    pass
# StratifiedKFold needs the labels
for train_index, validate_index in kf.split(dataset, dataset['label']):
    pass

After switching to StratifiedKFold, the folds are stratified by class, so no fold ends up with only one class (provided each class has at least as many samples as there are folds). This avoids the shape collapse in confusion_matrix, and the problem is solved.
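The difference between the two splitters can be checked with a small sketch like the one below. It is an illustration under assumptions, not my real data: the labels are made up, count_single_class_folds is a hypothetical helper, and shuffling is left off so the result is deterministic.

```python
import numpy as np
from sklearn.model_selection import KFold, StratifiedKFold

def count_single_class_folds(splitter, X, y):
    """Hypothetical helper: count validation folds that contain only one class."""
    count = 0
    for _, validate_index in splitter.split(X, y):
        if len(np.unique(y[validate_index])) < 2:
            count += 1
    return count

# Made-up toy data: 15 negatives followed by 5 positives, left unshuffled
y = np.array([0] * 15 + [1] * 5)
X = np.zeros((len(y), 1))  # the feature values do not affect the split

# KFold ignores y when splitting; StratifiedKFold uses it to balance the folds
kfold_bad = count_single_class_folds(KFold(n_splits=5), X, y)
stratified_bad = count_single_class_folds(StratifiedKFold(n_splits=5), X, y)
print(kfold_bad, stratified_bad)
```

On this toy data the plain split leaves several validation folds with a single class, while the stratified split leaves none. With shuffle=True the plain KFold result depends on the random seed, which is why the problem can look intermittent.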

Origin blog.csdn.net/weixin_48144018/article/details/129663521