Fully Convolutional Network (FCN) in Action: Using FCN for Semantic Segmentation

Abstract: FCN classifies images at the pixel level, thus solving the problem of image segmentation at the semantic level.

This article is shared from the Huawei Cloud Community article "Fully Convolutional Network (FCN) in Practice: Using FCN to Implement Semantic Segmentation", by AI Hao.

FCN classifies images at the pixel level, thereby solving segmentation at the semantic level. Unlike a classic CNN, which appends fully connected layers after the convolutional layers to obtain a fixed-length feature vector for classification (fully connected layers + softmax output), FCN accepts input images of arbitrary size. It uses deconvolution (transposed convolution) layers to upsample the feature map of the last convolutional layer back to the size of the input image, so that a prediction can be produced for every pixel while the spatial information of the original input is preserved; pixel-wise classification is then performed on the upsampled feature map.
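For intuition, here is a minimal sketch (not code from the project) of how a transposed convolution can upsample a coarse per-class score map back to the input resolution:

import torch
import torch.nn as nn

num_classes = 21                              # 20 VOC classes + background
# A coarse score map, e.g. 1/32 the resolution of a 480x480 input.
scores = torch.randn(1, num_classes, 15, 15)

# Transposed convolution ("deconvolution") that upsamples 32x back to 480x480.
upsample = nn.ConvTranspose2d(num_classes, num_classes,
                              kernel_size=64, stride=32, padding=16)
print(upsample(scores).shape)                 # torch.Size([1, 21, 480, 480])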

The following figure is a schematic diagram of the structure of the fully convolutional network (FCN) used in semantic segmentation:

What are the disadvantages of traditional CNN-based segmentation methods?

Traditional CNN-based segmentation methods: In order to classify a pixel, an image patch around the pixel is used as the input of the CNN for training and prediction. This method has several disadvantages:

1) High storage cost. For example, a 15*15 image patch is taken for each pixel, and the window is slid continuously so that each patch is fed into the CNN for classification. The required storage therefore rises sharply with the number and size of the sliding windows;

2) Low efficiency. Adjacent patches overlap almost entirely, yet the convolution is computed separately for each patch, so most of the computation is redundant;

3) The patch size limits the receptive field. A patch is usually much smaller than the whole image, so only local features can be extracted, which limits classification performance.
The fully convolutional network (FCN) instead recovers the category of each pixel from the abstract features; that is, classification is extended from the image level to the pixel level.

What has FCN changed?

A typical classification CNN, such as VGG or ResNet, adds several fully connected layers at the end of the network and obtains class probabilities after softmax. However, this probability information is 1-dimensional: it can only identify the category of the whole image, not the category of each pixel, so this fully connected approach is not suitable for image segmentation.

FCN proposes replacing these fully connected layers with convolutions, so that a 2-dimensional feature map is obtained; a softmax layer then yields per-pixel classification information, which solves the segmentation problem, as shown in the figure.
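As an illustration of this "convolutionalization" (a minimal sketch, not the project's code), a fully connected head can be converted into an equivalent convolution so that the network outputs a spatial class map instead of a single vector:

import torch
import torch.nn as nn

# A fully connected head that expects a fixed-size 512x7x7 feature map.
fc = nn.Linear(512 * 7 * 7, 21)           # 21 = 20 VOC classes + background

# The equivalent convolutional head: a 7x7 convolution with the same weights.
conv = nn.Conv2d(512, 21, kernel_size=7)
conv.weight.data = fc.weight.data.view(21, 512, 7, 7)
conv.bias.data = fc.bias.data

feat = torch.randn(1, 512, 14, 14)        # a larger-than-training feature map
out = conv(feat)                          # -> (1, 21, 8, 8): a coarse per-class score map
print(out.shape)

Because the head is now convolutional, the network no longer requires a fixed input size.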

Disadvantages of FCN

(1) The results are still not precise enough. Although 8x upsampling is much better than 32x, the upsampled result is still relatively blurry and smooth, and it is not sensitive to details in the image.
(2) It classifies each pixel independently, without fully considering the relationships between pixels. The spatial regularization step used in common pixel-level segmentation methods is ignored, so spatial consistency is lacking.

Dataset

The dataset in this example uses the PASCAL VOC 2012 dataset, which has twenty categories:

Person: person

Animal: bird, cat, cow, dog, horse, sheep

Vehicle: aeroplane, bicycle, boat, bus, car, motorbike, train

Indoor: bottle, chair, dining table, potted plant, sofa, tv/monitor

Download address: The PASCAL Visual Object Classes Challenge 2012 (VOC2012) (ox.ac.uk).

The structure of the dataset:

VOCdevkit
    └── VOC2012
         ├── Annotations               all image annotation files (XML)
         ├── ImageSets
         │   ├── Action                image lists for human actions
         │   ├── Layout                image lists for human body parts (layout)
         │   │
         │   ├── Main                  image lists for detection/classification
         │   │     ├── train.txt       training set (5717)
         │   │     ├── val.txt         validation set (5823)
         │   │     └── trainval.txt    training + validation set (11540)
         │   │
         │   └── Segmentation          image lists for segmentation
         │         ├── train.txt       training set (1464)
         │         ├── val.txt         validation set (1449)
         │         └── trainval.txt    training + validation set (2913)
         ├── JPEGImages                all image files
         ├── SegmentationClass         semantic segmentation PNG masks (by class)
         └── SegmentationObject        instance segmentation PNG masks (by object)

The dataset covers both object detection and semantic segmentation. We only need the semantic segmentation part, so the redundant images can be deleted. The idea for deleting them:

1. Get the name of all images.

2. Get the names of all semantic segmentation masks.

3. Take the difference between the two sets of names, and delete the images in the difference.

The code is as follows:

import glob
import os

# names (without extension) of all JPEG images
image_all = glob.glob('data/VOCdevkit/VOC2012/JPEGImages/*.jpg')
image_all_name = [image_file.replace('\\', '/').split('/')[-1].split('.')[0] for image_file in image_all]

# names of all semantic segmentation masks
image_SegmentationClass = glob.glob('data/VOCdevkit/VOC2012/SegmentationClass/*.png')
image_se_name = [image_file.replace('\\', '/').split('/')[-1].split('.')[0] for image_file in image_SegmentationClass]

# images without a segmentation mask are not needed; delete them
image_other = list(set(image_all_name) - set(image_se_name))
print(image_other)
for image_name in image_other:
    os.remove('data/VOCdevkit/VOC2012/JPEGImages/{}.jpg'.format(image_name))

Code link

The code used in this example is from the pytorch_segmentation/fcn directory of WZMIAOMIAO/deep-learning-for-image-processing on GitHub.

There are many other implementations as well; this one is relatively easy to understand.

In fact, there is a better image segmentation library: https://github.com/qubvel/segmentation_models.pytorch

This image segmentation collection was created by Pavel Yakubovskiy, a Russian programmer. In a later article, I will also demonstrate how to use this library.
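For reference, a minimal sketch of creating a model with that library (the encoder and class count here are illustrative choices, not settings used in this article):

import segmentation_models_pytorch as smp

# U-Net with an ImageNet-pretrained ResNet-34 encoder and 21 output classes
# (20 VOC classes + background); illustrative settings only.
model = smp.Unet(
    encoder_name="resnet34",
    encoder_weights="imagenet",
    in_channels=3,
    classes=21,
)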

Project structure

├── src: the model backbone and the FCN construction
├── train_utils: modules for training, validation and multi-GPU training
├── my_dataset.py: custom dataset for reading the VOC dataset
├── train.py: training with fcn_resnet50 (dilated/atrous convolution is used here)
├── predict.py: a simple prediction script that runs test predictions with the trained weights
├── validation.py: evaluates mIoU and other metrics on the validation/test data with the trained weights and generates record_mAP.txt
└── pascal_voc_classes.json: pascal_voc label file

Since there is too much code to explain line by line, we will analyze the important parts next.

Custom dataset reading

my_dataset.py defines a custom way of reading the data; the code is as follows:

import os
import torch.utils.data as data
from PIL import Image

class VOCSegmentation(data.Dataset):
    def __init__(self, voc_root, year="2012", transforms=None, txt_name: str = "train.txt"):
        super(VOCSegmentation, self).__init__()
        assert year in ["2007", "2012"], "year must be in ['2007', '2012']"
        root = os.path.join(voc_root, "VOCdevkit", f"VOC{year}")
        root=root.replace('\\','/')
        assert os.path.exists(root), "path '{}' does not exist.".format(root)
        image_dir = os.path.join(root, 'JPEGImages')
        mask_dir = os.path.join(root, 'SegmentationClass')

        txt_path = os.path.join(root, "ImageSets", "Segmentation", txt_name)
        txt_path=txt_path.replace('\\','/')
        assert os.path.exists(txt_path), "file '{}' does not exist.".format(txt_path)
        with open(os.path.join(txt_path), "r") as f:
            file_names = [x.strip() for x in f.readlines() if len(x.strip()) > 0]

        self.images = [os.path.join(image_dir, x + ".jpg") for x in file_names]
        self.masks = [os.path.join(mask_dir, x + ".png") for x in file_names]
        assert (len(self.images) == len(self.masks))
        self.transforms = transforms

Import the required packages.

The VOCSegmentation class reads the VOC dataset. The core of its __init__ method is building the list of images and the corresponding list of masks.

    def __getitem__(self, index):
        img = Image.open(self.images[index]).convert('RGB')
        target = Image.open(self.masks[index])

        if self.transforms is not None:
            img, target = self.transforms(img, target)
        return img, target

The __getitem__ method fetches a single image and its corresponding mask, and then applies data augmentation to them.

    @staticmethod
    def collate_fn(batch):
        images, targets = list(zip(*batch))
        batched_imgs = cat_list(images, fill_value=0)
        batched_targets = cat_list(targets, fill_value=255)
        return batched_imgs, batched_targets

The collate_fn method calls cat_list on the samples in a batch to pad them to the same size (a sketch of cat_list is shown below).
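cat_list is defined elsewhere in my_dataset.py. A minimal sketch of what such a function does, assuming its job is to pad every tensor in the batch up to the largest height and width and fill the rest with fill_value:

import torch

def cat_list(images, fill_value=0):
    # largest size along every dimension in the batch (works for 3-D images and 2-D masks)
    max_size = tuple(max(s) for s in zip(*[img.shape for img in images]))
    batch_shape = (len(images),) + max_size
    # allocate the batch tensor and fill it with the padding value
    batched_imgs = images[0].new(*batch_shape).fill_(fill_value)
    # copy each tensor into the top-left corner of its padded slot
    for img, pad_img in zip(images, batched_imgs):
        pad_img[..., :img.shape[-2], :img.shape[-1]].copy_(img)
    return batched_imgs

Masks are padded with 255, which is typically treated as an ignore index by the loss.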

The torch.utils.data.DataLoader calls in train.py:

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size,
                                               num_workers=num_workers,
                                               shuffle=True,
                                               pin_memory=True,
                                               collate_fn=train_dataset.collate_fn)
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=1,
                                             num_workers=num_workers,
                                             pin_memory=True,
                                             collate_fn=val_dataset.collate_fn)

Training

Important parameters

Open train.py, let's first understand the important parameters:

def parse_args():
    import argparse
    parser = argparse.ArgumentParser(description="pytorch fcn training")
    # folder containing the root directory of the dataset (VOCdevkit)
    parser.add_argument("--data-path", default="data/", help="VOCdevkit root")
    parser.add_argument("--num-classes", default=20, type=int)
    parser.add_argument("--aux", default=True, type=bool, help="auxilier loss")
    parser.add_argument("--device", default="cuda", help="training device")
    parser.add_argument("-b", "--batch-size", default=32, type=int)
    parser.add_argument("--epochs", default=30, type=int, metavar="N",
                        help="number of total epochs to train")

    parser.add_argument('--lr', default=0.0001, type=float, help='initial learning rate')
    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                        help='momentum')
    parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
                        metavar='W', help='weight decay (default: 1e-4)',
                        dest='weight_decay')
    parser.add_argument('--print-freq', default=10, type=int, help='print frequency')
    parser.add_argument('--resume', default='', help='resume from checkpoint')
    parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                        help='start epoch')
    # whether to use mixed precision training
    parser.add_argument("--amp", default=False, type=bool,
                        help="Use torch.cuda.amp for mixed precision training")

    args = parser.parse_args()

    return args

data-path: defines the folder where the root directory of the dataset (VOCdevkit) is located

num-classes: Number of target classes (excluding the background).

aux: Whether to use aux_classifier.

device: Use cpu or gpu for training, the default is cuda.

batch-size: BatchSize setting.

epochs: The number of epochs.

lr: learning rate.

resume: The checkpoint to load when resuming training.

start-epoch: The starting epoch; when resuming training, it does not need to start from 0.

amp: Whether to use torch's automatic mixed precision training.

Data augmentation

The augmentations are defined and invoked in transforms.py.

The training set is augmented as follows:

class SegmentationPresetTrain:
    def __init__(self, base_size, crop_size, hflip_prob=0.5, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
        # minimum size for the random resize
        min_size = int(0.5 * base_size)
        # maximum size for the random resize
        max_size = int(2.0 * base_size)
        # random resize augmentation
        trans = [T.RandomResize(min_size, max_size)]
        if hflip_prob > 0:
            # random horizontal flip
            trans.append(T.RandomHorizontalFlip(hflip_prob))
        trans.extend([
            # random crop
            T.RandomCrop(crop_size),
            T.ToTensor(),
            T.Normalize(mean=mean, std=std),
        ])
        self.transforms = T.Compose(trans)

    def __call__(self, img, target):
        return self.transforms(img, target)

The training-set augmentations include random resize, random horizontal flip, and random crop. Each transform operates on the image and its mask together; a sketch of such a paired transform is shown below.
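The T.* classes here are presumably the project's own paired transforms from transforms.py, not the standard torchvision ones, because the mask must undergo exactly the same geometric operation as the image. As an illustration of the idea (a sketch, not the project's verbatim code), a paired random horizontal flip could look like this:

import random
from torchvision.transforms import functional as F

class RandomHorizontalFlip:
    """Flip the image and its segmentation mask together with probability flip_prob."""
    def __init__(self, flip_prob=0.5):
        self.flip_prob = flip_prob

    def __call__(self, image, target):
        if random.random() < self.flip_prob:
            image = F.hflip(image)
            target = F.hflip(target)
        return image, target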

Validation set enhancement:

class SegmentationPresetEval:
    def __init__(self, base_size, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
        self.transforms = T.Compose([
            T.RandomResize(base_size, base_size),
            T.ToTensor(),
            T.Normalize(mean=mean, std=std),
        ])

    def __call__(self, img, target):
        return self.transforms(img, target)

The validation-set transform is simpler: only a resize (RandomResize is called with min and max both equal to base_size, so the size is effectively fixed), followed by ToTensor and Normalize.

Main method

I made some modifications to the main method; the modified code is as follows:

    # define the model and load the pretrained weights
    model = fcn_resnet50(pretrained=True)
    # the default number of classes is 21; if it is not 21, the classifier heads must be modified
    if num_classes != 21:
        model.classifier[4] = torch.nn.Conv2d(512, num_classes, kernel_size=(1, 1), stride=(1, 1))
        model.aux_classifier[4] = torch.nn.Conv2d(256, num_classes, kernel_size=(1, 1), stride=(1, 1))
    print(model)
    model.to(device)
    # if there are multiple GPUs, use all of them
    if torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        model = torch.nn.DataParallel(model)

I changed the model to the official PyTorch fcn_resnet50; if the official model meets your needs, prefer it.

The default number of classes is 21. If yours is not 21, the classifier heads need to be modified accordingly.

The code then checks whether the system has multiple GPUs and, if so, uses all of them so that no resources are wasted.

If you don't want to use all the cards, but rather specify a few of them, you can use:

os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'

It can also be set in the DataParallel method:

model = torch.nn.DataParallel(model,device_ids=[0,1])

If multiple GPUs are used, the model's parameters must be accessed through model.module.xxx, for example (a fuller sketch follows the snippet):

    params = [p for p in model.module.aux_classifier.parameters() if p.requires_grad]
    params_to_optimize.append({"params": params, "lr": args.lr * 10})
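A convenient pattern (a sketch under the assumption that the model exposes backbone, classifier, and aux_classifier attributes, as torchvision's fcn_resnet50 does; not the repository's verbatim code) is to unwrap DataParallel once, so the same parameter-group code works on one or many GPUs:

    # unwrap DataParallel if present so attribute access is uniform
    net = model.module if isinstance(model, torch.nn.DataParallel) else model

    params_to_optimize = [
        {"params": [p for p in net.backbone.parameters() if p.requires_grad]},
        {"params": [p for p in net.classifier.parameters() if p.requires_grad]},
    ]
    if args.aux:
        # train the auxiliary head with a 10x larger learning rate
        params = [p for p in net.aux_classifier.parameters() if p.requires_grad]
        params_to_optimize.append({"params": params, "lr": args.lr * 10})

    optimizer = torch.optim.SGD(params_to_optimize, lr=args.lr,
                                momentum=args.momentum, weight_decay=args.weight_decay)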

After the above is completed, you can start training, as shown below:

Test

Before starting the test, we also need the palette. Create a new script get_palette.py with the following code:

import json
import numpy as np
from PIL import Image
# read a mask label
target = Image.open("./2007_001288.png")
# get the palette
palette = target.getpalette()

palette = np.reshape(palette, (-1, 3)).tolist()
print(palette)
# convert it to a dictionary
pd = dict((i, color) for i, color in enumerate(palette))

json_str = json.dumps(pd)
with open("palette.json", "w") as f:
    f.write(json_str)

Pick a mask, read its palette with the getpalette method, convert it into a dictionary, and save it as JSON.

Next, start the prediction part, create a new predict.py, and insert the following code:

import os
import sys
import time
import json

import torch
from torchvision import transforms
import numpy as np
from PIL import Image
from torchvision.models.segmentation import fcn_resnet50

Import the packages the program needs; then, in the main method:

def main():
    aux = False  # inference time not need aux_classifier
    classes = 20
    weights_path = "./save_weights/model_5.pth"
    img_path = "./2007_000123.jpg"
    palette_path = "./palette.json"
    assert os.path.exists(weights_path), f"weights {weights_path} not found."
    assert os.path.exists(img_path), f"image {img_path} not found."
    assert os.path.exists(palette_path), f"palette {palette_path} not found."
    with open(palette_path, "rb") as f:
        pallette_dict = json.load(f)
        pallette = []
        for v in pallette_dict.values():
            pallette += v
  • Define whether aux_classifier is needed; prediction does not need it, so it is set to False.
  • Set the number of classes to 20, excluding the background.
  • Define the path to the weights.
  • Define the path to the palette.
  • Read the palette.

The next step is loading the model. Loading a model trained on a single GPU differs from loading one trained on multiple GPUs. Let's first look at how to load a model trained on a single GPU.

    model = fcn_resnet50(num_classes=classes+1)
    print(model)
    # loading a checkpoint trained on a single GPU
    # delete weights about aux_classifier
    weights_dict = torch.load(weights_path, map_location='cpu')['model']
    for k in list(weights_dict.keys()):
        if "aux_classifier" in k:
            del weights_dict[k]

    # load weights
    model.load_state_dict(weights_dict)
    model.to(device)

Define the fcn_resnet50 model, with num_classes set to the number of classes + 1 (for the background).

Load the trained checkpoint and delete the aux_classifier entries.

Then load the weights into the model.

Now let's see how to load a model trained on multiple GPUs:

    # create model
    model = fcn_resnet50(num_classes=classes+1)
    model = torch.nn.DataParallel(model)
    # delete weights about aux_classifier
    weights_dict = torch.load(weights_path, map_location='cpu')['model']
    print(weights_dict)
    for k in list(weights_dict.keys()):
        if "aux_classifier" in k:
            del weights_dict[k]
    # load weights
    model.load_state_dict(weights_dict)
    model=model.module
    model.to(device)

Define the model fcn_resnet50, set num_classes to category + 1 (background), and put the model into the DataParallel class.

Load the trained model and delete the aux_classifier.

Load weights.

When torch.nn.DataParallel(model) is executed, the original model is wrapped and stored in model.module, so model.module is the model we actually need. That is why model.module is assigned back to model here.

Next comes the processing of the image data:

    # load image
    original_img = Image.open(img_path)

    # from pil image to tensor and normalize
    data_transform = transforms.Compose([transforms.Resize(520),
                                         transforms.ToTensor(),
                                         transforms.Normalize(mean=(0.485, 0.456, 0.406),
                                                              std=(0.229, 0.224, 0.225))])
    img = data_transform(original_img)
    # expand batch dimension
    img = torch.unsqueeze(img, dim=0)

Load the image.

Resize the image, convert it to a tensor, and normalize it.

Use torch.unsqueeze to add a dimension.

Once the processing of the image is complete, predictions can begin.
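The prediction code below times inference with a time_synchronized helper that does not appear in the import block above. A minimal sketch of what such a helper typically does (an assumption, not the repository's verbatim code):

import time
import torch

def time_synchronized():
    # wait for pending CUDA kernels to finish so the timing is accurate
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    return time.time()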

    model.eval()  # switch to evaluation mode
    with torch.no_grad():
        # init model
        img_height, img_width = img.shape[-2:]
        init_img = torch.zeros((1, 3, img_height, img_width), device=device)
        model(init_img)

        t_start = time_synchronized()
        output = model(img.to(device))
        t_end = time_synchronized()
        print("inference+NMS time: {}".format(t_end - t_start))

        prediction = output['out'].argmax(1).squeeze(0)
        prediction = prediction.to("cpu").numpy().astype(np.uint8)
        np.set_printoptions(threshold=sys.maxsize)
        print(prediction.shape)
        mask = Image.fromarray(prediction)
        mask.putpalette(pallette)
        mask.save("test_result.png")

Save the predicted result to test_result.png. View the running results:

Original image:

Result:

The data printed out:

Category List:

{
    "aeroplane": 1,
    "bicycle": 2,
    "bird": 3,
    "boat": 4,
    "bottle": 5,
    "bus": 6,
    "car": 7,
    "cat": 8,
    "chair": 9,
    "cow": 10,
    "diningtable": 11,
    "dog": 12,
    "horse": 13,
    "motorbike": 14,
    "person": 15,
    "pottedplant": 16,
    "sheep": 17,
    "sofa": 18,
    "train": 19,
    "tvmonitor": 20
}

From the results, the predicted class in the image is "train".
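For convenience, the predicted pixel values can be mapped back to class names with pascal_voc_classes.json, whose contents are shown above. A small sketch, assuming prediction is the uint8 class map produced in predict.py:

import json
import numpy as np

# pascal_voc_classes.json maps class names to pixel values (see the list above)
with open("pascal_voc_classes.json") as f:
    name_to_id = json.load(f)
id_to_name = {v: k for k, v in name_to_id.items()}

# 'prediction' is the uint8 class map from predict.py; 0 is the background
present = [id_to_name[int(i)] for i in np.unique(prediction) if int(i) in id_to_name]
print(present)  # e.g. ['train'] for the example image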

Summary

The core content of this article is to explain how to use FCN to achieve semantic segmentation of images.

At the beginning of the article, we covered the structure of FCN and its advantages and disadvantages. Then we explained how to read the dataset and how to run training, and finally we went through testing and the presentation of the results. I hope this article helps you.

Full code: https://download.csdn.net/download/hhhhhhhhhhwwwwwwwww/83778007

 

Click Follow to learn about HUAWEI CLOUD's new technologies for the first time~
