学习笔记|Pytorch使用教程33(图像目标检测一瞥(上))

学习笔记|Pytorch使用教程33

本学习笔记主要摘自“深度之眼”,做一个总结,方便查阅。
使用Pytorch版本为1.2

  • 图像目标检测是什么?
  • 模型是如何完成目标检测的?
  • 深度学习目标检测模型简介
  • PyTorch中的Faster RCNN训练

一.图像目标检测是什么?

目标检测:判断图像中目标位置
目标检测两要素

  • 1.分类:分类向量[p0, … pn]
  • 2.回归:回归边界框[x1, y1, x2, y2]
    在这里插入图片描述
    测试代码:
import os
import time
import torch.nn as nn
import torch
import numpy as np
import torchvision.transforms as transforms
import torchvision
from PIL import Image
from matplotlib import pyplot as plt

BASE_DIR = os.path.dirname(os.path.abspath(__file__))
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# classes_coco
COCO_INSTANCE_CATEGORY_NAMES = [
    '__background__', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
    'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A', 'stop sign',
    'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
    'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack', 'umbrella', 'N/A', 'N/A',
    'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
    'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket',
    'bottle', 'N/A', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl',
    'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
    'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table',
    'N/A', 'N/A', 'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone',
    'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'N/A', 'book',
    'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush'
]


if __name__ == "__main__":

    # path_img = os.path.join(BASE_DIR, "demo_img1.png")
    path_img = os.path.join(BASE_DIR, "demo_img2.png")

    # config
    preprocess = transforms.Compose([
        transforms.ToTensor(),
    ])

    # 1. load data & model
    input_image = Image.open(path_img).convert("RGB")
    model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
    model.eval()

    # 2. preprocess
    img_chw = preprocess(input_image)

    # 3. to device
    if torch.cuda.is_available():
        img_chw = img_chw.to('cuda')
        model.to('cuda')

    # 4. forward
    input_list = [img_chw]
    with torch.no_grad():
        tic = time.time()
        print("input img tensor shape:{}".format(input_list[0].shape))
        output_list = model(input_list)
        output_dict = output_list[0]
        print("pass: {:.3f}s".format(time.time() - tic))
        for k, v in output_dict.items():
            print("key:{}, value:{}".format(k, v))

    # 5. visualization
    out_boxes = output_dict["boxes"].cpu()
    out_scores = output_dict["scores"].cpu()
    out_labels = output_dict["labels"].cpu()

    fig, ax = plt.subplots(figsize=(12, 12))
    ax.imshow(input_image, aspect='equal')

    num_boxes = out_boxes.shape[0]
    max_vis = 40
    thres = 0.5

    for idx in range(0, min(num_boxes, max_vis)):

        score = out_scores[idx].numpy()
        bbox = out_boxes[idx].numpy()
        class_name = COCO_INSTANCE_CATEGORY_NAMES[out_labels[idx]]

        if score < thres:
            continue

        ax.add_patch(plt.Rectangle((bbox[0], bbox[1]), bbox[2] - bbox[0], bbox[3] - bbox[1], fill=False,
                                   edgecolor='red', linewidth=3.5))
        ax.text(bbox[0], bbox[1] - 2, '{:s} {:.3f}'.format(class_name, score), bbox=dict(facecolor='blue', alpha=0.5),
                fontsize=14, color='white')
    plt.show()
    plt.close()



    # appendix
    classes_pascal_voc = ['__background__',
                       'aeroplane', 'bicycle', 'bird', 'boat',
                       'bottle', 'bus', 'car', 'cat', 'chair',
                       'cow', 'diningtable', 'dog', 'horse',
                       'motorbike', 'person', 'pottedplant',
                       'sheep', 'sofa', 'train', 'tvmonitor']

    # classes_coco
    COCO_INSTANCE_CATEGORY_NAMES = [
        '__background__', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
        'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A', 'stop sign',
        'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
        'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack', 'umbrella', 'N/A', 'N/A',
        'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
        'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket',
        'bottle', 'N/A', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl',
        'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
        'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table',
        'N/A', 'N/A', 'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone',
        'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'N/A', 'book',
        'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush'
    ]

输出:

input img tensor shape:torch.Size([3, 624, 1270])
pass: 13.661s
key:boxes, value:tensor([[2.1437e+01, 4.0840e+02, 5.6342e+01, 5.3993e+02],
        [2.7507e+02, 4.1659e+02, 3.1846e+02, 5.2799e+02],
        [3.3170e+02, 5.0658e+02, 3.8219e+02, 6.2113e+02],
        [1.0627e+03, 5.6276e+02, 1.1684e+03, 6.2371e+02],
        [8.8013e+02, 5.0102e+02, 9.3208e+02, 6.2317e+02],
        [2.9642e+02, 5.2642e+02, 3.4381e+02, 6.2200e+02],
        [1.5379e+02, 3.9273e+02, 1.9051e+02, 4.7901e+02],
        [5.2459e+02, 5.5500e+02, 5.9428e+02, 6.2307e+02],
        [4.3968e+02, 4.7425e+02, 4.9720e+02, 6.1554e+02],
        [9.6592e+02, 4.4677e+02, 1.0049e+03, 5.7215e+02],
        [1.0311e+03, 4.7703e+02, 1.0741e+03, 6.1917e+02],
        [7.1520e+02, 5.5527e+02, 7.6435e+02, 6.2250e+02],
        [1.9180e+02, 3.9129e+02, 2.1840e+02, 4.5502e+02],
        [5.9519e+02, 5.6863e+02, 6.5838e+02, 6.2400e+02],
        [9.2346e+02, 4.2539e+02, 9.6890e+02, 5.4164e+02],
        [8.4545e+02, 4.2685e+02, 8.8473e+02, 5.3350e+02],
        [5.5792e-01, 3.6247e+02, 1.9292e+01, 4.2037e+02],
        [7.8786e+02, 4.5473e+02, 8.3009e+02, 5.5746e+02],
        [5.9756e+02, 4.3980e+02, 6.4331e+02, 5.7260e+02],
        [7.5372e+02, 5.4176e+02, 8.4086e+02, 6.2388e+02],
        [1.0174e+03, 5.0093e+02, 1.0505e+03, 5.4634e+02],
        [6.8192e+02, 5.3671e+02, 7.2875e+02, 6.2382e+02],
        [8.1197e+02, 4.2305e+02, 8.4461e+02, 5.3224e+02],
        [7.5444e+02, 3.9091e+02, 7.9017e+02, 4.9864e+02],
        [5.3107e+02, 3.9075e+02, 5.6285e+02, 4.8725e+02],
        [1.1842e+03, 5.6935e+02, 1.2687e+03, 6.2372e+02],
        [9.0154e+02, 4.5109e+02, 9.1972e+02, 4.7041e+02],
        [8.9092e+02, 4.1312e+02, 9.2181e+02, 5.0889e+02],
        [4.9160e+02, 4.8394e+02, 5.1212e+02, 5.2896e+02],
        [7.1178e+02, 4.7320e+02, 7.4839e+02, 5.6364e+02],
        [1.1422e+03, 4.1846e+02, 1.1851e+03, 5.2725e+02],
        [1.1044e+03, 4.1391e+02, 1.1432e+03, 5.1564e+02],
        [4.8151e+02, 5.2476e+02, 4.9801e+02, 5.6309e+02],
        [9.6673e+02, 4.7047e+02, 9.9382e+02, 5.1017e+02],
        [1.5301e+02, 4.0614e+02, 1.7877e+02, 4.3976e+02],
        [3.3971e+02, 3.4014e+02, 3.6640e+02, 4.1000e+02],
        [1.1215e+01, 3.0503e+02, 2.5390e+01, 3.4648e+02],
        [5.6783e+02, 4.4656e+02, 6.0336e+02, 5.6767e+02],
        [1.0671e+03, 4.0842e+02, 1.1084e+03, 5.1374e+02],
        [7.0506e+02, 4.0975e+02, 7.3976e+02, 4.9427e+02],
        [1.1736e+03, 4.1151e+02, 1.2080e+03, 5.2507e+02],
        [2.5137e+02, 3.2949e+02, 2.7344e+02, 3.9106e+02],
        [1.6691e+02, 2.8140e+02, 1.8110e+02, 3.1285e+02],
        [3.0369e+02, 4.6951e+02, 3.1904e+02, 5.0108e+02],
        [1.3483e+02, 3.2507e+02, 1.5080e+02, 3.6978e+02],
        [1.0107e+03, 4.4341e+02, 1.0458e+03, 5.4823e+02],
        [9.8960e+02, 3.7219e+02, 1.0161e+03, 4.4772e+02],
        [8.3098e+02, 3.9532e+02, 8.5813e+02, 4.6472e+02],
        [6.6482e+02, 4.5117e+02, 6.8733e+02, 4.8071e+02],
        [3.6000e+02, 3.9332e+02, 3.8890e+02, 4.8298e+02],
        [1.0539e+03, 5.0596e+02, 1.0702e+03, 5.3395e+02],
        [4.6973e+02, 4.5248e+02, 5.0838e+02, 5.7020e+02],
        [1.5856e+02, 3.3735e+02, 1.7752e+02, 3.7869e+02],
        [1.5349e+02, 4.0826e+02, 1.7287e+02, 4.4103e+02],
        [3.8870e+02, 3.7187e+02, 4.2914e+02, 5.0503e+02],
        [9.6698e+02, 4.7291e+02, 9.9026e+02, 5.0854e+02],
        [5.5847e+02, 3.8495e+02, 5.8719e+02, 4.8738e+02],
        [4.9743e+02, 3.8822e+02, 5.2440e+02, 4.8499e+02],
        [6.0820e+01, 2.8248e+02, 7.6256e+01, 3.1529e+02],
        [6.8791e+02, 4.9479e+02, 7.2739e+02, 5.5288e+02],
        [6.5066e+02, 4.9294e+02, 7.0139e+02, 6.1894e+02],
        [2.0727e+02, 3.9253e+02, 2.2674e+02, 4.5656e+02],
        [3.3184e+02, 3.0833e+02, 3.4600e+02, 3.4902e+02],
        [1.0159e+03, 4.9606e+02, 1.0568e+03, 5.4730e+02],
        [6.0135e+01, 3.0983e+02, 7.7837e+01, 3.4432e+02],
        [6.3866e+02, 4.2136e+02, 6.7227e+02, 5.2106e+02],
        [4.6559e+02, 3.9241e+02, 4.8686e+02, 4.2766e+02],
        [5.6188e+01, 3.1174e+02, 7.0932e+01, 3.4426e+02],
        [4.3119e+02, 3.2678e+02, 4.6984e+02, 3.9337e+02],
        [5.9947e+02, 3.9388e+02, 6.3010e+02, 4.5644e+02],
        [1.1757e+03, 5.3985e+02, 1.2376e+03, 6.1111e+02],
        [6.6622e+02, 4.1731e+02, 7.0462e+02, 4.9162e+02],
        [1.7327e+02, 3.9034e+02, 1.9244e+02, 4.6006e+02],
        [4.7853e+02, 4.7409e+02, 5.1007e+02, 5.2950e+02],
        [2.8340e+02, 3.0262e+02, 2.9943e+02, 3.3402e+02],
        [7.4611e+02, 3.5429e+02, 7.7780e+02, 4.1244e+02],
        [7.4060e+02, 4.8190e+02, 7.6980e+02, 5.5989e+02],
        [9.6401e+02, 3.4509e+02, 9.8135e+02, 3.9366e+02],
        [4.1680e+02, 3.6824e+02, 4.4263e+02, 4.5963e+02],
        [8.7578e+02, 3.5269e+02, 8.9966e+02, 4.2111e+02],
        [1.0104e+03, 4.4886e+02, 1.0562e+03, 6.0378e+02],
        [3.0327e+02, 4.4166e+02, 3.1852e+02, 5.0216e+02],
        [4.4137e+02, 5.9189e+02, 4.8637e+02, 6.2351e+02],
        [1.9031e+02, 3.3614e+02, 2.1258e+02, 3.8856e+02],
        [1.8251e+02, 2.8241e+02, 1.9433e+02, 3.0937e+02],
        [1.2041e+03, 4.6998e+02, 1.2527e+03, 5.5469e+02],
        [1.0764e+03, 4.9217e+02, 1.1277e+03, 5.8620e+02],
        [1.0449e+03, 3.4939e+02, 1.0654e+03, 4.0458e+02],
        [1.0922e+03, 3.8381e+02, 1.1198e+03, 4.3049e+02],
        [5.1150e+02, 3.8435e+02, 5.3621e+02, 4.8211e+02],
        [3.1652e+02, 3.1374e+02, 3.3143e+02, 3.5101e+02],
        [9.4753e+02, 3.4059e+02, 9.6551e+02, 3.9753e+02],
        [5.1159e+02, 3.4381e+02, 5.3470e+02, 3.9003e+02],
        [5.8443e+02, 3.9833e+02, 6.1497e+02, 4.8086e+02],
        [7.4492e+02, 3.8018e+02, 7.6006e+02, 4.0452e+02],
        [1.1097e+03, 3.0709e+02, 1.1257e+03, 3.4568e+02],
        [6.6792e+02, 3.3659e+02, 6.8839e+02, 3.8192e+02],
        [3.0073e+02, 3.0162e+02, 3.1883e+02, 3.4768e+02],
        [1.0730e+03, 4.9895e+02, 1.1603e+03, 6.1907e+02],
        [3.9530e+02, 4.2951e+02, 4.2193e+02, 4.6254e+02]])
key:labels, value:tensor([ 1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
         1,  1, 31,  1,  1,  1,  1,  1, 31,  1, 27,  1,  1,  1, 31, 27, 27,  1,
         1,  1,  1,  1,  1,  1,  1, 31,  1,  1,  1,  1, 31,  1, 31,  1,  1, 31,
         1, 31,  1,  1,  1,  1,  1,  1,  1, 27,  1,  1, 31,  1,  1,  1,  1,  1,
         1, 27,  1,  1,  1,  1,  1,  1,  1, 31,  1,  1,  1,  1,  1,  1,  1,  1,
         1,  1,  1,  1, 31,  1,  1,  1,  1, 31])
key:scores, value:tensor([0.9860, 0.9852, 0.9780, 0.9779, 0.9774, 0.9739, 0.9500, 0.9464, 0.9456,
        0.9088, 0.8751, 0.8721, 0.8549, 0.8533, 0.8455, 0.8064, 0.8006, 0.7799,
        0.7588, 0.7488, 0.7124, 0.7113, 0.6831, 0.6695, 0.6562, 0.6551, 0.6532,
        0.6498, 0.6471, 0.6365, 0.6178, 0.5983, 0.5870, 0.5829, 0.5744, 0.5698,
        0.5638, 0.5590, 0.5522, 0.5413, 0.5313, 0.5283, 0.5203, 0.4811, 0.4558,
        0.4536, 0.4442, 0.4402, 0.4374, 0.4368, 0.4313, 0.4210, 0.4119, 0.4099,
        0.3986, 0.3920, 0.3912, 0.3827, 0.3754, 0.3654, 0.3584, 0.3502, 0.3496,
        0.3414, 0.3399, 0.3283, 0.3225, 0.3126, 0.3124, 0.3101, 0.3049, 0.3025,
        0.3005, 0.2963, 0.2946, 0.2830, 0.2799, 0.2790, 0.2783, 0.2782, 0.2772,
        0.2759, 0.2711, 0.2684, 0.2643, 0.2574, 0.2509, 0.2462, 0.2401, 0.2385,
        0.2353, 0.2311, 0.2245, 0.2233, 0.2224, 0.2205, 0.2173, 0.2157, 0.2141,
        0.2109])

在这里插入图片描述
Debug分析一下整个流程:

  • 1.获取图片:path_img = os.path.join(BASE_DIR, "demo_img2.png")
  • 2.加载模型torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True),并设置成测试模式
  • 3.把数据(图片)处理成模型输入的格式(张量):img_chw = preprocess(input_image),其shape为:input img tensor shape:torch.Size([3, 624, 1270])
  • 4.前向传播:output_list = model(input_list)
    查看output_list:
    在这里插入图片描述
    这里只使用了一张图像,所以len(output_list) = 1。每一个字典都有三部分组成:boxeslabelsscores
  • 5.获取第一个张图的输出结果output_dict。分别查看其属性:
    在这里插入图片描述在这里插入图片描述在这里插入图片描述
    也就是输出了100个boxes,对应每个boxes都有应该score和label。
  • 6.把结果分别保存到out_boxesout_scoresout_labels。打印out_scores
    在这里插入图片描述
    发现out_scores是按顺序排列的,所以为了保证效果,只取得分较高的boxes。这就是使用for idx in range(0, min(num_boxes, max_vis)):的原因。这样,有了boxes和对应的类别class_name = COCO_INSTANCE_CATEGORY_NAMES[out_labels[idx]],就可以可视化了。

二.模型是如何完成目标检测的?

将3D张量映射到两个张量

  • 1.分类张量: shape为[N, C+ 1]
  • 2.边界框张量: shape为[N, 4]

Recent Advances in Deep Learning for Object Detection》-2019
在这里插入图片描述
边界框数量N如何确定?
传统方法一一滑动窗策略
缺点:

  • 1.重复计算量大
  • 2.窗口大小难确定
    在这里插入图片描述
    在这里插入图片描述

利用卷积减少重复计算
在这里插入图片描述
重要概念:

  • 特征图一个像素对应原图一块区域
    在这里插入图片描述

三.深度学习目标检测模型简介

Object Detection in 23 Years- A Survey》 -2019
在这里插入图片描述
按流程分为: one- stagetwo-stage
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

Faster RCNN——经典two stage检测网络
A. Faster RCNN 的 backbone structure 对图像进行特征提取,生成feature map。
B. feature map一部分会进入RPN网络,RNP网络会生成数十万各候选框,再使用非极大值抑制(NMS),挑选2000个proposals,这2000个候选框(proposals)会叠加在上一步生成的feature map上,进行“抠图”。生成子区域的特征图
C. 子区域特征图在经过ROI Layer,进行池化操作,生成统一固定大小的feature map。
D. 上一步生成的feature map会经过一系列的全链接,进行边界框回归(Regression)和类别分类(C+1 Softmax
E. 训练细节:生成的2000个候选框,会进一步筛选成512个,也就是输入到Stage2的子区域特征图是512个。
Faster RCNN数据流

  1. Feature map: [256,h f,w_f]
  2. 2 Softmax : [num_anchors, h_f, w_f]
  3. Regressors : [num_anchors*4, h_f, w_f]
  4. NMS OUT: [n_proposals = 2000, 4]
  5. ROI Layer: [512, 256, 7, 7]
  6. FC1 FC2: [512, 1024]
  7. c+1 Sofmax: [512, c + 1]
  8. Regressors: [512, (C+1)*4]

四.PyTorch中的Faster RCNN训练

学习笔记|Pytorch使用教程34(图像目标检测一瞥(下))

发布了76 篇原创文章 · 获赞 44 · 访问量 1万+

猜你喜欢

转载自blog.csdn.net/qq_24739717/article/details/103392767