SAM (Segment Anything Model) は、Meta の研究者チームによって作成およびトレーニングされた深層学習モデルです。2023 年 4 月 5 日に発表された研究論文で発表されたこのイノベーションは、すぐに広く世間の関心を呼び起こし、関連する Twitter スレッドは現在までに 350 万回以上の閲覧を集めています。
コンピューター ビジョンの専門家は現在 SAM に注目していますが、なぜでしょうか?
推奨: NSDT エディタを使用して、プログラム可能な 3D シーンをすばやく構築します
1.SAMとは何ですか?
セグメントすべての研究論文では、SAM は「ベース モデル」と呼ばれています。
基本モデルは、より具体的なタスクで使用および再トレーニングすることを目的として、大量のデータでトレーニングされた機械学習モデル (通常は自己教師あり学習または半教師あり学習を介して) です。
言い換えれば、SAM は、他のタスク (特に微調整を通じて) に適応するように設計された事前トレーニング済みモデルです。
たとえば、SAM を再トレーニングして、データセット内の人物のみをセグメント化するために使用できます。
人物のセグメンテーションは、SAM がそのようなオブジェクトを含むデータセットでトレーニングされているため、実行できる補助タスクですが、それだけではありません。
2. SAM はどのようにトレーニングされますか?
SAM は、Segment Anything 研究論文と並行して Meta によって導入された SA-1B データセットでトレーニングされました。
Facebook の親会社のデータセットには、ほぼ地球全体から収集された 1,100 万枚以上の画像が含まれており、これは一般化する機能を備えたモデルを開発する上で重要な側面です。
ほぼ地球全体から収集された画像 – SA-1B データセット
これらの高品質画像 (平均 1500 × 2250 ピクセル) には、データセット ラベルに対応する 11 億個のセグメンテーション マスクが付属しています。
Meta がこのデータセットを使用する目的は、AI 博士号取得者向けのセグメンテーション参照を作成することです。研究目的であれば正式に無料でライセンスされています。
非常に有益ですが、マスクはカテゴリに依存しないことに注意してください。言い換えれば、SAM が人のマスクを生成できたとしても、そのマスクが人を表していることを示すことはできません。
これは、SAM を実際に活用するには他のアルゴリズムと組み合わせる必要があることを意味するため、留意すべき重要な点です。
詳しく見てみましょう。
3. SAM の使用方法は?
まず、2 つの項目をロードする必要があります。
- SAM を使用するためのクラスと関数を含む GitHub フォルダーのセグメント
- メタ研究者によって取得されたモデル バージョンを使用して事前トレーニングされたモデルの重み
!pip install git+https://github.com/facebookresearch/segment-anything.git &> /dev/null
!wget -q https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth
次に、3 つのグローバル変数を作成します。
- MODEL_TYPE: 使用する SAM アーキテクチャ
- CHECKPOINT_PATH: モデルの重みを含むファイルへのパス
- デバイス: 使用されているプロセッサー、「cpu」または「cuda」 (GPU が使用可能な場合)
MODEL_TYPE = "vit_h"
CHECKPOINT_PATH = "/content/sam_vit_h_4b8939.pth"
DEVICE = "cuda" #cpu,cuda
これで、sam_model_registry 関数を使用して SAM モデルをロードし、モデルの重みを示すことができます。
from segment_anything import sam_model_registry
sam = sam_model_registry[MODEL_TYPE](checkpoint=CHECKPOINT_PATH).to(device=DEVICE)
モデルがロードされた後、Meta は 2 つの使用オプションを提供します。
- ジェネレーター オプション。モデルによって生成されたすべてのマスクを画像から取得できます。
- 予測オプション。ヒントに基づいて画像から 1 つ以上の特定のマスクを取得できます。
次の数行で両方のオプションについて説明します。
その前に、モデルを実験する画像をインターネットから読み込みましょう。
from urllib.request import urlopen
import cv2
import numpy as np
from google.colab.patches import cv2_imshow
resp = urlopen('https://images.unsplash.com/photo-1615948812700-8828458d368a?ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D&auto=format&fit=crop&w=2072&q=80')
image = np.asarray(bytearray(resp.read()), dtype='uint8')
image = cv2.imdecode(image, cv2.IMREAD_COLOR)
image = cv2.resize(image, (int(image.shape[1]/2.5), int(image.shape[0]/2.5)))
cv2_imshow(image)
画像には数人の人物、犬、数台の車が含まれています。
次に、SAM およびジェネレーター オプションを使用して画像をセグメント化します。
4. 発電機
このセクションでは、SAM のジェネレーター バージョンを使用します。これにより、モデルによる画像の分析の結果として生成されたマスクのセットを取得できるようになります。
SamAutomaticMaskGenerator オブジェクトを初期化しましょう。
from segment_anything import SamAutomaticMaskGenerator
mask_generator = SamAutomaticMaskGenerator(sam)
次に、generate() 関数を使用してマスクの生成を開始します。
masks_generated = mask_generator.generate(image)
この関数は、検出されたオブジェクトごとに他のデータとともにマスクを生成します。SAM は実際に、検出したオブジェクトに関連する一連の情報を (辞書の形式で) 生成します。
5. 予測結果
情報セットごとに取得されたキーを表示できます。
print(masks_generated[0].keys())
出力:
dict_keys(['segmentation', 'area', 'bbox', 'predicted_iou', 'point_coords', 'stability_score', 'crop_box'])
結果は 7 つのメッセージのセットになります。最初の「セグメント」は、検出されたオブジェクトの位置に対応するピクセルを表します。ピクセルにオブジェクトが含まれている場合は True、そうでない場合は False。
マスクは次のように表示できます。
cv2_imshow(masks_generated[3]['segmentation'].astype(int)*255)
このコレクションのその他の情報は、次の説明に対応します。
- エリア: マスクエリア(ピクセル単位)
- bbox: XYWH 形式のマスク境界ボックス
- Predicted_iou: モデルによって予測されたマスク品質スコア
- point_coords: このマスクを生成したサンプリングされた入力ポイント
- steady_score: 追加のマスク品質スコア
- Crop_box: このマスクを XYWH 形式で生成するために使用される画像のトリミング
ほとんどの実務者はこの情報を使用しませんが、特定のケースでは、SAM がマスクだけでなく、このような追加情報も生成することを知っておくことが重要です。
上記のマスクに関して取得された残りの情報は次のとおりです。
print('area :', masks_generated[3]['area'])
print('bbox :',masks_generated[3]['bbox'])
print('predicted_iou :',masks_generated[3]['predicted_iou'])
print('point_coords :',masks_generated[3]['point_coords'])
print('stability_score :',masks_generated[3]['stability_score'])
print('crop_box :',masks_generated[3]['crop_box'])
出力:
area : 5200 bbox : [499, 284, 92, 70]
predicted_iou : 1.005275845527649
point_coords : [[582.1875, 318.546875]]
stability_score : 0.981315553188324
crop_box : [0, 0, 828, 551]
SAM によって生成されたマスクの数を表示することもできます。
print(len(masks_generated))
出力:
111
SAM は画像から合計 111 個のマスクを生成しました。
6. 予測を表示する
この投稿で紹介したdraw_masks_fromDict関数を使用すると、画像上に生成されたすべてのマスクを描画できます。
segmented_image = draw_masks_fromDict(image, masks_generated)
cv2_imshow(segmented_image)
開始イメージには、SAM によって生成されたマスクが含まれています。
このセクションでは、SAM のジェネレーター バージョンを使用します。これにより、画像から 111 個のマスクを生成できました。マスクに加えて、SAM は追加の検出情報を生成します。モデルの予測を視覚化するために、最後にすべてのマスクを開始画像にプロットします。
したがって、SAM を使用すると、画像のセグメンテーションを実行できます。ただし、生成されたマスクは順序付けされていないことがわかります。異なるマスクを区別するための分類がありません。たとえば、人々のマスクは単一の色に関連付けられていません。したがって、結果のセグメントを並べ替えることはできません。ここで取得される情報は、オブジェクトの位置と境界だけです。
また、生成されたマスクは重複する可能性があります。実際、SAM は他のオブジェクト内のオブジェクトを検出できます。良い面としては、これは、SAM が画像内のほぼすべてのオブジェクトを検出できることを示しています。これは、犬、車、人、その他の物体 (車輪、窓、ズボンなど) をセグメント化できることを意味します。したがって、SAM のジェネレーター バージョンでは、重なっているオブジェクトも含めて、画像内のすべてのオブジェクトをセグメント化できます。
7. 発電機を超えて
ただし、この機能には欠点もあります。特定の領域での予測の数が増加し、目標の達成が損なわれる可能性があります。たとえば、画像内の人物を検出したい場合、そのジャケットとパンツに対応するマスクも検出しても問題ありません。
さらに、SAM はラベル付きデータでトレーニングされていないため、その予測をフィルターして目的の予測を保持することはできません。これは、SAM のジェネレーター バージョンを使用してデータセット内のすべての画像をセグメント化したとしても、たとえば人物のマスクを簡単に抽出することはできないことを意味します。したがって、画像内のすべてのオブジェクトをセグメント化する SAM ジェネレーターの機能は、一部の問題の解決には適していない可能性があります。
したがって、目標物体の検出には、ジェネレーター版の SAM を使用するのは適していません。代わりに、予測バージョンを使用する必要があります。このリリースでは、SAM を使用できるようになり、リクエストと計測するターゲット オブジェクトを指定するように求められます。
8. 予測者
このセクションでは、SAM の予測バージョンを使用します。プレディクター バージョンを使用すると、対象のオブジェクトを検出できるようになります。これを行うには、検出するオブジェクトを指定するための SAM ヒントを送信します。
現在、SAM にプロンプトを送信するには 2 つの方法があります。
- 興味のあるポイントごとに
- 境界ボックスによる
SAM は、オブジェクトを表す画像ピクセルの関心のある点 (x および y 座標) を入力として受け取ることができます。対象のポイントによって指定されたオブジェクトにより、SAM はこのオブジェクトに関連付けられたマスクを生成できるようになります。
SAM は、画像内のオブジェクトの輪郭を区切る境界ボックスを入力として受け取ることもできます。これらのアウトラインに基づいて、SAM は適切なマスクを生成します。
注: 「プロンプト」は、ほとんどの場合、ChatGPT に送信されるテキスト リクエストを指すために使用される一般的な用語です。ただし、SAM で示されているように、ヒントはテキスト リクエストに限定されません。これは、実践者が機械学習モデルに送信できる一連のクエリにまで及びます。
この機能は現在公開されていませんが、Meta はモデルの細分化を通じてテキスト リクエストの理解をすでに条件付けしていることに注意することが重要です。
ただし、このチュートリアルの残りの部分では、プロンプトを SAM に送信する必要があります。境界ボックスはコンピュータ ビジョンの標準なので、それを使用します。
9. 境界ボックスのヒントを使用する
このチュートリアルを続行する場合は、まずセグメント化するオブジェクトにバウンディング ボックスを関連付ける必要があります。
画像の境界ボックスがない場合は、YOLO テンプレートを使用して数行のコードで簡単に境界ボックスを生成できます。
このテンプレートを使用して、独自の境界ボックスをすばやく生成する方法を学習できます。YOLO の最新バージョンに特化したチュートリアルがここであなたを待っています。
画像に YOLO を使用すると、次のような結果が得られます。
image_bboxes = image.copy()
boxes = np.array(results[0].to('cpu').boxes.data)
plot_bboxes(image_bboxes, boxes, score=False)
注: 結果変数は、モデルによって予測された結果です。
YOLO を使用して取得された境界ボックスは次の形式になります。
print(boxes)
出力:
[[ 495.96 285.65 589.8 356.48 0.89921 2]
[ 270.63 147.99 403.17 496.82 0.79781 0]
…
[ 235.32 279.23 508.93 399.63 0.3193 2]
[ 612.13 303.94 647.61 333.11 0.2854 2]]
最初の 4 つの値は境界ボックスの座標を表し、5 番目の値は予測された境界ボックスの信頼スコアを表し、6 番目の値は検出されたクラスを表します。
ヒントが得られたので、SamPredictor オブジェクトを初期化しましょう。
from segment_anything import SamPredictor
mask_predictor = SamPredictor(sam)
次に、SAM で分析する画像を指定します。
mask_predictor.set_image(image)
ここから、チュートリアルは 2 つの部分に分かれています。
- 単一オブジェクトの検出
- バッチオブジェクト検出
最初のオプションから始めましょう。
10. 単一オブジェクトの検出
オブジェクトのマスクを予測するには、predict() 関数でオブジェクトに対応する境界ボックスを Predictor に伝えます。
mask, _, _ = mask_predictor.predict(
box=boxes[1][:-2]
)
検出されたオブジェクトの位置を示すブール配列の形式でマスクを取得します (辞書の「セグメンテーション」キーで前述したように)。ピクセルにオブジェクトが含まれている場合は True、そうでない場合は False。
この投稿で説明されているdraw_mask関数を使用して、このマスクを画像上に描画できます。
私たちの突出部には、SAM によって検出されたマスクが含まれています。
SAM へのヒントのおかげで、オブジェクトのマスクを取得して画像に表示することができました。
次に、すべての境界ボックスに対応するマスクを検出する方法を見てみましょう。
11. 複数のオブジェクトを検出する
一連の境界ボックスに対して予測を行うには、それらを PyTorch テンソルに収集する必要があります。
次に、transform.apply_boxes_torch() を使用してオブジェクトを更新します。
最後に、predict_torch を使用して、対応するマスクを予測します。
import torch
input_boxes = torch.tensor(boxes[:, :-2], device=mask_predictor.device)
transformed_boxes = mask_predictor.transform.apply_boxes_torch(input_boxes, image.shape[:2])
masks, _, _ = mask_predictor.predict_torch(
point_coords=None,
point_labels=None,
boxes=transformed_boxes,
multimask_output=False,
)
結果は、1 次元 (1、551、828) でエンコードされた 13 個のマスクのバッチになります。
このテンソルをより適切に操作するために、最初の無関係な次元を削除しましょう。
print(masks.shape)
masks = torch.squeeze(masks, 1)
print(masks.shape)
出力:
torch.Size([13, 1, 551, 828])
torch.Size([13, 551, 828])
SAM の上流でバウンディング ボックスを使用する利点は、生成された各マスクをバウンディング ボックスに対応するラベルに関連付けることができるため、表示時に色を使用して区別できることです。
YOLO が予測できるクラスに関連付けられた色のグラデーションを定義してみましょう。
COLORS = [(89, 161, 197),(67, 161, 255),(19, 222, 24),(186, 55, 2),(167, 146, 11),(190, 76, 98),(130, 172, 179),(115, 209, 128),(204, 79, 135),(136, 126, 185),(209, 213, 45),(44, 52, 10),(101, 158, 121),(179, 124, 12),(25, 33, 189),(45, 115, 11),(73, 197, 184),(62, 225, 221),(32, 46, 52),(20, 165, 16),(54, 15, 57),(12, 150, 9),(10, 46, 99),(94, 89, 46),(48, 37, 106),(42, 10, 96),(7, 164, 128),(98, 213, 120),(40, 5, 219),(54, 25, 150),(251, 74, 172),(0, 236, 196),(21, 104, 190),(226, 74, 232),(120, 67, 25),(191, 106, 197),(8, 15, 134),(21, 2, 1),(142, 63, 109),(133, 148, 146),(187, 77, 253),(155, 22, 122),(218, 130, 77),(164, 102, 79),(43, 152, 125),(185, 124, 151),(95, 159, 238),(128, 89, 85),(228, 6, 60),(6, 41, 210),(11, 1, 133),(30, 96, 58),(230, 136, 109),(126, 45, 174),(164, 63, 165),(32, 111, 29),(232, 40, 70),(55, 31, 198),(148, 211, 129),(10, 186, 211),(181, 201, 94),(55, 35, 92),(129, 140, 233),(70, 250, 116),(61, 209, 152),(216, 21, 138),(100, 0, 176),(3, 42, 70),(151, 13, 44),(216, 102, 88),(125, 216, 93),(171, 236, 47),(253, 127, 103),(205, 137, 244),(193, 137, 224),(36, 152, 214),(17, 50, 238),(154, 165, 67),(114, 129, 60),(119, 24, 48),(73, 8, 110)]
最後に、この記事で開発したdraw_masks_fromList関数を使用してすべてのマスクを描画し、各ラベルを色に関連付けることができます。
segmented_image = draw_masks_fromList(image, masks.to('cpu'), boxes, COLORS)
cv2_imshow(segmented_image)
提供された境界ボックスを使用して、YOLO によって予測されたすべてのマスクを表示します。さらに、各マスクは、境界ボックスで示されるクラスに従って色付けされます。これにより、さまざまなセグメンテーション オブジェクトを簡単に区別できるようになります。