Model training and inference code for detecting dangerous/prohibited goods in X-ray images

Foreword

1. Security checks in public places are vitally important, and protecting the personal safety of the public is their top priority; they are an indispensable step in all kinds of venues. The X-ray security inspection machine is an important inspection tool, but it relies on a human operator to monitor and judge the imaging, and this limitation is obvious.
To address this limitation, artificial intelligence is brought in: the YOLO algorithm is used under the PyTorch framework to detect 10 categories of dangerous objects, such as knives, gas tanks, and firecrackers, in X-ray images.
2. Source address: https://download.csdn.net/download/matt45m/88178088

Dataset

1. Dataset definition

The contraband involved varies with the scenario; from the perspective of laws and regulations, contraband can be divided into 10 major categories and more than 60 subcategories, as follows:

  1. Guns: sports guns, machine guns, riot guns, stun guns, starting guns, anesthesia injection guns, air guns, rifles, submachine guns, nail guns, fire guns, shotguns, toy guns, imitation guns, tear gas guns, prop guns, ball guns, pistols
  2. Hardware tools: big wrench, iron hammer, shovel, baton, axe, arm strength device, throwing stick
  3. Pets: cats, hamsters, otters, snakes, hedgehogs, chinchillas, rabbits, dogs (except guide dogs), turtles, lizards
  4. Disinfection supplies: medical alcohol, 84 disinfectant, alcohol wash-free gel, alcohol spray, peracetic acid disinfectant, hydrogen peroxide disinfectant, alcohol disinfectant
  5. Drugs and corrosion: hydrochloric acid, selenium powder, pesticides, mercury, cyanide, arsenic, phenol, sulfuric acid, rat poison, nitric acid, potassium hydroxide, insecticides, sodium hydroxide
  6. Fuel, gas: kerosene, ethane, ethylene, natural gas, methane, diesel, propylene, acetylene, liquefied petroleum gas, butane, gasoline, hydrogen, carbon monoxide
  7. Flammables: red phosphorus, solid alcohol, white phosphorus, celluloid, magnesium aluminum powder, potassium, lithium, sodium, paint, thinner, oil paper, ether, yellow phosphorus, rosin oil, calcium carbide (calcium carbide), acetone, benzene, flash powder
  8. Food: unsealed durian, self-heating pot, live fish, live shrimp, self-heating rice, live crab, stinky tofu, self-heating hot pot
  9. Explosives: incendiary bombs, gas bombs, fuses, detonators, tear gas bombs, explosives, bombs, flares, grenades, signal flares, various fireworks, smoke bombs, pyrotechnic powder.
  10. Knives and related items: utility knives, butcher knives, pointed scissors, ceramic knives, fruit knives, table knives, Swiss Army knives, kitchen knives.

2. Dataset collection

X-ray data comes in two forms: color imaging and black-and-white imaging. Compared with color imaging, black-and-white imaging makes it harder for a human viewer to pick out contraband.

[Figure: sample X-ray images, color imaging and black-and-white imaging]

3. Data annotation

The dataset combines open-source data from the Internet with some privately collected data and covers essentially all categories of contraband. However, limited by manpower and computing power, only 4,000 color X-ray images were selected and annotated, covering 10 prohibited items that are common in China: 'lighter', 'scissors', 'powerbank', 'pressure', 'knife', 'zippooil', 'handcuffs', 'slingshot', 'firecrackers', 'nailpolish'. The data was annotated with LabelImg, and the label format is YOLO; an example label file is shown below.
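
For reference, the YOLO label format stores one .txt file per image, with one line per object: the class index followed by the normalized center coordinates and box size. A hypothetical label file for an image containing a lighter (class 0) and a knife (class 4) might look like this:

0 0.412 0.530 0.061 0.118
4 0.705 0.248 0.210 0.094

Each value after the class index is the box center x, center y, width, and height, all divided by the image width or height.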
[Figure: sample annotated X-ray images]

Model training

The training and development environment is Windows 10 with an RTX 3080 graphics card, CUDA 10.2 and cuDNN 7.1, OpenCV 4.5, and Anaconda 3.5; YOLOv5 is used with the 5s model.

1. Create an environment

 conda create --name yolov5 python=3.8
 conda activate yolov5
 git clone https://github.com/ultralytics/yolov5.git
 cd yolov5
 pip install -r requirements.txt

or

conda create --name yolov5 python=3.8
conda activate yolov5
git clone https://github.com/ultralytics/yolov5.git
cd yolov5
conda install pytorch torchvision cudatoolkit=10.2 -c pytorch
pip install cython matplotlib tqdm opencv-python tensorboard scipy pillow onnx pyyaml pandas seaborn
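
Before training, it is worth a quick sanity check that PyTorch can see the GPU. A minimal snippet, assuming the environment above installed correctly:

import torch

print(torch.__version__)                    # installed PyTorch version
print(torch.cuda.is_available())            # True if the CUDA build matches the driver
if torch.cuda.is_available():
    print(torch.cuda.get_device_name(0))    # e.g. the RTX 3080 mentioned above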

2. Training

Open models/yolov5s.yaml and change nc to the number of classes to detect.

# parameters
nc: 10  # total number of detection classes
depth_multiple: 0.33  # model depth multiple (scales the number of layer repeats)
width_multiple: 0.50  # layer channel multiple (scales the number of convolution kernels)

# anchors: candidate boxes; they can be changed to match your own target sizes, and more anchors can be added
anchors:
  - [10,13, 16,30, 33,23]  # P3/8
  - [30,61, 62,45, 59,119]  # P4/16
  - [116,90, 156,198, 373,326]  # P5/32

# YOLOv5 backbone
backbone: # feature extraction module
  # [from, number, module, args]
  # from - where the input comes from; -1 means the output of the previous layer
  # number - how many times this layer is repeated, scaled by depth_multiple; values below 1 become 1 (source: n = max(round(n * gd), 1) if n > 1 else n)
  # module - layer name
  # args - number of convolution kernels
  [[-1, 1, Focus, [64, 3]],  # 0-P1/2  # 64 is scaled by width_multiple: 64*0.5 = 32 feature maps
   [-1, 1, Conv, [128, 3, 2]],  # 1-P2/4
   [-1, 3, BottleneckCSP, [128]],
   [-1, 1, Conv, [256, 3, 2]],  # 3-P3/8
   [-1, 9, BottleneckCSP, [256]],
   [-1, 1, Conv, [512, 3, 2]],  # 5-P4/16
   [-1, 9, BottleneckCSP, [512]],
   [-1, 1, Conv, [1024, 3, 2]],  # 7-P5/32
   [-1, 1, SPP, [1024, [5, 9, 13]]],
   [-1, 3, BottleneckCSP, [1024, False]],  # 9
  ]

# YOLOv5 head
head:
  [[-1, 1, Conv, [512, 1, 1]],
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [[-1, 6], 1, Concat, [1]],  # cat backbone P4
   [-1, 3, BottleneckCSP, [512, False]],  # 13

   [-1, 1, Conv, [256, 1, 1]],
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [[-1, 4], 1, Concat, [1]],  # cat backbone P3
   [-1, 3, BottleneckCSP, [256, False]],  # 17 (P3/8-small)

   [-1, 1, Conv, [256, 3, 2]],
   [[-1, 14], 1, Concat, [1]],  # cat head P4
   [-1, 3, BottleneckCSP, [512, False]],  # 20 (P4/16-medium)

   [-1, 1, Conv, [512, 3, 2]],
   [[-1, 10], 1, Concat, [1]],  # cat head P5
   [-1, 3, BottleneckCSP, [1024, False]],  # 23 (P5/32-large)

   [[17, 20, 23], 1, Detect, [nc, anchors]],  # Detect(P3, P4, P5) on layers 17, 20 and 23
  ]

Add a dangerous.yaml training-data configuration file to the data directory with the following content:

# download command/URL (optional); not needed for this custom dataset

# paths to the training-set and validation-set txt files
train: data/xxx/train.txt
val: data/xxx/val.txt

# total number of classes
nc: 10

# class names
names: ['lighter','scissors','powerbank','pressure','knife','zippooil','handcuffs','slingshot','firecrackers','nailpolish']
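
train.txt and val.txt are plain lists of image paths, one per line; YOLOv5 finds the matching label .txt files by replacing the images directory with labels in each path. A minimal sketch for generating the two lists with a random 9:1 split is shown below; the data/xxx/images directory and the split ratio are assumptions, adjust them to your own layout:

import random
from pathlib import Path

image_dir = Path('data/xxx/images')                 # assumed location of the 4000 annotated images
images = sorted(str(p) for p in image_dir.glob('*.jpg'))
random.seed(0)
random.shuffle(images)

split = int(len(images) * 0.9)                      # 90% train / 10% val (assumed ratio)
Path('data/xxx/train.txt').write_text('\n'.join(images[:split]) + '\n')
Path('data/xxx/val.txt').write_text('\n'.join(images[split:]) + '\n')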

Start training with one of the following commands:

Single GPU:

python train.py --cfg models/yolov5s.yaml --data data/dangerous.yaml --hyp data/hyps/hyp.scratch.yaml --epochs 100 --multi-scale --device 0

Multi-GPU:

python train.py --cfg models/yolov5s.yaml --data data/dangerous.yaml --hyp data/hyps/hyp.scratch.yaml --epochs 100 --multi-scale --device 0,1

Test the model:

python test.py --weights runs/train/exp/weights/best.pt --data data/dangerous.yaml --device 0 --verbose
--weights: the trained model weights
--data: the data configuration file
--device: the GPU to use for evaluation
--verbose: print evaluation metrics for each class
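
Besides test.py, a quick way to spot-check the trained weights on a single image is through torch.hub. This is only a sketch, not part of the project code; it assumes network access to fetch the ultralytics/yolov5 hub code and that best.pt sits at the path shown:

import torch

# load the custom-trained weights through the ultralytics/yolov5 hub entry point
model = torch.hub.load('ultralytics/yolov5', 'custom', path='runs/train/exp/weights/best.pt')
model.conf = 0.5                        # confidence threshold, same default as the UI below

results = model('data/xxx/sample.jpg')  # hypothetical test image path
results.print()                         # print per-class detections and timing
results.save()                          # save the annotated image (to a runs/detect/ subfolder by default)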

Model inference

For simple project requirements, Gradio can be used for rapid deployment. However, as the project grows and the interface layout becomes more complex, managing and positioning every control in Gradio becomes difficult. In that case Qt is a better fit: PyQt is a Python binding of the Qt framework, and Qt's layout and designer tools let you build the interface by drag-and-drop and convert it into Python code, which simplifies development.
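
For a sense of scale, a minimal Gradio sketch of such a single-image demo might look like the following (a hypothetical example, not part of this project; it reuses the torch.hub loading shown earlier and assumes the weights/dangerous-best.pt file used by the UI below):

import gradio as gr
import torch

model = torch.hub.load('ultralytics/yolov5', 'custom', path='weights/dangerous-best.pt')

def detect(image):
    results = model(image)       # RGB numpy image in, detections out
    return results.render()[0]   # annotated copy of the image

gr.Interface(fn=detect, inputs=gr.Image(), outputs=gr.Image(),
             title='X-ray contraband detection').launch()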
PyQt5 installation

pip install PyQt5
pip install PyQt5-tools

After installing PyQt5, you can start writing the inference code. PyQt makes it quick to put together a simple inference UI. The project supports three kinds of input: image, camera, and video.

#! /usr/bin/env python
# -*- coding: utf-8 -*-
"""
Run a YOLOv5 detection model on test images.
"""
import numpy as np
import argparse

import cv2
import torch
import torch.backends.cudnn as cudnn
from numpy import random

from models.experimental import attempt_load
from utils.general import check_img_size, non_max_suppression, scale_coords
from utils.plots import plot_one_box
from utils.torch_utils import select_device
import sys
from PyQt5 import QtCore, QtGui, QtWidgets
from PyQt5.QtWidgets import *
flag = False

def letterbox(img, new_shape=(640, 640), color=(114, 114, 114), auto=True, scaleFill=False, scaleup=True):
    # Resize image to a 32-pixel-multiple rectangle https://github.com/ultralytics/yolov3/issues/232
    shape = img.shape[:2]  # current shape [height, width]
    if isinstance(new_shape, int):
        new_shape = (new_shape, new_shape)

    # Scale ratio (new / old)
    r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
    if not scaleup:  # only scale down, do not scale up (for better test mAP)
        r = min(r, 1.0)

    # Compute padding
    ratio = r, r  # width, height ratios
    new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
    dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1]  # wh padding
    if auto:  # minimum rectangle
        dw, dh = np.mod(dw, 32), np.mod(dh, 32)  # wh padding
    elif scaleFill:  # stretch
        dw, dh = 0.0, 0.0
        new_unpad = (new_shape[1], new_shape[0])
        ratio = new_shape[1] / shape[1], new_shape[0] / shape[0]  # width, height ratios

    dw /= 2  # divide padding into 2 sides
    dh /= 2

    if shape[::-1] != new_unpad:  # resize
        img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_LINEAR)
    top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
    left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
    img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color)  # add border
    return img, ratio, (dw, dh)
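# Note: letterbox keeps the aspect ratio and pads to a multiple of 32; it returns the padded
# image plus the scale ratio and the (dw, dh) padding, and scale_coords() later maps the
# detected boxes from the network-input shape back to the original frame.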
# with torch.no_grad():
#     detect()
class Ui_MainWindow(QtWidgets.QWidget):
    def __init__(self, parent=None):
        super(Ui_MainWindow, self).__init__(parent)
        self.timer_camera = QtCore.QTimer()
        self.timer_camera_capture = QtCore.QTimer()
        self.cap = cv2.VideoCapture()
        self.CAM_NUM = 0
        self.set_ui()
        self.slot_init()
        # self.detect_image(self.image)
        self.__flag_work = 0
        self.x = 0
        parser = argparse.ArgumentParser()
        parser.add_argument('--weights', nargs='+', type=str, default='weights/dangerous-best.pt', help='model.pt path(s)')
        parser.add_argument('--source', type=str, default='images', help='source')  # file/folder, 0 for webcam
        parser.add_argument('--img-size', type=int, default=640, help='inference size (pixels)')
        parser.add_argument('--conf-thres', type=float, default=0.5, help='object confidence threshold')
        parser.add_argument('--iou-thres', type=float, default=0.45, help='IOU threshold for NMS')
        parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
        parser.add_argument('--view-img', action='store_true', help='display results')
        parser.add_argument('--save-txt', action='store_true', help='save results to *.txt')
        parser.add_argument('--save-conf', action='store_true', help='save confidences in --save-txt labels')
        parser.add_argument('--save-dir', type=str, default='results', help='directory to save results')
        parser.add_argument('--classes', nargs='+', type=int, help='filter by class: --class 0, or --class 0 2 3')
        parser.add_argument('--agnostic-nms', action='store_true', help='class-agnostic NMS')
        parser.add_argument('--augment', action='store_true', help='augmented inference')
        parser.add_argument('--update', action='store_true', help='update all models')
        self.opt = parser.parse_args()
        print(self.opt)
        out, source, weights, view_img, save_txt, imgsz = \
            self.opt.save_dir, self.opt.source, self.opt.weights, self.opt.view_img, self.opt.save_txt, self.opt.img_size
        webcam = source.isnumeric() or source.startswith(('rtsp://', 'rtmp://', 'http://')) or source.endswith('.txt')
        self.device = select_device(self.opt.device)
        # if os.path.exists(out):  # output dir
        #     shutil.rmtree(out)  # delete dir
        # os.makedirs(out)  # make new dir
        self.half = self.device.type != 'cpu'  # half precision only supported on CUDA

        # Load model
        self.model = attempt_load(weights,device=self.device)  # load FP32 model
        self.imgsz = check_img_size(imgsz, s=self.model.stride.max())  # check img_size
        if self.half:
            self.model.half()  # to FP16

        cudnn.benchmark = True  # set True to speed up constant image size inference

        # Get names and colors
        self.names = self.model.module.names if hasattr(self.model, 'module') else self.model.names
        self.colors = [[random.randint(0, 255) for _ in range(3)] for _ in range(len(self.names))]


    def set_ui(self):

        self.__layout_main = QtWidgets.QHBoxLayout()
        self.__layout_fun_button = QtWidgets.QVBoxLayout()
        self.__layout_data_show = QtWidgets.QVBoxLayout()

        self.openimage = QtWidgets.QPushButton(u'图片')
        self.opencameras = QtWidgets.QPushButton(u'摄像头')
        self.train = QtWidgets.QPushButton(u'视频')

        # self.Openvideo = QtWidgets.QPushButton(u'打开视频')
        self.openimage.setMinimumHeight(50)
        self.opencameras.setMinimumHeight(50)
        self.train.setMinimumHeight(50)
        # self.Openvideo.setMinimumHeight(50)
        # self.lineEdit = QtWidgets.QLineEdit(self)  # 创建 QLineEdit
        # self.lineEdit.textChanged.connect(self.text_changed)
        # self.lineEdit.setMinimumHeight(50)

        self.openimage.move(10, 30)
        self.opencameras.move(10, 50)
        self.train.move(15,70)

        # image display area
        self.showimage = QtWidgets.QLabel()
        # self.label_move = QtWidgets.QLabel()
        # self.lineEdit.setFixedSize(70, 30)

        self.showimage.setFixedSize(641, 481)
        self.showimage.setAutoFillBackground(False)

        self.__layout_fun_button.addWidget(self.openimage)
        self.__layout_fun_button.addWidget(self.opencameras)
        self.__layout_fun_button.addWidget(self.train)
        # self.__layout_fun_button.addWidget(self.Openvideo)

        self.__layout_main.addLayout(self.__layout_fun_button)
        self.__layout_main.addWidget(self.showimage)

        self.setLayout(self.__layout_main)
        # self.label_move.raise_()
        self.setWindowTitle(u'X光下目标识别0.1版本')

    def slot_init(self):
        self.openimage.clicked.connect(self.button_open_image_click)
        self.opencameras.clicked.connect(self.button_opencameras_click)
        self.timer_camera.timeout.connect(self.show_camera)
        # self.timer_camera_capture.timeout.connect(self.capture_camera)
        self.train.clicked.connect(self.button_train_click)
        # self.Openvideo.clicked.connect(self.Openvideo_click)


    def button_open_image_click(self):
        imgName, imgType = QFileDialog.getOpenFileName(self, "打开图片", "", "*.jpg;;*.png;;All Files(*)")
        if not imgName:  # dialog cancelled, nothing to do
            return
        img = cv2.imread(imgName)
        print(imgName)
        showimg = img
        with torch.no_grad():
            img = letterbox(img, new_shape=self.opt.img_size)[0]

            # Convert
            img = img[:, :, ::-1].transpose(2, 0, 1)  # BGR to RGB, to 3x416x416
            img = np.ascontiguousarray(img)
            img = torch.from_numpy(img).to(self.device)
            img = img.half() if self.half else img.float()  # uint8 to fp16/32
            img /= 255.0  # 0 - 255 to 0.0 - 1.0
            if img.ndimension() == 3:
                img = img.unsqueeze(0)

            # Inference
            pred = self.model(img, augment=self.opt.augment)[0]

            # Apply NMS
            pred = non_max_suppression(pred, self.opt.conf_thres, self.opt.iou_thres, classes=self.opt.classes, agnostic=self.opt.agnostic_nms)
            # Process detections
            for i, det in enumerate(pred):  # detections per image
                if det is not None and len(det):
                    # Rescale boxes from img_size to im0 size
                    det[:, :4] = scale_coords(img.shape[2:], det[:, :4], showimg.shape).round()

                    # Write results
                    for *xyxy, conf, cls in reversed(det):
                        label = '%s %.2f' % (self.names[int(cls)], conf)
                        plot_one_box(xyxy, showimg, label=label, color=self.colors[int(cls)], line_thickness=3)
        self.result = cv2.cvtColor(showimg, cv2.COLOR_BGR2BGRA)
        self.result = cv2.resize(self.result, (640, 480), interpolation=cv2.INTER_AREA)
        self.QtImg = QtGui.QImage(self.result.data, self.result.shape[1], self.result.shape[0],
                                  QtGui.QImage.Format_RGB32)
        # show the detection result in the label
        self.showimage.setPixmap(QtGui.QPixmap.fromImage(self.QtImg))

    def button_train_click(self):
        global flag
        self.timer_camera_capture.stop()
        self.cap.release()
        if flag == False:
            flag = True
            imgName, imgType = QFileDialog.getOpenFileName(self, "打开视频", "", "*.mp4;;*.avi;;All Files(*)")
            flag = self.cap.open(imgName)
            if flag == False:
                msg = QtWidgets.QMessageBox.warning(self, u"Warning", u"打开视频失败",
                                                    buttons=QtWidgets.QMessageBox.Ok,
                                                    defaultButton=QtWidgets.QMessageBox.Ok)
            else:
                self.timer_camera.start(30)
                self.train.setText(u'关闭识别')
        else:
            flag = False
            self.timer_camera.stop()
            self.cap.release()
            self.showimage.clear()
            self.train.setText(u'打开视频')
    def button_opencameras_click(self):
        self.timer_camera_capture.stop()
        self.cap.release()
        if self.timer_camera.isActive() == False:
            flag = self.cap.open(self.CAM_NUM)
            if flag == False:
                msg = QtWidgets.QMessageBox.warning(self, u"Warning", u"请检测相机与电脑是否连接正确",
                                                    buttons=QtWidgets.QMessageBox.Ok,
                                                    defaultButton=QtWidgets.QMessageBox.Ok)
            else:
                self.timer_camera.start(30)

                self.opencameras.setText(u'关闭识别')
        else:
            self.timer_camera.stop()
            self.cap.release()
            self.showimage.clear()
            self.opencameras.setText(u'打开摄像头')

    def show_camera(self):
        flag, img = self.cap.read()
        if img is not None:
            showimg = img
            with torch.no_grad():
                img = letterbox(img, new_shape=self.opt.img_size)[0]
                # Convert
                img = img[:, :, ::-1].transpose(2, 0, 1)  # BGR to RGB, to 3x416x416
                img = np.ascontiguousarray(img)
                img = torch.from_numpy(img).to(self.device)
                img = img.half() if self.half else img.float()  # uint8 to fp16/32
                img /= 255.0  # 0 - 255 to 0.0 - 1.0
                if img.ndimension() == 3:
                    img = img.unsqueeze(0)

                # Inference
                pred = self.model(img, augment=self.opt.augment)[0]

                # Apply NMS
                pred = non_max_suppression(pred, self.opt.conf_thres, self.opt.iou_thres, classes=self.opt.classes,
                                           agnostic=self.opt.agnostic_nms)
                # Process detections
                for i, det in enumerate(pred):  # detections per image
                    if det is not None and len(det):
                        # Rescale boxes from img_size to im0 size
                        det[:, :4] = scale_coords(img.shape[2:], det[:, :4], showimg.shape).round()

                        # Write results
                        for *xyxy, conf, cls in reversed(det):
                            label = '%s %.2f' % (self.names[int(cls)], conf)
                            print(label)
                            plot_one_box(xyxy, showimg, label=label, color=self.colors[int(cls)], line_thickness=3)
            show = cv2.resize(showimg, (640, 480))
            self.result = cv2.cvtColor(show, cv2.COLOR_BGR2RGB)
            showImage = QtGui.QImage( self.result.data,  self.result.shape[1],  self.result.shape[0], QtGui.QImage.Format_RGB888)
            self.showimage.setPixmap(QtGui.QPixmap.fromImage(showImage))
        else:
            flag = False
            self.timer_camera.stop()
            self.cap.release()
            self.showimage.clear()
            self.train.setText(u'打开视频')


if __name__ == '__main__':
    app = QtWidgets.QApplication(sys.argv)
    ui = Ui_MainWindow()
    ui.show()
    sys.exit(app.exec_())

Save the code as main.py and run it.
[Figure: screenshots of the inference UI]

Postscript

1. Among the contraband, lighters are small and easily occluded in complex, crowded bags, so they are sometimes hard to detect accurately. I used the S model and the detection results are acceptable; if you need higher accuracy, besides adding more training data you can choose a larger model or switch to YOLOv8.
2. In the security inspection scenario, a missed detection is more serious than a false detection. To reduce the missed-detection rate, you can appropriately lower the confidence threshold and add more similar samples as well as scene negative samples.
3. If you are interested in this project or run into errors during installation, you can join my QQ group: 487350510, and we can discuss it there.

Origin blog.csdn.net/matt45m/article/details/132129122