【Paper Reading】Sepsis World Model: A MIMIC-based OpenAI Gym "World Model" Simulator for Sepsis Treatment

Sepsis World Model: A MIMIC-based OpenAI Gym “World Model” Simulator for Sepsis Treatment

Authors: Kiani et al.
Comments: done as a class project for CS221 at Stanford University (so not a particularly strong paper)
Year: 2019
Tags: deep reinforcement learning, sepsis treatment, predict next state, DQN

1 Introduction and Motivation / Abstract

With medical datasets such as MIMIC-III, it has become common to use deep reinforcement learning to search for optimal sepsis treatment strategies.

However, there is a challenge: the observed states cover only a small part of the entire state space (essentially a limited-data problem), and the data are also noisy.

Existing work addresses this with off-policy evaluation using importance sampling, training stochastic policies, and other evaluation techniques.

The authors' solution: use a "world model" (a concept first proposed by Ha and Schmidhuber in 2018) to build a simulator whose goal is to predict the next state given the patient's current state and treatment action. In this scheme, the simulator must learn latent, less noisy representations from the EHR data.

Data used: MIMIC-III

Model structure: VAE (Variational Auto-Encoder) + MDN-RNN (Mixture Density Network combined with an RNN). To reduce the effect of noise, the simulator samples the next state from the predicted distribution at each step, and uncertainty is introduced by controlling a "temperature" variable, similar to the one proposed by Ha et al.

Finally, model performance is evaluated by comparing the simulator's outputs with the real EHR data, and feasibility is assessed by using deep Q-learning to learn a realistic sepsis treatment policy.

2 Approach

2.1 Dataset Overview and Preprocessing

Each state has 46 normalized features.

The action is a discrete value between 0 and 24 (25 possible actions).

The ultimate goal is to recommend, at each time step, a treatment action based on the patient's current information so as to maximize the patient's chance of survival. To that end, a "state model" is built that predicts the next state given the current state and the action taken.
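As a quick orientation (not the authors' code), the state and action spaces described above translate naturally into Gym spaces; the value bounds below are assumptions, not taken from the paper:

```python
import numpy as np
from gym import spaces

# 46 normalized state features; the bounds are illustrative.
observation_space = spaces.Box(low=-5.0, high=5.0, shape=(46,), dtype=np.float32)

# 25 discrete treatment actions, encoded as integers 0-24.
action_space = spaces.Discrete(25)
```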

2.2 Simulator Models

Baseline: three plain RNN models (the baseline uses neither the VAE nor the MDN-RNN to model state uncertainty, so it tends to overfit noisy data points):

① A state model that predicts the next state (an RNN);

② A termination model that predicts the end of an episode (an RNN);

③ An outcome model that predicts the episode outcome (an RNN).

This paper's model: VAE + MDN-RNN.

The VAE reduces the input features from 46 dimensions to 30. During training, the VAE minimizes the MSE between the original input and its reconstruction.
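A minimal PyTorch sketch of such a VAE (46 → 30 latent dimensions); the hidden size and the KL weight are assumptions, and the paper only reports minimizing the reconstruction MSE:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class StateVAE(nn.Module):
    """Compresses 46 state features into a 30-dim latent code (layer sizes are illustrative)."""

    def __init__(self, in_dim=46, latent_dim=30, hidden=64):
        super().__init__()
        self.enc = nn.Sequential(nn.Linear(in_dim, hidden), nn.ReLU())
        self.mu = nn.Linear(hidden, latent_dim)
        self.logvar = nn.Linear(hidden, latent_dim)
        self.dec = nn.Sequential(nn.Linear(latent_dim, hidden), nn.ReLU(),
                                 nn.Linear(hidden, in_dim))

    def forward(self, x):
        h = self.enc(x)
        mu, logvar = self.mu(h), self.logvar(h)
        z = mu + torch.randn_like(mu) * torch.exp(0.5 * logvar)  # reparameterization trick
        return self.dec(z), mu, logvar

def vae_loss(recon, x, mu, logvar, beta=1e-3):
    # Reconstruction MSE (what the paper reports minimizing) plus a small KL term (weight assumed).
    mse = F.mse_loss(recon, x, reduction="mean")
    kl = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())
    return mse + beta * kl
```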


The VAE's latent output is fed into the MDN-RNN, which predicts a probability distribution over the next state.
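A rough sketch (not the authors' code) of an MDN-RNN over the 30-dimensional latent space, including the temperature-controlled sampling mentioned in the abstract; the number of mixture components, hidden size, and action encoding are all assumptions:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MDNRNN(nn.Module):
    """LSTM over (latent, action) pairs; outputs a Gaussian mixture over the next latent state."""

    def __init__(self, latent_dim=30, action_dim=1, hidden=128, n_mix=5):
        super().__init__()
        self.n_mix, self.latent_dim = n_mix, latent_dim
        self.rnn = nn.LSTM(latent_dim + action_dim, hidden, batch_first=True)
        self.head = nn.Linear(hidden, n_mix * (2 * latent_dim + 1))  # mu, sigma, pi per component

    def forward(self, z, a):
        # z: (batch, time, 30) latent states; a: (batch, time, 1) action encoded as a float
        x = torch.cat([z, a], dim=-1)
        h, _ = self.rnn(x)
        out = self.head(h[:, -1])  # predict from the last time step
        mu, log_sigma, pi_logits = torch.split(
            out, [self.n_mix * self.latent_dim, self.n_mix * self.latent_dim, self.n_mix], dim=-1)
        mu = mu.view(-1, self.n_mix, self.latent_dim)
        sigma = log_sigma.view(-1, self.n_mix, self.latent_dim).exp()
        return pi_logits, mu, sigma

def sample_next_latent(pi_logits, mu, sigma, temperature=1.0):
    """Temperature > 1 widens both the mixture weights and the per-component noise."""
    pi = F.softmax(pi_logits / temperature, dim=-1)
    k = torch.multinomial(pi, 1).squeeze(-1)                 # pick one mixture component per sample
    idx = k.view(-1, 1, 1).expand(-1, 1, mu.size(-1))
    mu_k = mu.gather(1, idx).squeeze(1)
    sigma_k = sigma.gather(1, idx).squeeze(1)
    return mu_k + sigma_k * torch.randn_like(mu_k) * (temperature ** 0.5)
```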


For this article, the overall structure of the World Model is as follows:

[Figure: overall structure of the World Model]

Besides the next-state prediction models, the World Model setup also includes a policy learned with the DQN algorithm and the experts' (physicians') policy.

The VAE-processed features are fed to the three RNNs mentioned above, and the performance of VAE+RNN, MDN+RNN, and VAE+MDN+RNN is compared to see whether the baseline can be improved and, if so, which changes are responsible.

2.3 State Model

The state model is an RNN. Input: VAE-processed features over a 10-step window (30×10) plus the action taken at the current time step. Output: the predicted features of the next state.
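A small sketch of a deterministic state model matching these inputs and outputs (the MDN variant above would replace the final linear layer with the mixture head); the GRU choice, hidden size, and action normalization are assumptions:

```python
import torch
import torch.nn as nn

class StateModel(nn.Module):
    """Predicts the next 30-dim latent state from a 10-step latent history plus the current action."""

    def __init__(self, latent_dim=30, hidden=64):
        super().__init__()
        self.rnn = nn.GRU(latent_dim + 1, hidden, batch_first=True)
        self.out = nn.Linear(hidden, latent_dim)

    def forward(self, latent_window, action):
        # latent_window: (batch, 10, 30); action: (batch,) discrete value in [0, 24]
        a = action.float().view(-1, 1, 1).expand(-1, latent_window.size(1), 1) / 24.0
        x = torch.cat([latent_window, a], dim=-1)
        h, _ = self.rnn(x)
        return self.out(h[:, -1])  # predicted next latent state
```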


2.4 Episode Termination Model

This model is used to detect episode transitions.

Two mutually exclusive transitions: ① the episode terminates; ② the episode continues.

Input: VAE-processed features (30×10), the action taken at the current time step (1), and a step-number feature (1) (the step number is included so the model can account for episode length).

Output: Boolean indicating whether to terminate or continue.


2.5 Episode Outcome Model

This model predicts two mutually exclusive outcomes: ① death; ② discharge from hospital. The predicted outcome of each episode is used to determine the reward in the environment.

Input: Same as Episode Termination Model.

Output: Boolean indicating death or discharge.
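Since the termination model and the outcome model share the same inputs (latent window, action, step number) and both emit a binary prediction, a single illustrative classifier shape covers both; again, the GRU choice, hidden size, and input encodings are assumptions:

```python
import torch
import torch.nn as nn

class BinaryEpisodeModel(nn.Module):
    """Shared shape for the termination and outcome models: latent history + action + step -> probability."""

    def __init__(self, latent_dim=30, hidden=64):
        super().__init__()
        self.rnn = nn.GRU(latent_dim + 2, hidden, batch_first=True)  # +1 action, +1 step number
        self.out = nn.Linear(hidden, 1)

    def forward(self, latent_window, action, step):
        # latent_window: (batch, 10, 30); action, step: (batch,)
        T = latent_window.size(1)
        a = action.float().view(-1, 1, 1).expand(-1, T, 1) / 24.0
        s = step.float().view(-1, 1, 1).expand(-1, T, 1)
        x = torch.cat([latent_window, a, s], dim=-1)
        h, _ = self.rnn(x)
        return torch.sigmoid(self.out(h[:, -1]))  # P(terminate) or P(death), depending on the training target
```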


2.6 DQN Agents Model

To evaluate the simulators proposed in the paper, the DQN implementation from OpenAI Baselines is used to train agents on the different simulator architectures (baseline, VAE, MDN, VAE+MDN); each simulator is wrapped as an environment so that, given a state and an action, it produces the next state and a reward.
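A minimal sketch of such a wrapper, assuming hypothetical `state_model`, `termination_model`, and `outcome_model` callables; the class name, interfaces, and the terminal-only reward shown here are illustrative, not the authors' code:

```python
import numpy as np
import gym
from gym import spaces

class SepsisSimEnv(gym.Env):
    """Wraps the learned simulator so (state, action) -> (next_state, reward, done)."""

    def __init__(self, state_model, termination_model, outcome_model, initial_states):
        super().__init__()
        self.observation_space = spaces.Box(low=-5.0, high=5.0, shape=(46,), dtype=np.float32)
        self.action_space = spaces.Discrete(25)
        self.state_model = state_model              # predicts the next state
        self.termination_model = termination_model  # predicts whether the episode ends
        self.outcome_model = outcome_model          # predicts death vs. discharge
        self.initial_states = initial_states        # real patient starting states from MIMIC
        self.state = None
        self.t = 0

    def reset(self):
        self.t = 0
        idx = np.random.randint(len(self.initial_states))
        self.state = self.initial_states[idx].astype(np.float32)
        return self.state

    def step(self, action):
        next_state = self.state_model(self.state, action)
        done = self.termination_model(self.state, action, self.t)
        # The paper compares three reward formulations; only a simple terminal reward
        # based on the predicted outcome is shown here, as an illustration.
        reward = (1.0 if self.outcome_model(self.state, action, self.t) else -1.0) if done else 0.0
        self.state = np.asarray(next_state, dtype=np.float32)
        self.t += 1
        return self.state, reward, done, {}
```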

The DQN algorithm is used to learn optimal policies in a simulated environment.
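Assuming the simulator is wrapped as sketched above, training with the DQN implementation from OpenAI Baselines might look roughly like this; the hyperparameters follow the library's standard example and are not values from the paper:

```python
from baselines import deepq

# `env` is the Gym-wrapped simulator sketched above; all hyperparameters are illustrative.
act = deepq.learn(
    env,
    network="mlp",
    lr=1e-4,
    total_timesteps=200000,
    buffer_size=50000,
    exploration_fraction=0.1,
    exploration_final_eps=0.02,
    gamma=0.99,
    print_freq=100,
)
act.save("sepsis_dqn_policy.pkl")
```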

Reward function definitions (three are proposed; the resulting policies are later compared with the experts' policy):

[Figure: definitions of the three proposed reward functions]

The third formulation tries to ensure that intermediate rewards do not overshadow the final reward while still providing per-step feedback that nudges the policy in the right direction.

3 Results and Analysis

3.1 Autoencoder and VAE

The trained VAE is used to obtain the reduced-dimensional features, which are then compared against the original state features; the reconstruction is found to be good.

How do you compare? Are these features used to predict?


The VAE is not meant to reproduce every state feature exactly; its main purpose is noise reduction and extracting latent features hidden in the original data.

3.2 Simulator: State, Termination, and Outcome

Analyzing the training results of the three models, the performance of RNN and RNN+VAE is very close, so the VAE does not seem to make a significant difference on its own. The authors' explanation: although the VAE may lose some of the original information, it still captures the important information needed for prediction; achieving performance similar to predictions made from the raw, non-VAE-processed data is therefore acceptable.

But there is a problem here: the paper claims the VAE reduces the noise in the data, yet where does that noise come from, and how do we know it has actually been reduced?


Figure 7 plots simulated projections of the SOFA and SpO2 state features for AE, VAE, VAE+MDN, and MDN. In these plots, even if the model predicts a state incorrectly, it receives the correct version of that state as input when predicting the next step.

The results show that the MDN assigns a larger variance to the predicted state, as expected. It also seems to learn more than the AE and VAE models, rather than simply holding its prediction until the corrected state is fed back in as input (this trend is visible in the yellow curve following the blue one in the two left plots).

MDN+VAE has larger variance than MDN itself, as expected.

But what can a large variance mean?


3.3 Analysis of Rollout on Physician’s policy

The simulator is visually inspected by rolling out the physicians' policy.

Specifically, the model is initialized with a patient's starting state, the actual sequence of actions the physicians performed on that patient is replayed, and the state features are visualized over the length of the episode. Only the states generated by the model are used as history; this mirrors what happens when training an agent to learn a policy through exploration, since the dataset cannot cover the effectively infinite, continuous state-action space.
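A small sketch of this rollout procedure, using the hypothetical `SepsisSimEnv` wrapper from the earlier sketch (constructed with the single patient's initial state, so `reset()` returns it):

```python
import numpy as np

def rollout_physician_policy(env, physician_actions):
    """Replay a recorded physician action sequence, feeding only simulated states back as history."""
    state = env.reset()  # starts from the real patient's initial state
    trajectory = [np.asarray(state, dtype=np.float32)]
    for action in physician_actions:
        state, reward, done, _ = env.step(action)
        trajectory.append(np.asarray(state, dtype=np.float32))
        if done:
            break
    return np.stack(trajectory)  # (steps, 46): plot each feature column over time to visualize the episode
```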

The results show that the plain RNN produces smooth curves, while the MDN-based models fluctuate constantly. The authors attribute this to the plain RNN being unable to fully capture the variance of the output dynamics, so it converges to predicting the "mean" of the possible next states. This is exactly why the MDN-RNN was introduced: it lets the model predict the set of distributions the next state could come from and captures the idea that the state must be drawn from one of them.

Indeed, the MDN-RNN seems to better capture the variance of the whole episode and follow its general trend, although this carries the risk of occasionally sampling from the wrong distribution.

With the MDN, although an individual step may deviate strongly from the previous one, the distribution usually corrects itself back to a more stable value at the next prediction.

Overall, MDN and VAE seem to successfully model the variance and distribution of the next state.


3.4 Normalized Trajectory Means

A quantitative metric is computed to measure the error of the simulator's results.

The paper proposes a Normalized Trajectory Mean metric, which computes the mean of each feature over all rollouts produced by a given state model. This value is measured for different features, as shown in Fig. 9.

[Figure 9: normalized trajectory means per feature]

This metric indicates how well the model is calibrated for each feature. It serves as a sanity check on model performance and suggests where future improvements should be prioritized.
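The exact normalization is not spelled out in this note, so the following is just the obvious reading of "mean of each feature over all rollouts"; function and variable names are illustrative:

```python
import numpy as np

def trajectory_feature_means(rollouts):
    """rollouts: list of (steps_i, n_features) arrays; returns the per-feature mean over all steps of all rollouts."""
    stacked = np.concatenate(rollouts, axis=0)
    return stacked.mean(axis=0)

# Compare simulated rollouts against ground-truth EHR trajectories feature by feature.
# sim_means = trajectory_feature_means(simulated_rollouts)
# true_means = trajectory_feature_means(ground_truth_trajectories)
# per_feature_error = np.abs(sim_means - true_means)
```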

The ground-truth means (right) are closer to the MDN means (middle) than to the plain RNN means (left). This supports the earlier conjecture that although MDN models have more variance at each step, they tend overall to stay stable and correct large deviations, preventing divergence, whereas the smooth RNN models can drift away.

3.5 Evaluation on OpenAI Baseline Learned Policies

The ultimate goal is to learn a policy that improves patient outcomes, so the DQN algorithm is used to evaluate the environment. While there are no exact "labels" or quantities for measuring the clinical effectiveness of the learned policies (other than clinical validation), qualitative comparisons of episode length, reward, and actions against the real dataset allow an assessment of how well the simulator mimics the treatment process. Figure 10 shows the experts' distributions over actions, rewards, and episode lengths, against which the learned policies are compared.

[Figure 10: expert policy distributions over actions, rewards, and episode lengths]

After replaying the experts' actions in the simulated environment, the distributions of episode lengths, rewards, and actions in the real and simulated worlds are compared.

With reward formulation (1), the simulation results are too extreme: the policy collapses onto a single action and a very short episode length. Slightly more realistic state trajectories are obtained with an MDN-based simulator, since the MDN learns the distribution of each feature and therefore provides a more representative set of state features when sampling. Even so, the learned policy is unrealistic compared with the experts' policy, and the environment overfits to a small set of interventions and their positive outcomes.


With reward formulation (2), the episode lengths are still unrealistic, and this reward does not appear to produce a more diverse policy either.


With reward formulation (3), the policy distribution is closer to reality. This is likely because the agent must choose an action at each time step to optimize a quantity that matters at that moment, giving it an incentive to pick the action most effective for that particular state. In the other two reward schemes, the agent only optimizes something at the end of the episode, with no time dependence, which leads it to predict the same action every time.

However, the length of episodes is still short, which may indicate that the termination model is overfitting.


4 Conclusion

The work of this paper:

① Using the VAE and MDN, the state distribution of sepsis patients is modeled better than with a simple RNN.

② Agents are built on top of these simulators (VAE and MDN), and various reward functions can be iterated on to model the patient's treatment trajectory.

Future work:

① Optimize the architecture of the state/termination/outcome models.

② Refine the reward and uncertainty functions.
