Llama-2 を例として、生成されたモデルでカスタム StoppingCriteria を使用します。

Llama-2 を例として、生成されたモデルでカスタム StoppingCriteria を使用します。

1 はじめに

前回の記事では、transformers モジュールを使用して作成されたモデル、その生成メソッドの詳細な原​​理と使用法が紹介されました。記事のリンクは次のとおりです。

ビームサーチを例に、変圧器での生成方法を詳しく説明(上)
ビームサーチを例に、変圧器での生成方法を詳しく説明(下)

ここでは、生成プロセスへのユーザー参加の 2 つの主要なコンポーネントについて説明しており、logits_processorこれらstopping_criteria2 つのクラスを使用する は、ユーザーが生成プロセスを制御する主な手段です。このうち、logits_processor生成処理中にユーザーが設定した指定ルールに従って語彙空間上の現在のステップの確率分布を強制的に変更したり、ユーザーが指定したルールに従って生成を停止したりするために使用されますstopping_criteria

これら 2 つのコンポーネントtransformersには、モジュール内に直接使用できるプリセット クラスがいくつかあります。プリセット クラスの基本情報の紹介については、トランスフォーマでの生成メソッドの説明 (その 1) のビーム サーチを例として参照してください。

この記事では、実際のアプリケーション シナリオを組み合わせて、ユーザーが自分のニーズに応じてカスタム シナリオを設計および実装しstopping_criteria、生成プロセスの早期終了を制御する方法を紹介します。

2. シナリオの紹介

今回紹介するシーンは、Llama-2 の生成機能を利用してニュースを要約するもので、ニュースで起こった核心的な事柄を要約する短い文を生成したいと考えています。

対話の背景を考慮し、歴史的なサンプルと組み合わせることで、Llama-2 が期待どおりの結果を出力できることを期待しています。

ダイアログのプロンプト構築方法については、以前のコンテンツ「NLP 実践 - Llama-2 マルチラウンド ダイアログ プロンプト構築」を参照してください。

ただし、コンテキスト内学習を行ったとしても、Llama-2 によって生成される結果は依然として冗長すぎます。

たとえば、次のようなニュース記事の場合:

text = """, Photo Credit : Associated Press Four air crew members were missing after an Australian army helicopter ditched into waters off the Queensland state coast during joint military exercises with the United States, officials said Saturday. The MRH-90 Taipan helicopter went down near Lindeman Island, a Great Barrier... ..."""  
# 后边忽略若干内容

モデルによって生成された結果は次のとおりです。

Four Australian army air crew members are missing after an Australian army MRH-90 Taipan helicopter ditched into waters off the Queensland state coast during joint military exercises with the United States. The helicopter went down near Lindeman Island, a Great Barrier Reef tourist resort, at around 11 pm on Friday. A search involving US, Canadian, and Australian personnel is underway to find the missing crew, who are all Australian men. Debris that appeared to be from a helicopter has been recovered, according to Queensland Police Assistant Commissioner Douglas McDonald. The Taipan was taking part in Talisman Sabre, a biennial joint US-Australian military exercise that is largely based in Queensland. This year's exercise involves 13 nations and over 30,000 military personnel. Defense Minister Richard Marles said the helicopter ditched, which refers to an emergency landing on water. He added that defense exercises, which are so necessary for the readiness of our defense force, are serious and carry risk. US Defense Secretary... ...
# 后边忽略若干内容

モデルによって生成された結果は悪くないことがわかりますが、冗長すぎるため、私のニーズでは、モデルは最初の文を出力するだけで済みます。

このとき、「じゃあ文を分けて最初の文は残しておけばいいのでは?」と思う人もいるかもしれません。

——これでも効果は得られますが、生成プロセス、時間、計算能力が消費されています。

したがって、最初のフルストップが発生した時点でモデルの生成を停止し、結果を返す方法を採用する必要があります。したがって、今日の主役である停止基準を使用する必要があります。

3. 解決策

変圧器モジュールにはいくつかのデフォルトの停止基準が組み込まれていますが、多くの場合、要件を満たすことができません。この場合、カスタムの停止基準を作成する必要があります。

まず、基本クラスを参照する必要があります。

from transformers.generation.stopping_criteria import StoppingCriteria, StoppingCriteriaList, \
    STOPPING_CRITERIA_INPUTS_DOCSTRING, add_start_docstrings

で、

  • StoppingCriteriaList はコンテナであり、すべての条件をそれに追加する必要があり、このコンテナは生成時に渡されます。
  • StoppingCriteria は基本クラスであり、カスタム基準はこの基本クラスから継承する必要があります。

次に、基準を実装します。その結果、指定されたトークンが検出されると、生成が停止されます。

class StopAtSpecificTokenCriteria(StoppingCriteria):
    """
    当生成出第一个指定token时,立即停止生成
    ---------------
    ver: 2023-08-02
    by: changhongyu
    """
    def __init__(self, token_id_list: List[int] = None):
        """
        :param token_id_list: 停止生成的指定token的id的列表
        """
        self.token_id_list = token_id_list
        
    @add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        # return np.argmax(scores[-1].detach().cpu().numpy()) in self.token_id_list
        # 储存scores会额外占用资源,所以直接用input_ids进行判断
        return input_ids[0][-1].detach().cpu().numpy() in self.token_id_list

次に、ピリオドが発生したときに生成を停止する場合は、ピリオドに対応する token_id を使用してそのような停止基準をインスタンス化し、それをコンテナーに追加します。

# Llama-2的词表中,英文句号的id是29889
stopping_criteria = StoppingCriteriaList()
stopping_criteria.append(StopAtSpecificTokenCriteria(token_id_list=[29889]))

次に、生成時に、元の生成コマンドが次の場合、

model.generate(**inputs)

次に、停止基準をパラメータとして渡すと、機能します。

model.generate(stopping_criteria=stopping_criteria, **inputs)

4. 結論

停止基準は、各ステップ生成の終了時に生成プロセスを終了するかどうかを判断するために使用されます。ユーザーが生成プロセスを制御するための有効な手段です。動作方法も比較的簡単です。カスタム実装は複雑ではありません。このクラスの call メソッドの戻り値が bool 値であることを確認するだけでよく、これはあらゆる状況に対応できます。

ロジッツプロセッサもユーザー制御生成に有効なツールです。次回のブログではカスタムロジッツプロセッサの使い方を紹介しますので、興味のある学生は引き続き注目してください。

おすすめ

転載: blog.csdn.net/weixin_44826203/article/details/132089110