Scikit-Learn を使用してマルチラベル分類を実装し、機械学習を支援します

皆さん、こんにちは。機械学習タスクにおいて、分類は入力データに基づいてラベルを予測するために使用される教師あり学習方法です。たとえば、過去の特徴に基づいて誰かが販売オファーに興味があるかどうかを予測したい場合、利用可能なトレーニング データを使用して機械学習モデルをトレーニングすることで、入力データに対して分類タスクを実行できます。

私たちは通常、二値分類 (2 つのラベル) や多クラス分類 (3 つ以上のラベル) などの古典的な分類タスクに遭遇します。この場合、分類器をトレーニングし、モデルは利用可能なすべてのラベルからラベルを予測しようとします。分類に使用されるデータセットは次の画像のようになります。

写真

上の画像は、ターゲット (販売オファー) にバイナリ分類の 2 つのラベルとマルチクラス分類の 3 つのラベルが含まれていることを示しています。モデルは利用可能な特徴からトレーニングされ、ラベルを 1 つだけ出力します。

マルチラベル分類は、バイナリ分類やマルチクラス分類とは異なります。マルチラベル分類では、1 つの出力ラベルを予測するだけではなく、逆に、マルチラベル分類では、出力ラベルに適用できる可能な限り多くのラベルを予測しようとします。入力データの場合、出力はタグなしから利用可能なタグの最大数まで可能です。

マルチラベル分類はテキスト データの分類タスクによく使用されます。以下はマルチラベル分類のデータセットの例です。

上記の例では、テキスト1〜テキスト5が、イベント、スポーツ、ポップカルチャー、自然の4つのカテゴリーに分類できる文章であるとする。上記のトレーニング データを使用して、マルチラベル分類タスクは、どのラベルが特定の文に適用されるかを予測します。各カテゴリは相互に排他的ではないため、互いに対立するものではなく、各ラベルは独立していると考えることができます。

さらに詳しく見ると、テキスト 1 にはスポーツとポップ カルチャーのタグが付けられ、テキスト 2 にはポップ カルチャーと自然のタグが付けられていることがわかります。これは、各ラベルが相互に排他的であり、複数ラベル分類の予測出力がラベルなし、または同時にすべてのラベルのいずれかになる可能性があることを示しています。

上記の説明を踏まえて、Scikit-Learn を使用してマルチラベル分類器を構築してみましょう。

Scikit-Learn を使用したマルチラベル分類

この記事では、Kaggle で公開されている生物医学 PubMed マルチラベル分類データセットを使用します。これにはさまざまな特徴が含まれていますが、この記事では、abstractText 特徴とその MeSH 分類 (A: 解剖学、B: 生物、C: 疾患など) のみを使用します。サンプルデータを以下に示します。

[生物医学 PubMed マルチラベル分類データセット]: https://www.kaggle.com/datasets/owaiskhan9654/pubmed-multilabel-text-classification

上記のデータセットは、各論文が複数のカテゴリに分類できることを示しています。これは、マルチラベル分類のケースです。このデータ セットを使用すると、Scikit-Learn を使用してマルチラベル分類器を構築できます。モデルをトレーニングする前に、まずデータ セットを準備します。

import pandas as pd
from sklearn.feature_extraction.text import TfidfVectorizer

df = pd.read_csv('PubMed Multi Label Text Classification Dataset Processed.csv')
df = df.drop(['Title', 'meshMajor', 'pmid', 'meshid', 'meshroot'], axis =1)

X = df["abstractText"]
y = np.asarray(df[df.columns[1:]])

vectorizer = TfidfVectorizer(max_features=2500, max_df=0.9)
vectorizer.fit(X)

上記のコードでは、Scikit-Learn モデルがトレーニング データを受け入れられるように、テキスト データが TF-IDF 表現に変換されます。さらに、チュートリアルを簡略化するために、この記事ではストップワードの削除など、データの前処理の手順を省略しています。

データ変換が完了したら、データ セットをトレーニング セットとテスト セットに分割します。

from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=101)
  
X_train_tfidf = vectorizer.transform(X_train)
X_test_tfidf = vectorizer.transform(X_test)

すべての準備が完了したら、マルチラベル分類器のトレーニングを開始します。Scikit-Learn では、MultiOutputClassifier オブジェクトを使用して、マルチラベル分類子モデルをトレーニングします。このモデルの背後にある戦略は、ラベルごとに分類子をトレーニングすることであり、各ラベルには独自の分類子があります。

この例では、ロジスティック回帰を使用し、MultiOutputClassifier を使用してそれをすべてのラベルに拡張します。

from sklearn.multioutput import MultiOutputClassifier
from sklearn.linear_model import LogisticRegression

clf = MultiOutputClassifier(LogisticRegression()).fit(X_train_tfidf, y_train)

モデルを変更したり、MultiOutputClasiffier に渡すモデル パラメーターを調整したりすることができますので、要件に応じて管理してください。トレーニングが完了したら、モデルを使用してテスト データを予測します。

prediction = clf.predict(X_test_tfidf)
prediction

写真

予測結果は各 MeSH カテゴリのラベルの配列であり、各行は文を表し、各列はラベルを表します。

最後に、マルチラベル分類器を評価する必要があります。精度メトリックを使用してモデルを評価できます。

from sklearn.metrics import accuracy_score
print('Accuracy Score: ', accuracy_score(y_test, prediction))****

精度スコアは 0.145 です。

精度スコアの結果は 0.145 でした。これは、モデルが正しいラベルの組み合わせを予測できる確率が 14.5% 未満であることを示しています。ただし、精度スコアには、マルチラベル予測評価の欠点があります。精度スコアでは、各文のすべてのラベルが正確な位置に表示されることが必要です。そうでない場合は、間違っていると見なされます。

たとえば、予測の最初の行は、テスト データと 1 つのラベルだけ異なります。

写真

精度スコアの場合、ラベルの組み合わせが異なるため、これは誤った予測とみなされます。そのため、モデルのメトリック スコアは低くなります。

この問題を解決するには、ラベルの組み合わせではなくラベルの予測を評価する必要があります。この場合、ハミング損失評価指標を使用できます。ハミング損失は、ラベルの総数に対する誤った予測の比率を計算することによって計算されます。これは、ハミング損失は損失関数であり、スコアが低いほど優れているためです (0 は誤った予測がないことを意味し、1 はすべての予測が誤っていることを意味します)。

from sklearn.metrics import hamming_loss
print('Hamming Loss: ', round(hamming_loss(y_test, prediction),2))

ハミング損失は 0.13 です。

マルチラベル分類器のハミング損失モデルは 0.13 です。これは、モデルが独立して約 13% 誤って予測していることを意味します。つまり、各ラベルの予測は 13% 誤る可能性があります。

要約する

マルチラベル分類は機械学習タスクであり、その出力はラベルなし、または入力データが与えられた場合に考えられるすべてのラベルのいずれかになります。これは、ラベル出力が相互排他的なバイナリ分類やマルチクラス分類とは異なります。

Scikit-Learn の MultiOutputClassifier を使用すると、ラベルごとに 1 つの分類器をトレーニングするマルチラベル分類器を開発できます。モデルの評価に関しては、精度スコアが全体的な状況を正しく反映していない可能性があるため、ハミング損失メトリクスを使用することをお勧めします。

 

おすすめ

転載: blog.csdn.net/csdn1561168266/article/details/132383752