Machine Learning Notes - Building Custom Object Detection with Detecto

1. Introduction to Detecto

        Detecto is a Python package that lets you build fully functional computer vision and object detection models in just 5 lines of code. Inference on still images and videos, transfer learning on custom datasets, and serializing models to files are just a few of Detecto's capabilities. Detecto is also built on PyTorch, allowing easy transfer of models between the two libraries.

        The power of Detecto is its simplicity and ease of use. Creating and running a pretrained Faster R-CNN ResNet-50 FPN from PyTorch's Model Zoo requires 4 lines of code:

from detecto.core import Model
from detecto.visualize import detect_video

model = Model()  # Initialize a pre-trained model
detect_video(model, 'input_video.mp4', 'output.avi')  # Run inference on a video

        You can install it with pip:

pip install detecto

2. Build a custom object detection model

1. Dataset

        We use a dog dataset here, available from the download link below. The dataset was already labeled by the developers of Detecto, so we only need to import it into our environment. It contains 300 labeled images of golden retrievers and chihuahuas.

Link: https://pan.baidu.com/s/1oVXgh093jGnZ1_auQO3mSA 
Extraction code: htsp

        There are three folders in the dataset: images, train_labels, and val_labels, which contain image files, training set labels, and validation set labels, respectively.
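        Before going further, it can help to confirm that the extracted archive really has this layout. The following is a minimal sketch, not part of Detecto: the helper name summarize_dataset and the root path you pass to it are assumptions made for illustration; only the three folder names come from the description above.

```python
from pathlib import Path

# Folder names taken from the dataset description; everything else is illustrative.
EXPECTED_FOLDERS = ("images", "train_labels", "val_labels")


def summarize_dataset(root):
    """Return {folder_name: number_of_files}; raise if a folder is missing."""
    root = Path(root)
    summary = {}
    for name in EXPECTED_FOLDERS:
        folder = root / name
        if not folder.is_dir():
            raise FileNotFoundError(f"missing folder: {folder}")
        # Count only regular files, ignoring any nested directories
        summary[name] = sum(1 for p in folder.iterdir() if p.is_file())
    return summary
```

        Calling summarize_dataset('.') from the directory where the archive was extracted should report the number of files found in each of the three folders.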

2. Data preprocessing

(1) Convert labels and visualize images

from detecto import utils
import matplotlib.pyplot as plt
import matplotlib.image as img
from torchvision import transforms
from detecto import core
from detecto import visualize

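# Convert the XML label files in each folder into a single CSV file
utils.xml_to_csv('train_labels', 'train.csv')
utils.xml_to_csv('val_labels', 'val.csv')

# Load one image from the dataset and display it
image = img.imread('images/n02085620_8611.jpg')
plt.imshow(image)
plt.show()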
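
        To make the label conversion less of a black box, here is a rough sketch of the kind of work it involves, assuming the label files follow the Pascal VOC XML layout. The sample annotation, the helper name voc_to_rows, and the exact set of output fields are assumptions for illustration; this is not Detecto's own implementation, and its CSV may contain additional columns.

```python
import xml.etree.ElementTree as ET

# A made-up annotation in Pascal VOC style, for illustration only
SAMPLE_XML = """
<annotation>
  <filename>example.jpg</filename>
  <size><width>500</width><height>375</height></size>
  <object>
    <name>Chihuahua</name>
    <bndbox><xmin>48</xmin><ymin>30</ymin><xmax>320</xmax><ymax>350</ymax></bndbox>
  </object>
</annotation>
"""


def voc_to_rows(xml_text):
    """Flatten one annotation into a list of dicts, one per labeled object."""
    root = ET.fromstring(xml_text.strip())
    filename = root.findtext("filename")
    width = int(root.findtext("size/width"))
    height = int(root.findtext("size/height"))
    rows = []
    for obj in root.findall("object"):
        box = obj.find("bndbox")
        rows.append({
            "filename": filename,
            "width": width,
            "height": height,
            "class": obj.findtext("name"),
            "xmin": int(box.findtext("xmin")),
            "ymin": int(box.findtext("ymin")),
            "xmax": int(box.findtext("xmax")),
            "ymax": int(box.findtext("ymax")),
        })
    return rows
```

        Each returned dict corresponds to one row of a CSV file: an image with two labeled dogs would produce two rows that share the same filename.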

(2) Image transformation and visualization

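# Transformations applied to every training image
transform_img = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize(800),  # resize the shorter side to 800 pixels
    transforms.RandomHorizontalFlip(0.5),  # flip each image with probability 0.5
    transforms.ToTensor(),
    utils.normalize_transform(),  # Detecto's default normalization
])

dataset = core.Dataset('train.csv', 'images/', transform=transform_img)

# Each item is an (image, targets) pair; targets holds the boxes and labels
image, information = dataset[50]
visualize.show_labeled_image(image, information['boxes'], information['labels'])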
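
        Resizing or flipping an image only makes sense for detection if the bounding boxes move with it. The sketch below shows the arithmetic involved, assuming boxes are stored as (xmin, ymin, xmax, ymax) in pixels. The function names hflip_box and scale_box are hypothetical and are not part of Detecto or torchvision; they only illustrate the idea.

```python
def hflip_box(box, image_width):
    """Mirror a (xmin, ymin, xmax, ymax) box around the vertical center line."""
    xmin, ymin, xmax, ymax = box
    # The old right edge becomes the new left edge, and vice versa
    return (image_width - xmax, ymin, image_width - xmin, ymax)


def scale_box(box, scale):
    """Scale a box by the same factor used to resize the image."""
    return tuple(coord * scale for coord in box)
```

        For example, in a 500-pixel-wide image a box spanning x from 10 to 110 ends up spanning 390 to 490 after a horizontal flip, while its y coordinates are unchanged.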

3. Train the model and make predictions

        A Faster R-CNN model with a ResNet-50 FPN backbone is used here. core.Model downloads the pretrained ResNet-50 weights by default, and two alternative backbones are currently available: mobilenet_v3 and mobilenet_v3_320.

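dataloader = core.DataLoader(dataset)  # batches the training set
validation_data = core.Dataset('val.csv', 'images/')
categories = ['Chihuahua', 'golden_retriever']
classifier = core.Model(categories)  # pretrained model adapted to our two classes
# With a validation set, fit returns the validation loss of each epoch
history = classifier.fit(dataloader, validation_data, epochs=20, verbose=True)
plt.plot(history)
plt.show()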
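
        Besides plotting the loss curve, it can be useful to read off which epoch gave the lowest validation loss. A minimal sketch, assuming the history is a plain list with one loss value per epoch; the helper name best_epoch is hypothetical.

```python
def best_epoch(losses):
    """Return (epoch_number, loss) for the lowest loss; epochs count from 1."""
    if not losses:
        raise ValueError("loss history is empty")
    index = min(range(len(losses)), key=lambda i: losses[i])
    return index + 1, losses[index]
```

        If the best epoch comes well before the last one, training for fewer epochs may give a similar result in less time.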


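# Take every third image from the first 36 validation samples (12 images in total)
images = []
for i in range(0, 36, 3):
    image, _ = validation_data[i]
    images.append(image)

# Show the predictions in a grid of 4 rows and 3 columns
visualize.plot_prediction_grid(classifier, images, dim=(4, 3), figsize=(16, 12))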

        [Figure: training output]

        [Figure: predicted results]
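
        A detector usually returns many candidate boxes, including weak ones, so predictions are commonly filtered by confidence score before being used. The sketch below assumes predictions arrive as three parallel sequences of labels, boxes, and scores; the helper name filter_predictions and the 0.6 threshold are illustrative choices, not Detecto defaults.

```python
def filter_predictions(labels, boxes, scores, threshold=0.6):
    """Keep only the (label, box, score) triples scoring at least threshold."""
    return [
        (label, box, float(score))
        for label, box, score in zip(labels, boxes, scores)
        if float(score) >= threshold
    ]
```

        Raising the threshold removes more false positives at the cost of missing some real objects, so the value is worth tuning on the validation images.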

4. Summary

        The Detecto library is very convenient to use, but it still has some version compatibility problems. When running it, you may run into issues with mismatched package versions, DLLs, and so on. If you are interested, take a look at its GitHub repository.

https://github.com/alankbi/detecto

Origin blog.csdn.net/bashendixie5/article/details/124282791