Causal inference with DoWhy: counterfactual analysis in a medical case

0x01. Background

In this example, we know the causal structure of three observed variables, and we want answers to counterfactual questions such as "What would have happened if I had followed different advice from the doctor?"

More specifically, Alice, who suffers from severe dry eyes, decided to use a telemedicine platform because she was unable to see an ophthalmologist where she lived. She reported her medical history, which reveals whether she has a rare allergy, and the platform finally recommended two possible eye drops with slightly different ingredients ("Option 1" and "Option 2").

Alice did a quick search online and found that option 1 has a lot of positive reviews. Still, she decided to use option 2 because her mother had used it in the past with good results. After a few days, Alice's vision improved and her symptoms began to disappear. However, she was curious what would have happened if she had used the very popular option 1, or even done nothing at all.

The platform lets users ask counterfactual questions as long as they report the outcome of the option they followed.

0x02. Simulated data

We describe the SCM as an additive noise model, where $f_{p1,p2}$ is the deterministic function of the parent nodes: $Vision = N_V + f_{p1,p2}(Treatment, Condition)$. We sample the intrinsic noise terms $N_T$, $N_C$ and $N_V$ of the three observed variables; the target variable Vision is then the noise $N_V$ plus the function of its input nodes.

$Treatment = N_T$ takes the value 0, 1 or 2 with a probability of about 33% each: 33% of users do nothing, 33% choose option 1, and 33% choose option 2. This is independent of whether the patient has the rare condition.

$Condition = N_C \sim \mathrm{Bernoulli}(0.01)$: whether the patient has the rare condition.

$$Vision = N_V + f_{p1,p2}(Treatment, Condition) = N_V - P_1(1 - Condition)(1 - Treatment)(2 - Treatment) + 2P_2(1 - Condition)\,Treatment\,(2 - Treatment) + P_2(1 - Condition)(3 - Treatment)(1 - Treatment)\,Treatment - 2P_2\,Condition\,Treatment\,(2 - Treatment) - P_2\,Condition\,(3 - Treatment)(1 - Treatment)\,Treatment$$

$P_1$ is a constant that controls how much the raw vision decreases when the patient does not have the rare condition and takes no medication.

$P_2$ is a constant that controls how much the raw vision increases or decreases, depending on whether the patient has the condition and which drops they use. More specifically:

If Condition = 0 and Treatment = 1 then Vision = N_V + 2P_2

elif Condition = 0 and Treatment = 2 then Vision = N_V - 2P_2

elif Condition = 1 and Treatment = 1 then Vision = N_V - 2P_2

elif Condition = 1 and Treatment = 2 then Vision = N_V + 2P_2

elif Condition = 0 and Treatment = 0 then Vision = N_V - 2P_1

elif Condition = 1 and Treatment = 0 then Vision = N_V
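One way to cross-check the case table against the polynomial is to evaluate the deterministic part of the Vision equation for every combination of Condition and Treatment. The sketch below is only an illustration: the helper name `f` is hypothetical, and the constants are the values P_1 = 0.2 and P_2 = 0.15 that the generator code in this post uses.

```python
P_1 = 0.2   # same constants as the data generator in this post
P_2 = 0.15

def f(treatment, condition, p1=P_1, p2=P_2):
    """Deterministic part of Vision, i.e. Vision - N_V (hypothetical helper)."""
    t, c = treatment, condition
    return (-p1 * (1 - c) * (1 - t) * (2 - t)
            + 2 * p2 * (1 - c) * t * (2 - t)
            + p2 * (1 - c) * (3 - t) * (1 - t) * t
            - 2 * p2 * c * t * (2 - t)
            - p2 * c * (3 - t) * (1 - t) * t)

# Offset added to the raw vision N_V for each (Condition, Treatment) pair
offsets = {(c, t): f(t, c) for c in (0, 1) for t in (0, 1, 2)}
for (c, t), offset in offsets.items():
    print(f"Condition={c}, Treatment={t}: Vision = N_V {offset:+.2f}")
```

Because Treatment only takes the values 0, 1 and 2 and Condition only 0 and 1, most products vanish for any given pair, so each printed offset comes from a single term of the polynomial.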

Rare events such as Condition = 1, which has a probability of only 1%, need a large number of samples for the fitted model to reflect them accurately. That is why we use 10000 samples here to generate the patient database.
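A rough back-of-the-envelope calculation makes the sample-size argument concrete. The numbers below follow from the distributions stated above (Bernoulli(0.01) for Condition, three equally likely treatments); the smaller database of 1000 rows is a hypothetical comparison, not something used in this post.

```python
n_samples = 10000      # size of the patient database in this post
p_condition = 0.01     # Bernoulli parameter of Condition
p_treatment = 1 / 3    # each treatment value is equally likely

# Expected number of patients with the rare condition, overall and per treatment
expected_rare = n_samples * p_condition
expected_per_treatment = expected_rare * p_treatment
print(f"expected patients with Condition = 1: {expected_rare:.0f}")
print(f"expected per treatment value: {expected_per_treatment:.1f}")

# Hypothetical database that is ten times smaller, for comparison
small_per_treatment = 1000 * p_condition * p_treatment
print(f"with 1000 samples: {small_per_treatment:.1f} per treatment value")
```

With only a handful of rows per (Condition = 1, Treatment) combination, a fitted model would have very little to learn the allergy effect from.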

Generate the patient database:

from random import randint

import numpy as np
import pandas as pd
from scipy.stats import bernoulli

n_unobserved = 10000
# Intrinsic noise terms of the three observed variables
unobserved_data = {
    'N_T': np.array([randint(0, 2) for _ in range(n_unobserved)]),  # treatment: 0, 1 or 2
    'N_vision': np.random.uniform(0.4, 0.6, size=(n_unobserved,)),  # raw vision
    'N_C': bernoulli.rvs(0.01, size=n_unobserved),                  # rare condition
}
P_1 = 0.2
P_2 = 0.15

def create_observed_medical_data(unobserved_data, name):
    """Apply the structural equations to the noise terms and save the result as CSV."""
    condition = unobserved_data['N_C']
    treatment = unobserved_data['N_T']
    vision = (unobserved_data['N_vision']
              - P_1 * (1 - condition) * (1 - treatment) * (2 - treatment)
              + 2 * P_2 * (1 - condition) * treatment * (2 - treatment)
              + P_2 * (1 - condition) * treatment * (1 - treatment) * (3 - treatment)
              + 0 * condition * (1 - treatment) * (2 - treatment)
              - 2 * P_2 * condition * treatment * (2 - treatment)
              - P_2 * condition * treatment * (1 - treatment) * (3 - treatment))
    df = pd.DataFrame({'Condition': condition, 'Treatment': treatment, 'Vision': vision})
    df.to_csv(name, index=False)
    return df

medical_data = create_observed_medical_data(unobserved_data, 'patients_database.csv')

Generate the data for a specific patient (Alice):


num_samples = 1
original_vision = np.random.uniform(0.4, 0.6, size=num_samples)

def generate_specific_patient_data(num_samples, name):
    # Alice has the rare condition (N_C = 1) and took option 2 (N_T = 2)
    return create_observed_medical_data({
        'N_T': np.full((num_samples,), 2),
        'N_C': bernoulli.rvs(1, size=num_samples),
        'N_vision': original_vision,
    }, name)

# The file name matches the CSV that is read back in section 0x05
specific_patient_data = generate_specific_patient_data(num_samples, 'newly_come_patients.csv')

0x03. Read the patient database

We have a database with three observed variables: a continuous variable from 0 to 1 indicating vision quality ("Vision"), a binary variable indicating whether the patient has a rare condition, i.e. an allergy ("Condition"), and a categorical variable ("Treatment") that takes three values (0: "do nothing", 1: "option 1", 2: "option 2"). We load the data as follows:

import pandas as pd

medical_data = pd.read_csv('patients_database.csv')
medical_data.head()

Data are as follows:

   Condition  Treatment    Vision
0          0          2  0.223475
1          0          2  0.197306
2          0          0  0.101252
3          0          1  0.703056
4          0          0  0.020249

medical_data.iloc[0:100].plot(figsize=(15, 10))

The dataset reflects patients' vision after one of the three treatment options, given whether they have the rare condition. Note that the dataset contains no information about the patients' raw vision before treatment (i.e. the noise of the Vision variable). As we will see below, the counterfactual algorithm recovers this noise part of Vision, provided the model is a post-nonlinear model (for example an ANM, an additive noise model).

0x04. Modeling

We know that the Treatment and Condition nodes cause Vision, but we do not know the structural causal model itself. However, we can learn it from observational data. We assume that the graph correctly represents the causal relationships and that there are no hidden confounders (causal sufficiency). Given the graph and the data, we can fit a causal model and begin to answer counterfactual questions.

import networkx as nx
import dowhy.gcm as gcm

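# Treatment and Condition are both direct causes of Vision
causal_model = gcm.InvertibleStructuralCausalModel(nx.DiGraph([('Treatment', 'Vision'), ('Condition', 'Vision')]))
# Let dowhy choose a causal mechanism for each node based on the data
gcm.auto.assign_causal_mechanisms(causal_model, medical_data)

# Draw the causal graph
gcm.util.plot(causal_model.graph)

# Fit all causal mechanisms to the patient database
gcm.fit(causal_model, medical_data)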

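To give some intuition for what fitting an additive noise model involves, here is a minimal pure-Python sketch. It is not what dowhy does internally: the per-group mean is a stand-in assumption for whatever regression model the automatic assignment selects, and the seed and the helper name `f` are chosen only for this illustration. The data is simulated with the same distributions and constants as the generator above.

```python
import random
from collections import defaultdict
from statistics import mean

random.seed(0)          # fixed seed, chosen only to make the sketch repeatable
P_1, P_2 = 0.2, 0.15    # constants from the data generator in this post

def f(t, c):
    """Ground-truth deterministic part of Vision, used only to simulate data."""
    return (-P_1 * (1 - c) * (1 - t) * (2 - t)
            + 2 * P_2 * (1 - c) * t * (2 - t)
            + P_2 * (1 - c) * (3 - t) * (1 - t) * t
            - 2 * P_2 * c * t * (2 - t)
            - P_2 * c * (3 - t) * (1 - t) * t)

# Simulate (Condition, Treatment, Vision) rows like the patient database
rows = []
for _ in range(10000):
    t = random.randint(0, 2)
    c = 1 if random.random() < 0.01 else 0
    rows.append((c, t, random.uniform(0.4, 0.6) + f(t, c)))

# "Fit": estimate E[Vision | Condition, Treatment] with one mean per group
groups = defaultdict(list)
for c, t, vision in rows:
    groups[(c, t)].append(vision)
f_hat = {key: mean(values) for key, values in groups.items()}

# Residuals play the role of the recovered noise
residuals = [vision - f_hat[(c, t)] for c, t, vision in rows]

print(f"f_hat for Condition=0, Treatment=1: {f_hat[(0, 1)]:.3f}")
print(f"largest |residual|: {max(abs(r) for r in residuals):.3f}")
```

Note that the estimated function absorbs the mean of the raw vision, so the residuals are centered around zero rather than around 0.5. For counterfactuals this constant shift cancels out, since the same residual is added back after the intervention.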

0x05. Read the specific patient's data

specific_patient_data = pd.read_csv('newly_come_patients.csv')
specific_patient_data.head()

The output is as follows:

   Condition  Treatment    Vision
0          1          2  0.857103

0x06. Answer Alice's counterfactual question

To examine a hypothetical outcome, that is, what would have happened if an event had not occurred or had occurred differently, we use counterfactual logic based on structural causal models. We know the following: Alice's treatment was option 2; Alice has the rare allergy (Condition = 1); after treatment 2 her vision was about 0.86 (Vision = 0.857103 in the output above); and we are able to recover the noise from the learned structural causal model.
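The three steps behind a counterfactual query (recover the noise, change the treatment, recompute the outcome) can be sketched by hand for an additive noise model. The sketch below assumes access to the ground-truth function from the data generator, which is possible only because the data is simulated; dowhy works with a learned approximation instead, so its numbers will differ somewhat. The helper names `f` and `counterfactual_vision` are hypothetical.

```python
P_1, P_2 = 0.2, 0.15    # constants from the data generator in this post

def f(treatment, condition):
    """Ground-truth deterministic part of Vision (known only because the data is simulated)."""
    t, c = treatment, condition
    return (-P_1 * (1 - c) * (1 - t) * (2 - t)
            + 2 * P_2 * (1 - c) * t * (2 - t)
            + P_2 * (1 - c) * (3 - t) * (1 - t) * t
            - 2 * P_2 * c * t * (2 - t)
            - P_2 * c * (3 - t) * (1 - t) * t)

# Alice's observed record, taken from the output shown in section 0x05
observed = {"Condition": 1, "Treatment": 2, "Vision": 0.857103}

# Step 1 (abduction): recover the noise term N_V from the observation
noise = observed["Vision"] - f(observed["Treatment"], observed["Condition"])

# Steps 2 and 3 (action, prediction): change Treatment, keep Condition and noise fixed
def counterfactual_vision(new_treatment):
    return noise + f(new_treatment, observed["Condition"])

for t in (0, 1):
    print(f"Treatment={t}: counterfactual Vision = {counterfactual_vision(t):.3f}")
```

The key design point is that the noise is held fixed across the factual and counterfactual worlds: it represents everything about Alice that the intervention does not change, here her raw vision before treatment.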

We can now examine the counterfactual outcome for her vision if the Treatment node had been different. Below, we look at the counterfactual value of Alice's vision if she had received no treatment (Treatment = 0) and if she had taken the other eye drops (Treatment = 1).

# Counterfactual: Alice takes option 1 instead
counterfactual_data1 = gcm.counterfactual_samples(
    causal_model,
    {'Treatment': lambda x: 1},
    observed_data=specific_patient_data)

# Counterfactual: Alice takes no treatment
counterfactual_data2 = gcm.counterfactual_samples(
    causal_model,
    {'Treatment': lambda x: 0},
    observed_data=specific_patient_data)


import matplotlib.pyplot as plt

df_plot2 = pd.DataFrame()
df_plot2['Vision after option 2'] = specific_patient_data['Vision']
df_plot2['Counterfactual vision (option 1)'] = counterfactual_data1['Vision']
df_plot2['Counterfactual vision (No treatment)'] = counterfactual_data2['Vision']

df_plot2.plot.bar(title="Counterfactual outputs")
plt.xlabel('Alice')
plt.ylabel('Eyesight quality')
plt.legend()

The resulting bar chart shows that if Alice had chosen option 1, her eyesight would have been worse than with option 2. She therefore realizes that the rare condition (Condition = 1) she reported in her medical history could cause an allergic reaction to the popular option 1. Alice can also see that if she had chosen neither of the recommended options, her eyesight would have been worse than with option 2 (the Vision variable takes a relatively smaller value).

Origin blog.csdn.net/l8947943/article/details/129750849