NLP テキスト マッチング タスク テキスト マッチング [教師ありトレーニング]: PointWise (単一タワー)、DSSM (2 タワー)、Sentence BERT (2 タワー) プロジェクトの実践

NLP テキスト マッチング タスク テキスト マッチング [教師ありトレーニング]: PointWise (単一タワー)、DSSM (2 タワー)、Sentence BERT (2 タワー) プロジェクトの実践

0 背景の紹介と関連概念

このプロジェクトでは、一般的に使用される 3 つのテキスト マッチング方法、PointWise (単一タワー)、DSSM (2 タワー)、および Sentence BERT (2 タワー) を実装します。

テキスト マッチング (テキスト マッチング) は NLP の一分野であり、通常 2 つの文間の類似性を計算するために使用され、推奨や推論などのシナリオで重要な役割を果たします。

たとえば、今日は大量のコメント データがあるため、指定されたカテゴリのコメント データを検索したいとします。次に例を示します。

1. 为什么是开过的洗发水都流出来了,是用过的吗?是这样子包装的吗?
2. 喜欢折叠手机的我对这款手机情有独钟,简洁的外观设计非常符合当代年轻人的口味,给携带增添了一份愉悦。
3. 物流很快,但是到货时有的水果已经不新鲜了,坏掉了,不满意本次购物。
...

このコメントの山の中で、「果物」に関連するコメントを見つけたい場合は、3 番目のコメントを呼び出す必要があります。この問題はテキスト分類としてモデル化できますよね? テキスト分類モデルをトレーニングすることで同じ目標を達成できます。

ただし、分類モデルの主な問題は、分類ラベルが固定されていることです。トレーニング中に設定されたラベルが「入浴、コンピューター、果物」で、今日は「衣服」に関する別のコメントがある場合、モデルは、どれが間違っているとしても、元のラベル セットで推論することしかできません。したがって、ある程度まで適応でき、新しいラベルを追加した後にモデルを再トレーニングすることなくタスクを適切に完了できるモデルが必要です。

現在、テキスト マッチングに一般的に使用される構造が 2 つあります。

  1. シングルタワーモデル: 高精度ですが、計算速度が遅くなります。

  2. 2塔モデル: 計算速度は速いですが、精度は比較的低くなります。

以下では、これら 2 つの方法を個別に紹介します。

0.1 シングルタワーモデル

名前が示すように、シングルタワー モデルは、プロセス全体で 1 つのモデル計算のみが実行されることを意味します。ここでの「タワー」とは「複数のモデルの計算」を指しており、必ずしも「モデルの数」を指すわけではありません。ツインタワーの部分で説明します。シングルタワー モデルでは、[SEP] を通じて 2 つのテキスト文を結合し、結合されたデータをモデルにフィードし、出力内の [CLS] トークンを通じてバイナリ分類タスクを実行する必要があります。

単一タワー モデルの前部は次のようになります。完全なソース コードは記事の最後にあります。

    def __init__(self, encoder, dropout=None):
        """
        init func.
        Args:
            encoder (transformers.AutoModel): backbone, 默认使用 ernie 3.0
            dropout (float): dropout 比例
        """
        super().__init__()
        self.encoder = encoder
        hidden_size = 768
        self.dropout = nn.Dropout(dropout if dropout is not None else 0.1)
        self.classifier = nn.Linear(768, 2)

    def forward(self,
                input_ids,
                token_type_ids,
                position_ids=None,
                attention_mask=None) -> torch.tensor:
        """
        Foward 函数,输入匹配好的pair对,返回二维向量(相似/不相似)。
        Args:
            input_ids (torch.LongTensor): (batch, seq_len)
            token_type_ids (torch.LongTensor): (batch, seq_len)
            position_ids (torch.LongTensor): (batch, seq_len)
            attention_mask (torch.LongTensor): (batch, seq_len)

        Returns:
            torch.tensor: (batch, 2)
        """
        pooled_embedding = self.encoder(
            input_ids=input_ids,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            attention_mask=attention_mask
        )["pooler_output"]                                  # (batch, hidden_size)
        pooled_embedding = self.dropout(pooled_embedding)   # (batch, hidden_size)
        logits = self.classifier(pooled_embedding)          # (batch, 2)
        
        return logits

シングルタワーモデルの利点は正解率が高いことですが、欠点は計算が遅いことです。

  • なぜ遅いのでしょうか?

たとえば、今日「コンピューター、果物、お風呂」という 3 つのカテゴリがある場合、各カテゴリで文をつなぎ合わせ、推論のためにモデルにフィードする必要があります。

水果[SEP]苹果不是很新鲜,不满意这次购物[SEP]
电脑[SEP]苹果不是很新鲜,不满意这次购物[SEP]
洗浴[SEP]苹果不是很新鲜,不满意这次购物[SEP]

カテゴリの数が数百、さらには数千に達すると、何千回も結合する必要があり、サンプルを判断するには最後のモデルを通過する必要があり、大規模なモデルの計算には通常非常に時間がかかります。多くの場合、単一タワー モデルの効率では人々のニーズを満たせないことがよくあります。

0.2 2塔モデル

単一タワー モデルの欠点は明らかで、多くのカテゴリを何度もカウントする必要があるためです。しかし実際には、これらのカテゴリは変更されず、変更されるのは新しいコメント データだけです。では、これらの変化しない「カテゴリー情報」と「事前計算」を保存しておいて、見ていない「コメントデータ」だけを計算することはできるのでしょうか?これがツインタワーモデルの考え方です。2 タワー モデルの「2 つのタワー」の意味は、2 つのモデルの計算です。つまり、カテゴリ特徴は 1 回計算され、コメント特徴も 1 回計算されます。

上の図からわかるように、「カテゴリー」と「コメント」は一緒に結合されてモデルに供給されるのではなく、別々にモデルに供給され、それぞれの埋め込みベクトルが後続の計算のために取得されます。上図では、左側と右側のモデルは同じモデルを使用することもできます (この方法は同型モデルと呼ばれます)。または、2 つの異なるモデルを使用することもできます (この方法は異種モデルと呼ばれます)。したがって、「ツイン タワー」は必ずしも 2 つのモデルがあることを意味するのではなく、2 つのモデルの計算を実行する必要があることを意味します。

0.2.1 DSSM (深層構造セマンティックモデル、深層構造セマンティックモデル)

論文リファレンス: https://posenhuang.github.io/papers/cikm2013_DSSM_fullversion.pdf

DSSM は初期の論文であり、私たちは主に、埋め込み間のコサイン類似性によるリコールソートのアイデアを利用しています。「カテゴリ」テキストと「コメント」テキストをモデルに別々に渡し、2 つのテキストの埋め込みを取得します。一致するペア間のコサイン類似度ラベルを 1 に設定し、一致しないペア間のコサイン類似度ラベルを 0 に設定します。

注: コサイン類似度の値の範囲は [-1, 1] ですが、便宜上、ラベルを 0 に設定し、MSE を使用してトレーニングします。これでも良好な結果が得られます。

トレーニング データ セットは次のようになり、一致するペアのラベルは 1、不一致は 0 です。

蒙牛	不错还好挺不错	0
蒙牛	我喜欢demom制造的蒙牛奶	1
衣服	裤子太差了,刚穿一次屁股就起毛了。	1
...

実装には 2 つの主要な関数があります。1 つは文の埋め込み関数の取得 (推論用)、もう 1 つは文のペアのコサイン類似度の取得 (トレーニング用) です。

    def forward(
        self,
        input_ids: torch.tensor,
        token_type_ids: torch.tensor,
        attention_mask: torch.tensor
    ) -> torch.tensor:
        """
        forward 函数,输入单句子,获得单句子的embedding。
        Args:
            input_ids (torch.LongTensor): (batch, seq_len)
            token_type_ids (torch.LongTensor): (batch, seq_len)
            attention_mask (torch.LongTensor): (batch, seq_len)
        Returns:
            torch.tensor: embedding -> (batch, hidden_size)
        """
        embedding = self.encoder(
                input_ids=input_ids,
                token_type_ids=token_type_ids,
                attention_mask=attention_mask
            )["pooler_output"]                                  # (batch, hidden_size)
        return embedding

    def get_similarity(
        self,
        query_input_ids: torch.tensor,
        query_token_type_ids: torch.tensor,
        query_attention_mask: torch.tensor,
        doc_input_ids: torch.tensor,
        doc_token_type_ids: torch.tensor,
        doc_attention_mask: torch.tensor
    ) -> torch.tensor:
        """
        输入query和doc的向量,返回query和doc两个向量的余弦相似度。
        Args:
            query_input_ids (torch.LongTensor): (batch, seq_len)
            query_token_type_ids (torch.LongTensor): (batch, seq_len)
            query_attention_mask (torch.LongTensor): (batch, seq_len)
            doc_input_ids (torch.LongTensor): (batch, seq_len)
            doc_token_type_ids (torch.LongTensor): (batch, seq_len)
            doc_attention_mask (torch.LongTensor): (batch, seq_len)
        Returns:
            torch.tensor: (batch, 1)
        """
        query_embedding = self.encoder(
            input_ids=query_input_ids,
            token_type_ids=query_token_type_ids,
            attention_mask=query_attention_mask
        )["pooler_output"]                                 # (batch, hidden_size)
        query_embedding = self.dropout(query_embedding)

        doc_embedding = self.encoder(
            input_ids=doc_input_ids,
            token_type_ids=doc_token_type_ids,
            attention_mask=doc_attention_mask
        )["pooler_output"]                                  # (batch, hidden_size)
        doc_embedding = self.dropout(doc_embedding)

        similarity = nn.functional.cosine_similarity(query_embedding, doc_embedding)
        return similarity

0.2.2 センテンストランスフォーマー

論文参考:https://arxiv.org/pdf/1908.10084.pdf

Sentence Transformer も 2 タワー モデルですが、トレーニング中に 2 つの文のエンベディングのコサイン類似度を直接計算するのではなく、2 つのエンベディングとエンベディングの間の差分ベクトルを結合し、3 つのベクトルが組み立てられた後にフィードします。二値分類タスクの識別層に。

元の論文では、トレーニング アーキテクチャは推論に使用されなくなりましたが、コサイン類似度の方法が再現に使用されています。ただし、実装中に推論部分でトレーニング済みのモデル アーキテクチャを引き続き使用しました。一貫性のない構造によるギャップを消去したかったためです。また、トレーニング層はもう 1 つの線形層だけであり、推論中にあまり時間がかからないようになりました。 .時間。Sentence Transformer は、次のように、推論中に「現在のコメント情報」と事前に計算された「すべてのカテゴリの埋め込み」を同時に渡す必要があります。

    def forward(
        self,
        query_input_ids: torch.tensor,
        query_token_type_ids: torch.tensor,
        query_attention_mask: torch.tensor,
        doc_embeddings: torch.tensor,
    ) -> torch.tensor:
        """
        forward 函数,输入query句子和doc_embedding向量,将query句子过一遍模型得到
        query embedding再和doc_embedding做二分类。
        Args:
            input_ids (torch.LongTensor): (batch, seq_len)
            token_type_ids (torch.LongTensor): (batch, seq_len)
            attention_mask (torch.LongTensor): (batch, seq_len)
            doc_embedding (torch.LongTensor): 所有需要匹配的doc_embedding -> (batch, doc_embedding_numbers, hidden_size)
        Returns:
            torch.tensor: embedding_match_logits -> (batch, doc_embedding_numbers, 2)
        """
        query_embedding = self.encoder(
            input_ids=query_input_ids,
            token_type_ids=query_token_type_ids,
            attention_mask=query_attention_mask
        )["last_hidden_state"]                                                  # (batch, seq_len, hidden_size)
        
        query_attention_mask = torch.unsqueeze(query_attention_mask, dim=-1)    # (batch, seq_len, 1)
        query_embedding = query_embedding * query_attention_mask                # (batch, seq_len, hidden_size)
        query_sum_embedding = torch.sum(query_embedding, dim=1)                 # (batch, hidden_size)
        query_sum_mask = torch.sum(query_attention_mask, dim=1)                 # (batch, 1)
        query_mean = query_sum_embedding / query_sum_mask                       # (batch, hidden_size)

        query_mean = query_mean.unsqueeze(dim=1).repeat(1, doc_embeddings.size()[1], 1)  # (batch, doc_embedding_numbers, hidden_size)
        sub = torch.abs(torch.subtract(query_mean, doc_embeddings))                      # (batch, doc_embedding_numbers, hidden_size)
        concat = torch.cat([query_mean, doc_embeddings, sub], dim=-1)                    # (batch, doc_embedding_numbers, hidden_size * 3)
        logits = self.classifier(concat)                                                 # (batch, doc_embedding_numbers, 2)
        return logits

1. 環境のインストール

このプロジェクトはpytorch+の実装に基づいていますtransformers。実行する前に関連する依存関係パッケージをインストールしてください。

pip install -r ../../requirements.txt

torch
transformers==4.22.1
datasets==2.4.0
evaluate==0.2.2
matplotlib==3.6.0
rich==12.5.1
scikit-learn==1.1.2
requests==2.28.1

2. データセットの準備

プロジェクトではサンプルデータの一部を提供しています。テキストマッチングタスクには「商品レビュー」と「商品カテゴリ」を使用します。データは にありますdata/comment_classify

トレーニングを使用するには自定义数据、サンプル データのようなデータセットを構築するだけです。

衣服:指穿在身上遮体御寒并起美化作用的物品。	为什么是开过的洗发水都流出来了、是用过的吗?是这样子包装的吗?	0
衣服:指穿在身上遮体御寒并起美化作用的物品。	开始买回来大很多 后来换了回来又小了 号码区别太不正规 建议各位谨慎	1
...

各行は\tセパレータで区切られ、最初の部分は商品类型(text1)、中間部分は商品评论(text2)、最後の部分は です商品评论和商品类型是否一致(label)

3. 教師ありモデルのトレーニング

3.1 PointWise (単一タワー)

3.1.1 モデルのトレーニング

train_pointwise.shトレーニング スクリプト内の対応するパラメーターを変更して、モデルのトレーニングを開始します。

python train_pointwise.py \
    --model "nghuyong/ernie-3.0-base-zh" \  # backbone
    --train_path "data/comment_classify/train.txt" \    # 训练集
    --dev_path "data/comment_classify/dev.txt" \    #验证集
    --save_dir "checkpoints/comment_classify" \ # 训练模型存放地址
    --img_log_dir "logs/comment_classify" \ # loss曲线图保存位置
    --img_log_name "ERNIE-PointWise" \  # loss曲线图保存文件夹
    --batch_size 8 \
    --max_seq_len 128 \
    --valid_steps 50 \
    --logging_steps 10 \
    --num_train_epochs 10 \
    --device "cuda:0"

トレーニングが正しく開始されると、ターミナルは次の情報を出力します。

...
global step 10, epoch: 1, loss: 0.77517, speed: 3.43 step/s
global step 20, epoch: 1, loss: 0.67356, speed: 4.15 step/s
global step 30, epoch: 1, loss: 0.53567, speed: 4.15 step/s
global step 40, epoch: 1, loss: 0.47579, speed: 4.15 step/s
global step 50, epoch: 2, loss: 0.43162, speed: 4.41 step/s
Evaluation precision: 0.88571, recall: 0.87736, F1: 0.88152
best F1 performence has been updated: 0.00000 --> 0.88152
global step 60, epoch: 2, loss: 0.40301, speed: 4.08 step/s
global step 70, epoch: 2, loss: 0.37792, speed: 4.03 step/s
global step 80, epoch: 2, loss: 0.35343, speed: 4.04 step/s
global step 90, epoch: 2, loss: 0.33623, speed: 4.23 step/s
global step 100, epoch: 3, loss: 0.31319, speed: 4.01 step/s
Evaluation precision: 0.96970, recall: 0.90566, F1: 0.93659
best F1 performence has been updated: 0.88152 --> 0.93659
...

logs/comment_classifyトレーニング曲線グラフは次のファイルに保存されます。

3.1.2 モデル推論

モデルのトレーニングが完了したら、inference_pointwise.pyトレーニングされたモデルをロードして適用するために実行します。

...
    test_inference(
        '手机:一种可以在较广范围内使用的便携式电话终端。',     # 第一句话
        '味道非常好,京东送货速度也非常快,特别满意。',        # 第二句话
        max_seq_len=128
    )
...

推論プログラムを実行します。

python inference_pointwise.py

次の推論結果が得られます。

tensor([[ 1.8477, -2.0484]], device='cuda:0')   # 两句话不相似(0)的概率更大

3.2 DSSM (ツインタワー)

3.2.1 モデルのトレーニング

train_dssm.shトレーニング スクリプト内の対応するパラメーターを変更して、モデルのトレーニングを開始します。

python train_dssm.py \
    --model "nghuyong/ernie-3.0-base-zh" \
    --train_path "data/comment_classify/train.txt" \
    --dev_path "data/comment_classify/dev.txt" \
    --save_dir "checkpoints/comment_classify/dssm" \
    --img_log_dir "logs/comment_classify" \
    --img_log_name "ERNIE-DSSM" \
    --batch_size 8 \
    --max_seq_len 256 \
    --valid_steps 50 \
    --logging_steps 10 \
    --num_train_epochs 10 \
    --device "cuda:0"

トレーニングが正しく開始されると、ターミナルは次の情報を出力します。

...
global step 0, epoch: 1, loss: 0.62319, speed: 15.16 step/s
Evaluation precision: 0.29912, recall: 0.96226, F1: 0.45638
best F1 performence has been updated: 0.00000 --> 0.45638
global step 10, epoch: 1, loss: 0.40931, speed: 3.64 step/s
global step 20, epoch: 1, loss: 0.36969, speed: 3.69 step/s
global step 30, epoch: 1, loss: 0.33927, speed: 3.69 step/s
global step 40, epoch: 1, loss: 0.31732, speed: 3.70 step/s
global step 50, epoch: 1, loss: 0.30996, speed: 3.68 step/s
...

logs/comment_classifyトレーニング曲線グラフは次のファイルに保存されます。

3.2.2 モデル推論

シングルタワーモデルとは異なり、ツータワーモデルは、すべての候補カテゴリのエンベディングを事前に計算することができ、新しい文が来たときに、新しい文のエンベディングを計算し、コサイン類似度を通じて最適解を見つけるだけで済みます。

したがって、推論の前に、すべてのカテゴリの Embedding を事前に計算して保存する必要があります。

カテゴリ 埋め込み計算

get_embedding.pyファイルを実行して、対応するカテゴリの埋め込みを計算し、ローカルに保存します。

...
text_file = 'data/comment_classify/types_desc.txt'                       # 候选文本存放地址
output_file = 'embeddings/comment_classify/dssm_type_embeddings.json'    # embedding存放地址

device = 'cuda:0'                                                        # 指定GPU设备
model_type = 'dssm'                                                      # 使用DSSM还是Sentence Transformer
saved_model_path = './checkpoints/comment_classify/dssm/model_best/'     # 训练模型存放地址
tokenizer = AutoTokenizer.from_pretrained(saved_model_path) 
model = torch.load(os.path.join(saved_model_path, 'model.pt'))
model.to(device).eval()
...

このうち、事前計算が必要なすべてのコンテンツはtypes_desc.txtファイルに保存されます。

ファイルは、それぞれ\t表す で区切られています类别id类别名称类别描述

0	水果	指多汁且主要味觉为甜味和酸味,可食用的植物果实。
1	洗浴	洗浴用品。
2	平板	也叫便携式电脑,是一种小型、方便携带的个人电脑,以触摸屏作为基本的输入设备。
...

コマンドを実行するとpython get_embeddings.py、コードに設定された埋め込みストレージ アドレスで、対応する埋め込みファイルが見つかります。

{
    
    
    "0": {
    
    "label": "水果", "text": "水果:指多汁且主要味觉为甜味和酸味,可食用的植物果实。", "embedding": [0.3363891839981079, -0.8757723569869995, -0.4140555262565613, 0.8288457989692688, -0.8255823850631714, 0.9906797409057617, -0.9985526204109192, 0.9907819032669067, -0.9326567649841309, -0.9372553825378418, 0.11966298520565033, -0.7452883720397949,...]},
    "1": ...,
    ...
}

モデル推論

事前計算が完了したら、次のステップは推論を開始することです。

新しいコメントを作成します: 这个破笔记本卡的不要不要的,差评

実行するpython inference_dssm.pyと、次の結果が得られます。

[
    ('平板', 0.9515482187271118),
    ('电脑', 0.8216977119445801),
    ('洗浴', 0.12220608443021774),
    ('衣服', 0.1199738010764122),
    ('手机', 0.07764233648777008),
    ('酒店', 0.044791921973228455),
    ('水果', -0.050112202763557434),
    ('电器', -0.07554933428764343),
    ('书籍', -0.08481660485267639),
    ('蒙牛', -0.16164332628250122)
]

この関数は、(カテゴリ、コサイン類似度) のバイナリ グループを出力し、類似度に従って反転を実行します (類似度の値の範囲: [-1, 1])。

3.3 文トランスフォーマー (ツインタワー)

3.3.1 モデルのトレーニング

train_sentence_transformer.shトレーニング スクリプト内の対応するパラメーターを変更して、モデルのトレーニングを開始します。

python train_sentence_transformer.py \
    --model "nghuyong/ernie-3.0-base-zh" \
    --train_path "data/comment_classify/train.txt" \
    --dev_path "data/comment_classify/dev.txt" \
    --save_dir "checkpoints/comment_classify/sentence_transformer" \
    --img_log_dir "logs/comment_classify" \
    --img_log_name "Sentence-Ernie" \
    --batch_size 8 \
    --max_seq_len 256 \
    --valid_steps 50 \
    --logging_steps 10 \
    --num_train_epochs 10 \
    --device "cuda:0"

トレーニングが正しく開始されると、ターミナルは次の情報を出力します。

...
Evaluation precision: 0.81928, recall: 0.64151, F1: 0.71958
best F1 performence has been updated: 0.46120 --> 0.71958
global step 260, epoch: 2, loss: 0.58730, speed: 3.53 step/s
global step 270, epoch: 2, loss: 0.58171, speed: 3.55 step/s
global step 280, epoch: 2, loss: 0.57529, speed: 3.48 step/s
global step 290, epoch: 2, loss: 0.56687, speed: 3.55 step/s
global step 300, epoch: 2, loss: 0.56033, speed: 3.55 step/s
...

logs/comment_classifyトレーニング曲線グラフは次のファイルに保存されます。

3.2.2 モデル推論

Sentence Transformer も 2 つのタワー モデルであるため、すべての候補テキストの埋め込み値を事前に計算する必要があります。

カテゴリ 埋め込み計算

get_embedding.pyファイルを実行して、対応するカテゴリの埋め込みを計算し、ローカルに保存します。

...
text_file = 'data/comment_classify/types_desc.txt'                       # 候选文本存放地址
output_file = 'embeddings/comment_classify/sentence_transformer_type_embeddings.json'    # embedding存放地址

device = 'cuda:0'                                                        # 指定GPU设备
model_type = 'sentence_transformer'                                                      # 使用DSSM还是Sentence Transformer
saved_model_path = './checkpoints/comment_classify/sentence_transformer/model_best/'     # 训练模型存放地址
tokenizer = AutoTokenizer.from_pretrained(saved_model_path) 
model = torch.load(os.path.join(saved_model_path, 'model.pt'))
model.to(device).eval()
...

このうち、事前計算が必要なすべてのコンテンツはtypes_desc.txtファイルに保存されます。

ファイルは、それぞれ\t表す で区切られています类别id类别名称类别描述

0	水果	指多汁且主要味觉为甜味和酸味,可食用的植物果实。
1	洗浴	洗浴用品。
2	平板	也叫便携式电脑,是一种小型、方便携带的个人电脑,以触摸屏作为基本的输入设备。
...

コマンドを実行するとpython get_embeddings.py、コードに設定された埋め込みストレージ アドレスで、対応する埋め込みファイルが見つかります。

{
    
    
    "0": {
    
    "label": "水果", "text": "水果:指多汁且主要味觉为甜味和酸味,可食用的植物果实。", "embedding": [0.32447007298469543, -1.0908259153366089, -0.14340722560882568, 0.058471400290727615, -0.33798110485076904, -0.050156619399785995, 0.041511114686727524, 0.671889066696167, 0.2313404232263565, 1.3200652599334717, -1.10829496383667, 0.4710233509540558, -0.08577515184879303, -0.41730815172195435, -0.1956728845834732, 0.05548520386219025, ...]}
    "1": ...,
    ...
}

モデル推論

事前計算が完了したら、次のステップは推論を開始することです。

新しいコメントを作成します: 这个破笔记本卡的不要不要的,差评

を実行するとpython inference_sentence_transformer.py、関数はすべてのカテゴリで「一致した」カテゴリとその一致する値を出力し、次の結果を取得します。

Used 0.5233056545257568s.
[
    ('平板', 1.136274814605713), 
    ('电脑', 0.8851938247680664)
]

この関数は 2 つのタプル (一致するカテゴリ、一致する値) を出力し、一致する値に従って逆ランキングを実行します (値が大きいほど、一致度が高くなります)。

参考リンク:https://github.com/HarderThenHarder/transformers_tasks/blob/main/text_matching/supervised

github に接続できない場合は、https: //download.csdn.net/download/sinat_39620217/88214437からダウンロードできます。

さらに質の高いコンテンツについては、公式アカウント「Ting、人工知能」に注目してください。一部の関連リソースと質の高い記事が無料で提供されます。

おすすめ

転載: blog.csdn.net/sinat_39620217/article/details/132281264