The first article of this series gave a brief overview of all the modules needed to train the whole project, in particular the components that the training stage relies on.
The previous articles on data construction and the network model verified each of these building blocks separately. Checking that every module is correct before training starts avoids running into the same problems again and again during training.
By this point in the series, most of the modules have been introduced, including:
- In the overview, the optimizer, model acquisition and model saving are introduced;
- In the data flow module, we learned how to import data and verify the data flow;
- In the network model article, the loss function `loss` was called.
The main purpose of this article is to stitch these scattered pieces into a whole. The inference stage will be covered separately in a later article. Working through this series is also a chance to think more and deepen your understanding.
1. Loss function
In a segmentation task, predicting the target `mask` is converted into a per-pixel classification task. Therefore, the loss function used in the paper is the cross-entropy loss.
Later improvements to the loss often introduce `dice loss` or `focal loss`. Let's start with the cross-entropy loss and explore why it can be used in segmentation tasks.
This article continues with the cross-entropy loss function used in the network model evaluation stage, which is defined as follows. For other segmentation loss functions, refer to this article: [AI Interview] CrossEntropy Loss, Balanced Cross Entropy, Dice Loss and Focal Loss Classification Loss Hengping Review.
1.1、CrossEntropyLoss
In the previous article about the network model, the cross-entropy loss function was introduced in the testing phase of the model: [3D Image Segmentation] VNet 3D Image Segmentation 3 (3D UNet Model) based on Pytorch. The `loss` was introduced as follows:
expected_output_shape = (batch_size, num_out_classes, 64, 64, 64)
assert output.shape == expected_output_shape, "Unexpected output shape, check the architecture!"
# Defining loss fn
ce_layer = torch.nn.CrossEntropyLoss()
# Calculating loss
ce_loss = ce_layer(output, ground_truth)
print("CE Loss = {}".format(ce_loss))
Here, `ground_truth` has shape BxDxHxW and `output` has shape BxCxDxHxW.
- For the input prediction tensor, a softmax operation is usually performed along the C dimension, so that the output value of each channel (class) lies in the range [0, 1] and the output values of all channels sum to 1.
- The purpose of this is to convert the prediction results into a probability distribution, which is what the cross-entropy loss is computed on.
- In PyTorch, `torch.nn.CrossEntropyLoss()` applies the softmax operation (as log-softmax) to its input automatically, so the model output should be passed in as raw, unnormalized scores.
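To make the softmax step concrete, here is a minimal plain-Python sketch (no PyTorch) of what a cross-entropy loss computes for individual voxels, assuming the result is averaged over all voxels. The helper names `softmax` and `voxel_cross_entropy` and the toy scores are illustrative assumptions, not part of the series' code.

```python
import math


def softmax(scores):
    """Turn the raw per-class scores of one voxel into probabilities."""
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def voxel_cross_entropy(scores, target):
    """Cross-entropy of one voxel: -log(probability of the true class)."""
    return -math.log(softmax(scores)[target])


# Two voxels, two classes (C = 2): raw scores along the C dimension.
scores_per_voxel = [[2.0, 0.0], [0.5, 1.5]]
targets = [0, 1]  # ground-truth class index of each voxel

probs = [softmax(s) for s in scores_per_voxel]
losses = [voxel_cross_entropy(s, t) for s, t in zip(scores_per_voxel, targets)]
mean_loss = sum(losses) / len(losses)

print(probs)      # each inner list sums to 1
print(mean_loss)  # average over all voxels
```

The loss only looks at the probability assigned to the true class of each voxel, which is why a segmentation mask of class indices is enough as the target.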
1.2、Dice loss
The "Dice" in Dice coefficient is actually a scientist's surname. Its full name is the Sørensen–Dice coefficient, and it is often called the Dice similarity coefficient or F1 score. It was developed independently by the botanists Thorvald Sørensen and Lee Raymond Dice, who published it in 1948 and 1945 respectively.
The Dice coefficient is a common similarity measure, mainly used to calculate the similarity of two sets. In Dice Loss, the Dice coefficient measures the similarity between the predicted result and the ground-truth label, hence the name Dice Loss.
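The Dice coefficient is defined as follows, where X is the set of predicted foreground pixels and Y is the set of ground-truth foreground pixels:

$$\mathrm{Dice}(X, Y) = \frac{2\,|X \cap Y|}{|X| + |Y|}$$

If segmentation is regarded as a per-pixel classification task, it can also be written in terms of true positives (TP), false positives (FP) and false negatives (FN):

$$\mathrm{Dice} = \frac{2\,TP}{2\,TP + FP + FN}$$

So the Dice loss can be expressed as:

$$L_{\mathrm{Dice}} = 1 - \mathrm{Dice}$$

Dice Loss is therefore also called the "Dice similarity coefficient loss" or "Dice similarity loss".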
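The two usual ways of writing the Dice coefficient, 2|X∩Y| / (|X| + |Y|) on sets and 2TP / (2TP + FP + FN) on pixel counts, can be checked against each other on a tiny example. This is a plain-Python sketch; the function names and the six-pixel masks are made up for illustration.

```python
def dice_from_sets(pred, gt):
    """Dice = 2|X ∩ Y| / (|X| + |Y|) on two binary masks given as 0/1 lists."""
    intersection = sum(p * g for p, g in zip(pred, gt))
    return 2.0 * intersection / (sum(pred) + sum(gt))


def dice_from_counts(pred, gt):
    """Dice = 2TP / (2TP + FP + FN), counting pixels one by one."""
    tp = sum(1 for p, g in zip(pred, gt) if p == 1 and g == 1)
    fp = sum(1 for p, g in zip(pred, gt) if p == 1 and g == 0)
    fn = sum(1 for p, g in zip(pred, gt) if p == 0 and g == 1)
    return 2.0 * tp / (2 * tp + fp + fn)


pred_mask = [1, 1, 0, 1, 0, 0]  # predicted foreground pixels
gt_mask = [1, 0, 0, 1, 1, 0]    # ground-truth foreground pixels

dice_a = dice_from_sets(pred_mask, gt_mask)
dice_b = dice_from_counts(pred_mask, gt_mask)
dice_loss = 1.0 - dice_a
print(dice_a, dice_b, dice_loss)
```

Here TP = 2, FP = 1 and FN = 1, so both forms give 4/6, and the corresponding Dice loss is 1/3.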
The multi-class Dice loss (`multi dice loss`) is defined as follows:
import torch
import numpy as np


def one_hot_encode(label, num_classes):
    """ Torch One Hot Encode
    :param label: Tensor of shape BxHxW or BxDxHxW
    :param num_classes: K classes
    :return: label_ohe, Tensor of shape BxKxHxW or BxKxDxHxW
    """
    assert len(label.shape) == 3 or len(label.shape) == 4, 'Invalid Label Shape {}'.format(label.shape)
    label_ohe = None
    if len(label.shape) == 3:
        label_ohe = torch.zeros((label.shape[0], num_classes, label.shape[1], label.shape[2]))
    elif len(label.shape) == 4:
        label_ohe = torch.zeros((label.shape[0], num_classes, label.shape[1], label.shape[2], label.shape[3]))
    for batch_idx, batch_el_label in enumerate(label):
        for cls in range(num_classes):
            # 1 where the voxel belongs to class `cls`, 0 elsewhere
            label_ohe[batch_idx, cls] = (batch_el_label == cls)
    label_ohe = label_ohe.long()
    return label_ohe


def dice(outputs, labels):
    eps = 1e-5  # avoids division by zero when both inputs are empty
    outputs, labels = outputs.float(), labels.float()
    outputs, labels = outputs.flatten(), labels.flatten()
    intersect = torch.dot(outputs, labels)  # element-wise product, then sum
    union = torch.add(torch.sum(outputs), torch.sum(labels))
    dice_coeff = (2 * intersect + eps) / (union + eps)
    dice_loss = 1 - dice_coeff
    return dice_loss


def dice_n_classes(outputs, labels, do_one_hot=False, get_list=False, device=None):
    """
    Computes the Multi-class classification Dice Coefficient.
    It is computed as the average Dice for all classes, each time
    considering a class versus all the others.
    Class 0 (background) is not considered in the average.
    :param outputs: probabilities outputs of the CNN. Shape: [BxCxDxHxW]
    :param labels: ground truth Shape: [BxDxHxW]
    :param do_one_hot: set to True if ground truth has shape [BxHxW]
    :param get_list: set to True if you want the list of dices per class instead of average
    :param device: device on which to compute the dice (defaults to the device of outputs)
    :return: Multiclass classification Dice Loss
    """
    num_classes = outputs.shape[1]
    if do_one_hot:
        labels = one_hot_encode(labels, num_classes)
        # Move the labels next to the outputs; this also works on CPU-only machines.
        labels = labels.to(device if device is not None else outputs.device)
    dices = list()
    for cls in range(1, num_classes):  # start at 1 to skip the background class
        outputs_ = outputs[:, cls].unsqueeze(dim=1)
        labels_ = labels[:, cls].unsqueeze(dim=1)
        dice_ = dice(outputs_, labels_)
        dices.append(dice_)
    if get_list:
        return dices
    else:
        return sum(dices) / (num_classes - 1)


def get_multi_dice_loss(outputs, labels, device=None):
    return dice_n_classes(outputs, labels, do_one_hot=True, get_list=False, device=device)
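As a rough illustration of how the multi-class Dice loss behaves, the same averaging scheme (one-hot labels, per-class Dice loss, background class 0 skipped) can be traced in plain Python on four voxels. The names `soft_dice_loss` and `multi_class_dice_loss` and the toy probabilities are assumptions made for this sketch; it is not a replacement for the PyTorch version.

```python
def soft_dice_loss(probs, one_hot, eps=1e-5):
    """1 - Dice between predicted probabilities and a 0/1 mask of one class."""
    intersect = sum(p * y for p, y in zip(probs, one_hot))
    union = sum(probs) + sum(one_hot)
    return 1.0 - (2.0 * intersect + eps) / (union + eps)


def multi_class_dice_loss(class_probs, labels):
    """Average the per-class Dice loss over classes 1..C-1 (class 0 = background)."""
    num_classes = len(class_probs)
    per_class = []
    for cls in range(1, num_classes):
        one_hot = [1.0 if label == cls else 0.0 for label in labels]
        per_class.append(soft_dice_loss(class_probs[cls], one_hot))
    return sum(per_class) / (num_classes - 1)


labels = [0, 1, 2, 1]  # class index of four voxels
# class_probs[c][i] = predicted probability of class c at voxel i
perfect = [[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 1.0], [0.0, 0.0, 1.0, 0.0]]
uncertain = [[0.4, 0.3, 0.3, 0.3], [0.3, 0.4, 0.3, 0.4], [0.3, 0.3, 0.4, 0.3]]

loss_perfect = multi_class_dice_loss(perfect, labels)
loss_uncertain = multi_class_dice_loss(uncertain, labels)
print(loss_perfect, loss_uncertain)
```

A perfect prediction drives the loss to 0, while a hesitant prediction with the correct argmax still leaves a sizeable loss, because Dice loss works on the probabilities rather than on the hard labels.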
2. Dice coeff (coefficient) evaluation index
The `Dice coeff` was already introduced when defining `Dice loss`, and the relationship between them is:
Dice loss = 1 - Dice coeff
In this article, although there is only one foreground class, the multi-class case is still given: a `Dice coeff` is computed for each class, and their mean is the average `Dice coeff`. However, since the output in this article includes a background class, the background is not included in the calculation, so the calculation of the `Dice coeff` starts from class `1`.
The code is as follows:
import numpy as np


def one_hot_encode_np(label, num_classes):
    """ Numpy One Hot Encode
    :param label: Numpy Array of shape BxHxW or BxDxHxW
    :param num_classes: K classes
    :return: label_ohe, Numpy Array of shape BxKxHxW or BxKxDxHxW
    """
    assert len(label.shape) == 3 or len(label.shape) == 4, 'Invalid Label Shape {}'.format(label.shape)
    label_ohe = None
    if len(label.shape) == 3:
        label_ohe = np.zeros((label.shape[0], num_classes, label.shape[1], label.shape[2]))
    elif len(label.shape) == 4:
        label_ohe = np.zeros((label.shape[0], num_classes, label.shape[1], label.shape[2], label.shape[3]))
    for batch_idx, batch_el_label in enumerate(label):
        for cls in range(num_classes):
            label_ohe[batch_idx, cls] = (batch_el_label == cls)
    return label_ohe


def dice_coeff(gt, pred, eps=1e-5):
    # eps keeps the result finite when a class is absent from both gt and pred
    # (the coefficient is then 1 instead of 0/0).
    dice = (np.sum(pred[gt == 1]) * 2.0 + eps) / (np.sum(pred) + np.sum(gt) + eps)
    return dice


def multi_dice_coeff(gt, pred, num_classes):
    labels = one_hot_encode_np(gt, num_classes)
    outputs = one_hot_encode_np(pred, num_classes)
    dices = list()
    for cls in range(1, num_classes):  # start at 1 to skip the background class
        outputs_ = outputs[:, cls]
        labels_ = labels[:, cls]
        dice_ = dice_coeff(labels_, outputs_)
        dices.append(dice_)
    return sum(dices) / (num_classes - 1)
For multiple classes, the following steps are needed before calling `multi_dice_coeff`. (They assume one particular convention, namely that the `target` `mask` uses different integers to represent different classes, for example: 0 = background; 1 = class 1; 2 = class 2; 3 = class 3.)
outputs = torch.argmax(output, dim=1) # B x Z x Y x X
outputs_np = outputs.data.cpu().numpy() # B x Z x Y x X
labels_np = target.data.cpu().numpy() # B x Z x Y x X
multi_dice = multi_dice_coeff(labels_np, outputs_np, config.num_outs)
Here, `torch.argmax` operates along the class channel to determine which class each pixel belongs to. The `output` obtained through `argmax` in this way has the same format as `target`.
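The effect of taking the argmax over the class channel and then scoring the result with a per-class Dice coefficient can be traced by hand on a small example. The following plain-Python sketch uses hypothetical helpers (`argmax_over_classes`, `hard_dice`) and a made-up 3-class, 5-voxel score table; the background class 0 is again left out of the average.

```python
def argmax_over_classes(scores):
    """scores[c][i] is the score of class c at voxel i; return one class index per voxel."""
    num_classes = len(scores)
    num_voxels = len(scores[0])
    return [max(range(num_classes), key=lambda c: scores[c][i]) for i in range(num_voxels)]


def hard_dice(pred_labels, gt_labels, cls):
    """Dice coefficient of one class between two integer label maps."""
    pred = [1 if p == cls else 0 for p in pred_labels]
    gt = [1 if g == cls else 0 for g in gt_labels]
    intersection = sum(p * g for p, g in zip(pred, gt))
    return 2.0 * intersection / (sum(pred) + sum(gt))


# 3 classes x 5 voxels of raw network scores (C x N)
scores = [
    [2.0, 0.1, 0.3, 0.2, 1.5],  # class 0 (background)
    [0.5, 1.9, 0.2, 1.1, 0.4],  # class 1
    [0.1, 0.3, 2.2, 0.9, 0.2],  # class 2
]
target = [0, 1, 2, 2, 0]  # integer mask in the same layout as the argmax result

pred = argmax_over_classes(scores)
dices = [hard_dice(pred, target, cls) for cls in range(1, len(scores))]
mean_dice = sum(dices) / len(dices)
print(pred, dices, mean_dice)
```

The argmax removes the class dimension, so `pred` and `target` are both plain label maps and can be compared class by class.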
3. Training and verification
The overview already introduced essentially all the fixed parts of the framework, so there is not much left to expand on here. This article therefore adds the larger pieces: training and validation. Together with the two articles on the model and the data flow, this should be enough to build your own training code.
3.1. The main function
The main function coordinates the entire training code. It includes:
- Definition of training hyperparameters
- Data stream loading
- Creation of network models
- Optimizer definition
- Learning rate adjustment strategy
- Definition of loss function
- Training and validation function loop
- Saving of training process parameters
- Saving of the trained model
These steps were basically all introduced in the overview article; if you are interested, go back and take a closer look. If you were to build it yourself, could you write all of these parts?
The code of the main function is as follows:
import os
import torch
import torch.optim as optim

# Configuration, get_Dataloader, get_model, save_model, draw_plot, train_model and
# valid_model are the helpers defined in this series.
# Assumed definition of DEVICE; drop it if DEVICE is already defined elsewhere in the project.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def main():
    Config = Configuration()
    Config.display()
    train_loader, valid_loader = get_Dataloader(Config)
    print('---start get model now---')
    model = get_model(Config).to(DEVICE)

    # ---- OPTIMIZER ----
    if Config.OPTIMR == "SGD":
        optimizer = optim.SGD(model.parameters(), lr=Config.LR, momentum=Config.momentum, weight_decay=Config.weight_decay)
    elif Config.OPTIMR == "Adam":
        optimizer = optim.Adam(model.parameters(), lr=Config.LR, betas=(0.9, 0.999))
    elif Config.OPTIMR == "AdamW":
        optimizer = optim.AdamW(model.parameters(), lr=Config.LR, betas=(0.9, 0.999))
    elif Config.OPTIMR == "RMSProp":
        optimizer = optim.RMSprop(model.parameters(), lr=Config.LR)
    else:
        # Fail early instead of hitting an undefined `optimizer` further down.
        raise ValueError('Unsupported optimizer: {}'.format(Config.OPTIMR))

    # Lower the learning rate when the validation loss stops improving
    # (`verbose` is left at its default).
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.05, patience=20,
                                                           threshold=0.0001, threshold_mode='rel',
                                                           cooldown=0, min_lr=0, eps=1e-08)

    # Defining loss fn
    ce_layer = torch.nn.CrossEntropyLoss()

    train_loss_list = []  # records the training loss
    valid_loss_list = []  # records the validation loss
    valid_dice_list = []
    epoch_list = []
    for epoch in range(1, Config.Max_epoch + 1):
        epoch_list.append(epoch)
        train_loss = train_model(model, DEVICE, train_loader, optimizer, ce_layer, epoch)  # training
        valid_loss, valid_dice = valid_model(model, DEVICE, valid_loader, ce_layer, epoch)  # validation
        train_loss_list.append(train_loss)  # training loss of each epoch
        valid_loss_list.append(valid_loss)  # validation loss
        valid_dice_list.append(valid_dice)
        draw_plot(epoch_list, valid_dice_list, 'valid_dice')
        draw_plot(epoch_list, valid_loss_list, 'valid_loss')
        draw_plot(epoch_list, train_loss_list, 'train_loss')

        if valid_dice > Config.Dice_Best:
            path_ckpt = os.path.join(Config.model_path, 'best_model.pth')
            save_model(path_ckpt, model)
            Config.Dice_Best = valid_dice
        else:
            path_ckpt = os.path.join(Config.model_path, 'last_model.pth')
            save_model(path_ckpt, model)
        scheduler.step(valid_loss)
    print('best val Dice is ', Config.Dice_Best)
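To get a feel for how the learning rate adjustment strategy reacts to `valid_loss`, the idea behind a reduce-on-plateau rule can be sketched in plain Python. This is a simplified re-implementation written for illustration only, not PyTorch's actual code: it covers just the 'min' mode with a relative threshold, ignores `cooldown` and `min_lr`, and the function name `simulate_reduce_on_plateau` and the loss values are made up.

```python
def simulate_reduce_on_plateau(valid_losses, lr, factor=0.05, patience=20, threshold=1e-4):
    """Return the learning rate in effect after each epoch (simplified sketch)."""
    best = float('inf')
    bad_epochs = 0
    lr_history = []
    for loss in valid_losses:
        if loss < best * (1.0 - threshold):  # 'min' mode with a relative threshold
            best = loss
            bad_epochs = 0
        else:
            bad_epochs += 1
        if bad_epochs > patience:  # no improvement for more than `patience` epochs
            lr = lr * factor
            bad_epochs = 0
        lr_history.append(lr)
    return lr_history


# A short patience is used here so the effect is visible within a few epochs.
valid_losses = [1.0, 0.9, 0.9, 0.9, 0.9, 0.9]
lr_history = simulate_reduce_on_plateau(valid_losses, lr=0.1, factor=0.5, patience=2)
print(lr_history)
```

With the settings used in `main` (`factor=0.05`, `patience=20`), the same rule would wait much longer before acting and then cut the learning rate to one twentieth, so a stalled validation loss leads to a sharp drop.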
3.2. Training part
The training process and the validation process for a single `epoch` are defined here as separate functions. The advantage is that the main function stays relatively concise; putting everything together would make the indentation too deep and hurt readability.
The following is the training part, which includes:
- Iterating over all `batch`es in a single `epoch`
- Forward pass for a single `batch`
- Loss calculation on the prediction of a single `batch`
- Computing the `dice coeff` of the prediction of a single `batch`
- Zeroing the gradients and backpropagation
- Real-time progress printing
Here is the training code:
import time
import numpy as np
import torch
# AverageMeter and Bar are the helpers used throughout this series;
# Bar is assumed here to come from the `progress` package (from progress.bar import Bar).


def train_model(model, device, train_loader, optimizer, ce_layer, epoch):  # train the model for one epoch
    config = Configuration()
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    end = time.time()
    multi_dices = list()
    model.train()
    bar = Bar('Processing train ', max=len(train_loader))
    for batch_index, (data, target) in enumerate(train_loader):  # batch index and (data, target), i.e. image and label
        data_time.update(time.time() - end)
        data, target = data.to(device), target.to(device)
        output = model(data)  # forward pass: image in, prediction out
        # loss = Loss(output, target)  # compute the loss
        loss = ce_layer(output, target)
        losses.update(loss.item(), data.size(0))

        outputs = torch.argmax(output, dim=1)  # B x Z x Y x X
        outputs_np = outputs.detach().cpu().numpy()  # B x Z x Y x X
        labels_np = target.detach().cpu().numpy()  # B x Z x Y x X
        multi_dice = multi_dice_coeff(labels_np, outputs_np, config.num_outs)
        multi_dices.append(multi_dice)

        optimizer.zero_grad()  # reset the gradients
        loss.backward()  # backpropagation
        optimizer.step()  # one optimizer step

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        multi_dices_np = np.array(multi_dices)
        mean_multi_dice = np.mean(multi_dices_np)

        # plot progress
        bar.suffix = '(Epoch: {epoch} | {batch}/{size}) Data: {data:.3f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | Loss: {loss:.4f} | Dice: {dice:.4f}| LR: {lr:.6f}'.format(
            epoch=epoch,
            batch=batch_index + 1,
            size=len(train_loader),
            data=data_time.val,
            bt=batch_time.val,
            total=bar.elapsed_td,
            eta=bar.eta_td,
            loss=losses.avg,
            dice=mean_multi_dice,
            lr=optimizer.param_groups[0]['lr']
        )
        bar.next()
    bar.finish()
    return losses.avg  # return the average loss
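The training loop relies on an `AverageMeter` helper through `update(value, n)`, `.val` and `.avg`, but its definition is not shown in this article. A minimal sketch of such a helper, assuming it keeps a running average weighted by `n`, could look like the following; the series' own implementation may differ.

```python
class AverageMeter:
    """Keeps the latest value and the running weighted average of a metric."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0.0    # most recent value
        self.sum = 0.0    # weighted sum of all values
        self.count = 0    # total weight (e.g. number of samples)
        self.avg = 0.0    # running average = sum / count

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


losses = AverageMeter()
losses.update(1.0, 2)  # a batch of 2 samples with mean loss 1.0
losses.update(4.0, 1)  # a batch of 1 sample with mean loss 4.0
print(losses.val, losses.avg)
```

Weighting by the batch size, as `losses.update(loss.item(), data.size(0))` does, means a smaller last batch does not distort the epoch average.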
3.3. Verification part
The validation part is basically the same as the training part, except that:
- The training phase calls `model.train()`, whereas the validation phase requires `model.eval()`.
- No backpropagation is performed to update the model during the validation phase; the loss is computed for statistics only.
Everything else is almost the same. The code is as follows:
import time
import numpy as np
import torch


# The loader is named test_loader (1) to tell valid and test apart when printing and
# (2) because testing also needs to plot images, which requires special handling.
def valid_model(model, device, test_loader, ce_layer, epoch):
    config = Configuration()
    model.eval()  # evaluation mode, used for validation or testing
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    end = time.time()
    multi_dices = list()
    bar = Bar('Processing valid ', max=len(test_loader))
    with torch.no_grad():  # no gradient computation (no backpropagation)
        for batch_index, (data, target) in enumerate(test_loader):  # batch index and (image, label)
            data_time.update(time.time() - end)
            data, target = data.to(device), target.to(device)
            output = model(data)
            loss = ce_layer(output, target)
            losses.update(loss.item(), data.size(0))

            outputs = torch.argmax(output, dim=1)  # B x C x Z x Y x X > B x Z x Y x X
            outputs_np = outputs.detach().cpu().numpy()  # B x Z x Y x X
            labels_np = target.detach().cpu().numpy()  # B x Z x Y x X
            multi_dice = multi_dice_coeff(labels_np, outputs_np, config.num_outs)
            multi_dices.append(multi_dice)

            # measure elapsed time, so that the Batch and Data timings are per batch
            batch_time.update(time.time() - end)
            end = time.time()
            multi_dices_np = np.array(multi_dices)
            mean_multi_dice = np.mean(multi_dices_np)
            std_multi_dice = np.std(multi_dices_np)

            # plot progress
            bar.suffix = '(Epoch: {epoch} | {batch}/{size}) Data: {data:.3f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | Loss: {loss:.4f} | Dice: {dice:.4f}'.format(
                epoch=epoch,
                batch=batch_index + 1,
                size=len(test_loader),
                data=data_time.val,
                bt=batch_time.val,
                total=bar.elapsed_td,
                eta=bar.eta_td,
                loss=losses.avg,
                dice=mean_multi_dice
            )
            bar.next()
    bar.finish()
    return losses.avg, mean_multi_dice
3.4. Training experience
In the 3D UNet model article, we mentioned:
During the training phase of the model, there is no need to add a `sigmoid` or `softmax` operation at the end. This is only required during the inference phase.
However, although the model definition contains no `sigmoid` or `softmax` operation, `CrossEntropyLoss` quietly applies a `softmax` (as log-softmax) internally when it computes the loss.
If you use `Dice loss` instead of `CrossEntropyLoss`, should you first apply a normalization to the model output before computing the loss, similar to what `CrossEntropyLoss` does?
In my own training, I found that if no normalization is performed before calculating `Dice loss`, the gradient vanishes easily, which means training fails to converge. `sigmoid` or `softmax` may play a normalizing role here and make the model easier to train. Other causes and observations remain to be added as they are discovered.
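One way to see why normalization matters for Dice loss is to compare the loss on raw scores with the loss on softmax probabilities. The plain-Python sketch below uses made-up scores for three voxels and hypothetical helper names; it only illustrates that unnormalized scores push the Dice value outside [0, 1], and does not by itself reproduce the gradient behaviour observed in training.

```python
import math


def softmax(scores):
    """Turn the raw per-class scores of one voxel into probabilities."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def dice_loss(pred, gt, eps=1e-5):
    """1 - Dice between a predicted foreground map and a 0/1 ground truth."""
    intersect = sum(p * g for p, g in zip(pred, gt))
    union = sum(pred) + sum(gt)
    return 1.0 - (2.0 * intersect + eps) / (union + eps)


# Raw scores of 3 voxels for 2 classes: logits[i] = [background, foreground]
logits = [[-5.0, 5.0], [3.0, -3.0], [-4.0, 4.0]]
gt_foreground = [1.0, 0.0, 1.0]

raw_foreground = [voxel[1] for voxel in logits]               # unnormalized scores
prob_foreground = [softmax(voxel)[1] for voxel in logits]     # values in [0, 1]

loss_raw = dice_loss(raw_foreground, gt_foreground)
loss_prob = dice_loss(prob_foreground, gt_foreground)
print(loss_raw, loss_prob)
```

On the raw scores the "loss" becomes negative even though the prediction is correct, whereas on the softmax probabilities it stays between 0 and 1 and is close to 0, which is the range the Dice formula assumes.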
4. Summary
Someone commented last time asking for the complete code; it will definitely be released at the end. Taken together, the individual articles have already posted essentially all of the code, and with a little troubleshooting it should run without problems. Any issues that remain should be small and simple ones; I have verified this.
For beginners who are unfamiliar with some parts, such as Python's `os` library for file operations, it is recommended to read other articles and fill in that knowledge before continuing.
If an error occurs, first check the error message for suggested fixes, or use it to locate the faulty place and make a targeted change. If that doesn't work, search on Baidu; most problems have already been encountered by others on the Internet. If that still doesn't work, leave a message in the comment area and we can solve the problem together, which will be faster.
Finally, there is still one chapter left, on prediction, so let's continue reading.