大模型强化学习:RLHF、PPO

RL:什么是PPO

这一节主要参考:

https://towardsdatascience.com/proximal-policy-optimization-ppo-explained-abed1952457b​towardsdatascience.com/proximal-policy-optimization-ppo-explained-abed1952457b

首先我们要理解什么是策略梯度(Policy Gradient)。

以下推导过程包含很多个人理解,并不很严谨,有错误欢迎指出!

粗糙理解RL的过程就是,让智能体在一个状态S下选择动作A,然后获得收益R,然后我们希望优化选择动作的策略,使得总体收益的期望最大。因为搜索空间很大,我们利用模型的预测结果决策,同时为了不让模型陷入局部最优而按蒙特卡洛方式一定比例随机游走,在这个过程中得到每个state-action对应的reward作为新的训练样本,即所谓的探索和利用(Exploration and Exploitation)过程。

对一组模型参数,可以得到一组轨迹序列的概率分布 

对一条由多个状态动作对组成的轨迹 � ,我们得到reward期望:  ,其中 � 是0-1的折扣因子(因为远期奖励相对不重要), �� 是不同时间步的reward。

于是目标函数: 

我们想最大化目标所以用梯度上升: 

求解梯度过程:

为什么要写成期望呢?因为实际计算我们用多次采样的平均值就能作为期望的近似值。

如果把轨迹理解为输出的句子,那么 ���(�(�;�)) 对应在文本生成里就是给定一个输入文本 � ,得到输出文本 � 的概率 �(�|�)=�(�1|�)�(�2|�,�1)⋯�(��|�,�1,⋯,��−1) 。

而PPO算法的流程如下:

PPO

相比策略梯度的区别在于引入了一项对输出token序列约束,即上文提到的KL散度惩罚项:

这里的超参  (图中的  )也是启发式方法设计的,如果KL散度超过一定阈值,就使 翻倍以修正输出;反之小于一定阈值就减半我们牺牲了一些数学上的严谨性来支持实用主义!


相关资料

首先是相关论文。

关于近期的相关论文,直接看原博客下面的链接就行,我这里截个图作memo:

其次是代码。

OpenAI在2019年就做了一个基于tf做RLHF的项目:https://github.com/openai/lm-human-preferences

用torch的也有不少,这些都是Hug的博客里推的项目:

建议观望一下,因为:

Both TRLX and RL4LMs are under heavy further development, so expect more features beyond these soon.

毕竟大家都不希望好好的用到一半API都改完了吧(说的就是你tf),当然huggingface自己也好不到哪里去。

此外还有一些相关的项目,比如chatgpt的google插件:https://github.com/wong2/chat-gpt-google-extension

另外,我个人实践用过的chatgpt API(非官方)里面,能用的有:GitHub - acheong08/ChatGPT: Lightweight package for interacting with ChatGPT's API by OpenAI. Uses reverse engineered official API.

不过貌似OpenAI加了cloudflare保护之后这些工具都失效了,再观察一下吧。


杂谈:Why RL?

看完hug的博客让我有一种“就这”的感觉。如果只是在finetune阶段引入了RL,也没法说服我们为啥ChatGPT能做到这么强的效果,只能说OpenAI大概有一些数据集构建的技术,另外也许有一些神秘的建模逻辑没有公开……这里期待和大家评论区里讨论讨论。

可以设想一种场景的比较:尝试直接用human feedback作监督信号,用第二步得到的RM模型的打分作为finetune模型的loss权重来finetune一个gpt-3,这样做和引入RL的训练效果比较如何呢?

另外一个问题在于,如果持续训练这样一个模型,依然会遭遇灾难性遗忘的问题……作为搜索引擎估计是不太行了,在张俊林老师的文章里也有一些相关的讨论,也许未来是二者相辅相成的发展模式。

关于ChatGPT在未来是不是能做成AGI,商业应用的问题,我的姿势水平有限就暂时不评价了;但是不觉得这很酷吗?作为一名理工男我觉得这太酷了,很符合我对未来生活的想象,科技并带着趣味。

猜你喜欢

转载自blog.csdn.net/sinat_37574187/article/details/133297789