Training FCN on the Pascal VOC segmentation dataset with PyTorch

Semantic segmentation is the task of classifying every pixel in an image, thereby segmenting the image into regions. It is mainly used in fields such as medical imaging and autonomous driving.

A Survey of Image Semantic Segmentation

Like other vision tasks, image segmentation has evolved from traditional algorithms to deep learning. Traditional segmentation algorithms include thresholding, watershed, edge detection, and so on, and they face the same problem as other traditional image processing methods: they do not generalize robustly. Even so, in settings where the scene is simple and fixed, traditional image processing is still widely used.

FCN, published in 2014, is the pioneering work of deep-learning-based semantic segmentation and laid the conceptual foundation for the field.

Fully Convolutional Networks for Semantic Segmentation
Submitted on 14 Nov 2014
https://arxiv.org/abs/1411.4038

1. Introduction to FCN theory

The figure above is a screenshot from the original paper and shows the FCN architecture at a high level: the image passes through a series of convolutions, is then upsampled back to the original image size, and the class probability of each pixel is output.

The figure above describes the FCN network in more detail. The backbone is VGG16, and VGG's fully connected layers are expressed as convolutions, i.e. conv6-7 (a convolution whose kernel has the same size as the feature map is equivalent to a fully connected layer). In general, the network has the following key points:

1. Fully convolutional: used to solve per-pixel prediction. By replacing the last fully connected layers of the base network (e.g. VGG16) with convolutional layers, the network can accept an input image of any size, and the output size corresponds to the input (a minimal sketch of this equivalence follows this list);

2. Transposed convolution: the upsampling stage, used to restore the feature map to the original image size so that prediction can then be made pixel by pixel;

3. Skip architecture: used to fuse high-level and low-level feature information. Convolution (with pooling) is a downsampling operation, and although transposed convolution restores the image size, it is not a true inverse of convolution, so information is inevitably lost. The skip architecture fuses the fine-grained information of shallow layers with the coarse-grained information of deep layers to improve the fineness of the segmentation.
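To make point 1 concrete, here is a minimal sketch (my own illustration, assuming VGG16's fc6 layer and a 512x7x7 feature map) of why a fully connected layer can be rewritten as a convolution whose kernel covers the whole feature map, which is what allows inputs of arbitrary size:

import torch
import torch.nn as nn

# fc6 in VGG16 is Linear(512*7*7, 4096) and only accepts a fixed 7x7 feature map;
# the equivalent Conv2d(512, 4096, kernel_size=7) accepts any spatial size >= 7x7.
fc6_linear = nn.Linear(512 * 7 * 7, 4096)
fc6_conv = nn.Conv2d(512, 4096, kernel_size=7)

x_224 = torch.randn(1, 512, 7, 7)    # backbone output for a 224x224 input image
x_big = torch.randn(1, 512, 16, 16)  # backbone output for a larger input image

print(fc6_conv(x_224).shape)  # torch.Size([1, 4096, 1, 1])
print(fc6_conv(x_big).shape)  # torch.Size([1, 4096, 10, 10])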


FCN-32s: no skip connections. Each transposed-convolution layer enlarges the feature map by 2x; after five layers it is enlarged 32x, restoring the original image size.

FCN-16s: one skip connection. The 1/32 feature map is enlarged to 1/16 and added to VGG's 1/16 feature map, after which upsampling continues until the original image size is reached.

FCN-8s: two skip connections. First the 1/32 feature map is enlarged to 1/16 and added to VGG's 1/16 feature map; then the result is enlarged to 1/8 and added to VGG's 1/8 feature map, after which upsampling continues until the original image size is reached.
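As a minimal sketch of one such fusion step (shapes assume a 224x224 input; the transposed-convolution settings match those used in the model code later in this post):

import torch
import torch.nn as nn

pool5 = torch.randn(1, 512, 7, 7)    # 1/32 resolution feature map
pool4 = torch.randn(1, 512, 14, 14)  # 1/16 resolution feature map

# a transposed convolution that doubles H and W: 7x7 -> 14x14
up2x = nn.ConvTranspose2d(512, 512, kernel_size=3, stride=2,
                          padding=1, output_padding=1)
fused = up2x(pool5) + pool4          # the skip connection: (1, 512, 14, 14)
print(fused.shape)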

2. Training process

Training a deep learning model in PyTorch mainly requires three files: data.py, model.py, and train.py. data.py implements batched data loading, model.py defines the network, and train.py implements the training loop. In this post these are voc_seg_data.py, fcn8s_net.py, and train.py.

2.1 Introduction to the VOC dataset


Download address: Pascal VOC Dataset Mirror

The image names are listed in /ImageSets/Segmentation/train.txt and val.txt.
The images themselves are under the ./data/VOC2012/JPEGImages folder, so a .jpg suffix must be appended to each line read from train.txt; the labels are under the /data/VOC2012/SegmentationClass folder, so a .png suffix must be appended to each line read.

voc_seg_data.py

import torch
import torch.nn as nn
import torchvision.transforms as T
from torch.utils.data import DataLoader,Dataset
import numpy as np
import os
from PIL import Image
from datetime import datetime



class VOC_SEG(Dataset):
    def __init__(self, root, width, height, train=True, transforms=None):
        # all images are uniformly cropped to (width, height)
        self.width = width
        self.height = height
        # class names in the VOC dataset
        self.classes = ['background','aeroplane','bicycle','bird','boat',
           'bottle','bus','car','cat','chair','cow','diningtable',
           'dog','horse','motorbike','person','potted plant',
           'sheep','sofa','train','tv/monitor']
        # RGB color corresponding to each class
        self.colormap = [[0,0,0],[128,0,0],[0,128,0], [128,128,0], [0,0,128],
            [128,0,128],[0,128,128],[128,128,128],[64,0,0],[192,0,0],
            [64,128,0],[192,128,0],[64,0,128],[192,0,128],
            [64,128,128],[192,128,128],[0,64,0],[128,64,0],
            [0,192,0],[128,192,0],[0,64,128]]
        # helper variable: counts images filtered out for being smaller than the crop size
        self.fnum = 0
        if transforms is None:
            normalize = T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
            self.transforms = T.Compose([
                T.ToTensor(),
                normalize
            ])
        else:
            self.transforms = transforms
        # map each RGB pixel value to a class label (0, 1, 2, ...)
        self.cm2lbl = np.zeros(256**3)
        for i, cm in enumerate(self.colormap):
            self.cm2lbl[(cm[0]*256+cm[1])*256+cm[2]] = i

        
        if train:
            txt_fname = root+"/ImageSets/Segmentation/train.txt"
        else:
            txt_fname = root+"/ImageSets/Segmentation/val.txt"
        with open(txt_fname, 'r') as f:
            images = f.read().split()
        imgs = [os.path.join(root, "JPEGImages", item+".jpg") for item in images]
        labels = [os.path.join(root, "SegmentationClass", item+".png") for item in images]
        self.imgs = self._filter(imgs)
        self.labels = self._filter(labels)
        if train:
            print("训练集:加载了 " + str(len(self.imgs)) + " 张图片和标签" + ",过滤了" + str(self.fnum) + "张图片")
        else:
            print("测试集:加载了 " + str(len(self.imgs)) + " 张图片和标签" + ",过滤了" + str(self.fnum) + "张图片")

    def _crop(self, data, label):
        """
        切割函数,默认都是从图片的左上角开始切割。切割后的图片宽是width,高是height
        data和label都是Image对象
        """
        box = (0,0,self.width,self.height)
        data = data.crop(box)
        label = label.crop(box)
        return data, label

    def _image2label(self, im):
        data = np.array(im, dtype="int32")
        idx = (data[:,:,0]*256+data[:,:,1])*256+data[:,:,2]
        return np.array(self.cm2lbl[idx], dtype="int64")
        
    def _image_transforms(self, data, label):
        data, label = self._crop(data,label)
        data = self.transforms(data)
        label = self._image2label(label)
        label = torch.from_numpy(label)
        return data, label

    def _filter(self, imgs): 
        img = []
        for im in imgs:
            if (Image.open(im).size[1] >= self.height and 
               Image.open(im).size[0] >= self.width):
                img.append(im)
            else:
                self.fnum  = self.fnum+1
        return img

    def __getitem__(self, index: int):
        img_path = self.imgs[index]
        label_path = self.labels[index]
        img = Image.open(img_path)
        label = Image.open(label_path).convert("RGB")
        img, label = self._image_transforms(img, label)
        return img, label

    def __len__(self) :
        return len(self.imgs)



if __name__=="__main__":
    root = "./VOCdevkit/VOC2012"
    height = 224
    width = 224
    voc_train = VOC_SEG(root, width, height, train=True)
    voc_test = VOC_SEG(root, width, height, train=False)

    # train_data = DataLoader(voc_train, batch_size=8, shuffle=True)
    # valid_data = DataLoader(voc_test, batch_size=8)
    for data, label in voc_train:
        print(data.shape)
        print(label.shape)
        break



  • To keep things simple, auxiliary functions such as _crop() and _filter(), and variables such as colormap, are written inside the class. In practice it is better to put data preprocessing in a separate file, so that after training the same functions can be called directly at inference time.
  • The result of data processing is a (data, label) pair: data is the image as a tensor, and label is also a tensor, in which each pixel's RGB value has been replaced by an integer class index. During training, the cross-entropy loss then handles the one-hot encoding directly, just as when training a classification network (see the small example below).
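A small example of what this means in practice (a sketch with random tensors, not part of the dataset code):

import torch
import torch.nn as nn

num_classes = 21
logits = torch.randn(2, num_classes, 224, 224)          # model output, (N, C, H, W)
labels = torch.randint(0, num_classes, (2, 224, 224))   # integer label map, (N, H, W)

criterion = nn.CrossEntropyLoss()
loss = criterion(logits, labels)   # no manual one-hot encoding is needed
print(loss.item())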
     

2.2 Network Definition

fcn8s_net.py

import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.nn.functional as F
from torchsummary import summary
from torchvision import models


class FCN8s(nn.Module):
    def __init__(self, num_classes=21):
        super(FCN8s,self).__init__()
        net = models.vgg16(pretrained=True)   # load the pretrained VGG16 weights
        self.premodel = net.features          # use only VGG16's convolutional layers (feature extractor): (3,224,224) -> (512,7,7)

        # self.conv6 = nn.Conv2d(512,512,kernel_size=1,stride=1,padding=0,dilation=1) 
        # self.conv7 = nn.Conv2d(512,512,kernel_size=1,stride=1,padding=0,dilation=1)
        # (512,7,7)
        self.relu = nn.ReLU(inplace=True)
        self.deconv1 = nn.ConvTranspose2d(512,512,kernel_size=3,stride=2,padding=1,dilation=1,output_padding=1)  # x2
        self.bn1 = nn.BatchNorm2d(512)
        # (512, 14, 14)
        self.deconv2 = nn.ConvTranspose2d(512,256,kernel_size=3,stride=2,padding=1,dilation=1,output_padding=1)  # x2
        self.bn2 = nn.BatchNorm2d(256)
        # (256, 28, 28)
        self.deconv3 = nn.ConvTranspose2d(256,128,kernel_size=3,stride=2,padding=1,dilation=1,output_padding=1)  # x2
        self.bn3 = nn.BatchNorm2d(128)
        # (128, 56, 56)
        self.deconv4 = nn.ConvTranspose2d(128,64,kernel_size=3,stride=2,padding=1,dilation=1,output_padding=1)   # x2
        self.bn4 = nn.BatchNorm2d(64)
        # (64, 112, 112)
        self.deconv5 = nn.ConvTranspose2d(64,32,kernel_size=3,stride=2,padding=1,dilation=1,output_padding=1)    # x2
        self.bn5 = nn.BatchNorm2d(32)
        # (32, 224, 224)
        self.classifier = nn.Conv2d(32, num_classes, kernel_size=1)
        # (num_classes, 224, 224)
        

    def forward(self, input):
        x = input
        for i in range(len(self.premodel)):
            x = self.premodel[i](x)
            if i == 16:
                x3 = x  # feature map after maxpooling3 (1/8)
            if i == 23:
                x4 = x  # feature map after maxpooling4 (1/16)
            if i == 30:
                x5 = x  # feature map after maxpooling5 (1/32)

        # five transposed convolutions, each doubling the spatial size (the reverse of VGG16's downsampling), with two skip connections
        score = self.relu(self.deconv1(x5))   # out_size = 2*in_size (1/16)
        score = self.bn1(score + x4)

        score = self.relu(self.deconv2(score)) # out_size = 2*in_size (1/8)  
        score = self.bn2(score + x3)

        score = self.bn3(self.relu(self.deconv3(score)))  # out_size = 2*in_size (1/4)
        score = self.bn4(self.relu(self.deconv4(score)))  # out_size = 2*in_size (1/2)
        score = self.bn5(self.relu(self.deconv5(score)))  # out_size = 2*in_size (1)

        score = self.classifier(score)                    # size unchanged; output channels become the number of classes

        return score

if __name__ == "__main__":
    model = FCN8s()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    print(model)

  • Implementations of FCN found online differ in details, but the overall structure is always convolution + transposed convolution + skip connections. In essence it is enough to implement feature extraction (abstract features), transposed convolution (restoring the original image size), and per-pixel classification.
  • This experiment uses the five convolutional stages of VGG16 as the feature extraction network, followed by five 2x transposed convolutions to restore the original image size, and then a convolutional layer that adjusts the feature-map channels to the number of classes (21). Finally, softmax classification over the channels is all that is needed. A quick shape check of the model is shown below.
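A quick shape check of the model (my own sketch; it assumes the file is saved as fcn8s_net.py as above, and it downloads the pretrained VGG16 weights on first run):

import torch
from fcn8s_net import FCN8s

model = FCN8s(num_classes=21)
dummy = torch.randn(2, 3, 224, 224)   # a batch of two 224x224 RGB images
out = model(dummy)
print(out.shape)                      # expected: torch.Size([2, 21, 224, 224])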

2.3 Training

train.py

import torch
import torch.nn as nn
from torch.utils.data import DataLoader,Dataset
from voc_seg_data import VOC_SEG
from fcn8s_net import FCN8s
import os
import numpy as np
 

# compute the confusion matrix
def _fast_hist(label_true, label_pred, n_class):
    mask = (label_true >= 0) & (label_true < n_class)
    hist = np.bincount(
        n_class * label_true[mask].astype(int) +
        label_pred[mask], minlength=n_class ** 2).reshape(n_class, n_class)
    return hist


# compute Acc and mIoU from the confusion matrix
def label_accuracy_score(label_trues, label_preds, n_class):
    """Returns accuracy score evaluation result.
      - overall accuracy
      - mean accuracy
      - mean IU
    """
    hist = np.zeros((n_class, n_class))
    for lt, lp in zip(label_trues, label_preds):
        hist += _fast_hist(lt.flatten(), lp.flatten(), n_class)
    acc = np.diag(hist).sum() / hist.sum()
    with np.errstate(divide='ignore', invalid='ignore'):
        acc_cls = np.diag(hist) / hist.sum(axis=1)
    acc_cls = np.nanmean(acc_cls)
    with np.errstate(divide='ignore', invalid='ignore'):
        iu = np.diag(hist) / (
            hist.sum(axis=1) + hist.sum(axis=0) - np.diag(hist)
        )
    mean_iu = np.nanmean(iu)
    freq = hist.sum(axis=1) / hist.sum()
    return acc, acc_cls, mean_iu


def main():
    # 1. load dataset
    root = "./VOCdevkit/VOC2012"
    batch_size = 32
    height = 224
    width = 224
    voc_train = VOC_SEG(root, width, height, train=True)
    voc_test = VOC_SEG(root, width, height, train=False)
    train_dataloader = DataLoader(voc_train,batch_size=batch_size,shuffle=True)
    val_dataloader = DataLoader(voc_test,batch_size=batch_size,shuffle=True)
    
    # 2. load model
    num_class = 21
    model = FCN8s(num_classes=num_class)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    
    # 3. prepare super parameters
    criterion = nn.CrossEntropyLoss() 
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3, momentum=0.7)
    epoch = 50
 
    # 4. train
    val_acc_list = []
    out_dir = "./checkpoints/"
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)
    for epoch in range(0, epoch):
        print('\nEpoch: %d' % (epoch + 1))
        model.train()
        sum_loss = 0.0
        for batch_idx, (images, labels) in enumerate(train_dataloader):
            length = len(train_dataloader)
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images) # torch.Size([batch_size, num_class, height, width])
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
        
            sum_loss += loss.item()
            predicted = torch.argmax(outputs.data, 1)
            
            label_pred = predicted.data.cpu().numpy()
            label_true = labels.data.cpu().numpy()
            acc, acc_cls, mean_iu = label_accuracy_score(label_true,label_pred,num_class)
            
            print('[epoch:%d, iter:%d] Loss: %.03f | Acc: %.3f%% | Acc_cls: %.03f%% |Mean_iu: %.3f' 
                % (epoch + 1, (batch_idx + 1 + epoch * length), sum_loss / (batch_idx + 1), 
                100. *acc, 100.*acc_cls, mean_iu))
            
        #get the ac with testdataset in each epoch
        print('Waiting Val...')
        mean_iu_epoch = 0.0
        mean_acc = 0.0
        mean_acc_cls = 0.0
        with torch.no_grad():
            for batch_idx, (images, labels) in enumerate(val_dataloader):
                model.eval()
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                predicted = torch.argmax(outputs.data, 1)


                label_pred = predicted.data.cpu().numpy()
                label_true = labels.data.cpu().numpy()
                acc, acc_cls, mean_iu = label_accuracy_score(label_true,label_pred,num_class)

                # total += labels.size(0)
                # iou = torch.sum((predicted == labels.data), (1,2)) / float(width*height)
                # iou = torch.sum(iou)
                # correct += iou
                mean_iu_epoch += mean_iu
                mean_acc += acc
                mean_acc_cls += acc_cls
            
            print('Acc_epoch: %.3f%% | Acc_cls_epoch: %.03f%% |Mean_iu_epoch: %.3f' 
                % ((100. *mean_acc / len(val_dataloader)), (100.*mean_acc_cls/len(val_dataloader)), mean_iu_epoch/len(val_dataloader)) )
            
            val_acc_list.append(mean_iu_epoch/len(val_dataloader))
 
 
        torch.save(model.state_dict(), out_dir+"last.pt")
        if mean_iu_epoch/len(val_dataloader) == max(val_acc_list):
            torch.save(model.state_dict(), out_dir+"best.pt")
            print("save epoch {} model".format(epoch))
 
if __name__ == "__main__":
    main()
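
As a quick sanity check of the metric helpers above, here is a tiny worked example with hypothetical 2x2 label maps and three classes (it assumes the script is saved as train.py so that the functions can be imported); 3 of the 4 pixels match, so the overall accuracy is 0.75:

import numpy as np
from train import label_accuracy_score

gt   = np.array([[0, 1], [2, 2]])   # ground-truth label map
pred = np.array([[0, 1], [2, 1]])   # predicted label map
acc, acc_cls, mean_iu = label_accuracy_score([gt], [pred], n_class=3)
print(acc, acc_cls, mean_iu)        # 0.75, ~0.833, ~0.667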

The overall training procedure is sound, and readers can change the evaluation metrics and related code as needed. In this run, Acc is the main evaluation metric; it is simply the number of correctly classified pixels divided by the total number of pixels. The final training results are as follows:


The training-set Acc reached 0.8 and the validation-set Acc reached 0.77. Since some functions (such as _fast_hist) were copied from elsewhere, other metrics are not reported for now.
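For completeness, a minimal inference sketch (my own addition; it assumes the file names and the ./checkpoints/best.pt checkpoint used above) that turns a predicted class map back into a VOC-style color mask:

import numpy as np
import torch
from PIL import Image
from fcn8s_net import FCN8s
from voc_seg_data import VOC_SEG

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = FCN8s(num_classes=21).to(device)
model.load_state_dict(torch.load("./checkpoints/best.pt", map_location=device))
model.eval()

dataset = VOC_SEG("./VOCdevkit/VOC2012", 224, 224, train=False)
img, _ = dataset[0]                                    # already cropped and normalized
with torch.no_grad():
    pred = model(img.unsqueeze(0).to(device)).argmax(dim=1)[0].cpu().numpy()

colormap = np.array(dataset.colormap, dtype=np.uint8)  # class index -> RGB color
Image.fromarray(colormap[pred]).save("pred_mask.png")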


Origin blog.csdn.net/Eyesleft_being/article/details/121676803