论文笔记:Generative Adversarial Imitation Learning

继续我们上篇博文模仿学习概述中的内容,上文中我们讲到,模仿学习中的逆向强化学习和对抗神经网络如出一辙,在本文中,我们就继续分析将对抗神经网络和逆向强化学习结合遇到的困难和所提出的解决方法。

背景

在上文模仿学习概述中讲过,模仿学习目前分为两个大类,一类是“行为克隆”,一类是“逆向强化学习”,前者可以看作是一种有监督的学习,根据输入的State,输出的Action,通过神经网络进行训练,这种方式弊病很多,效果不够理想。“逆向强化学习”就是在专家的输入是最优的假定下,去学习专家数据的分布,找到适合此类问题的强化学习的奖励函数。当然,这种方法也存在很多问题,比如逆向强化学习中也要进行强化学习的步骤,计算量偏大,针对同样的专家数据,有可能产生不同的奖励函数等。在这篇文章中,我们着重关注的是对抗神经网络和逆向强化学习的关系。

思路概述

我们再把前文的图放出来,看一看对抗神经网络和逆向强化学习的关系。

这种图很好说明了逆向强化学习和对抗神经网络的关系。我们知道,对抗神经网络常常用来生成以假乱真的图片,我们可以做一个类比,真图片就是逆向强化学习中的专家数据,假图片就是逆向强化学习中Actor生成的数据,在GAN中,discriminator用来去区分真假图片,在逆向强化学习中,discriminator用来区分专家数据和Actor数据。这个关系搞明白了,下面就是利用这两个方面的知识来实现这个想法。

逆向强化学习的目标

要训练一个网络,就要找到任务目标的loss function,为了确定这个lost function,我们先看看逆向强化学习最终要优化的目标是什么,这里使用的是maximum causal entropy IRL,具体可见 Maximum entropy inverse reinforcement learning(这文章我没看明白):

\underset{c \in \mathcal{C}}{\operatorname{maximize}}\left(\min _{\pi \in \Pi}-H(\pi)+\mathbb{E}_{\pi}[c(s, a)]\right)-\mathbb{E}_{\pi_{E}}[c(s, a)]

其中H(\pi) \triangleq \mathbb{E}_{\pi}[-\log \pi(a | s)]

把这个逆强化学习的过程变成正向的强化学习,我们的目标便是下面的公式:

\mathrm{RL}(c)=\underset{\pi \in \Pi}{\arg \min }-H(\pi)+\mathbb{E}_{\pi}[c(s, a)]

下面我们定义后面逆向强化学习的基本目标

\operatorname{IRL}_{\psi}\left(\pi_{E}\right)=\underset{c \in \mathbb{R}^{s \times A}}{\arg \max }-\psi(c)+\left(\min _{\pi \in \Pi}-H(\pi)+\mathbb{E}_{\pi}[c(s, a)]\right)-\mathbb{E}_{\pi_{E}}[c(s, a)]

我们的目标。便是寻找到一个cost function c,使得专家的行为始终以-\psi(c)的大小高于其他的策略。

本文的后面就是一堆又一堆的恼人公式,挑点能看懂的说吧。

占有率度量

占有率度量(occupancy measure)这个东西一时半会很难理解它的意思,按照原文的说法,"the occupancy measure can be interpreted as the distribution of state-action pairs that an agent encounters when navigating the environment with policy \pi"。也就是说占有率度量衡量的是某个策略下状态-动作对的分布。有如下公式可以证明:

\mathbb{E}_{\pi}[c(s, a)]=\sum_{s, a} \rho_{\pi}(s, a) c(s, a)

这个公式说明了特定state-action对的cost与总体cost期望的关系。

算法基于的公式

算法的目标是寻找下面表达式的鞍点(pi,D)

\mathbb{E}_{\pi}[\log (D(s, a))]+\mathbb{E}_{\pi_{E}}[\log (1-D(s, a))]-\lambda H(\pi)

算法

发布了85 篇原创文章 · 获赞 100 · 访问量 13万+

猜你喜欢

转载自blog.csdn.net/caozixuan98724/article/details/103840170