【论文阅读】Sepsis World Model A MIMIC-based OpenAI Gym World Model Simulator for Sepsis Treatment

Sepsis World Model: A MIMIC-based OpenAI Gym “World Model” Simulator for Sepsis Treatment

Author Kiani et al.
Comments This project was done as a class project for CS221 at Stanford University(Not so good)
Year 2019
Tags deep reinforcement learning, sepsis treatment, predict next state, DQN

1 Introduction and Motivation / Abstract

有了如MIMICⅢ的医学数据集之后,使用deep reinforcement learning为sepsis治疗寻找最优策略变得普遍。

然而,面临的一种挑战:我们所知道的states只是整个状态空间中的一小部分(其实也就是limited data的问题),而且这些数据也会包含噪声。

现有工作的解决方法是:使用带有重要性采样的off-policy评估策略;训练随机策略和其他评估技术等。

我们的方案:使用”world model"(这个概念最早是Ha and Schmidhuber在2018年提出的)方法创建一个模拟器,目标是用于在给定病人的当前state和treatment action的情况下预测下一个state。在这个方案中,需要让模拟器从EHR data中学习到潜在的、含有更少噪声的表征特征。

使用的数据:MIMIC

模型的结构VAE(Variational Auto-Encoder)+MDN-RNN(Mixture Density Network combined with a RNN)。为了减少噪声的影响,在模拟过程中从下一步生成的分布中采样,并通过控制类似于Ha等人提出的”温度“变量将不确定性引入模拟器。

最终通过比较测试环境输出和真实EHR数据的相似性,评估模型的表现;通过使用deep Q-Learning学习脓毒症的治疗的现实政策,评估其可行性。

2 Approach

2.1 Dataset Overview and Preprocessing

每个状态有46个经过标准化的features。

action用0-24之间的离散数值表示。

最终的目标是在基于关于患者在某个特定时间步的信息给出一个treatment action建议,确保患者存活。因此,构建”State Model“,在给定当前state和采取的action的情况下预测下一个state。

2.2 Simulator Models

baseline:(baseline没有使用VAE,没有使用MDN-RNN来模拟state的不确定性,因此会对噪声数据点过拟合)

①模拟next state的state model(用RNN);

②模拟end of state的termination model(用RNN);

③模拟outcome prediction的outcome model(用RNN)。

本文的model:VAE + MDN-RNN。

VAE把输入的特征从46维降至30维。训练VAE的过程中将原始输入和降维后的输出的MSE控制到最小。

在这里插入图片描述

VAE输出的数据作为MDN-RNN的输入。MDN-RNN预测下一个state的概率分布。

在这里插入图片描述

对于本文,World Model的整体结构如下:

在这里插入图片描述

除了本文用于预测next state的模型,World Model中还有一个使用DQN算法训练学习的最优策略、一个由专家提出的策略。

VAE对数据进行处理后,将特征输出到前面提到的三个RNN,然后又对VAE+RNN、MDN+RNN、VAE+MDN+RNN的表现进行比较,分析能否对baseline做提升,如果可以,又是哪些提升是有效的。

2.3 State Model

state model是一个RNN。输入:VAE处理后的features(30×10)、当前时间步的action value。输出:表示下一个state的feature。

在这里插入图片描述

2.4 Episode Termination Model

此模型用于检测episode transitions(情节转换)。

两种互斥的转换:①终止episode;②继续episode。

输入:VAE处理后的features(30×10)、当前时间步的action value(1)、step number feature(1)(添加step number是为了计算episode的长度)

输出:布尔值,表示终止还是继续。

在这里插入图片描述

2.5 Episode Outcome Model

此模型用于预测两种互斥的结果:①death;②出院(release from hospital)。预测出每个episode的结果之后用于在环境中决定奖励值。

输入:与 Episode Termination Model一样。

输出:布尔值,表示死亡还是出院。

在这里插入图片描述

2.6 DQN Agents Model

为了评估本文提出的这些模拟器,利用openAI中的baseline离线算法(基于不同的架构,如:baseline、VAE、MDN、VAE+MDN)训练三个agent,并进行封装,使得提供一个state和采取的action就可以产生state和reward。

DQN算法用于在模拟环境中学习最优策略。

奖励函数的定义:(提出了三个,后面会分别与专家的policy对比测试)

在这里插入图片描述

第三个策略是为了同时确保中间奖励不会掩盖最终奖励,并提供一些指导性反馈,以在每个单独的时间步向正确的方向修正策略。

3 Results and Analysis

3.1 Autoencoder and VAE

利用训练好的VAE得到降维后的特征,用这些降维后的特征与原来的state特征进行比较,发现预测效果不错。

怎么比较的呢?是用这些特征都去预测?

在这里插入图片描述

在本文中使用VAE不是为了很好地匹配所有state特征,而是主要为了降噪、在原始数据中挖掘隐藏特征。

3.2 Simulator: State, Termination, and Outcome

对三个model的训练结果进行分析,发现RNN和RNN+VAE的表现很接近,似乎有无VAE表现都没有明显区别。不过作者给出的解释是:尽管VAE可能会丢失一些原始信息,但仍能捕捉到重要信息(以完成预测),至少能取得与没有用VAE处理过的原始数据所做的预测相近的表现,那也是可以认可的。

但是其实有一个问题,文章声称使用VAE能减少数据中的噪声,但是噪声从何而来、又如何得知噪声已经减小了呢?

在这里插入图片描述

图7绘制AE、VAE、VAE+MDN、MDN的SOFA和SpO2状态特征的模拟投影。在这些图中,即使模型错误地预测了某个状态,它也会收到该状态的正确版本作为预测下一步的输入。

结果表明,MDN对预测状态赋予了更大的方差,与预期一致。似乎比AE和VAE学习得更多,而不是简单地保持预测不变,直到一个旧的状态被添加回来作为输入(从左边两幅图中蓝色之后的黄色运动中看出这种趋势)。

MDN+VAE比MDN本身的方差更大,与预期一致。

可是方差大能代表什么?

在这里插入图片描述

3.3 Analysis of Rollout on Physician’s policy

尝试对模拟器进行关于医生提出的策略的可视化检查。

具体来说,以患者的起始状态初始化模型,然后执行医生在每个患者身上执行的实际系列动作,并在发作的长度上可视化状态特征。这里只访问模型作生成的状态作为历史,这代表了试图训练一个agent通过探索来学习政策时可能获得的东西,因为我们不一定能访问数据集中无限大小和连续的状态-动作空间。

结果表明,RNN本身产生平滑的曲线,而基于MDN的模型中则有不断变化的趋势。我们认为这可能是由于RNN本身无法完全捕捉输出的动态方差,从而收敛于寻找潜在下一状态的"均值"。这就是我们首先引入MDN-RNN的原因,这样模型可以预测下一个状态来自的一组分布,并捕获状态必须来自其中一个分布的想法。

的确,MDN+RNN似乎能更好地捕捉整个事件的方差,并遵循一般趋势,尽管这带来了有时错误分布的风险。

在MDN中,虽然个别步骤可能与上一步有较大的方差,但分布通常会在下一次预测中修正自己回到一个更稳定的值。

总体而言,MDN和VAE似乎成功地建模了下一个状态的方差和分布。

在这里插入图片描述

3.4 Normalized Trajectory Means

计算一个定量的度量,以便测量模拟器结果的误差。

文章提出了归一化轨迹均值度量(Normalized Trajectory Mean metric),它根据特定的state model计算每个特征在所有rollout中的均值。对不同的特征测量了这个值,如图9所示。

在这里插入图片描述

这一指标说明了模型对每个指标的校准程度。它可以作为模型性能的检查,并为未来改进的优先顺序提供方向。

可以看到真值均值(右)比简单的RNN均值(左)更类似于MDN均值(中)。这证实了前文的猜想,尽管MDN模型在每一步都有更多的方差,但通常被调节为稳定的整体和正确的主要变化,防止发散,而光滑的RNN模型可能会发散。

3.5 Evaluation on OpenAI Baseline Learned Policies

最终的目标是学习一个策略来改善患者的结果,因此使用DQN算法来评估环境。虽然没有确切的"标签"或数量来衡量我们学习的策略(除了临床验证)的临床效果,但与真实数据集中的长度、奖励和动作的定性比较可以评估模拟器如何模拟治疗过程。图10显示了专家在action、reward和时间长度上的策略分布,我们将其与我们的政策进行比较。

在这里插入图片描述

在我们的模拟环境上回放专家的动作后,我们比较了真实世界和模拟世界之间的episode长度、奖励和动作的分布。

对于使用奖励公式(1)的情况,模拟结果过于极端,策略围绕一个行动和一个非常短的episode长度。使用基于MDN的模拟器实现了稍微更真实的状态轨迹。MDN模型学习每个特征的分布,在采样时提供更具代表性的状态特征集合。然而,与专家的策略相比,这一学习的策略是不现实的,环境过度拟合了一组小的干预措施及其积极的结果。

在这里插入图片描述

使用奖励公式(2)的情况:episode的长度仍然不切实际。同时这种奖励的制定似乎也没有给出更加多元的政策。

在这里插入图片描述

使用奖励公式(3)的情况:策略分布更贴近实际。这可能是由于模型必须在每个时间步选择一个动作来优化一个对当下重要的特定值,因此有激励选择一个对该状态最有效的最优特定action。在另外两种reward方案中,需要的动作是在情节结束时对某事进行优化,不具有时间依赖性,这导致agent每次都会预测相同的动作。

然而,episode的长度仍然偏短,这可能表明termination模型是过拟合的。

在这里插入图片描述

4 Conclusion

本文的工作:

①使用VAE和MDN两种方法,比一个简单的RNN更能建模出脓毒症患者状态分布。

②在这些模拟器(也就是VAE和MDN)的基础上建立一个模型,并且可以迭代各种奖励函数来建模患者的治疗轨迹。

未来的工作:

①优化状态/终止/结果模型的结构。

②细化reward和uncertainty函数。

猜你喜欢

转载自blog.csdn.net/Mocode/article/details/128215734
今日推荐