mlr3 in practice | Classifying liver disease patients based on clinical parameters (7 commonly used machine learning methods)


Preamble

The example below is part of an introductory lecture on machine learning at the University of Munich. The goal of this project is to create and compare one or several machine learning pipelines for the problem at hand, while conducting exploratory analysis and interpreting the results.

Prepare

For a detailed guide to mlr3 see:

mlr3 book (https://mlr3book.mlr-org.com/index.html)

## Install and load the required packages
install.packages('mlr3verse')
install.packages('DataExplorer')
install.packages('gridExtra')
library(mlr3verse)
library(dplyr)
library(tidyr)
library(DataExplorer)
library(ggplot2)
library(gridExtra)

Initialize the random number generator with a fixed seed to guarantee repeatability and reduce the verbosity of the logger to keep the output clean.

set.seed(7832)
lgr::get_logger("mlr3")$set_threshold("warn")
lgr::get_logger("bbotk")$set_threshold("warn")

In this example, the authors investigate applications of machine learning algorithms and learners to liver disease detection. The task is binary classification: to predict whether a patient has liver disease based on some common diagnostic measurements.


Liver Disease Data in India

# Importing data
data("ilpd", package = "mlr3data")

It contains data on 583 patients collected in the north east of Andhra Pradesh, India. Observations are divided into two classes based on whether the patient has liver disease. In addition to the target variable, ten mostly numeric features are provided. The table below describes the variables in the dataset.

Variable                 Description
age                      Age of the patient (all patients above 89 are labelled as 90)
gender                   Sex of the patient (1 = female, 0 = male)
total_bilirubin          Total serum bilirubin (in mg/dL)
direct_bilirubin         Direct bilirubin level (in mg/dL)
alkaline_phosphatase     Serum alkaline phosphatase level (in U/L)
alanine_transaminase     Serum alanine transaminase level (in U/L)
aspartate_transaminase   Serum aspartate transaminase level (in U/L)
total_protein            Total serum protein (in g/dL)
albumin                  Serum albumin level (in g/dL)
albumin_globulin_ratio   Albumin-to-globulin ratio
diseased                 Target variable (1 = liver disease, 0 = no liver disease)
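
Before preprocessing, it is worth taking a quick look at the raw data. The snippet below is a minimal inspection step (an addition, not part of the original analysis):

## Quick inspection of the raw data
dim(ilpd)             # expect 583 rows and 11 columns
str(ilpd)             # variable types and factor levels
colSums(is.na(ilpd))  # check for missing values per column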

Obviously, some measurements are components of other variables. For example, total serum bilirubin is the sum of direct and indirect bilirubin levels, and the amount of albumin enters both total serum protein and the albumin-globulin ratio. Therefore, some features are highly correlated with each other; this is dealt with below.

Data preprocessing

Univariate distribution

Next, investigate the univariate distribution of each variable. Start with the target variable and the only discrete feature - gender, which are both binary variables.

## Frequency distribution of all discrete variables
plot_bar(ilpd,ggtheme = theme_bw())

It can be seen that the distribution of the target variable (i.e. patients with and without liver disease) is quite unbalanced: the number of patients with and without liver disease is 416 and 167, respectively. Under-representation of a class can worsen the performance of an ML model. To investigate this, the authors also fit the models on a dataset in which the minority class is randomly oversampled, resulting in a perfectly balanced dataset. In addition, we applied stratified sampling to ensure that the class proportions are maintained during cross-validation. The only discrete feature, gender, is also quite unbalanced.
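
As a quick numeric check (a small addition, using the raw factor levels of the dataset), the class counts can be tabulated directly:

## Class counts of the target and the only discrete feature
table(ilpd$diseased)  # should show 416 diseased vs. 167 non-diseased patients
table(ilpd$gender)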

## Histograms of all continuous variables
plot_histogram(ilpd,ggtheme = theme_mlr3())

It can be seen that some features are extremely right-skewed and contain several extreme values. To reduce the effect of outliers, and since some models assume normality of features, we log-transform these variables.

Feature grouping

To explore the relationship between the features and the target, we next analyze the feature distributions grouped by target class. First, we look at the discrete feature gender.

plot_bar(ilpd,by = 'diseased',ggtheme = theme_mlr3())

In the "disease" category, there was a slightly higher proportion of men, but overall, the difference was not significant. In addition to this, as we mentioned earlier, gender imbalance can be observed in both categories .

To see the differences in the continuous features, we compare the following boxplots, in which the right-skewed features have not yet been log-transformed.

## View bivariate continuous distribution based on `diseased`
plot_boxplot(ilpd,by = 'diseased')

You can see that, except for total_protein, the medians of the two classes differ for each feature. It is worth noting that among the strongly right-skewed features, the "diseased" class contains far more extreme values than the "non-diseased" class, probably because of its larger class size.

As you can see from the graph below, this effect is attenuated after log transformation. Furthermore, these features are more spread out in the "disease" class, as indicated by the length of the boxplots. Overall, these features appear to be related to the target, so it makes sense to use them for this task and model their relationship to the target.

Log transformation of some features

ilpd_log = ilpd %>%
  mutate(
    # Log for features with skewed distributions
    alanine_transaminase = log(alanine_transaminase),
    total_bilirubin =log(total_bilirubin),
    alkaline_phosphatase = log(alkaline_phosphatase),
    aspartate_transaminase = log(aspartate_transaminase),
    direct_bilirubin = log(direct_bilirubin)
  )
plot_histogram(ilpd_log,ggtheme = theme_mlr3(),ncol = 3)
plot_boxplot(ilpd_log,by = 'diseased')

It can be seen that the distributions of the log-transformed data are much improved.

Correlation analysis

As mentioned in the data description, some features are measured indirectly through other features, which suggests they are highly correlated. Some of the models we want to compare assume independent features or have problems with multicollinearity. Therefore, we examine the correlations between the features.

plot_correlation(ilpd)

As can be seen, four pairs of features have very high correlation coefficients. Looking at these features, it is clear how they relate to each other. Since the complexity of the model should be minimized, and due to multicollinearity considerations, we decided to keep only one feature of each pair. In deciding which features to keep, we selected those that are more specific and relevant to liver disease. Therefore, we chose albumin rather than the albumin-to-globulin ratio or total protein. The same reasoning applies to using direct bilirubin instead of total bilirubin. Regarding aspartate transaminase and alanine transaminase, we did not notice any fundamental differences between the two features, so we chose aspartate transaminase arbitrarily.
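
To back the plot up with numbers, the pairwise correlations of the affected measurements can also be inspected directly (a quick check we add here; the column names follow the table above):

## Numeric correlations between the related measurements
num_vars = c("total_bilirubin", "direct_bilirubin", "total_protein",
             "albumin", "albumin_globulin_ratio",
             "alanine_transaminase", "aspartate_transaminase")
round(cor(ilpd[, num_vars], use = "pairwise.complete.obs"), 2)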

Final dataset

## Reducing, transforming and scaling dataset
ilpd = ilpd %>%
  select(-total_bilirubin, -alanine_transaminase, -total_protein,
         -albumin_globulin_ratio) %>%
  mutate(
    # Recode gender
    gender = as.numeric(ifelse(gender == "Female", 1, 0)),
     # Remove labels for class
    diseased = factor(ifelse(diseased == "yes", 1, 0)),
     # Log for features with skewed distributions
    alkaline_phosphatase = log(alkaline_phosphatase),
    aspartate_transaminase = log(aspartate_transaminase),
    direct_bilirubin = log(direct_bilirubin)
  )
## Standardization via a scaling PipeOp
po_scale = po("scale")
po_scale$param_set$values$affect_columns =
  selector_name(c("age", "direct_bilirubin", "alkaline_phosphatase",
                  "aspartate_transaminase", "albumin"))
task_liver = as_task_classif(ilpd, target = "diseased", positive = "1")
ilpd_f = po_scale$train(list(task_liver))[[1]]$data()

Finally, we standardized all continuous features, which is especially important for the k-NN model. The code above shows the transformations applied to obtain the final dataset. Note: unlike the log or other fixed transformations, scaling depends on the data itself. Scaling the data before it is split can lead to data leakage (see: Nature Reviews Genetics | Common pitfalls of applying machine learning in genomics), because information is shared between training and test sets. Since data leakage can lead to inflated performance estimates, scaling should always be applied separately to each data split produced by the ML workflow. Therefore, we strongly recommend using PipeOpScale in this case.
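
To make this concrete, the sketch below shows the general pattern (assuming only functions from mlr3pipelines) of composing a scaling PipeOp with a learner into a single GraphLearner, so that the scaling parameters are estimated on each training split only; the same pattern is used for the oversampled learners further below.

## Scaling as part of the learner: parameters are re-estimated per training split
graph = po("scale") %>>% lrn("classif.kknn")
graph_learner = GraphLearner$new(graph)
## graph_learner can now be resampled or benchmarked without leaking scaling information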

Learners and Tuning

First, we need to define a task, which contains the final dataset and some meta information. We also need to specify the positive class, because the package by default takes the first class level as the positive class. The assignment of the positive class has implications for the evaluations below.

## Task definition
task_liver = as_task_classif(ilpd_f, target = "diseased", positive = "1")

Below we evaluate the following binary classifiers: logistic regression, linear discriminant analysis (LDA), quadratic discriminant analysis (QDA), naive Bayes, k-nearest neighbours (k-NN), classification trees (CART), and random forest.

## Define the learners; predicting on both train and test sets lets us detect overfitting
install.packages('e1071')   # for naive Bayes
install.packages('kknn')    # for k-NN
install.packages('ranger')  # for random forest
learners = list(
  learner_logreg = lrn("classif.log_reg", predict_type = "prob",
                       predict_sets = c("train", "test")),
  learner_lda = lrn("classif.lda", predict_type = "prob",
                    predict_sets = c("train", "test")),
  learner_qda = lrn("classif.qda", predict_type = "prob",
                    predict_sets = c("train", "test")),
  learner_nb = lrn("classif.naive_bayes", predict_type = "prob",
                   predict_sets = c("train", "test")),
  learner_knn = lrn("classif.kknn", scale = FALSE,
                    predict_type = "prob"),
  learner_rpart = lrn("classif.rpart",
                      predict_type = "prob"),
  learner_rf = lrn("classif.ranger", num.trees = 1000,
                   predict_type = "prob")
)

Parameter tuning

To find good hyperparameters, we use random search to cover the hyperparameter space. We define search spaces only for k-NN, CART, and random forest, as the other methods rest on strong assumptions and serve as baselines.

For k-NN, we choose 3 as the lower bound for k (the number of neighbours) and 50 as the upper bound; too small a k can lead to overfitting. We also try different distance measures (1 = Manhattan distance, 2 = Euclidean distance) and kernels. For CART, we tune the hyperparameters cp (complexity parameter) and minsplit (the minimum number of observations in a node for a split to be attempted). cp controls the tree size: small values lead to overfitting, while large values lead to underfitting. For the random forest, we tune the minimum size of terminal nodes and the number of candidate variables randomly sampled at each split (from 1 to the number of features).

tune_ps_knn = ps(
  k = p_int(lower = 3, upper = 50), # Number of neighbors considered
  distance = p_dbl(lower = 1, upper = 3),
  kernel = p_fct(levels = c("rectangular", "gaussian", "rank", "optimal"))
)
tune_ps_rpart = ps(
  # Minimum number of observations that must exist in a node in order for a
  # split to be attempted
  minsplit = p_int(lower = 10, upper = 40),
  cp = p_dbl(lower = 0.001, upper = 0.1) # Complexity parameter
)
tune_ps_rf = ps(
  # Minimum size of terminal nodes
  min.node.size = p_int(lower = 10, upper = 50),
  # Number of variables randomly sampled as candidates at each split
  mtry = p_int(lower = 1, upper = 6)
)

The next step is to instantiate an AutoTuner from mlr3tuning for each tuned learner. We use 5-fold cross-validation as the inner loop of the nested resampling, set 100 evaluations as the stopping criterion, and use the AUC as the tuning metric.
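
The AutoTuner construction itself is not shown above; following the pattern of the mlr3gallery post this example is based on, a sketch looks like the following (make_at is a small helper we introduce here for brevity). The tuned learners replace the untuned k-NN, CART, and random forest entries in the learners list, which is why the benchmark below reports ids such as classif.kknn.tuned.

## Wrap the three tunable learners in AutoTuners: inner 5-fold CV,
## random search with 100 evaluations, AUC as the tuning measure
make_at = function(learner, search_space) {
  AutoTuner$new(
    learner = learner,
    resampling = rsmp("cv", folds = 5L),
    measure = msr("classif.auc"),
    search_space = search_space,
    terminator = trm("evals", n_evals = 100),
    tuner = tnr("random_search")
  )
}
learners$learner_knn = make_at(learners$learner_knn, tune_ps_knn)
learners$learner_rpart = make_at(learners$learner_rpart, tune_ps_rpart)
learners$learner_rf = make_at(learners$learner_rf, tune_ps_rf)
## Also predict on the training sets to allow the overfitting check below
learners$learner_knn$predict_sets = c("train", "test")
learners$learner_rpart$predict_sets = c("train", "test")
learners$learner_rf$predict_sets = c("train", "test")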

As mentioned before, due to the unbalanced target classes we also fit each learner on an oversampled, perfectly balanced version of the data. By using mlr3pipelines, we can apply the benchmark function to both variants later.

# Oversampling minority class to get perfectly balanced classes
po_over = po("classbalancing", id = "oversample", adjust = "minor",
             reference = "minor", shuffle = FALSE, ratio = 416/167)
table(po_over$train(list(task_liver))$output$truth()) # Check class balance

# Learners with balanced/oversampled data
learners_bal = lapply(learners, function(x) {
  GraphLearner$new(po_scale %>>% po_over %>>% x)
})
for (l in learners_bal) {
  l$predict_sets = c("train", "test")
}

Model Fitting and Benchmarking

After defining the learners, choosing the inner resampling method, and setting up the tuners, we select the outer resampling method. We choose stratified 5-fold cross-validation to preserve the distribution of the target variable in each fold, independent of the oversampling. It turns out that ordinary cross-validation without stratification produces very similar results.

# 5-fold cross-validation
resampling_outer = rsmp("cv", folds = 5L)

# Stratification
task_liver$col_roles$stratum = task_liver$target_names

To rank the different learners and ultimately decide which one is best for the task at hand, we use benchmarking. The code block below executes our benchmark for all learners.

design = benchmark_grid(
  tasks = task_liver,
  learners = c(learners, learners_bal),
  resamplings = resampling_outer
)

bmr = benchmark(design, store_models = FALSE) ## this step can take a while

As mentioned above, we chose stratified 5-fold cross-validation. This means performance is determined as the average over five model evaluations, each with an 80%/20% train-test split. Furthermore, the choice of performance metric is crucial for ranking the different learners. While each metric has its specific use case, we chose the AUC, which takes both sensitivity and specificity into account and which we also used for hyperparameter tuning.

We start by comparing the AUC of all learners, with and without oversampling, on both training and test data.

measures = list(
  msr("classif.auc", predict_sets = "train", id = "auc_train"),
  msr("classif.auc", id = "auc_test")
)

tab = bmr$aggregate(measures)
tab_1 = tab[, c('learner_id', 'auc_train', 'auc_test')]
print(tab_1)
                               learner_id auc_train  auc_test
 1:                       classif.log_reg 0.7548382 0.7485372
 2:                           classif.lda 0.7546522 0.7487159
 3:                           classif.qda 0.7683438 0.7441634
 4:                   classif.naive_bayes 0.7539374 0.7498427
 5:                    classif.kknn.tuned 0.8652143 0.7150679
 6:                   classif.rpart.tuned 0.7988561 0.6847818
 7:                  classif.ranger.tuned 0.9871615 0.7426650
 8:      scale.oversample.classif.log_reg 0.7540066 0.7497002
 9:          scale.oversample.classif.lda 0.7537952 0.7489675
10:          scale.oversample.classif.qda 0.7679012 0.7481963
11:  scale.oversample.classif.naive_bayes 0.7536208 0.7503436
12:   scale.oversample.classif.kknn.tuned 0.9982251 0.6870297
13:  scale.oversample.classif.rpart.tuned 0.8903927 0.6231100
14: scale.oversample.classif.ranger.tuned 1.0000000 0.7409655

From the results above, logistic regression, LDA, QDA, and naive Bayes perform very similarly on training and test data, with or without oversampling. On the other hand, k-NN, CART, and random forest predict much better on the training data, suggesting overfitting.

Furthermore, oversampling leaves the AUC performance of all learners almost unchanged.

The boxplots below show the AUC across the 5 cross-validation folds for all learners.

# boxplot of AUC values across the 5 folds
autoplot(bmr, measure = msr("classif.auc"))

## ROC curves for all learners
autoplot(bmr, type = "roc") +
  scale_color_discrete() +
  theme_bw()

Subsequently, we output the sensitivity, specificity, false negative rate (FNR), and false positive rate (FPR) of each learner.

tab2 = bmr$aggregate(msrs(c('classif.auc', 'classif.sensitivity', 'classif.specificity',
                            'classif.fnr', 'classif.fpr')))
tab2 = tab2[, c('learner_id', 'classif.auc', 'classif.sensitivity', 'classif.specificity',
                'classif.fnr', 'classif.fpr')]
print(tab2)
                               learner_id classif.auc classif.sensitivity
 1:                       classif.log_reg   0.7485372           0.8917097
 2:                           classif.lda   0.7487159           0.9037005
 3:                           classif.qda   0.7441634           0.6779116
 4:                   classif.naive_bayes   0.7498427           0.6250430
 5:                    classif.kknn.tuned   0.7180074           0.8509180
 6:                   classif.rpart.tuned   0.6987046           0.8679289
 7:                  classif.ranger.tuned   0.7506405           0.9447504
 8:      scale.oversample.classif.log_reg   0.7475678           0.6008893
 9:          scale.oversample.classif.lda   0.7489090           0.5841652
10:          scale.oversample.classif.qda   0.7431096           0.5529547
11:  scale.oversample.classif.naive_bayes   0.7494055           0.5505164
12:   scale.oversample.classif.kknn.tuned   0.6924480           0.6948078
13:  scale.oversample.classif.rpart.tuned   0.6753005           0.7090075
14: scale.oversample.classif.ranger.tuned   0.7393948           0.7427424
    classif.specificity classif.fnr classif.fpr
 1:           0.2516934  0.10829030   0.7483066
 2:           0.1855615  0.09629948   0.8144385
 3:           0.6946524  0.32208835   0.3053476
 4:           0.7488414  0.37495697   0.2511586
 5:           0.2581105  0.14908204   0.7418895
 6:           0.3108734  0.13207114   0.6891266
 7:           0.1554367  0.05524957   0.8445633
 8:           0.7663102  0.39911073   0.2336898
 9:           0.8023173  0.41583477   0.1976827
10:           0.8139037  0.44704532   0.1860963
11:           0.8381462  0.44948365   0.1618538
12:           0.5811052  0.30519220   0.4188948
13:           0.5449198  0.29099254   0.4550802
14:           0.5509804  0.25725760   0.4490196

It turns out that without oversampling, logistic regression, LDA, k-NN, CART, and random forest score high on sensitivity and quite low on specificity; QDA and naive Bayes, on the other hand, score relatively high on specificity but lower on sensitivity. By definition, high sensitivity (specificity) corresponds to a low false negative (false positive) rate, which is also reflected in the data.
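
These quantities can be read directly off a confusion matrix. As a small illustration (an addition; for simplicity the learner is trained and evaluated on the full task, so the numbers only demonstrate the arithmetic; in mlr3 confusion matrices, rows are the predicted response and columns the truth):

## Sensitivity/specificity and FNR/FPR from a confusion matrix
pred = lrn("classif.log_reg")$train(task_liver)$predict(task_liver)
conf = pred$confusion
sensitivity = conf["1", "1"] / sum(conf[, "1"])  # TP / (TP + FN)
fnr = 1 - sensitivity                            # FNR = FN / (TP + FN)
specificity = conf["0", "0"] / sum(conf[, "0"])  # TN / (TN + FP)
fpr = 1 - specificity                            # FPR = FP / (TN + FP)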

Extract a single model

## Extract the random forest model
bmr_rf = bmr$clone(deep = TRUE)$filter(learner_ids = 'classif.ranger.tuned')
## ROC
autoplot(bmr_rf,type = "roc")+
  scale_color_discrete() +
  theme_bw()
## PRC
autoplot(bmr_rf, type = "prc")+
  scale_color_discrete() +
  theme_bw()

Which learner works best, including whether oversampling should be used, depends largely on the practical implications of sensitivity and specificity; in terms of practical importance, one of the two may outweigh the other many times over. Consider the example of a typical HIV rapid diagnostic test, where high sensitivity at the expense of low specificity can cause (unnecessary) shock but is otherwise not dangerous, whereas low sensitivity is very dangerous. As is often the case, there is no black-and-white "best model". To recap, even with oversampling, none of our models performed well on both sensitivity and specificity. In our case, we have to weigh the consequences of high specificity at the cost of low sensitivity, which means telling many patients with liver disease that they are healthy, against those of high sensitivity at the cost of low specificity, which means telling many healthy patients that they have liver disease. In the absence of further domain-specific information, we can only state which learner performs best on the chosen performance metric. As shown above, based on the AUC, random forest performs best. Random forest is also the learner with the highest sensitivity (i.e. the lowest FNR), while naive Bayes has the best specificity (i.e. the lowest FPR).

However, our analysis is by no means exhaustive. At the feature level, while we focused almost exclusively on the machine learning and statistical aspects, one could also dig deeper into the actual subject matter (liver disease) and try to understand the variables and their potential correlations and interactions more thoroughly. This may also mean reconsidering variables that were removed. Additionally, further feature engineering and data preprocessing could be performed, for example with principal component analysis. Regarding hyperparameter tuning, one could use a larger hyperparameter space and a larger number of evaluations. Tuning could also be applied to some of the learners we labelled as baselines. Finally, there are many more classifiers, notably gradient boosting and support vector machines, that could additionally be applied to this task and potentially yield better results.

Reference

  • mlr3gallery: Liver Patient Classification Based on Diagnostic Measures (https://mlr3gallery.mlr-org.com/posts/2020-09-11-liver-patient-classification/)

