Object detection with RetinaNet in PyTorch

This article uses PyTorch 1.7.1 together with the pre-trained RetinaNet model provided by torchvision. With this pre-trained model we can detect the 80 kinds of objects in the COCO dataset.

Input format of RetinaNet

The format of the input image is [C, H, W], namely (channels, height, and width), and we also need to provide a batch size. Batch size refers to how many images are processed at a time. So the input image format is [N, C, H, W]. At the same time, the pixel value of the image should be between 0-1.

Output format of RetinaNet

The model outputs a list of dictionaries, one per input image, each holding the resulting tensors. The format is List[Dict[Tensor]]. Each Dict includes the following keys:

boxes (FloatTensor[N, 4]): The predicted boxes are in [x1, y1, x2, y2] format

labels (Int64Tensor[N]): the predicted label of each detection in the image

scores (Tensor[N]): the confidence score of each detection

coco_names.py

# Class names, indexed by the integer label that the detection model predicts.
# Index 0 is the '__background__' entry; the 'N/A' entries are placeholder
# slots (presumably label ids that have no category in the pre-trained model).
COCO_INSTANCE_CATEGORY_NAMES = [
    '__background__', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
    'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A', 'stop sign',
    'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
    'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack', 'umbrella', 'N/A', 'N/A',
    'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
    'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket',
    'bottle', 'N/A', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl',
    'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
    'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table',
    'N/A', 'N/A', 'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone',
    'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'N/A', 'book',
    'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush'
]

detect_utils.py

To get the detections we need a simple function, predict(), which takes four input parameters: image, model, device and detection_threshold.

import torchvision.transforms as transforms
import cv2
import numpy as np
import torch
from coco_names import COCO_INSTANCE_CATEGORY_NAMES as coco_names

# Draw each category in its own color: one random RGB triple per entry of
# coco_names (91 entries in total).
COLORS = np.random.uniform(low=0, high=255, size=(len(coco_names), 3))
# torchvision transform pipeline that turns the input image into a tensor.
transform = transforms.Compose([transforms.ToTensor()])

def predict(image, model, device, detection_threshold):
    """Run the detector on a single image and keep the confident detections.

    Args:
        image: Image accepted by the module-level ``transform`` (the caller
            in this article passes an RGB PIL image).
        model: Detection model. It is called with one batched image tensor
            and must return a list whose first element is a dict with the
            keys ``'boxes'``, ``'labels'`` and ``'scores'``.
        device: Device the image tensor is moved to before inference.
        detection_threshold: Minimum score (inclusive) a detection needs in
            order to be returned.

    Returns:
        A ``(boxes, pred_classes)`` tuple. ``boxes`` is an ``int32`` NumPy
        array with one ``[x1, y1, x2, y2]`` row per kept detection and
        ``pred_classes`` is the list of matching class names, so both always
        have the same length and the same order.
    """
    # Convert the image to a tensor and add the batch dimension.
    batch = transform(image).to(device).unsqueeze(0)
    with torch.no_grad():
        outputs = model(batch)  # predictions for the single image in the batch
    detections = outputs[0]

    scores = detections['scores'].detach().cpu().numpy()
    # One boolean mask selects both boxes and labels. Looking indices up by
    # score value (list.index) picks the wrong detection when two scores are
    # equal, and mixing ``>`` with ``>=`` lets boxes and classes fall out of
    # step when a score equals the threshold.
    keep = scores >= detection_threshold

    boxes = detections['boxes'].detach().cpu().numpy()[keep].astype(np.int32)
    labels = detections['labels'].detach().cpu().numpy()[keep]
    pred_classes = [coco_names[label] for label in labels]
    return boxes, pred_classes

def draw_boxes(boxes, classes, image):
    """Draw every box and its class name onto ``image`` and return it.

    Args:
        boxes: Sequence of ``[x1, y1, x2, y2]`` boxes.
        classes: Class name for each box, in the same order as ``boxes``.
        image: Image array handed to the OpenCV drawing calls; it is drawn
            on in place.

    Returns:
        The same ``image`` object with the rectangles and labels drawn on it.
    """
    for idx, box in enumerate(boxes):
        label = classes[idx]
        # Same class name -> same color, looked up by position in coco_names.
        color = COLORS[coco_names.index(label)]
        top_left = (int(box[0]), int(box[1]))
        bottom_right = (int(box[2]), int(box[3]))
        # Place the text 5 pixels above the top edge of the box.
        text_origin = (int(box[0]), int(box[1]-5))
        cv2.rectangle(image, top_left, bottom_right, color, 2)
        cv2.putText(
            image, label, text_origin,
            cv2.FONT_HERSHEY_SIMPLEX, 0.8, color, 2,
            lineType=cv2.LINE_AA,
        )
    return image


def predict(image, model, device, detection_threshold):

First pass the image into the RetinaNet model and get the output

with torch.no_grad():

outputs = model(image) 

Get all the scores in the output dictionary

scores = list(outputs[0]['scores'].detach().cpu().numpy())

Collect the indices of the scores that are greater than the threshold

thresholded_preds_inidices = [scores.index(i) for i in scores if i > detection_threshold]

Extract all candidate boxes in the output dictionary

bboxes = outputs[0]['boxes'].detach().cpu().numpy()

Keep only the candidate boxes whose score is at least the threshold

boxes = bboxes[np.array(scores) >= detection_threshold].astype(np.int32)

def draw_boxes(boxes, classes, image):

boxes are the candidate boxes extracted by the previous filtering

classes are the predicted category names of those boxes

image is the picture on which each box is drawn as a frame, with its category name written just above it

detect_images.py

import argparse
import os

import cv2
import numpy as np
import torch
import torchvision
from PIL import Image

import detect_utils

# Command-line script: run the pre-trained RetinaNet on one image, save the
# annotated result under outputs/ and show it in a window.

# construct the argument parser
parser = argparse.ArgumentParser()
# required: without an input path there is nothing to open
parser.add_argument('-i', '--input', required=True,
                    help='path to input image/video')
# type=int so a value given on the command line is numeric like the default
parser.add_argument('-m', '--min-size', dest='min_size', default=800, type=int,
                    help='minimum input size for the RetinaNet network')
parser.add_argument('-t', '--threshold', default=0.6, type=float,
                    help='minimum confidence score for detection')
args = vars(parser.parse_args())
print('USING:')
print(f"Minimum image size: {args['min_size']}")
print(f"Confidence threshold: {args['threshold']}")

# download or load the model from disk
model = torchvision.models.detection.retinanet_resnet50_fpn(pretrained=True,
                                                            min_size=args['min_size'])
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# switch to inference mode and load the model onto the computation device
model.eval().to(device)

image = Image.open(args['input']).convert('RGB')
# a NumPy copy for OpenCV functions
image_array = np.array(image)
# convert to OpenCV BGR color format
image_array = cv2.cvtColor(image_array, cv2.COLOR_RGB2BGR)

# get the bounding boxes and class labels
boxes, classes = detect_utils.predict(image, model, device, args['threshold'])
# get the final image
result = detect_utils.draw_boxes(boxes, classes, image_array)

# Save before displaying so the result is kept even when no window can be
# opened. cv2.imwrite does not create missing directories and only reports
# failure through its return value.
os.makedirs('outputs', exist_ok=True)
# basename/splitext handle any path separator and dots inside the file name
input_stem = os.path.splitext(os.path.basename(args['input']))[0]
# round() instead of int(): 0.29 * 100 is 28.999999999999996 in floating point
threshold_pct = round(args['threshold'] * 100)
save_name = f"{input_stem}_{args['min_size']}_t{threshold_pct}"
save_path = f"outputs/{save_name}.jpg"
if not cv2.imwrite(save_path, result):
    print(f"Could not write the result image to {save_path}")

try:
    cv2.imshow('Image', result)
    cv2.waitKey(0)
except cv2.error as err:
    # OpenCV builds without GUI support raise "The function is not implemented"
    print(f"Could not display the result window: {err}")

the function is not implemented
Rebuild the library with Windows, GTK+ 2.x or Cocoa support. 

Solution: if cv2.imshow raises the error above, uninstall the installed OpenCV package and install it again.

 

https://download.pytorch.org/whl/torch_stable.html

Guess you like

Origin blog.csdn.net/m0_67105022/article/details/123539027