深度强化学习-策略梯度算法深入理解

1 引言

深度强化学习-策略梯度算法推导博文中,采用了两种方法推导策略梯度算法,并给出了Reinforce算法的伪代码。可能会有小伙伴对策略梯度算法的形式比较疑惑,本文就带领大家剖析其中的原理,深入理解策略梯度算法的公式。本文主要参考了百度飞桨的视频Policy Gradient算法有兴趣的小伙伴可以看看,我觉得讲的非常透彻。

2 手写数字识别

我们先来看一下手写数字识别案列,采用LeNet网络,其输入为一张手写数字照片,输出为0-9每个数字对应的概率。LeNet网络结构不是本文介绍的重点,我们主要看损失函数部分。

假设网络的输入为数字5,标签为one-hot编码形式,即数字5对应概率值为1,其余为0,网络的输出如上图所示。对于分类问题,通常采用交叉熵(Cross Entropy) 损失函数

交叉熵:

H(p,q)=-\sum_{i=1}^{n}p(x_{i})log(q(x_{i}))

\large p\large q分别表示两个不同的分布,交叉熵可以衡量两个分布的差距,通过最小化交叉熵损失,就可以缩小两个分布之间的距离。将标签看作分布\large p,预测概率看作分布\large q,根据交叉熵公式,计算上图中的交叉熵

H=-(0\cdot log0.01+0\cdot 0.02+\cdots +1\cdot log0.8+\cdots +0\cdot log0.01)=-log0.8

将其作为损失进行梯度反传,更新网络参数,从而让预测概率分布更加接近标签。

3 策略梯度算法

看完手写数字识别案列后,回到策略梯度算法,单步损失和策略梯度的形式为

单步损失:

loss=-\gamma ^{t}G_{t}ln\pi _{\theta }(A_{t}\mid S_{t})

策略梯度:

\triangledown E_{\pi _{\theta }}\left [ G_{0} \right ]=E\left [ \sum_{t=0}^{+\infty }\gamma ^{t}G_{t}\triangledown ln\pi _{\theta }(A_{t}\mid S_{t}) \right ]

假设智能体的动作空间为离散形式,包括“左、停、右”三个动作,策略网络\large \pi _{\theta }(a\mid s)的输入为状态\large s_{t},输出为每个动作对应的概率。如下图所示

 其中预测概率为网络输出的概率分布,真实动作为智能体真正执行的动作,但是它并一定是一个正确的动作,无法作为标签。计算预测概率与真实动作之间的交叉熵,得到

H=-0\cdot log0.02-0\cdot log0.08-1\cdot log0.9=-log0.9

 发现它与单步损失中的ln\pi _{\theta }(A_{t}\mid S_{t})形式一致。由于真实动作不一定是正确的标签,所以加上累积奖励G_{t}作为权重。G_{t}越大,对应的损失越需要重视,反之G_{t}越小,对应的损失就不那么重要。\gamma^{t}可以认为是一个缩放因子,始终为正数,并不影响梯度的方向,因此可以忽略。综上,单步损失具体可以表示为

loss=-\gamma ^{t}G_{t}\sum Y_{i}{}'\cdot log(\pi _{\theta }(A_{t}\mid S_{t}))

其中Y_{i}{}'表示真实动作。对单步损失求梯度即为策略梯度的蒙特卡洛近似,通过梯度反传不断优化策略网络参数,让网络输出的概率分布接近累积回报较大的动作。

4 总结

本文利用离散动作模型剖析了策略梯度公式,发现它与分类模型类似。对于连续动作模型也是同样的道理,利用交叉熵衡量网络预测的概率分布与真实动作的概率分布,并采用累积奖励加权作为单步损失。对损失求梯度,然后沿着梯度的反方向不断更新策略网络参数,从而不断提升策略。

猜你喜欢

转载自blog.csdn.net/weixin_46133643/article/details/122284271