Segmentation of the CARVANA dataset by UNET

Table of contents

1. Introduction

2. UNET network

3. dataset data loading

4. utils tool module

4.1 get_loaders function

4.2 check_accuracy function

4.3 save_predictions_as_imgs function

4.4 Complete code

5. train function

5.1 About imported library files

5.2 Setting hyperparameters

5.3 train_fn: training for one epoch

5.4 main function

5.5 Complete code

6. Display

6.1 Network Training

6.2 Loading pre-trained weights

6.3 Result display


Project download address: Segmentation of unet network based on CARVANA dataset

1. Introduction

The directory structure of the project is as follows:

  • data stores the training data (5056 images) and the validation data (32 images)
  • saved_val_images stores the network's segmentation results on the validation set

CARVANA data:

The corresponding segmentation label:

2. UNET network

UNET gets its name from the U shape of the network. The left side is the downsampling (encoder) path, and the right side is the upsampling (decoder) path.
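To make the U shape concrete, the sketch below traces how the channel count and spatial size change along the downsampling path. It assumes the configuration used later in this article (a 160x240 input, feature widths 64, 128, 256, 512, and 2x2 max pooling); the helper name trace_encoder is made up for this illustration and is not part of the project code.

```python
def trace_encoder(height, width, features=(64, 128, 256, 512)):
    """Return (stage, channels, height, width) for each encoder stage and the bottleneck."""
    stages = []
    for i, feature in enumerate(features):
        # DoubleConv uses 3x3 convolutions with padding=1, so only the channel count changes
        stages.append((f"down{i + 1}", feature, height, width))
        # 2x2 max pooling with stride 2 halves each side (floor division)
        height, width = height // 2, width // 2
    # The bottleneck doubles the channel count of the last encoder stage
    stages.append(("bottleneck", features[-1] * 2, height, width))
    return stages


for name, channels, h, w in trace_encoder(160, 240):
    print(f"{name:<10} channels={channels:<4} size={h}x{w}")
```

The upsampling path then mirrors these stages in reverse order, doubling the spatial size and halving the channel count at each step.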

For details, please refer to previous articles: UNET

The UNet here is built differently from the one in the previous article. Both implement the same architecture in different ways, and either can be used.

import torch.nn as nn
import torch
import torchvision.transforms.functional as TF


# Build the UNet network
class DoubleConv(nn.Module):  # two consecutive convolutions
    def __init__(self, in_channels, out_channels):
        super(DoubleConv, self).__init__()
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1,stride=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),

            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, stride=1,bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        x = self.double_conv(x)
        return x


class UNet(nn.Module):
    def __init__(self,in_channels=3,out_channels=1,features=(64,128,256,512)): # features holds the channel count of each stage
        super(UNet, self).__init__()
        self.ups = nn.ModuleList()
        self.downs = nn.ModuleList()
        self.pool = nn.MaxPool2d(kernel_size=2,stride=2)

        # down sampling part of unet
        for feature in features:
            self.downs.append(DoubleConv(in_channels,feature))
            in_channels = feature

        # up sampling part of unet
        for feature in reversed(features):
            self.ups.append(
                nn.ConvTranspose2d(feature*2,feature,kernel_size=2,stride=2)
            )
            self.ups.append(DoubleConv(feature*2,feature))

        # bottom part of unet
        self.bottleneck = DoubleConv(features[-1],features[-1]*2)

        # out layer part of unet
        self.final_conv = nn.Conv2d(features[0],out_channels,kernel_size=1)

    def forward(self,x):
        skip_connections = []       # feature maps saved for the skip connections

        # down sampling
        for down in self.downs:
            x = down(x)
            skip_connections.append(x)
            x = self.pool(x)

        x = self.bottleneck(x)
        skip_connections = skip_connections[::-1]

        # up sampling
        for idx in range(0,len(self.ups),2):   # self.ups alternates ConvTranspose2d and DoubleConv
            x = self.ups[idx](x)                # transposed convolution
            skip_connection = skip_connections[idx //2]

            if x.shape != skip_connection.shape:   # allows arbitrary input sizes
                x = TF.resize(x,size = skip_connection.shape[2:])

            concat_skip = torch.cat((skip_connection,x),dim = 1)   # concatenate along the channel dimension
            x = self.ups[idx+1](concat_skip)                        # DoubleConv

        x = self.final_conv(x)
        return x


# if __name__ == '__main__':
#     x = torch.rand((3,1,159,159))
#     model = UNet(in_channels=1,out_channels=1)
#     out = model(x)
#     assert x.shape == out.shape
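The shape check before the concatenation matters for inputs whose sides are not divisible by 16, such as the 159x159 input in the commented-out test. As a rough illustration (plain integer arithmetic, not the network itself), the sketch below follows one side length through the four poolings and the first transposed convolution; the helper names are hypothetical.

```python
def pooled_sizes(size, num_pools=4):
    """Side length of each skip connection, plus the size that reaches the bottleneck."""
    sizes = []
    for _ in range(num_pools):
        sizes.append(size)   # size stored in skip_connections
        size = size // 2     # MaxPool2d(kernel_size=2, stride=2) floors odd sizes
    return sizes, size


def first_upsample_vs_skip(size):
    """Compare the first upsampled size with the deepest skip connection."""
    skips, bottom = pooled_sizes(size)
    upsampled = bottom * 2   # ConvTranspose2d(kernel_size=2, stride=2) doubles the size
    return upsampled, skips[-1]


print(first_upsample_vs_skip(160))  # (20, 20): shapes already match
print(first_upsample_vs_skip(159))  # (18, 19): x must be resized before torch.cat
```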

3. dataset data loading

The dataset class is similar to the one in the previous article, with a few small differences.

For details, please refer to the previous article: dataset

Here is the code for the dataset:

import os
from PIL import Image
from torch.utils.data import Dataset
import numpy as np


# Data loading
class CarvanaDataset(Dataset):
    def __init__(self,image_dir,mask_dir,transform = None):
        self.image_dir = image_dir  # path to the training images
        self.mask_dir = mask_dir    # path to the labels
        self.transform = transform
        self.images = os.listdir(image_dir)     # all files in the folder

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        img_path = os.path.join(self.image_dir,self.images[index])  # full path of each image
        mask_path = os.path.join(self.mask_dir,self.images[index].replace('.jpg','_mask.gif'))  # the label name differs only in its suffix, so replace it

        image = np.array(Image.open(img_path).convert('RGB'))
        mask = np.array(Image.open(mask_path).convert("L"),dtype=np.float32)  # 'L' is grayscale
        mask[mask == 255.0] = 1.0       # convert to a binary image

        if self.transform is not None:
            augmentations = self.transform(image = image,mask = mask)
            image = augmentations['image']
            mask = augmentations['mask']

        return image,mask

Note that the label must be converted to a binary image here.

In the label files, foreground pixels are 255 and background pixels are 0, so the code maps 255 to 1.
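A small NumPy sketch of this conversion, using a made-up 2x3 array in place of a real mask file (the values are illustrative only):

```python
import numpy as np

# Hypothetical stand-in for
# np.array(Image.open(mask_path).convert("L"), dtype=np.float32)
mask = np.array([[0, 255, 255],
                 [0,   0, 255]], dtype=np.float32)

mask[mask == 255.0] = 1.0   # foreground 255 -> 1, background stays 0

print(np.unique(mask))      # only the values 0 and 1 remain
```

With the labels in {0, 1}, they can be compared directly with the thresholded network output and used as targets for a binary cross-entropy loss.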

I previously wondered why the label looked like a grayscale image rather than a binary one, like this:

It turned out to be a display issue: after zooming in, the label is indeed a binary image.

4. utils tool module

To keep the main program simple, the parts that are reused are wrapped in a utils module. It implements the following three functions:

  • get_loaders: loads the data
  • check_accuracy: checks the accuracy of the model
  • save_predictions_as_imgs: saves the model's segmentation results on the validation set as images

4.1 get_loaders function

Loading the data is straightforward and no different from before; get_loaders is just a thin wrapper.

Parameters of get_loaders:

  •  train_dir : path to the training images
  •  train_mask_dir : path to the training masks
  •  val_dir : path to the validation images
  •  val_mask_dir : path to the validation masks
  •  batch_size : the batch size
  •  train_transform, val_transform : preprocessing for the training and validation data
  •  num_workers : number of worker processes for data loading. On Windows, either set it to 0 or run the training code under if __name__ == '__main__': so that num_workers != 0 works

get_loaders returns two DataLoader objects, train_loader and val_loader, which yield (image, mask) batches from the training and validation sets.

4.2 check_accuracy function

check_accuracy measures the accuracy of the model on the validation set. It takes loader (the validation DataLoader, which yields images and labels), model (the network to evaluate), and device (the device the network runs on).

The binary label has no channel dimension, so one is added with unsqueeze(1) to match the shape of the network output.

The network output is passed through a sigmoid; pixels above 0.5 are mapped to foreground and all other pixels to background.
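The thresholding step can be sketched in plain Python for a handful of made-up logit values (the real code applies the same rule to whole tensors with torch.sigmoid; the helper names here are hypothetical):

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def to_binary(logits, threshold=0.5):
    """Map raw network outputs to 1.0 (foreground) or 0.0 (background)."""
    return [1.0 if sigmoid(z) > threshold else 0.0 for z in logits]


logits = [-2.0, -0.1, 0.0, 0.3, 4.0]
print(to_binary(logits))   # [0.0, 0.0, 0.0, 1.0, 1.0]
```

Because the sigmoid is monotonic and sigmoid(0) = 0.5, thresholding the probability at 0.5 is the same as thresholding the raw logit at 0.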

The Dice score of each batch is accumulated as follows:

dice_score += ( 2*(pred * y).sum() ) / ((pred + y).sum() + 1e-8 )
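As a worked example of this formula, the sketch below computes the Dice score for two short, made-up binary sequences in plain Python. The function name dice is an assumption for illustration; it uses the fact that (pred + y).sum() equals pred.sum() + y.sum().

```python
def dice(pred, target, eps=1e-8):
    """Dice coefficient for two flat sequences of 0/1 values."""
    intersection = sum(p * t for p, t in zip(pred, target))
    return 2 * intersection / (sum(pred) + sum(target) + eps)


pred   = [1, 1, 0, 0, 1, 0]
target = [1, 0, 0, 0, 1, 1]

# 2 overlapping foreground pixels, 3 predicted and 3 true: 2*2 / (3+3)
print(round(dice(pred, target), 4))   # 0.6667
```

The small eps only guards against division by zero when both prediction and label are empty. Note that check_accuracy sums one such score per batch and divides by len(loader), so the reported value is an average over batches rather than over individual images.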

4.3 save_predictions_as_imgs function

save_predictions_as_imgs saves the predictions and labels of the validation set as images. Two notes on saving:

  • Converting a tensor to a NumPy array before saving it as an image is cumbersome. torchvision provides the save_image() function, which saves a tensor directly as an image. If the tensor is on CUDA, it is moved to the CPU before saving.
  • In deep learning code, save_image() from torchvision.utils is the usual way to save images, but it always writes three-channel RGB images. If the network output is a single-channel grayscale image, the function still writes three channels with identical values, a "pseudo-grayscale" image. No difference is visible, but the saved image carries three channels instead of one and therefore takes more space than a true grayscale image.

4.4 Complete code

The complete utils module is as follows:

import torch
import torchvision
from dataset import CarvanaDataset
from torch.utils.data import DataLoader


# Build the training and validation data loaders
def get_loaders(train_dir,train_mask_dir,val_dir,val_mask_dir,batch_size,train_transform,val_transform,num_workers):

    # training set
    train_set = CarvanaDataset(image_dir=train_dir,mask_dir=train_mask_dir,transform=train_transform)
    train_loader = DataLoader(train_set,batch_size=batch_size,num_workers=num_workers,shuffle=True)

    # validation set
    val_set = CarvanaDataset(image_dir=val_dir,mask_dir=val_mask_dir,transform=val_transform)
    val_loader = DataLoader(val_set,batch_size=batch_size,num_workers=num_workers,shuffle=False)

    return train_loader,val_loader


# Check accuracy
def check_accuracy(loader,model,device):
    num_correct = 0
    num_pixels = 0
    dice_score = 0

    model.eval()            # evaluation mode
    with torch.no_grad():
        for x,y in loader:
            x = x.to(device)
            y = y.to(device).unsqueeze(1)   # add the channel dimension to the label
            pred = torch.sigmoid(model(x))
            pred = (pred > 0.5 ).float()        # convert to a binary image
            num_correct += (pred == y).sum()   # number of pixels where prediction and label agree
            num_pixels += torch.numel(pred)        # total number of pixels
            dice_score += ( 2*(pred * y).sum() ) / ((pred + y).sum() + 1e-8 )

    # correctly predicted pixels / total pixels
    print(
        f'Got {num_correct}/{num_pixels} with accuracy {num_correct/num_pixels*100:.2f}%'
    )
    # Dice score, averaged over the batches
    print(f'Dice score : {dice_score / len(loader)}')
    model.train()


# Save the predicted images
def save_predictions_as_imgs(loader,model,device,folder = './saved_val_images/'):
    print('------>Loading predictions')
    model.eval()
    for idx,(x,y) in enumerate(loader):
        x = x.to(device=device)
        with torch.no_grad():
            pred = torch.sigmoid(model(x))
            pred = (pred > 0.5).float()

        torchvision.utils.save_image(pred, f'{folder}/pred_{idx}.png')              # save the predicted image
        torchvision.utils.save_image(y.unsqueeze(1),f'{folder}/label_{idx}.png')    # save the label image

    model.train()

5. train function

The train module is the main program used to train the network.

OMP: Error #15: Initializing libiomp5md.dll, but found libiomp5md.dll already initialized.
OMP: Hint This means that multiple copies of the OpenMP runtime have been linked into the program. That is dangerous, since it can degrade performance or cause incorrect results. The best thing to do is to ensure that only a single OpenMP runtime is linked into the process, e.g. by avoiding static linking of the OpenMP runtime in any library. As an unsafe, unsupported, undocumented workaround you can set the environment variable KMP_DUPLICATE_LIB_OK=TRUE to allow the program to continue to execute, but that may cause crashes or silently produce incorrect results. For more information, please see http://www.intel.com/software/products/support/.

Running the train script here raises the error shown above. The simple workaround (which the message itself describes as unsafe) is to add the following at the very top of the code:

import os
os.environ['KMP_DUPLICATE_LIB_OK']='True'

5.1 About imported library files

import torch
import albumentations as A      # image augmentation library
from albumentations.pytorch import ToTensorV2   # only converts [h, w, c] -> [c, h, w]; does not normalize the data to [0, 1]
from tqdm import tqdm       # progress bar module
import torch.nn as nn
from unet import UNet
import torch.optim as optim
# custom modules
from utils import (
    get_loaders,                # load the data
    check_accuracy,             # check the accuracy
    save_predictions_as_imgs,   # save the predicted images
)

Some of the libraries here differ from those used previously; all of them are commented.

5.2 Setting hyperparameters

The setting to note here is LOAD_MODEL, a switch that controls whether previously saved weights are loaded.

If the network has been trained before and a saved weight file exists, setting LOAD_MODEL to True loads those weights; the learning rate can then be adjusted as needed and training continued.
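The save-then-resume cycle behind this switch can be sketched as follows. This is a minimal illustration under stated assumptions, not the project's training script: a tiny nn.Linear stands in for the UNet, the file is written to a temporary directory, and the os.path.exists guard is an addition that the original code does not have.

```python
import os
import tempfile

import torch
import torch.nn as nn

model = nn.Linear(4, 1)                               # tiny stand-in for the UNet
path = os.path.join(tempfile.mkdtemp(), "unet.pth")   # temporary location for this demo
torch.save(model.state_dict(), path)                  # what training does at the end

LOAD_MODEL = True
restored = nn.Linear(4, 1)                            # freshly initialised weights
if LOAD_MODEL and os.path.exists(path):
    # map_location lets weights saved on a GPU be loaded on a CPU-only machine
    state = torch.load(path, map_location="cpu")
    restored.load_state_dict(state)

print(torch.equal(model.weight, restored.weight))     # True
```

Only the model weights are restored this way; the optimizer state is not saved by torch.save(model.state_dict(), ...), so the optimizer starts fresh when training continues.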

5.3 train_fn: training for one epoch

The code of train_fn is included in the complete listing in section 5.5.

5.4 main function

Define the preprocessing of the training data:

Define the preprocessing of the validation data:

Create a model:

Get the training and validation data from the get_loaders function:

Decide whether to load the pre-trained model:

Train the model, save the parameters, and save the prediction results:
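Both the training and validation pipelines use A.Normalize with mean 0, std 1 and max_pixel_value 255. According to the formula quoted in the code comment, this reduces to dividing each pixel by 255. The sketch below applies that formula to a few scalar pixel values; the function normalize is a plain-Python illustration, not the albumentations implementation.

```python
def normalize(pixel, mean=0.0, std=1.0, max_pixel_value=255.0):
    """Formula from the code comment: (pixel - mean * max) / (std * max)."""
    return (pixel - mean * max_pixel_value) / (std * max_pixel_value)


for value in (0, 51, 255):
    print(value, "->", normalize(value))   # 0.0, 0.2, 1.0
```

This is why the images end up in [0, 1] even though ToTensorV2 itself does no scaling.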

5.5 Complete code

The complete code is as follows:

import os
os.environ['KMP_DUPLICATE_LIB_OK']='True'


import torch
import albumentations as A      # image augmentation library
from albumentations.pytorch import ToTensorV2   # only converts [h, w, c] -> [c, h, w]; does not normalize the data to [0, 1]
from tqdm import tqdm       # progress bar module
import torch.nn as nn
from unet import UNet
import torch.optim as optim
# custom modules
from utils import (
    get_loaders,                # load the data
    check_accuracy,             # check the accuracy
    save_predictions_as_imgs,   # save the predicted images
)


# Hyperparameters
LEARNING_RATE = 1e-4
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
BATCH_SIZE = 16
NUM_EPOCHS = 2           # number of epochs
NUM_WORKERS = 5
IMAGE_HEIGHT = 160
IMAGE_WIDTH = 240
LOAD_MODEL = False
TRAIN_IMG_DIR = './data/train_images'
TRAIN_MASK_DIR = './data/train_masks'
VAL_IMG_DIR = './data/val_images'
VAL_MASK_DIR = './data/val_masks'


# Training function, one epoch
def train_fn(loader,model,optimizer,loss_fn,scaler):
    loop = tqdm(loader)
    for batch_idx,(img,label) in enumerate(loop):
        img = img.to(device=DEVICE)
        label = label.float().unsqueeze(1).to(DEVICE)   # add the channel dimension

        # forward
        with torch.cuda.amp.autocast():     # mixed precision: different layers run at different precisions to speed up training
            predictions = model(img)        # network output (logits)
            loss = loss_fn(predictions,label)

        # backward
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # update tqdm loop
        loop.set_postfix(loss = loss.item())


def main():
    # training data preprocessing
    train_transforms = A.Compose(
        [
            A.Resize(height=IMAGE_HEIGHT,width=IMAGE_WIDTH),
            A.Rotate(limit=35,p=0.5),   # rotate by a random angle in (-limit, limit) with 50% probability
            A.HorizontalFlip(p=0.5),    # horizontal flip (about the vertical axis) with 50% probability
            A.VerticalFlip(p=0.1),      # vertical flip (about the horizontal axis) with 10% probability

            A.Normalize(                # img = (img - mean * max_pixel_value) / (std * max_pixel_value)
                mean=[0.0,0.0,0.0],
                std=[1.0,1.0,1.0],
                max_pixel_value= 255.0
                     ),
            ToTensorV2(),               # [h, w, c] -> [c, h, w]
        ]
    )
    # validation data preprocessing
    val_transforms = A.Compose(
        [
            A.Resize(height=IMAGE_HEIGHT,width=IMAGE_WIDTH),
            A.Normalize(
                mean=[0.0,0.0,0.0],
                std=[1.0,1.0,1.0],
                max_pixel_value= 255.0
                     ),
            ToTensorV2(),
        ]
    )
    # instantiate the UNet model + loss + optimizer
    model = UNet(in_channels=3,out_channels=1).to(DEVICE)
    loss_fn = nn.BCEWithLogitsLoss()            # binary cross-entropy with a built-in sigmoid
    optimizer = optim.Adam(model.parameters(),lr=LEARNING_RATE)

    # get the data loaders
    # train_loader:train_images,train_masks
    # val_loader:val_images,val_masks
    train_loader,val_loader = get_loaders(
        TRAIN_IMG_DIR,
        TRAIN_MASK_DIR,
        VAL_IMG_DIR,
        VAL_MASK_DIR,
        BATCH_SIZE,
        train_transforms,
        val_transforms,
        NUM_WORKERS,
    )

    # load the pre-trained weights
    if LOAD_MODEL:
        print('Pretrained:')
        model.load_state_dict(torch.load('unet.pth',map_location=DEVICE))
        check_accuracy(val_loader,model,device=DEVICE)
        print('------>Loading pretrained model successfully!!')

    scaler = torch.cuda.amp.GradScaler()        # gradient scaling for mixed precision training

    for epoch in range(NUM_EPOCHS):
        print('Epoch:', epoch + 1)
        train_fn(train_loader,model,optimizer,loss_fn,scaler)   # train for one epoch

        # check accuracy
        check_accuracy(val_loader,model,device=DEVICE)

    # save model
    print('------>Saving checkpoint')
    torch.save(model.state_dict(),'unet.pth')

    # save some example predictions to a folder
    save_predictions_as_imgs(val_loader,model,folder='saved_val_images/',device=DEVICE)


if __name__ == '__main__':      # required so that num_workers != 0 works
    main()
    print(' training over!!!! ')

6. Display

6.1 Network Training

The results after training the network for two epochs:

The 316 here is the number of batches per epoch: samples / batch_size = 5056 / 16 = 316
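The same arithmetic applies to any dataset size. A short sketch, assuming the usual DataLoader behaviour of keeping a final partial batch unless drop_last is set (the helper batches_per_epoch is hypothetical):

```python
import math


def batches_per_epoch(num_samples, batch_size, drop_last=False):
    """Number of iterations a loader makes in one epoch."""
    if drop_last:
        return num_samples // batch_size
    return math.ceil(num_samples / batch_size)


print(batches_per_epoch(5056, 16))   # 316 training batches
print(batches_per_epoch(32, 16))     # 2 validation batches
print(batches_per_epoch(5000, 16))   # 313: the last batch holds only 8 samples
```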

6.2 Loading pre-trained weights

LOAD_MODEL = True

6.3 Result display

Network predictions:

Real label:

Network predictions:

Real label:

Origin blog.csdn.net/qq_44886601/article/details/129350075#comments_27392764