皆さんこんにちは, Wei Xue AI です. 今日はコンピュータ ビジョン 4 ターゲット検出タスクのアプリケーションを紹介します. Faster Rcnn+Resnet50+FPN モデルを使用してターゲットを予測します. ターゲット検出は、コンピューター ビジョンの 3 つの主要なタスク. Faster R-CNN はよく知られているターゲット検出ネットワークであり、主に領域提案ネットワーク (RPN) と Fast R-CNN の 2 つのモジュールに分かれています。ResNet50をベースネットワークとし、FPN(Feature Pyramid Network)を統合したFasterRCNNモデルについて詳しく紹介します。このモデルは次のように記述できます fasterrcnn_resnet50_fpn
。
今日はこの関数を実装します。誰もが操作でき、コードは直接実行されます。
1. モデル構造
1. ResNet50: ResNet は、残差ブロックを使用してトレーニング中の勾配消失の問題を解決する深い畳み込みニューラル ネットワークです。ResNet50 は、深さ 50 層の ResNet モデルを表します。このモデルは、生の画像から特徴を抽出する役割を果たします。
2.FPN: FPN は、ターゲット検出でさまざまなサイズのオブジェクトを処理するためにマルチスケールの特徴マップを生成する特徴処理アーキテクチャです。FPN は、畳み込みニューラル ネットワークの背後にレイヤーを追加して、異なる解像度の機能を融合させ、オブジェクト検出の精度を向上させます。
3.RPN : これは、FPN によって生成されたマルチスケールの特徴マップで動作する小さな畳み込みネットワークです。RPN の主な目的は、下流の Fast R-CNN のターゲット候補ボックス (Region of Interest、略して RoI) を生成することです。これはターゲット検出タスクの第 1 段階です.RPN はスライディング ウィンドウを使用して複数の候補ボックスを生成し、さまざまなスケールとアスペクト比のアンカー ポイントにバウンディング ボックスを生成します。
4. Fast R-CNN : このモジュールは、RPN によって生成された候補フレームを受け取り、RoI Align を使用して異なるスケールのフィーチャ ピラミッド マップからフィーチャを抽出し、完全に接続されたレイヤーを使用して分類とフレーム回帰を行います。Fast R-CNN は、検出されたオブジェクト クラスとその境界ボックスの位置を出力します。
2. モデル原理
ターゲット検出プロセス: 特徴抽出 (ResNet50) -> FPN -> RPN -> RoI -> Fast R-CNN。まず、ResNet50 は元の画像の特徴を抽出し、これらの特徴を FPN に渡します。次に、FPN は、さまざまなサイズのオブジェクトに対応するために、マルチスケールの特徴マップを生成します。次に、RPN は、フィーチャ ピラミッドによって生成されたマルチスケールのフィーチャ マップを操作して、一連の提案ボックスを生成します。RPN の出力は Fast R-CNN の入力として使用され、RoI を使用して候補フレームから特徴を抽出した後、結果は分類され、有界ボックス回帰になります。
例えば:
このモデルを自動運転シナリオ、歩行者、車、交通信号などの検出に使用するとします。カメラを使用して画像のフレームをキャプチャする場合、最初にこの画像を ResNet50 に入力します。これにより、その後のターゲット検出に役立つ特徴が抽出されます。その後、FPN はさまざまなスケールの特徴マップを生成することで、さまざまなサイズのオブジェクトの検出能力を向上させます。次に、RPN はこれらの機能マップから領域提案 (ボックス提案) を生成します。これらの候補ボックスには、対象となる可能性のある領域 (歩行者、車など) が含まれています。最後に、Fast R-CNN は RoI を使用して、さまざまなスケールの特徴マップから候補フレームの特徴を抽出します. 全結合層を処理した後、候補フレームを分類して境界回帰を実行し、最終的に検出結果を出力します. 自動運転のシナリオでは、モデルはカメラでキャプチャされた画像を分析することで、歩行者、車、信号機、その他の障害物を迅速かつ正確に検出できるため、車両が正しい判断を下すのに役立ちます。
3. コードの実装
import torchvision
from PIL import Image, ImageDraw, ImageFont
from coco_class import class_names
# 加载COCO数据集预训练模型
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
# 设置模型为评估模式
model.eval()
# 加载图像并进行预处理
image = Image.open('banana.png')
transform = torchvision.transforms.Compose([
torchvision.transforms.ToTensor(),
])
image_tensor = transform(image)
image_tensor = image_tensor[:3]
# 利用模型进行预测
predictions = model([image_tensor])
# 处理预测结果并输出
draw = ImageDraw.Draw(image)
font = ImageFont.truetype("arial.ttf", 30) # 设置字体大小和样式
for box, label, score in zip(predictions[0]['boxes'], predictions[0]['labels'], predictions[0]['scores']):
if score > 0.5:
draw.rectangle([(box[0], box[1]), (box[2], box[3])], outline='red')
label_name = class_names[label.item()]
draw.text((box[0], box[1]), str(label_name), fill='red', font=font) # 在图片上打印分类名称
image.show()
coco_class.py ファイルは、coco データセットにカテゴリをロードするためのものです。
class_names = {
0: 'background',
1: 'person',
2: 'bicycle',
3: 'car',
4: 'motorcycle',
5: 'airplane',
6: 'bus',
7: 'train',
8: 'truck',
9: 'boat',
10: 'traffic light',
11: 'fire hydrant',
12: 'N/A',
13: 'stop sign',
14: 'parking meter',
15: 'bench',
16: 'bird',
17: 'cat',
18: 'dog',
19: 'horse',
20: 'sheep',
21: 'cow',
22: 'elephant',
23: 'bear',
24: 'zebra',
25: 'giraffe',
26: 'N/A',
27: 'backpack',
28: 'umbrella',
29: 'N/A',
30: 'N/A',
31: 'handbag',
32: 'tie',
33: 'suitcase',
34: 'frisbee',
35: 'skis',
36: 'snowboard',
37: 'sports ball',
38: 'kite',
39: 'baseball bat',
40: 'baseball glove',
41: 'skateboard',
42: 'surfboard',
43: 'tennis racket',
44: 'bottle',
45: 'N/A',
46: 'wine glass',
47: 'cup',
48: 'fork',
49: 'knife',
50: 'spoon',
51: 'bowl',
52: 'banana',
53: 'apple',
54: 'sandwich',
55: 'orange',
56: 'broccoli',
57: 'carrot',
58: 'hot dog',
59: 'pizza',
60: 'donut',
61: 'cake',
62: 'chair',
63: 'couch',
64: 'potted plant',
65: 'bed',
66: 'N/A',
67: 'dining table',
68: 'N/A',
69: 'N/A',
70: 'toilet',
71: 'N/A',
72: 'tv',
73: 'laptop',
74: 'mouse',
75: 'remote',
76: 'keyboard',
77: 'cell phone',
78: 'microwave',
79: 'oven',
80: 'toaster',
81: 'sink',
82: 'refrigerator',
83: 'N/A',
84: 'book',
85: 'clock',
86: 'vase',
87: 'scissors',
88: 'teddy bear',
89: 'hair drier',
90: 'toothbrush'
}
操作結果:
ここでは、ターゲットの位置情報とカテゴリ情報を特定し、ビデオを特定して分類できます。