UNet semantic segmentation in practice: using UNet for portrait matting

Summary

In the previous article, I summarized some basic knowledge about UNet. If you are not familiar with UNet, you can read it here: https://wanghao.blog.csdn.net/article/details/123714994

I also walked through the PyTorch version of UNet; the article is here:

https://blog.csdn.net/hhhhhhhhhhwwwwwwwwww/article/details/123280059

Today's article explains how to use UNet for binary (foreground/background) segmentation of images.

There are generally two approaches to binary segmentation:

The first approach uses a single-channel output: the network output has shape [batch_size, 1, height, width], where batch_size is the batch size, 1 means a single output channel, and height and width match the height and width of the input image.

During training, the number of output channels is 1, and the values in output are arbitrary real numbers. The given target is a single-channel label map containing only the values 0 and 1. To make the network output approach this label, output is first passed through a sigmoid to normalize its values to [0, 1], giving output1; then output1 and target are used to compute the cross-entropy loss, and backpropagation updates the network weights. Eventually the network learns to make output1 approximate target.

After training, the network can transform output so that it approximates target. At prediction time, pass output through a sigmoid, then apply a threshold (usually 0.5): values greater than the threshold become 1, otherwise 0, which yields the prediction map predict. Evaluation-related calculations follow from there.

If the last layer of the network uses sigmoid, choose BCELoss; if not, choose BCEWithLogitsLoss. For example:

Without sigmoid on the last layer:

output = net(input)  # the last layer of net does not use sigmoid
loss_func1 = torch.nn.BCEWithLogitsLoss()
loss = loss_func1(output, target)

With sigmoid on the last layer:

output = net(input)  # the last layer of net does not use sigmoid
output = F.sigmoid(output)
loss_func1 = torch.nn.BCELoss()
loss = loss_func1(output, target)

At prediction time:

output = net(input)  # the last layer of net does not use sigmoid
output = F.sigmoid(output)
predict = torch.where(output > 0.5, torch.ones_like(output), torch.zeros_like(output))

The second approach uses a multi-channel output: the network output has shape [batch_size, num_class, height, width], where batch_size is the batch size, num_class is the number of output channels (equal to the number of classes), and height and width match the height and width of the input image.

During training, the number of output channels is num_class (here 2). The given target is a single-channel label map containing only the values 0 and 1. To make the network output approach this label, output is first passed through a softmax to normalize its values to [0, 1], giving output1; at each pixel, the values across the channels of output1 sum to 1. Since target is a single-channel image, it is first converted into a num_class-channel image with one-hot encoding, where the value in each channel is determined by the value in the single channel. For example, if the first pixel of the single-channel map has the value 1 (0 <= 1 <= num_class - 1, with num_class = 2 here), then after one-hot encoding the two channels at that pixel take the values 0 and 1, respectively. In other words, a pixel's value determines which channel is set to 1 while the other channels are 0, which is the key point. Call the result target1. Then output1 and target1 are used to compute the cross-entropy loss, and backpropagation updates the network weights. Eventually the network learns to make output1 approximate target1 (channel by channel).

After training, the network can transform output so that it approximates target. At prediction time, for each pixel position, take the index of the channel with the largest value across the channels of output; this yields the prediction map predict.
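
To make the one-hot and argmax steps concrete, here is a minimal sketch with made-up pixel values (not project code), using num_class = 2:

import torch
import torch.nn.functional as F

# a made-up 2x2 single-channel label map with values 0 (background) and 1 (person)
target = torch.tensor([[0, 1],
                       [1, 0]])

# one-hot encoding produces a 2-channel map: a pixel's value selects which channel becomes 1
target1 = F.one_hot(target, num_classes=2).permute(2, 0, 1).float()  # shape [2, 2, 2]
print(target1[:, 0, 1])  # the pixel labeled 1 -> channel values (0, 1)

# at prediction time, argmax over the channel dimension recovers the per-pixel class index
predict = target1.argmax(dim=0)  # identical to target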

For training, the loss is chosen as cross entropy, for example:

output = net(input)  # the last layer of net does not use sigmoid
loss_func = torch.nn.CrossEntropyLoss()
loss = loss_func(output, target)

At prediction time:

output = net(input)  # the last layer of net does not use sigmoid
predict = output.argmax(dim=1)

This tutorial uses the second approach.

Code used: milesial/Pytorch-UNet: PyTorch implementation of the U-Net for image semantic segmentation with high quality images (github.com)

After downloading the code, unzip it locally, as shown below:

(screenshot: the extracted project directory)

Dataset

Dataset address: http://www.cse.cuhk.edu.hk/~leojia/projects/automatting/, published in 2016.

The dataset contains 2000 images: 1700 for training and 300 for testing. All of them are portrait images from Flickr, with an original resolution of 600×800; the mattes were generated with closed-form matting and KNN matting.

Because portrait segmentation data has high commercial value, there are few large-scale public datasets. This dataset is one of the earliest published and most widely used. It has several important features:

(1) The images have uniform resolution, are sharply shot, and are of very high quality.

(2) All images are upper-body portraits, and the portrait region occupies at least 2/3 of the image in both height and width.

(3) The subjects' poses vary little: all are near-frontal views at a small angle, and the backgrounds are relatively simple.

(sample portrait images from the dataset)

[1] Shen X, Tao X, Gao H, et al. Deep Automatic Portrait Matting[M]// Computer Vision – ECCV 2016. Springer International Publishing, 2016: 92-107.

After downloading the dataset, put the training set in the data folder: the images go into the imgs folder and the masks into the masks folder. The test set goes under the test folder:

(screenshot: the data directory layout)

Since the original program was written for the Carvana Image Masking Challenge, we need to modify the dataset-loading logic. Open the utils/data_loading.py file:

class CarvanaDataset(BasicDataset):
    def __init__(self, images_dir, masks_dir, scale=1):
        super().__init__(images_dir, masks_dir, scale, mask_suffix='_matte')

Change mask_suffix to "_matte"
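
For reference, here is a minimal sketch of how the suffix pairs an image with its mask (the file names are hypothetical examples that follow the dataset's naming convention):

from pathlib import Path

mask_suffix = '_matte'
img_file = Path('data/imgs/00001.png')                                   # hypothetical image file
mask_file = Path('data/masks') / (img_file.stem + mask_suffix + '.png')  # -> data/masks/00001_matte.png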

Train

Open train.py and check the global parameters first:

def get_args():
    parser = argparse.ArgumentParser(description='Train the UNet on images and target masks')
    parser.add_argument('--epochs', '-e', metavar='E', type=int, default=300, help='Number of epochs')
    parser.add_argument('--batch-size', '-b', dest='batch_size', metavar='B', type=int, default=16, help='Batch size')
    parser.add_argument('--learning-rate', '-l', metavar='LR', type=float, default=0.001,
                        help='Learning rate', dest='lr')
    parser.add_argument('--load', '-f', type=str, default=False, help='Load model from a .pth file')
    parser.add_argument('--scale', '-s', type=float, default=0.5, help='Downscaling factor of the images')
    parser.add_argument('--validation', '-v', dest='val', type=float, default=10.0,
                        help='Percent of the data that is used as validation (0-100)')
    parser.add_argument('--amp', action='store_true', default=False, help='Use mixed precision')
    return parser.parse_args()

epochs: The number of epochs, generally set to 300.

batch-size: The batch size, set according to the available GPU memory.

learning-rate: The learning rate, generally set to 0.001; if a different optimizer is used, the initial learning rate should be adjusted accordingly.

load: The path of the model to load. To resume a previous training run, set this to the weight file from that run; if pre-trained weights are available, set this to their path.

scale: The downscaling factor; set to 0.5 here, so images are resized to half of their original size.

validation: The percentage of the data used as the validation set.

amp: Whether to use mixed precision.

The more important parameters are epochs, batch-size and learning-rate, which can be adjusted repeatedly to achieve the best accuracy.

Next, set up the model:

net = UNet(n_channels=3, n_classes=2, bilinear=True)
logging.info(f'Network:\n'
             f'\t{net.n_channels} input channels\n'
             f'\t{net.n_classes} output channels (classes)\n'
             f'\t{"Bilinear" if net.bilinear else "Transposed conv"} upscaling')
if args.load:
    net.load_state_dict(torch.load(args.load, map_location=device))
    logging.info(f'Model loaded from {args.load}')

Set the UNet parameters: n_channels is the number of channels of the input images (3 for RGB, 1 for grayscale). n_classes is set to 2, because the background is also treated as a class, so there are two classes.

If a weight file is specified, it is loaded. Loading pre-trained weights for transfer learning speeds up training and reduces the number of iterations, so load pre-trained weights whenever possible.

Next, modify the logic of the train_net function.

try:
    dataset = CarvanaDataset(dir_img, dir_mask, img_scale)
except (AssertionError, RuntimeError):
    dataset = BasicDataset(dir_img, dir_mask, img_scale)

# 2. Split into train / validation partitions
n_val = int(len(dataset) * val_percent)
n_train = len(dataset) - n_val
train_set, val_set = random_split(dataset, [n_train, n_val], generator=torch.Generator().manual_seed(0))
# 3. Create data loaders
loader_args = dict(batch_size=batch_size, num_workers=4, pin_memory=True)
train_loader = DataLoader(train_set, shuffle=True, **loader_args)
val_loader = DataLoader(val_set, shuffle=False, drop_last=True, **loader_args)

1. Load the dataset.

2. Divide the training set and the validation set according to the proportion.

3. Put the training set and validation set into the DataLoader.

 # (Initialize logging)
 experiment = wandb.init(project='U-Net', resume='allow', anonymous='must')
 experiment.config.update(dict(epochs=epochs, batch_size=batch_size, learning_rate=learning_rate,
                                  val_percent=val_percent, save_checkpoint=save_checkpoint, img_scale=img_scale,
                                  amp=amp))

Set up wandb. wandb is a very useful visualization tool; for installation and usage, see: https://blog.csdn.net/hhhhhhhhhhwwwwwwwww/article/details/116124285.

 # 4. Set up the optimizer, the loss, the learning rate scheduler and the loss scaling for AMP
    optimizer = optim.RMSprop(net.parameters(), lr=learning_rate, weight_decay=1e-8, momentum=0.9)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'max', patience=2)  # goal: maximize Dice score
    grad_scaler = torch.cuda.amp.GradScaler(enabled=amp)
    criterion = nn.CrossEntropyLoss()
    global_step = 0

1. Set the optimizer to RMSprop. I also tried switching it to SGD, which usually performs better, but during training the final results of the two were similar.

2. Use the ReduceLROnPlateau learning rate schedule, similar to the one in Keras. The Dice score is monitored here, so the mode is set to max: when the score stops rising, the learning rate is reduced.

3. Set the loss to nn.CrossEntropyLoss(), the cross-entropy loss commonly used for multi-class classification.

Next, the logic of the training loop needs to be modified as follows:

 masks_pred = net(images)
 true_masks = F.one_hot(true_masks.squeeze_(1), net.n_classes).permute(0, 3, 1, 2).float()
 print(masks_pred.shape)
 print(true_masks.shape)

The output of masks_pred = net(images) has shape [batch, 2, 400, 300], where 2 is the number of classes.

true_masks.shape is [batch, 1, 400, 300], so true_masks needs one-hot processing. If one-hot encoding is applied to true_masks directly, the result has shape [batch, 1, 400, 300, 2], which is incompatible with masks_pred. Therefore, first squeeze out the second dimension (the dimension of size 1), so that the shape after one-hot encoding is [batch, 400, 300, 2], and then permute the dimensions to match masks_pred.
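
A quick way to verify the shapes (a standalone sketch with dummy tensors, independent of the training loop):

import torch
import torch.nn.functional as F

n_classes = 2
true_masks = torch.randint(0, n_classes, (4, 1, 400, 300))  # dummy labels shaped like the loader output

one_hot_direct = F.one_hot(true_masks, n_classes)
print(one_hot_direct.shape)   # [4, 1, 400, 300, 2] -> incompatible with masks_pred

one_hot_fixed = F.one_hot(true_masks.squeeze(1), n_classes).permute(0, 3, 1, 2).float()
print(one_hot_fixed.shape)    # [4, 2, 400, 300] -> matches masks_pred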

The next step is to compute the loss. The loss has two parts: one is cross entropy, the other is dice_loss. The two losses have complementary strengths, and combining them works better. dice_loss is in the utils/dice_score.py file; the code is as follows:

import torch
from torch import Tensor

def dice_coeff(input: Tensor, target: Tensor, reduce_batch_first: bool = False, epsilon=1e-6):
    # Average of Dice coefficient for all batches, or for a single mask
    assert input.size() == target.size()
    if input.dim() == 2 and reduce_batch_first:
        raise ValueError(f'Dice: asked to reduce batch but got tensor without batch dimension (shape {input.shape})')
    if input.dim() == 2 or reduce_batch_first:
        inter = torch.dot(input.reshape(-1), target.reshape(-1))
        sets_sum = torch.sum(input) + torch.sum(target)
        if sets_sum.item() == 0:
            sets_sum = 2 * inter
        return (2 * inter + epsilon) / (sets_sum + epsilon)
    else:
        # compute and average metric for each batch element
        dice = 0
        for i in range(input.shape[0]):
            dice += dice_coeff(input[i, ...], target[i, ...])
        return dice / input.shape[0]

def dice_coeff_1(pred, target):
    # Note: despite its name, this returns 1 - Dice (a loss value), not the Dice coefficient itself
    smooth = 1.
    num = pred.size(0)
    m1 = pred.view(num, -1)  # Flatten
    m2 = target.view(num, -1)  # Flatten
    intersection = (m1 * m2).sum()
    return 1 - (2. * intersection + smooth) / (m1.sum() + m2.sum() + smooth)

def multiclass_dice_coeff(input: Tensor, target: Tensor, reduce_batch_first: bool = False, epsilon=1e-6):
    # Average of Dice coefficient for all classes
    assert input.size() == target.size()
    dice = 0
    for channel in range(input.shape[1]):
        dice += dice_coeff(input[:, channel, ...], target[:, channel, ...], reduce_batch_first, epsilon)

    return dice / input.shape[1]

def dice_loss(input: Tensor, target: Tensor, multiclass: bool = False):
    # Dice loss (objective to minimize) between 0 and 1
    assert input.size() == target.size()
    fn = multiclass_dice_coeff if multiclass else dice_coeff
    return 1 - fn(input, target, reduce_batch_first=True)

Import it into train.py, and then combine it with cross entropy as the loss of this project.

 loss = criterion(masks_pred, true_masks) \
        + dice_loss(F.softmax(masks_pred, dim=1).float(),
                                       true_masks,
                                       multiclass=True)
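
To sanity-check the combined loss on its own, here is a standalone sketch with random tensors (it assumes dice_loss is importable from the utils/dice_score.py file shown above, and a PyTorch version whose CrossEntropyLoss accepts probability targets):

import torch
import torch.nn as nn
import torch.nn.functional as F
from utils.dice_score import dice_loss  # the file defined earlier in this article

criterion = nn.CrossEntropyLoss()

masks_pred = torch.randn(4, 2, 400, 300)                       # dummy network output (raw logits)
labels = torch.randint(0, 2, (4, 400, 300))                    # dummy integer labels
true_masks = F.one_hot(labels, 2).permute(0, 3, 1, 2).float()  # one-hot targets, [4, 2, 400, 300]

loss = criterion(masks_pred, true_masks) \
       + dice_loss(F.softmax(masks_pred, dim=1).float(), true_masks, multiclass=True)
print(loss)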

The next step is to modify the logic of the evaluate function.

 mask_true = mask_true.to(device=device, dtype=torch.long)
 mask_true = F.one_hot(mask_true.squeeze_(1), net.n_classes).permute(0, 3, 1, 2).float()

Added onehot logic to mask_true.
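
For context, here is a rough sketch of how these two lines fit into the evaluation loop, computing the Dice score with the multiclass_dice_coeff defined earlier (simplified, not the exact code from the repository):

import torch
import torch.nn.functional as F
from utils.dice_score import multiclass_dice_coeff

@torch.no_grad()
def evaluate_sketch(net, dataloader, device):
    net.eval()
    dice_score = 0
    for batch in dataloader:
        image = batch['image'].to(device=device, dtype=torch.float32)
        mask_true = batch['mask'].to(device=device, dtype=torch.long)
        # same one-hot conversion as in training
        mask_true = F.one_hot(mask_true.squeeze(1), net.n_classes).permute(0, 3, 1, 2).float()

        mask_pred = net(image)
        # turn the logits into a one-hot prediction before computing the Dice score
        mask_pred = F.one_hot(mask_pred.argmax(dim=1), net.n_classes).permute(0, 3, 1, 2).float()
        dice_score += multiclass_dice_coeff(mask_pred, mask_true, reduce_batch_first=False)

    net.train()
    return dice_score / max(len(dataloader), 1)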

After modifying the above logic, you can start training.


Test

After training, you can test it. Open predict.py and modify the global parameters:

def get_args():
    parser = argparse.ArgumentParser(description='Predict masks from input images')
    parser.add_argument('--model', '-m', default='checkpoints/checkpoint_epoch7.pth', metavar='FILE',
                        help='Specify the file in which the model is stored')
    parser.add_argument('--input', '-i', metavar='INPUT',default='test/00002.png', nargs='+', help='Filenames of input images')
    parser.add_argument('--output', '-o', metavar='INPUT',default='00001.png', nargs='+', help='Filenames of output images')
    parser.add_argument('--viz', '-v', action='store_true',
                        help='Visualize the images as they are processed')
    parser.add_argument('--no-save', '-n', action='store_true',default=False, help='Do not save the output masks')
    parser.add_argument('--mask-threshold', '-t', type=float, default=0.5,
                        help='Minimum probability value to consider a mask pixel white')
    parser.add_argument('--scale', '-s', type=float, default=0.5,
                        help='Scale factor for the input images')

    return parser.parse_args()

model: The weight file path; change this to the weight file you trained yourself.

scale: 0.5, matching the training setting.

Other parameters are passed on the command line.

def mask_to_image(mask: np.ndarray):
    if mask.ndim == 2:
        return Image.fromarray((mask * 255).astype(np.uint8))
    elif mask.ndim == 3:
        img_np=(np.argmax(mask, axis=0) * 255 / (mask.shape[0]-1)).astype(np.uint8)
        print(img_np.shape)
        print(np.max(img_np))
        return Image.fromarray(img_np)

The line img_np = (np.argmax(mask, axis=0) * 255 / (mask.shape[0] - 1)).astype(np.uint8) is where the logic needs to be modified.

Source code:

 return Image.fromarray((np.argmax(mask, axis=0) * 255 / mask.shape[0]).astype(np.uint8))

We added the background as a class, so mask.shape[0] is 2. Dividing by mask.shape[0] would map the foreground class to 127 instead of 255, so the background class has to be subtracted from the divisor, i.e. divide by mask.shape[0] - 1.
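
To see why the divisor matters, here is a tiny worked example with hypothetical values (not from the project):

import numpy as np

# a hypothetical 2-channel output for two pixels: channel 0 = background, channel 1 = person
mask = np.array([[[0.2, 0.9]],
                 [[0.8, 0.1]]])                              # shape (2, 1, 2): (classes, height, width)
labels = np.argmax(mask, axis=0)                             # -> [[1, 0]]
old = (labels * 255 / mask.shape[0]).astype(np.uint8)        # -> [[127, 0]]: foreground is only mid-gray
new = (labels * 255 / (mask.shape[0] - 1)).astype(np.uint8)  # -> [[255, 0]]: foreground is fully white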

The method of displaying the results also needs to be modified:

def plot_img_and_mask(img, mask):
    print(mask.shape)
    classes = mask.shape[0] if len(mask.shape) > 2 else 1
    fig, ax = plt.subplots(1, classes + 1)
    ax[0].set_title('Input image')
    ax[0].imshow(img)
    if classes > 1:
        for i in range(classes):
            ax[i + 1].set_title(f'Output mask (class {i + 1})')
            ax[i + 1].imshow(mask[i, :, :])
    else:
        ax[1].set_title(f'Output mask')
        ax[1].imshow(mask)
    plt.xticks([]), plt.yticks([])
    plt.show()

Change the original ax[i + 1].imshow(mask[:, :, i]) to: ax[i + 1].imshow(mask[i, :, :]).

Execute the command:

python predict.py -i test/00002.png -o output.png  -v 

Output result:

(the predicted segmentation result)

At this point, we have successfully cut the person out completely from the background image!

Summary

This article implements image segmentation with UNet. Through this article, you can learn:

1. How to use UNet for binary semantic segmentation of images.

2. How to use wandb visualization.

3. How to use the combination of cross entropy and dice_loss.

4. How to run prediction for binary semantic segmentation.

Complete code:
https://download.csdn.net/download/hhhhhhhhhhwwwwwwwww/85083165
