[Tutorial] Training your own dataset from scratch with the PIDNet semantic segmentation model

Introduction

Train your own dataset from scratch with the semantic segmentation model PIDNet.

PIDNet paper address: https://arxiv.org/pdf/2206.02066.pdf

PIDNet project address: GitHub - XuJiacong/PIDNet: This is the official repository for our recent work: PIDNet

1. Preparing the dataset

First, let's explain what kind of dataset is needed. The semantic label images required by PIDNet are 8-bit grayscale images (the same training format required by BiSeNet, covered in an earlier post), where each segmentation class is represented by a grayscale value. The overall layout of a dataset folder is shown below:

PV is the name of our dataset. All we need to prepare are the four .lst mapping files under the list folder, plus the image (original images) and label (semantic segmentation maps) folders under the PV folder.
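Based on the paths used in the list-generation script further below, the folder layout looks roughly like this (PV is just the example dataset name; test labels are optional, since test.lst only lists images):

data/
├── list/
│   └── PV/
│       ├── train.lst
│       ├── val.lst
│       ├── test.lst
│       └── trainval.lst
└── PV/
    ├── image/
    │   ├── train/
    │   ├── val/
    │   └── test/
    └── label/
        ├── train/
        ├── val/
        └── test/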

1. First, annotate the images with labelme, convert the json files into segmentation maps, and convert those maps into 8-bit grayscale images. These steps were covered in a previous post; see:

Tutorial - Use BiSeNet (semantic segmentation) network to train your own data set from scratch - Programmer Sought

Just follow the method above to obtain the 8-bit grayscale label images.

2. Place the original images and the 8-bit grayscale labels from step 1 into the folder structure shown above. Once they are in place, use the following code to generate the .lst mapping files. Be careful to modify the paths to match your own dataset:

import os


def op_file():
    # train: write "image_path   label_path" pairs
    train_image_root = 'image/train/'
    train_label_root = 'label/train/'
    train_image_path = 'data/PV/image/train'
    train_label_path = 'data/PV/label/train'

    # sort both listings so each image is paired with its own label
    trainImageList = sorted(os.listdir(train_image_path))
    trainLabelList = sorted(os.listdir(train_label_path))

    train_image_list = [train_image_root + image for image in trainImageList]
    train_label_list = [train_label_root + label for label in trainLabelList]

    train_list_path = 'data/list/PV/train.lst'
    with open(train_list_path, 'w', encoding='utf-8') as f:
        for i1, i2 in zip(train_image_list, train_label_list):
            f.write(i1 + "   " + i2 + "\n")

    # test: images only, no labels
    test_image_root = 'image/test/'
    test_image_path = 'data/PV/image/test'

    testImageList = sorted(os.listdir(test_image_path))
    test_image_list = [test_image_root + image for image in testImageList]

    test_list_path = 'data/list/PV/test.lst'
    with open(test_list_path, 'w', encoding='utf-8') as f:
        for i1 in test_image_list:
            f.write(i1 + "\n")

    # val
    val_image_root = 'image/val/'
    val_label_root = 'label/val/'
    val_image_path = 'data/PV/image/val'
    val_label_path = 'data/PV/label/val'

    valImageList = sorted(os.listdir(val_image_path))
    valLabelList = sorted(os.listdir(val_label_path))

    val_image_list = [val_image_root + image for image in valImageList]
    val_label_list = [val_label_root + label for label in valLabelList]

    val_list_path = 'data/list/PV/val.lst'
    with open(val_list_path, 'w', encoding='utf-8') as f:
        for i1, i2 in zip(val_image_list, val_label_list):
            f.write(i1 + "   " + i2 + "\n")

    # trainval: train pairs followed by val pairs
    trainval_list_path = 'data/list/PV/trainval.lst'
    with open(trainval_list_path, 'w', encoding='utf-8') as f:
        for i1, i2 in zip(train_image_list, train_label_list):
            f.write(i1 + "   " + i2 + "\n")
        for i1, i2 in zip(val_image_list, val_label_list):
            f.write(i1 + "   " + i2 + "\n")


if __name__ == '__main__':
    op_file()
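After running the script, each line of train.lst, val.lst and trainval.lst should look like, for example, image/train/0001.png   label/train/0001.png (the filename is only an example; these relative paths are joined with the dataset root inside the dataset class), while test.lst contains only image paths.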

2. Relevant code modifications

1. Copy cityscapes.py, which sits in the datasets folder, and rename the copy after our dataset, i.e. PV.py, as shown below:

Open PV.py and change every Cityscapes in it to PV (the name of your dataset); set num_classes=3 (the number of categories including background; the blogger has three); modify mean and std; modify label_mapping (one entry per class); and modify class_weights (the calculation method is given below).
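For reference, with three classes whose grayscale values in the label images are 0, 1 and 2, the relevant lines in PV.py might look roughly like this. This is only a sketch; the weight values are placeholders to be replaced with the numbers computed by the script below:

self.label_mapping = {0: 0, 1: 1, 2: 2}  # grayscale value in the label image -> training class id
self.class_weights = torch.FloatTensor([1.0, 1.0, 1.0]).cuda()  # placeholder weights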

To compute the mean, std and class_weights for your own dataset, just run the following code:

import numpy as np
import os
import cv2


def get_weight(class_num, pixel_count):
    # inverse-log frequency weighting, normalized so the weights sum to class_num
    W = 1 / np.log(pixel_count)
    W = class_num * W / np.sum(W)
    return W


def get_MeanStdWeight(class_num=3, size=(1080, 700)):
    # size is the (width, height) passed to cv2.resize
    image_path = "data/PV/image/train/"
    label_path = "data/PV/label/train/"

    namelist = os.listdir(image_path)
    """====== use this instead if the training file names are stored in a txt file ======"""
    # file_name = "../datasets/train.txt"
    # with open(file_name, "r") as f:
    #     namelist = f.readlines()
    #     namelist = [file[:-1].split(",") for file in namelist]
    """==================================================================================="""

    MEAN = []
    STD = []
    pixel_count = np.zeros((class_num, 1))

    for i in range(len(namelist)):
        print(i, os.path.join(image_path, namelist[i]))

        # read as BGR, flip to RGB
        image = cv2.imread(os.path.join(image_path, namelist[i]))[:, :, ::-1]
        image = cv2.resize(image, size, interpolation=cv2.INTER_NEAREST)

        MEAN.append(np.mean(image, axis=(0, 1)))
        STD.append(np.std(image, axis=(0, 1)))

        # labels must be resized with nearest-neighbor so no new class values are created
        label = cv2.imread(os.path.join(label_path, namelist[i]), 0)
        label = cv2.resize(label, size, interpolation=cv2.INTER_NEAREST)

        for m in np.unique(label):
            pixel_count[m] += np.sum(label == m)

    MEAN = np.mean(MEAN, axis=0) / 255.0
    STD = np.mean(STD, axis=0) / 255.0

    weight = get_weight(class_num, pixel_count.T)
    print(MEAN)
    print(STD)
    print(weight)

    return MEAN, STD, weight


if __name__ == '__main__':
    get_MeanStdWeight()
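The script prints MEAN, STD and the per-class weights in that order; copy those values into the mean/std arguments and class_weights in PV.py.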

2. Import the dataset class we just created in the datasets/__init__.py file:
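Mirroring the existing cityscapes line in that file, the added import would look something like this (assuming the class was renamed to PV as described above; the alias after "as" should match the DATASET name used in the yaml):

from .PV import PV as PV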

3. Open the configs/cityscapes/pidnet_small_cityscapes.yaml file (the blogger uses the smallest model here; feel free to pick another) and modify the dataset name, dataset path, number of categories and pretrained model path:
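Since the screenshot is not reproduced here, a rough sketch of the fields involved follows. The key names mirror the original cityscapes yaml; the paths are examples and should be adapted to however your PV.py resolves them, and the PRETRAINED value is just a placeholder for wherever you keep the ImageNet weights:

DATASET:
  DATASET: PV
  ROOT: data/
  NUM_CLASSES: 3
  TRAIN_SET: list/PV/train.lst
  TEST_SET: list/PV/val.lst
MODEL:
  PRETRAINED: pretrained_models/imagenet/your_imagenet_weights.pth   # placeholder path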

4. Open models/pidnet.py and change num_classes in PIDNet to your own number of categories:

3. Start training

The blogger trains on a single GPU. Remember to modify the number of GPUs in the yaml file, then run the following command to start training:

python tools/train.py --cfg configs/cityscapes/pidnet_small_cityscapes.yaml

 The blogger did not encounter any errors. If you encounter errors, you can leave a message in the comment area, and the blogger will answer them one by one.

Note that the PIDNet network downsamples the input multiple times, so there are certain requirements on the size of the training images; otherwise you will run into a dimension mismatch. The blogger's training image size is 1080x640. The size of the training set can be changed by cropping; the cropping code is also available in my previous blog (Tutorial - Using BiSeNet (Semantic Segmentation) Network to train your own data set from scratch_Computer Illusion Blog-CSDN Blog).
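If you just need a quick way to do the cropping, here is a minimal sketch (not the crop script from that post): it crops every image and label in a folder down to the nearest multiple of a chosen stride. The stride of 8 is an assumption; check what your own configuration actually requires.

import os
import cv2


def crop_to_multiple(folder, stride=8):
    # crop height and width down to the nearest multiple of stride, in place
    for name in os.listdir(folder):
        path = os.path.join(folder, name)
        img = cv2.imread(path, cv2.IMREAD_UNCHANGED)  # keeps 8-bit labels untouched
        if img is None:
            continue
        h, w = img.shape[:2]
        cv2.imwrite(path, img[:h - h % stride, :w - w % stride])


if __name__ == '__main__':
    for d in ['data/PV/image/train', 'data/PV/label/train',
              'data/PV/image/val', 'data/PV/label/val']:
        crop_to_multiple(d)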

4. Test

1. Image test:

Before testing, you need to specify the trained model to load by modifying the yaml file, as shown below:
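Since the screenshot is not included here, the field in question is TEST.MODEL_FILE; with the output path used in this tutorial it would look something like this:

TEST:
  MODEL_FILE: output/PV/pidnet_small_cityscapes/best.pt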

Execute the following command to start testing:

python tools/eval.py --cfg configs/cityscapes/pidnet_small_cityscapes.yaml

The test results will be in the output folder, as shown below:

Note: at this point you will find that the output images look black. That is because the saved result is an 8-bit grayscale image, while what we want is a 24-bit RGB image. The solution:

Open the datasets/PV.py file again (the file where we define our own dataset), and add the color_list attribute, as follows:
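No screenshot survives here, so as a minimal example the attribute could be added in the dataset's __init__ like this (the colors are arbitrary; one RGB triple per class, background first):

self.color_list = [[0, 0, 0], [0, 255, 0], [255, 0, 0]]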

The blogger has three categories here (including the background), so three colors were written arbitrarily; choose them according to your own needs. Then add the label2color function and modify save_pred to use it, as shown below:

 Code:

    # both methods go inside the PV dataset class in datasets/PV.py
    def label2color(self, label):
        # map each class id in the prediction to its RGB color
        color_map = np.zeros(label.shape + (3,))
        for i, v in enumerate(self.color_list):
            color_map[label == i] = self.color_list[i]

        return color_map.astype(np.uint8)

    def save_pred(self, preds, sv_path, name):
        preds = np.asarray(np.argmax(preds.cpu(), axis=1), dtype=np.uint8)
        for i in range(preds.shape[0]):
            pred = self.label2color(preds[i])
            save_img = Image.fromarray(pred)
            save_img.save(os.path.join(sv_path, name[i] + '.png'))

 Test again and the output will be an RGB image.

2. Video test:

The source code does not provide a video test. The blogger provides one here. The code is as follows:

import os
import pprint
import sys
sys.path.insert(0, '.')
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.multiprocessing as mp
import time
from PIL import Image
import numpy as np
import cv2
import logging
import lib.data.transform_cv2 as T
from utils.utils import create_logger
from configs import config
from configs import update_config
torch.set_grad_enabled(False)
import torch.backends.cudnn as cudnn
import models

# args

parser = argparse.ArgumentParser()
parser.add_argument('--cfg', type=str, default='/home/zhangwei/Prg/Pycharm/FIles/PIDNet-main/configs/cityscapes/pidnet_small_cityscapes.yaml')
parser.add_argument('--weight-path', type=str, default='/home/zhangwei/Prg/Pycharm/FIles/PIDNet-main/output/PV/pidnet_small_cityscapes/best.pt')
parser.add_argument('--input', type=str, default='/home/zhangwei/Prg/Pycharm/FIles/PIDNet-main/test_dataset/video.avi')
parser.add_argument('--output', type=str, default='/home/zhangwei/Prg/Pycharm/FIles/PIDNet-main/test_dataset/PIDNet.mp4')
parser.add_argument('opts',
                        help="Modify config options using the command-line",
                        default=None,
                        nargs=argparse.REMAINDER)
args = parser.parse_args()
update_config(config, args)


# fetch frames
def get_func(inpth, in_q, done):

    cap = cv2.VideoCapture(args.input)
    width = cap.get(cv2.CAP_PROP_FRAME_WIDTH)  # type is float
    height = cap.get(cv2.CAP_PROP_FRAME_HEIGHT)  # type is float
    fps = cap.get(cv2.CAP_PROP_FPS)

    to_tensor = T.ToTensor(
        mean=(0.3257, 0.3690, 0.3223), # city, rgb
        std=(0.2112, 0.2148, 0.2115),
    )

    while cap.isOpened():
        ret, frame = cap.read()
        if not ret: break
        frame = frame[:, :, ::-1]
        frame = to_tensor(dict(im=frame, lb=None))['im'].unsqueeze(0)
        in_q.put(frame)

    in_q.put('quit')
    done.wait()

    cap.release()
    time.sleep(1)
    print('input queue done')


# save to video
def save_func(inpth, outpth, out_q):
    np.random.seed(123)
    palette = np.random.randint(0, 256, (256, 3), dtype=np.uint8)

    cap = cv2.VideoCapture(args.input)
    width = cap.get(cv2.CAP_PROP_FRAME_WIDTH)  # type is float
    height = cap.get(cv2.CAP_PROP_FRAME_HEIGHT)  # type is float
    fps = cap.get(cv2.CAP_PROP_FPS)
    cap.release()

    video_writer = cv2.VideoWriter(outpth,
            cv2.VideoWriter_fourcc(*"mp4v"),
            fps, (int(width), int(height)))

    while True:
        out = out_q.get()
        if isinstance(out, str) and out == 'quit': break  # 'quit' sentinel ends the writer
        out = out.numpy()
        preds = palette[out]
        for pred in preds:
            video_writer.write(pred)
    video_writer.release()
    print('output queue done')


# inference a list of frames
def infer_batch(frames):
    frames = torch.cat(frames, dim=0).cuda()
    H, W = frames.size()[2:]
    frames = F.interpolate(frames, size=(768, 768), mode='bilinear',
            align_corners=False) # must be divisible by 32
    out = model(frames)[0]
    out = F.interpolate(out, size=(H, W), mode='bilinear',
            align_corners=False).argmax(dim=1).detach().cpu()
    out_q.put(out)


if __name__ == '__main__':

    # args = parse_args()

    logger, final_output_dir, _ = create_logger(
        config, args.cfg, 'test')

    logger.info(pprint.pformat(args))
    logger.info(pprint.pformat(config))

    # cudnn related setting
    cudnn.benchmark = config.CUDNN.BENCHMARK
    cudnn.deterministic = config.CUDNN.DETERMINISTIC
    cudnn.enabled = config.CUDNN.ENABLED

    # build model
    model = models.pidnet.get_seg_model(config, imgnet_pretrained=True)

    if config.TEST.MODEL_FILE:
        model_state_file = config.TEST.MODEL_FILE
    else:
        model_state_file = os.path.join(final_output_dir, 'best.pt')

    logger.info('=> loading model from {}'.format(model_state_file))

    pretrained_dict = torch.load(model_state_file)
    if 'state_dict' in pretrained_dict:
        pretrained_dict = pretrained_dict['state_dict']
    model_dict = model.state_dict()
    pretrained_dict = {k[6:]: v for k, v in pretrained_dict.items()
                       if k[6:] in model_dict.keys()}
    for k, _ in pretrained_dict.items():
        logger.info(
            '=> loading {} from pretrained model'.format(k))
    model_dict.update(pretrained_dict)
    model.load_state_dict(model_dict)

    mp.set_start_method('spawn')

    in_q = mp.Queue(1024)
    out_q = mp.Queue(1024)
    done = mp.Event()

    in_worker = mp.Process(target=get_func,
                           args=(args.input, in_q, done))
    out_worker = mp.Process(target=save_func,
                            args=(args.input, args.output, out_q))

    in_worker.start()
    out_worker.start()
    model.eval()
    model = model.cuda()

    frames = []
    while True:
        frame = in_q.get()
        if isinstance(frame, str) and frame == 'quit': break  # 'quit' sentinel from the reader

        frames.append(frame)
        if len(frames) == 8:
            infer_batch(frames)
            frames = []
    if len(frames) > 0:
        infer_batch(frames)

    out_q.put('quit')
    done.set()

    out_worker.join()
    in_worker.join()





Modify your own file paths and execute the code:

 python demo_video.py --cfg /home/zhangwei/Prg/Pycharm/FIles/PIDNet-main/configs/cityscapes/pidnet_small_cityscapes.yaml

Wait a while, then open the generated mp4 file in the test_dataset directory to view the result.

At this point, the PIDNet training tutorial is over. If you have any questions, you can leave a message and the blogger will answer them one by one.


Origin: blog.csdn.net/qq_39149619/article/details/131931773