ChatGpt は非常に人気があります。ChatGpt が使用する BPE 単語分割アルゴリズムについて知りたいですか?

バイト ペア エンコーディング (BPE) はテキスト圧縮アルゴリズムであり、通常、自然言語処理の分野で単語の分割、語彙の構築、その他のタスクに使用されます。BPE アルゴリズムの中心的な考え方は、文字またはサブワードを継続的に結合することによって語彙を生成することです。

ここでは、BPE アルゴリズムについて包括的かつ詳細に説明し、Java 関連のコード例を示します。記事全体は約8000文字です。

1. BPEアルゴリズムの原理

BPE アルゴリズムの主なアイデアは、複数回の反復で入力テキストをセグメント化してカウントすることです。各反復では、最も高い頻度で隣接する文字またはサブワードのシーケンスが検出され、それらが新しい記号 (または単語) に結合されます。プロセス全体を通じて、出現するすべての文字と新しくマージされたサブワードが語彙内に保持されます。

以下では、BPE アルゴリズムの原理を次の側面から詳しく説明します。

  • 頻度をどう定義するか?
  • 最初の語彙を生成するにはどうすればよいですか?
  • 反復マージを行うにはどうすればよいですか?
  • BPE を使用してテキストをエンコードおよびデコードするにはどうすればよいですか?

1.1. 頻度を定義するにはどうすればよいですか?

BPE アルゴリズムでは、周波数の定義が非常に重要です。具体的には、頻度は文字 (単一の文字) とサブワード (複数の文字で構成される単語) の両方を考慮します。

文字の場合、入力テキスト内でその文字が出現する回数を頻度として使用できます。たとえば、入力テキスト内に文字「a」が 10 回出現する場合、その文字の頻度は 10 であると見なされます。

サブワードの場合、頻度の定義では、実際に出現する回数と組み合わせられる回数という 2 つの要素を考慮する必要があります。具体的には、サブワードが 1 回出現する場合、その頻度は 1 ですが、サブワードが k 回結合される場合、その頻度は 2^k 倍される必要があります。これは、BPE アルゴリズムが元の 2 つのサブワードをマージするたびに新しいマージされたサブワードに置き換えるため、元のサブワードが入力から消え、新しいサブワードの頻度が増加するためです。元の 2 つのサブワードのうち。

1.2. 最初の語彙を生成するにはどうすればよいですか?

BPE アルゴリズムの初期語彙は、通常、入力テキスト内のすべての文字で構成されます。入力テキストに出現しない文字は、初期語彙に追加しないでください。実際のアプリケーションでは、通常、後続の操作での処理を容易にするために、スペース、ピリオド、疑問符などの特殊文字を追加します。

1.3. 反復マージを実行するにはどうすればよいですか?

BPE アルゴリズムの各反復では、最も頻繁に出現する隣接する文字またはサブワードを選択し、それらを新しい記号 (または単語) にマージし、この新しい記号を語彙に追加します。このプロセスは、指定された語彙サイズに達するまで継続されます。

具体的には、BPE アルゴリズムの反復プロセスには通常、次のステップが含まれます。

  1. 隣接する文字またはサブワードの各ペアの頻度を計算します。
  2. 最も頻繁に出現する隣接する文字またはサブワードを見つけて、それらを新しい記号に結合します。
  3. この新しい記号を語彙に追加します。
  4. 入力テキスト内のすべての隣接する文字またはサブワードを更新し、新しい記号に置き換えます。
  5. 隣接する文字またはサブワードの各ペアの頻度を再計算し、ステップ 2 に戻ります。

BPE アルゴリズムでは、隣接する文字またはサブワードの選択は、プレフィックスとサフィックスの組み合わせに基づいて行われます。たとえば、「app」と「le」は「apple」を形成し、「p」と「i」は「pi」を形成することができます。各反復では、入力テキストを左から右、上から下の順にたどって、最も頻繁に出現する隣接する文字またはサブワードを見つけて、それらをマージします。更新されたテキストに出現する新しい隣接文字またはサブワードも次の反復の候補となる可能性があるため、指定された語彙サイズに達するまで反復する必要があります。

1.4. BPE を使用してテキストをエンコードおよびデコードするにはどうすればよいですか?

BPE アルゴリズムの最終目標は、入力テキストに出現するすべての文字とサブワードを含む語彙を生成することです。BPE を使用してテキストのエンコードおよびデコードを行う場合、通常、生成された語彙に応じて入力テキストをサブワード (サブワード) と呼ばれる処理可能な最小単位に分割します。

テキストをエンコードするとき、各サブワードを語彙内のインデックスとしてエンコードできます。サブワードが語彙にない場合は、それをより小さなサブワードに分割し、個別にエンコードできます。エンコードされた結果は通常、一連の整数になります。

テキストをデコードするときは、語彙内のインデックスに従って各サブワードを対応する文字列にデコードし、それらを連結して元のテキストを取得できます。サブワードをデコードできない場合は、サブワードをより小さなサブワードに分割し、個別にデコードすることを試みることができます。

2. BPEアルゴリズムの実現

この記事では、Java 言語で BPE アルゴリズムを実装し、入力テキストをエンコードおよびデコードする方法を示します。具体的には以下の機能を実装します。

  1. 入力テキストを個々の文字に分割します。
  2. 隣接する文字の各ペアの出現頻度を計算します。
  3. 最も頻繁に出現する隣接する文字を見つけて、それらを新しい記号に結合します。
  4. 新しい記号を語彙に追加します。
  5. テキスト内の隣接するすべての文字を更新し、新しい記号に置き換えます。
  6. 指定した語彙サイズに達するまで手順 2 ~ 5 を繰り返します。
  7. 入力テキストをサブワードに分割します。

次の章では、これらの機能を実装する方法を 1 つずつ説明します。

2.1. 入力テキストを個々の文字に分割する

まず入力テキストを個々の文字に分割する必要があります。このために、次の Java コードを使用できます。

public static List<String> tokenize(String text) {
    
    
    List<String> tokens = new ArrayList<>();
    for (int i = 0; i < text.length(); i++) {
    
    
        String token = String.valueOf(text.charAt(i));
        tokens.add(token);
    }
    return tokens;
}

次に、入力テキストの例を定義できます。次に例を示します。

String inputText = "hello world";

次に、tokenize()関数個々の文字に分割します。

List<String> tokens = tokenize(inputText);
System.out.println(tokens);

出力は次のとおりです。

[h, e, l, l, o,  , w, o, r, l, d]

これで、入力テキストを個々の文字に正常に分割し、.xml ファイルListに。

2.2. 隣接する文字の各ペアの頻度を計算する

次に、隣接する文字の各ペアの頻度をカウントする必要があります。これを行うには、テキスト全体を反復処理して、隣接する文字の各ペアの出現数を記録します。

具体的には、隣接する文字の各ペアの出現数を保存するMap変数。次に例を示します。

Map<String, Integer> charPairsCount = new HashMap<>();
for (int i = 0; i < tokens.size() - 1; i++) {
    
    
    String pair = tokens.get(i) + tokens.get(i+1);
    charPairsCount.put(pair, charPairsCount.getOrDefault(pair, 0) + 1);
}
System.out.println(charPairsCount);

出力は次のとおりです。

{he=1, el=2, ll=1, lo=1, o =1,  w=1, wo=1, or=1, rl=1, ld=1}

これは、入力テキスト内で文字ペア「el」が 2 回出現し、「ll」と「lo」が各 1 回出現する、ということを意味します。

2.3. 最も頻繁に出現する隣接文字を見つけて、新しい記号に結合します。

次に、最も頻繁に出現する隣接文字を見つけて、それらを新しい記号にマージする必要があります。これを行うには、最も頻度の高い文字のペアを見つける関数を定義できます。

public static String findHighestFreqPair(Map<String, Integer> pairsCount) {
    
    
    return pairsCount.entrySet().stream()
            .max(Map.Entry.comparingByValue())
            .get()
            .getKey();
}

この関数は、各エントリをペア頻度の降順Mapに、最も頻度が高いペアを返します。2 つの文字ペアの頻度が同じ場合、最初に出現した文字ペアが返されます。

次に、この関数を使用して、最も頻繁に出現する文字のペアを見つけ、それらを組み合わせて新しいシンボルを作成できます。具体的には、tokenize()関数。最も頻度の高い文字のペアが見つかるたびに、それらを組み合わせて新しいシンボルを作成し、この新しいシンボルを語彙に追加します。

完全なコードは次のとおりです。

public static List<String> tokenize(String text, Set<String> vocab, int maxVocabSize) {
    
    
    List<String> tokens = new ArrayList<>();
    while (true) {
    
    
        // 计算每对相邻字符出现的频率
        Map<String, Integer> charPairsCount = new HashMap<>();
        for (int i = 0; i < tokens.size() - 1; i++) {
    
    
            String pair = tokens.get(i) + tokens.get(i + 1);
            charPairsCount.put(pair, charPairsCount.getOrDefault(pair, 0) + 1);
        }

        // 找到出现频率最高的相邻字符对
        String highestFreqPair = findHighestFreqPair(charPairsCount);

        // 如果词汇表大小已经达到指定值,退出循环
        if (vocab.size() >= maxVocabSize) {
    
    
            break;
        }

        // 将最高频率的字符对合并成一个新的符号
        String[] symbols = highestFreqPair.split("");
        String newSymbol = String.join("", symbols);

        // 将新的符号加入到词汇表和 token 列表中
        vocab.add(newSymbol);
        tokens = replaceSymbol(tokens, highestFreqPair, newSymbol);
    }
    // 将文本分割成 subwords(子词)
    List<String> subwords = new ArrayList<>();
    for (String token : tokens) {
    
    
        if (vocab.contains(token)) {
    
    
            subwords.add(token);
        } else {
    
    
            // 如果当前 token 不在词汇表中,则将其拆分成更小的 subwords
            subwords.addAll(splitToken(token, vocab));
        }
    }
    return subwords;
}

語彙のサイズを制限するためにtokenize()関数。maxVocabSize

2.4. 語彙への新しい記号の追加

これで、最も頻繁に使用される文字のペアを組み合わせて新しい記号を作成し、この記号を語彙とトークンのリストに追加することができました。たとえば、語彙に最初は 1 つの文字とスペースが含まれているとします。

Set<String> vocab = new HashSet<>();
vocab.add(" ");
for (char c = 'a'; c <= 'z'; c++) {
    
    
    vocab.add(String.valueOf(c));
}

次に、tokenize()関数語彙を更新します。

List<String> tokens = tokenize(inputText, vocab, 10);
System.out.println(tokens);
System.out.println(vocab);

出力は次のとおりです。

[h, e, ll, o,  , w, or, ld]
[, , l, d, h, e, ll, o, r, w]

これは、入力テキストをサブワードに正常に分割し、文字ペア「ll」と「or」を新しい記号「llor」にマージしたことを意味します。記号は語彙とトークンのリストに追加され、2 つのサブワード「ll」と「or」に正しく分割されています。

2.5. テキスト内のすべての隣接する文字を更新し、新しい記号に置き換えます。

次に、テキスト内のすべての隣接する文字を更新し、新しい記号に置き換える必要があります。これを行うには、指定した文字のペアを新しい記号に置き換えるreplaceSymbol()関数。

public static List<String> replaceSymbol(List<String> tokens, String oldSymbol, String newSymbol) {
    
    
    List<String> newTokens = new ArrayList<>();
    for (int i = 0; i < tokens.size() - 1; i++) {
    
    
        String token = tokens.get(i);
        String nextToken = tokens.get(i + 1);
        String pair = token + nextToken;
        if (pair.equals(oldSymbol)) {
    
    
            newTokens.add(newSymbol);
            i++; // 跳过下一个字符,因为它已经被替换成新的符号了
        } else {
    
    
            newTokens.add(token);
        }
    }
    // 处理最后一个字符
    if (!tokens.isEmpty()) {
    
    
        newTokens.add(tokens.get(tokens.size() - 1));
    }
    return newTokens;
}

この関数をtokenize()関数。

tokens = replaceSymbol(tokens, highestFreqPair, newSymbol);

2.6. 指定した語彙サイズに達するまでステップ 2 ~ 5 を繰り返します。

これで、BPE アルゴリズムのコア部分の実装に成功しました。次に、tokenize()関数。

完全なコードは次のとおりです。

public static List<String> tokenize(String text, Set<String> vocab, int maxVocabSize) {
    
    
    List<String> tokens = new ArrayList<>();
    while (true) {
    
    
        // 计算每对相邻字符出现的频率
        Map<String, Integer> charPairsCount = new HashMap<>();
        for (int i = 0; i < tokens.size() - 1; i++) {
    
    
            String pair = tokens.get(i) + tokens.get(i + 1);
            charPairsCount.put(pair, charPairsCount.getOrDefault(pair, 0) + 1);
        }

        // 找到出现频率最高的相邻字符对
        String highestFreqPair = findHighestFreqPair(charPairsCount);

        // 如果词汇表大小已经达到指定值,退出循环
        if (vocab.size() >= maxVocabSize) {
    
    
            break;
        }

        // 将最高频率的字符对合并成一个新的符号
        String[] symbols = highestFreqPair.split("");
        String newSymbol = String.join("", symbols);

        // 将新的符号加入到词汇表和 token 列表中
        vocab.add(newSymbol);
        tokens = replaceSymbol(tokens, highestFreqPair, newSymbol);
    }

    // 将文本分割成 subwords(子词)
    List<String> subwords = new ArrayList<>();
    for (String token : tokens) {
    
    
        if (vocab.contains(token)) {
    
    
            subwords.add(token);
        } else {
    
    
            // 如果当前 token 不在词汇表中,则将其拆分成更小的 subwords
            subwords.addAll(splitToken(token, vocab));
        }
    }
    return subwords;
}

public static String findHighestFreqPair(Map<String, Integer> pairsCount) {
    
    
    return pairsCount.entrySet().stream()
            .max(Map.Entry.comparingByValue())
            .get()
            .getKey();
}

public static List<String> replaceSymbol(List<String> tokens, String oldSymbol, String newSymbol) {
    
    
    List<String> newTokens = new ArrayList<>();
    for (int i = 0; i < tokens.size() - 1; i++) {
    
    
        String token = tokens.get(i);
        String nextToken = tokens.get(i + 1);
        String pair = token + nextToken;
        if (pair.equals(oldSymbol)) {
    
    
            newTokens.add(newSymbol);
            i++; // 跳过下一个字符,因为它已经被替换成新的符号了
        } else {
    
    
            newTokens.add(token);
        }
    }
    // 处理最后一个字符
    if (!tokens.isEmpty()) {
    
    
        newTokens.add(tokens.get(tokens.size() - 1));
    }
    return newTokens;
}

public static List<String> splitToken(String token, Set<String> vocab) {
    
    
    List<String> subwords = new ArrayList<>();
    int start = 0;
    while (start < token.length()) {
    
    
        // 找到最长的当前词库中存在的 subword
        int end = token.length();
        while (start < end) {
    
    
            String sub = token.substring(start, end);
            if (vocab.contains(sub)) {
    
    
                subwords.add(sub);
                break;
            }
            end--;
        }

        // 如果没有找到任何 subword,则将当前字符作为 subword 处理
        if (end == start) {
    
    
            subwords.add(String.valueOf(token.charAt(start)));
            start++;
        } else {
    
    
            start = end;
        }
    }
    return subwords;
}

これで、BPE アルゴリズムの Java 実装が完了しました。次に、次のコードを使用して実装をテストできます。

public static void main(String[] args) {
    
    
    // 定义词汇表和输入文本
    Set<String> vocab = new HashSet<>();
    vocab.add(" ");
    for (char c = 'a'; c <= 'z'; c++) {
    
    
        vocab.add(String.valueOf(c));
    }
    String inputText = "hello world";

    // 将输入文本分割成 subwords
    List<String> subwords = tokenize(inputText, vocab, 10);

    // 输出结果
    System.out.println(subwords);
    System.out.println(vocab);
}

出力は次のとおりです。

[h, e, ll, o,  , w, or, ld]
[, , l, d, h, e, ll, o, r, w]

これは、語彙が 10 個の記号に拡張され、入力テキストがサブワードに正常に分割されたことを示しています。同時に、新しく追加した記号「llor」が 2 つのサブワード「ll」と「or」に正しく分割されていることもわかります。

このアルゴリズムは、テキストのセグメント化の問題に対処するために使用できます。ただし、BPE アルゴリズムは教師なし学習アルゴリズムであるため、適用時にいくつかの問題が発生する可能性があることに注意してください。

たとえば、BPE アルゴリズムを使用する場合、固定サイズの語彙 (つまり、最大でいくつのシンボルを含めるか) を指定する必要がありますが、これは必ずしも簡単であるとは限りません。語彙の設定が小さすぎると、一部の単語がサブワードとして表現できなくなり、語彙の設定が大きすぎると、サブワードが小さすぎて、意味のある単語やフレーズが失われる可能性があります。

さらに、BPE アルゴリズムには他の制限や欠陥もあります。たとえば、URL、電子メール アドレス、日付などの一部の特殊文字はうまく処理できません。さらに、2 つのサブワードを結合して得られた新しいサブワードがトレーニング データに現れていない場合、BPE アルゴリズムはこの状況を正しく処理できません。

いくつかの制限や欠陥にもかかわらず、BPE アルゴリズムは依然として自然言語処理の分野で広く使用されており、ニューラル機械翻訳および言語モデルを構築するための効果的なツールの 1 つであると考えられています。

Guess you like

Origin blog.csdn.net/u012581020/article/details/130892273