[3D Image Segmentation] 3D Image Segmentation 5 Based on PyTorch (Training)

At the beginning of this series, a brief introduction was given to all the modules required to train the entire project, and in particular to each component that comes up during training.

In the earlier articles on data construction and the network model, each of those building blocks was verified separately. Checking the correctness of each module before starting training avoids running into the same problems over and over during training.

After working through this series of articles, most of the modules have already been introduced, including:

  1. In the overview, the optimizer, model acquisition and model saving are introduced;
  2. In the data flow module, we learned how to import data and verify the data flow;
  3. In the network model article, the loss function was called to check the output.

The main purpose of this article is to stitch these scattered pieces into a whole. The inference stage will be covered in a separate article later. Working through this series should also give you room to think further and deepen your understanding.

1. Loss function

In a segmentation task, predicting the target mask is treated as a per-voxel classification task. Therefore, when calculating the loss, the original paper uses the cross-entropy loss function.

Later improvements to the loss introduce Dice loss or focal loss. Let's start with the cross-entropy loss function and explore why it can be used for segmentation tasks.

This article continues with the cross-entropy loss function used in the network model testing stage, defined below. For other segmentation loss functions, refer to this article: [AI Interview] Cross-Entropy Loss, Balanced Cross-Entropy, Dice Loss and Focal Loss: A Side-by-Side Review of Classification Losses.

1.1. CrossEntropyLoss

In the previous article about the network model, the cross-entropy loss function was used in the testing phase of the model. The link is here: [3D Image Segmentation] VNet 3D Image Segmentation 3 (3D UNet Model) based on Pytorch. The loss was introduced as follows:

expected_output_shape = (batch_size, num_out_classes, 64, 64, 64)
assert output.shape == expected_output_shape, "Unexpected output shape, check the architecture!"

# Defining loss fn
ce_layer = torch.nn.CrossEntropyLoss()
# Calculating loss
ce_loss = ce_layer(output, ground_truth)
print("CE Loss = {}".format(ce_loss))

where:

  • ground_truth has shape B x D x H x W
  • output has shape B x C x D x H x W
  • For the predicted tensor, a softmax operation is usually applied along the C dimension, so that the output value of each channel (category) lies in the range [0, 1] and the outputs across channels sum to 1.
  • The purpose of this is to convert the predictions into a probability distribution, which is what the cross-entropy loss expects.
  • In PyTorch, torch.nn.CrossEntropyLoss() performs the softmax internally on its input, so the model should output raw logits; a small check below illustrates this.
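
A minimal sketch (the shapes are my own small example, not from the article) verifying that CrossEntropyLoss applied to raw logits matches an explicit log-softmax followed by a negative log-likelihood loss:

# Sketch: CrossEntropyLoss == log_softmax + NLLLoss (shapes are assumed examples)
import torch

B, C, D, H, W = 2, 3, 8, 8, 8
logits = torch.randn(B, C, D, H, W)            # raw network output, no softmax applied
target = torch.randint(0, C, (B, D, H, W))     # integer class labels, B x D x H x W

ce = torch.nn.CrossEntropyLoss()(logits, target)

log_probs = torch.log_softmax(logits, dim=1)   # explicit normalization over the channel dim
nll = torch.nn.NLLLoss()(log_probs, target)

print(ce.item(), nll.item())                   # the two values match (up to float precision)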

1.2. Dice loss

DiceThe "Dice" in the coefficient is actually the abbreviation of a scientist's name, whose full name is Sørensen–Dice coefficient, often called Dice similarity coefficient or F1 score. It was independently developed by botanists Thorvald Sørensen and Lee Raymond Dice in 1948 and 1945 respectively published.

The Dice coefficient is a common similarity measure, mainly used to compute the similarity of two sets. In Dice loss, the Dice coefficient measures the similarity between the predicted result and the ground-truth label, hence the name Dice loss.

The Dice coefficient is defined as follows:

$$\mathrm{Dice}(X, Y) = \frac{2\,|X \cap Y|}{|X| + |Y|}$$

If segmentation is regarded as a per-pixel classification task, with $p_i$ the predicted value and $g_i$ the ground-truth value at pixel $i$, it can also be written as:

$$\mathrm{Dice} = \frac{2\sum_i p_i\, g_i}{\sum_i p_i + \sum_i g_i}$$

So Dice loss can be expressed as (with a small smoothing term $\epsilon$, matching the eps in the code below, to avoid division by zero):

$$\mathrm{Dice\ loss} = 1 - \frac{2\sum_i p_i\, g_i + \epsilon}{\sum_i p_i + \sum_i g_i + \epsilon}$$

The Dice coefficient is also referred to as the "Dice similarity coefficient" or "Dice similarity", so Dice loss can also be called the "Dice similarity coefficient loss" or "Dice similarity loss".
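
As a quick numeric check (a toy example of my own, not from the article), for four pixels:

# Toy example: Dice coefficient and Dice loss for a 4-pixel prediction vs. label
import torch

pred  = torch.tensor([1., 1., 0., 0.])   # predicted foreground "probabilities"
label = torch.tensor([1., 0., 1., 0.])   # ground-truth foreground mask

intersect = torch.dot(pred, label)        # = 1
union = pred.sum() + label.sum()          # = 2 + 2 = 4
dice_coeff = 2 * intersect / union        # = 0.5
dice_loss = 1 - dice_coeff                # = 0.5
print(dice_coeff.item(), dice_loss.item())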

A multi-class Dice loss is defined as follows:

import torch
import numpy as np

def one_hot_encode(label, num_classes):
    """ Torch One Hot Encode
    :param label: Tensor of shape BxHxW or BxDxHxW
    :param num_classes: K classes
    :return: label_ohe, Tensor of shape BxKxHxW or BxKxDxHxW
    """
    assert len(label.shape) == 3 or len(label.shape) == 4, 'Invalid Label Shape {}'.format(label.shape)
    label_ohe = None
    if len(label.shape) == 3:
        label_ohe = torch.zeros((label.shape[0], num_classes, label.shape[1], label.shape[2]))
    elif len(label.shape) == 4:
        label_ohe = torch.zeros((label.shape[0], num_classes, label.shape[1], label.shape[2], label.shape[3]))

    for batch_idx, batch_el_label in enumerate(label):
        for cls in range(num_classes):
            label_ohe[batch_idx, cls] = (batch_el_label == cls)
    label_ohe = label_ohe.long()
    return label_ohe

def dice(outputs, labels):
    eps = 1e-5
    outputs, labels = outputs.float(), labels.float()
    outputs, labels = outputs.flatten(), labels.flatten()
    intersect = torch.dot(outputs, labels)  # element-wise product, then sum
    union = torch.add(torch.sum(outputs), torch.sum(labels))
    dice_coeff = (2 * intersect + eps) / (union + eps)
    dice_loss = 1 - dice_coeff
    return dice_loss

def dice_n_classes(outputs, labels, do_one_hot=False, get_list=False, device=None):
    """
    Computes the Multi-class classification Dice Coefficient.
    It is computed as the average Dice for all classes, each time
    considering a class versus all the others.
    Class 0 (background) is not considered in the average.

    :param outputs: probabilities outputs of the CNN. Shape: [BxCxDxHxW]
    :param labels:  ground truth                      Shape: [BxDxHxW]
    :param do_one_hot: set to True if ground truth has shape [BxHxW]
    :param get_list:   set to True if you want the list of dices per class instead of average
    :param device: CUDA device on which compute the dice
    :return: Multiclass classification Dice Loss
    """
    num_classes = outputs.shape[1]
    if do_one_hot:
        labels = one_hot_encode(labels, num_classes)
        labels = labels.cuda(device=device)

    dices = list()
    for cls in range(1, num_classes):
        outputs_ = outputs[:, cls].unsqueeze(dim=1)
        labels_  = labels[:, cls].unsqueeze(dim=1)
        dice_ = dice(outputs_, labels_)
        dices.append(dice_)
    if get_list:
        return dices
    else:
        return sum(dices) / (num_classes-1)


def get_multi_dice_loss(outputs, labels, device=None):
    return dice_n_classes(outputs, labels, do_one_hot=True, get_list=False, device=device)
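
A minimal usage sketch (the shapes and the softmax normalization are my own assumptions, not part of the original training loop; see also section 3.4). Note that get_multi_dice_loss above moves the one-hot labels to CUDA, so on a CPU-only machine you can one-hot encode the labels yourself and call dice_n_classes directly:

# Usage sketch (assumed shapes; softmax added for normalization, see section 3.4)
import torch

B, C, D, H, W = 1, 2, 32, 32, 32
logits = torch.randn(B, C, D, H, W)            # raw network output
target = torch.randint(0, C, (B, D, H, W))     # integer mask: 0 = background, 1 = class 1

probs = torch.softmax(logits, dim=1)           # normalize over the class channel

labels_ohe = one_hot_encode(target, C)         # B x C x D x H x W one-hot labels
loss = dice_n_classes(probs, labels_ohe, do_one_hot=False)
print(loss)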

2. Dice coeff (coefficient) evaluation metric

The Dice coeff has already been introduced when defining Dice loss; the relationship between them is: Dice loss = 1 - Dice coeff.

Although this article only has one foreground category, the multi-class case is still given: a Dice coeff is computed per class and then averaged into a mean Dice coeff. Since the output also contains a background class that should not be included in the calculation, the Dice coeff loop starts from class index 1.

The code is as follows:

def one_hot_encode_np(label, num_classes):
    """ Numpy One Hot Encode
    :param label: Numpy Array of shape BxHxW or BxDxHxW
    :param num_classes: K classes
    :return: label_ohe, Numpy Array of shape BxKxHxW or BxKxDxHxW
    """
    assert len(label.shape) == 3 or len(label.shape) == 4, 'Invalid Label Shape {}'.format(label.shape)
    label_ohe = None
    if len(label.shape) == 3:
        label_ohe = np.zeros((label.shape[0], num_classes, label.shape[1], label.shape[2]))
    elif len(label.shape) == 4:
        label_ohe = np.zeros((label.shape[0], num_classes, label.shape[1], label.shape[2], label.shape[3]))
    for batch_idx, batch_el_label in enumerate(label):
        for cls in range(num_classes):
            label_ohe[batch_idx, cls] = (batch_el_label == cls)
    return label_ohe

def dice_coeff(gt, pred, eps=1e-5):
    # eps keeps the ratio defined when both prediction and ground truth are empty
    dice = (np.sum(pred[gt == 1]) * 2.0 + eps) / (np.sum(pred) + np.sum(gt) + eps)
    return dice

def multi_dice_coeff(gt, pred, num_classes):
    # print('loss shape:', gt.shape, pred.shape)  # debug print
    labels = one_hot_encode_np(gt, num_classes)
    outputs = one_hot_encode_np(pred, num_classes)
    dices = list()
    for cls in range(1, num_classes):
        outputs_ = outputs[:, cls]
        labels_  = labels[:, cls]
        dice_ = dice_coeff(outputs_, labels_)
        dices.append(dice_)
    return sum(dices) / (num_classes-1)

For multiple categories, before calling multi_dice_coeff you need to perform the following operations. (This assumes the target mask uses a different integer for each category, e.g. 0 = background, 1 = class 1, 2 = class 2, 3 = class 3.)

outputs = torch.argmax(output, dim=1)  # B x Z x Y x X
outputs_np = outputs.data.cpu().numpy()  # B x Z x Y x X
labels_np = target.data.cpu().numpy()  # B x Z x Y x X
multi_dice = multi_dice_coeff(labels_np, outputs_np, config.num_outs)

Here, torch.argmax operates along the category (channel) dimension to determine which category each voxel belongs to. The output obtained this way has the same form as the target.

3. Training and validation

In the overview, we basically introduced all the fixed parts of the framework, so apart from the larger training and validation chunks added here, there is not much new to expand on in this article. Together with the two articles on the model and the data flow, building your own training code should not be a problem.

3.1. Main function part

The main function coordinates the entire training code. It includes:

  1. Definition of training hyperparameters
  2. Data stream loading
  3. Creation of network models
  4. Optimizer definition
  5. Learning rate adjustment strategy
  6. Definition of loss function
  7. Training and validation function loop
  8. Saving of training process parameters
  9. Save the trained model

This process was basically covered in the overview article; if you are interested, go back and take a closer look. If you were building it yourself, could you complete all of these pieces?

The code of the main function is as follows:

def main():
    Config = Configuration()
    Config.display()

    train_loader, valid_loader = get_Dataloader(Config)

    print('---start get model now---')
    model = get_model(Config).to(DEVICE)

    # ---- OPTIMIZER ----
    if Config.OPTIMR == "SGD":
        optimizer = optim.SGD(model.parameters(), lr=Config.LR, momentum=Config.momentum, weight_decay=Config.weight_decay)
    elif Config.OPTIMR == "Adam":
        optimizer = optim.Adam(model.parameters(), lr=Config.LR, betas=(0.9, 0.999))
    elif Config.OPTIMR == "AdamW":
        optimizer = optim.AdamW(model.parameters(), lr=Config.LR, betas=(0.9, 0.999))
    elif Config.OPTIMR == "RMSProp":
        optimizer = optim.RMSprop(model.parameters(), lr=Config.LR)

    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.05, patience=20,
                                                           verbose=False, threshold=0.0001, threshold_mode='rel',
                                                           cooldown=0, min_lr=0, eps=1e-08)

    # Defining loss fn
    ce_layer = torch.nn.CrossEntropyLoss()

    train_loss_list = []  # records training loss per epoch
    valid_loss_list = []  # records validation loss per epoch
    valid_dice_list = []
    epoch_list = []
    for epoch in range(1, Config.Max_epoch + 1):
        epoch_list.append(epoch)
        train_loss = train_model(model, DEVICE, train_loader, optimizer, ce_layer, epoch)  # train for one epoch

        valid_loss, valid_dice = valid_model(model, DEVICE, valid_loader, ce_layer, epoch)   # validate for one epoch
        train_loss_list.append(train_loss)  # record training loss for this epoch
        valid_loss_list.append(valid_loss)  # record validation loss
        valid_dice_list.append(valid_dice)
        draw_plot(epoch_list, valid_dice_list, 'valid_dice')
        draw_plot(epoch_list, valid_loss_list, 'valid_loss')
        draw_plot(epoch_list, train_loss_list, 'train_loss')

        if valid_dice > Config.Dice_Best:  
            path_ckpt = os.path.join(Config.model_path, 'best_model.pth')
            save_model(path_ckpt, model)
            Config.Dice_Best = valid_dice 
        else:
            path_ckpt = os.path.join(Config.model_path, 'last_model.pth')
            save_model(path_ckpt, model)

        scheduler.step(valid_loss)
    print('best val Dice is ', Config.Dice_Best)
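
The main() above relies on a Configuration object plus helpers such as get_Dataloader, get_model, save_model and draw_plot, all defined in the earlier articles of this series. For readers assembling the code themselves, below is a minimal sketch of what such a Configuration could look like; the field names are inferred from how main() uses them, and the default values are placeholders of my own:

# Sketch of a Configuration holder (field names inferred from main(); values are placeholders)
import torch

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

class Configuration:
    def __init__(self):
        self.OPTIMR = 'Adam'               # one of: SGD / Adam / AdamW / RMSProp
        self.LR = 1e-3                     # initial learning rate
        self.momentum = 0.9                # used by SGD only
        self.weight_decay = 1e-4           # used by SGD only
        self.Max_epoch = 100               # number of training epochs
        self.Dice_Best = 0.0               # best validation Dice seen so far
        self.model_path = './checkpoints'  # where best_model.pth / last_model.pth are saved
        self.num_outs = 2                  # number of output classes (background + foreground)

    def display(self):
        for key, value in vars(self).items():
            print('{:15s} {}'.format(key, value))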

3.2. Training part

The training process for a single epoch and the validation process for a single epoch are defined as separate functions. The advantage is that the main function stays relatively concise, instead of putting everything together with deep indentation that hurts readability.

The following is the training part, including:

  1. Iterating over all batches within a single epoch
  2. Forward inference for a single batch
  3. Loss calculation for the predictions of a single batch
  4. Calculating the Dice coeff of the predictions for a single batch
  5. Zeroing gradients, backpropagation and the optimizer step
  6. Real-time progress printing

Here is the training code:

def train_model(model, device, train_loader, optimizer, ce_layer, epoch):  # train for one epoch
    config = Configuration()
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    end = time.time()

    multi_dices = list()

    model.train()
    bar = Bar('Processing train ', max=len(train_loader))
    for batch_index, (data, target) in enumerate(train_loader):  # iterate over batches: (data, target), i.e. image and label
        data_time.update(time.time() - end)
        data, target = data.to(device), target.to(device)

        output = model(data)  # forward pass: image in, prediction out
        # loss = Loss(output, target)  # compute loss
        loss = ce_layer(output, target)
        losses.update(loss.item(), data.size(0))

        outputs = torch.argmax(output, dim=1)  # B x Z x Y x X
        outputs_np = outputs.data.cpu().numpy()  # B x Z x Y x X
        labels_np = target.data.cpu().numpy()  # B x Z x Y x X
        multi_dice = multi_dice_coeff(labels_np, outputs_np, config.num_outs)
        multi_dices.append(multi_dice)

        optimizer.zero_grad()  # zero the gradients
        loss.backward()  # backpropagation
        optimizer.step()  # optimizer step

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        multi_dices_np = np.array(multi_dices)
        mean_multi_dice = np.mean(multi_dices_np)

        # plot progress
        bar.suffix = '(Epoch: {epoch: .1f} | {batch}/{size}) Data: {data:.3f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | Loss: {loss:.4f} | Dice: {dice:.4f}| LR: {lr:.6f}'.format(
            epoch=epoch,
            batch=batch_index + 1,
            size=len(train_loader),
            data=data_time.val,
            bt=batch_time.val,
            total=bar.elapsed_td,
            eta=bar.eta_td,
            loss=losses.avg,
            dice=mean_multi_dice,
            lr=optimizer.param_groups[0]['lr']
        )
        bar.next()
    bar.finish()
    return losses.avg  # return the average loss
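
train_model() (and valid_model() below) also uses the AverageMeter and Bar helpers from the earlier articles; Bar appears to come from the progress package. In case you are assembling the code yourself, here is a minimal AverageMeter sketch matching how it is used here (.update(value, n), .val, .avg); it is my own reconstruction, not the original helper:

# Minimal AverageMeter sketch (reconstructed from its usage above; not the original helper)
class AverageMeter:
    def __init__(self):
        self.val = 0.0    # most recently recorded value
        self.sum = 0.0    # weighted sum of all recorded values
        self.count = 0    # total weight (e.g. number of samples)
        self.avg = 0.0    # running average

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count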

3.3. Validation part

The validation part is basically the same as the training part, except that:

  1. The training phase uses model.train(), while the validation phase requires model.eval()
  2. No backpropagation or parameter update is performed during validation; the loss is computed only for statistics.

Everything else is almost the same; the code is as follows:

def valid_model(model, device, test_loader, ce_layer, epoch):    # the loader is named 'test' here: 1) so printing distinguishes valid from test, 2) the test stage plots images and is designed separately
    config = Configuration()
    model.eval()  # evaluation mode, used for validation or testing
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    end = time.time()

    multi_dices = list()
    bar = Bar('Processing valid ', max=len(test_loader))

    with torch.no_grad():  # no gradient computation (no backpropagation)
        for batch_index, (data, target) in enumerate(test_loader):  # iterate over batches: (image, label)
            data_time.update(time.time() - end)
            data, target = data.to(device), target.to(device)

            output = model(data)
            loss = ce_layer(output, target)
            losses.update(loss.item(), data.size(0))

            outputs = torch.argmax(output, dim=1)  # B x C x Z x Y x X   >   B x Z x Y x X
            outputs_np = outputs.data.cpu().numpy()  # B x Z x Y x X
            labels_np = target.data.cpu().numpy()  # B x Z x Y x X
            multi_dice = multi_dice_coeff(labels_np, outputs_np, config.num_outs)
            multi_dices.append(multi_dice)

            multi_dices_np = np.array(multi_dices)
            mean_multi_dice = np.mean(multi_dices_np)
            std_multi_dice = np.std(multi_dices_np)

            # plot progress
            bar.suffix = '(Epoch: {epoch: .1f} | {batch}/{size}) Data: {data:.3f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | Loss: {loss:.4f} | Dice: {dice:.4f}'.format(
                epoch=epoch,
                batch=batch_index + 1,
                size=len(test_loader),
                data=data_time.val,
                bt=batch_time.val,
                total=bar.elapsed_td,
                eta=bar.eta_td,
                loss=losses.avg,
                dice=mean_multi_dice
            )
            bar.next()
    bar.finish()

    return losses.avg, mean_multi_dice

3.4. Training experience

In the 3D UNet model article, we mentioned:

During the training phase of the model, there is no need to add sigmoid or softmax operations at the end. This is only required during the inference phase.

CrossEntropyLoss, on the other hand, quietly applies the softmax itself: even though no sigmoid or softmax appears in the model, the normalization happens inside the loss when it is computed.

If you use Dice loss instead of CrossEntropyLoss, do you then need to apply a similar normalization to the model output yourself before computing the loss, the way CrossEntropyLoss does internally?

From my own training runs, if no normalization is performed before calculating the Dice loss, the gradients easily vanish, the model fails to converge, and training becomes very difficult. The sigmoid or softmax likely plays a normalizing role here and makes training easier; other reasons and phenomena remain to be explored. A minimal sketch of this normalization step follows.
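
Below is a sketch of how that normalization could be plugged in (my own assumption, not the article's training code): apply softmax over the class channel before the Dice loss, while CrossEntropyLoss keeps taking raw logits.

# Sketch: softmax normalization before Dice loss (assumption, not the original code)
import torch

def dice_training_loss(output, target, device=None):
    # output: B x C x D x H x W raw logits; target: B x D x H x W integer mask
    probs = torch.softmax(output, dim=1)   # without this, Dice training tends not to converge
    return get_multi_dice_loss(probs, target, device=device)

In train_model above, the line loss = ce_layer(output, target) would then become loss = dice_training_loss(output, target, device=device).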

4. Summary

Someone commented last time asking for the complete code, and it will definitely be released at the end of the series. Across the individual articles the complete code has essentially already been posted; after a little assembly and troubleshooting there should be no problem, and any remaining issues are small and simple ones. I have verified this.

For beginners who are unfamiliar with some parts, such as Python's os file-operations library, it is recommended to read other articles and fill in that knowledge before continuing with this series.

If an error occurs, check the error message for hints as soon as possible, or follow it to locate and fix the offending spot. If that doesn't work, just use Baidu; most problems have already been encountered by someone online. Failing that, leave a message in the comment area and we can solve it together, which will be faster.

Finally, there is still one inference article left, so let's keep going.

Origin blog.csdn.net/wsLJQian/article/details/134250370