Paddle2.0 + CPM-LM:让AI帮你写文章吧

引入

  • 之前的两个项目分别介绍了GPT-2模型的构建,和如何使用GPT-2加载CPM模型实现问答机器人
  • 由于之前项目所使用的解码方式为Greedy_Search,生成的文本固定且单一,并不适合与写作类型的文本生成
  • 本次就修改一下之前项目的解码方式,改为采样解码,实现AI文本续写

快速体验

  • 可以在百度AIStudio平台上快速体验文本续写生成:项目链接

效果展示

  • 照例先展示一般文本续写的生成效果

输入:《乡土中国》是费孝通先生在社区研究的基础上从宏观角度探讨中国社会结构的著作,
输出:《乡土中国》是费孝通先生在社区研究的基础上从宏观角度探讨中国社会结构的著作,是整个社会理论体系的构成部分,从宏观的解剖上,选择了中国社会的十个典型实例,以乡土中国为主体,传达出一个宏廓的概念和理念,此为其特色,能引起很多人的共鸣。
输出:《乡土中国》是费孝通先生在社区研究的基础上从宏观角度探讨中国社会结构的著作,代表了他在“社会整合的理论与实践”、“社会流动的理论与实践”和“中国传统社会”三个专题上的研究成果。《中国农民调查》(原名《中国农村社会调查》)是费孝通先生在“乡村研究”的基础上,从社会学角度探讨中国农民问题的重要著作,在一定程度上填补了我国的“农民社会学”研究领域的空白。
输出:《乡土中国》是费孝通先生在社区研究的基础上从宏观角度探讨中国社会结构的著作,出版于1984年。书中讨论的“中国社会结构”有两层含义,一层是经济结构,另一层是政治结构。费孝通先生认为,“中国社会结构的变化就是政治结构变化的结果。”具体说来,费孝通先生把中国社会的各种结构分为:“分”、“层”和“级”三级。

输入:在此处输入文本的开头,
输出:在此处输入文本的开头,再按下空格键,就可以开始朗读。
输出:在此处输入文本的开头,然后在“查找与替换”对话框中输入替换文本的前缀,然后在“替换”对话框中用IF 命令选择需要替换的字符。
输出:在此处输入文本的开头,然后再输入下一个字符 ,就可以得到一个"HelloWorld!"的文本。

输入:方平带众人骑马出了城,残雪点缀原本泛黄的大地。他一身黑衣在一群铁甲士兵中尤其显眼。
输出:方平带众人骑马出了城,残雪点缀原本泛黄的大地。他一身黑衣在一群铁甲士兵中尤其显眼。他负手彷徨,曾经在铜宫带领大军,横扫天下的军师如今只是位数趾高气扬的小卒,如今自己,连他身边的一个随从的支使就算不会武功也是位高权重。横刀立马,换来的是什么?他不知道,今天他走路都有些飘摇。
输出:方平带众人骑马出了城,残雪点缀原本泛黄的大地。他一身黑衣在一群铁甲士兵中尤其显眼。 “来者何人?” “我乃黑山军司使杨平,奉圣命前来查探敌情。

输入:黛玉因在一风雨交加的夜晚被拒怡红院后,
输出:黛玉因在一风雨交加的夜晚被拒怡红院后,前去投靠贾府的柳湘莲。虽然他们并没有见过面,但是因为是旧时挚友,黛玉对柳湘莲有一种说不出来的情愫。有一次,黛玉自窗外经过,忽然看见柳湘莲的那棵柳树上结了一颗小小的红豆,便摘了下来,要送给柳湘莲,但柳湘莲却推托不要。
输出:黛玉因在一风雨交加的夜晚被拒怡红院后,去往荣府,在贾母处说了些闲话,便被邢夫人打发了。 黛玉回到自己的屋中,想着母亲那“不过是我的东西,因为不归我管,我还能怨谁”,于是便拿着扫帚打扫房间,想要打扫干净,但是想起了怡红院里的事情,便打定了主意,“还是先将就着用吧”。

  • 是不是感觉上面生成的文本还不错,那当然我挑出来展示的能不好吗
  • 使用采样的方式进行文本生成确实能让生成的文本更加多样和生动,更像人写出来的文字。
  • 但是由于这样的文本生成不太可控,所有经常会生成一些奇奇怪怪的文字,就像下面这些:

输入:黛玉因在一风雨交加的夜晚被拒怡红院后,
输出:黛玉因在一风雨交加的夜晚被拒怡红院后,怀恨在心而嫁祸宝玉。鹿肉丸里有哪些成分可以推断出「有毒」的? 所用的肌肉攻击/溶血/激素…这些成分能推断出有毒吗? 文科生如何看得懂这些问题… 这些成份,貌似需要少量的,但是真的有毒,怎么致死的呢? 刚刚在微博上看的。没有典型的肉丸类(丸类有没有,疑问…问日下的对吧),大多都是糖

输入:方平带众人骑马出了城,残雪点缀原本泛黄的大地。他一身黑衣在一群铁甲士兵中尤其显眼。
输出:方平带众人骑马出了城,残雪点缀原本泛黄的大地。他一身黑衣在一群铁甲士兵中尤其显眼。 “还有谁?”刘关张三人围着残雪的方向站定,“我还有一个问题。

解码方式

  • 此次采用的是 top-k filtering 和 nucleus filtering 这两种解码方式
  • 参考的代码为GPT-2 Chinese中解码函数
def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    """ Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
        Args:
            logits: logits distribution shape (vocabulary size)
            top_k > 0: keep only top k tokens with highest probability (top-k filtering).
            top_p > 0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
                Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
        From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
    """
    assert logits.dim() == 1  # batch size 1 for now - could be updated for more but the code would be less clear
    top_k = min(top_k, logits.size(-1))  # Safety check
    if top_k > 0:
        # Remove all tokens with a probability less than the last token of the top-k
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value

    if top_p > 0.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        # Remove tokens with cumulative probability above the threshold
        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift the indices to the right to keep also the first token above the threshold
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0

        indices_to_remove = sorted_indices[sorted_indices_to_remove]
        logits[indices_to_remove] = filter_value
    return logits
  • 因为一些功能貌似Paddle Tensor无法直接使用,所以就使用了Numpy作为中转(也可能是我菜,没发现怎么实现)
  • 下面是自己使用Paddle2.0实现的代码,仅供参考
    def top_k_top_p_filtering(logits, top_k=0, top_p=1.0, filter_value=-float('Inf')):
        """ Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
            Args:
                logits: logits distribution shape (vocabulary size)
                top_k > 0: keep only top k tokens with highest probability (top-k filtering).
                top_p > 0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
                    Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
            From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
        """
        top_k = min(top_k, logits.shape[-1])  # Safety check
        logits_np = logits.numpy()
        if top_k > 0:
            # Remove all tokens with a probability less than the last token of the top-k
            indices_to_remove = logits_np < np.sort(logits_np)[-top_k]
            logits_np[indices_to_remove] = filter_value

        if top_p < 1.0:
            sorted_logits = paddle.sort(logits, descending=True)
            sorted_indices = paddle.argsort(logits, descending=True).numpy()
            cumulative_probs = paddle.cumsum(paddle.nn.functional.softmax(sorted_logits, axis=-1), axis=-1).numpy()

            # Remove tokens with cumulative probability above the threshold
            sorted_indices_to_remove = cumulative_probs > top_p
            # Shift the indices to the right to keep also the first token above the threshold
            sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1]
            sorted_indices_to_remove[..., 0] = 0

            indices_to_remove = sorted_indices[sorted_indices_to_remove]
            logits_np[indices_to_remove] = filter_value

        return paddle.to_tensor(logits_np)

总结

  • CPM-LM在不微调的情况下能够生成出质量还不错的文本
  • 但也就只是还不错而已,远远称不上优秀,更别说时不时会生成跑偏
  • 而且模型过大导致无法实现长文本直接生成,在32G的显存环境中也只能勉强进行200个token运算
  • 期待以后能出现更多生成效果更加自然的模型

猜你喜欢

转载自blog.csdn.net/jm_12138/article/details/111599530
今日推荐