图像数据增强之Mixup

原理:不同数据进行像素点/特征融合;
作用:增加返回能力;鲁棒性;性能提高。

以下直接上代码:

# 如何使用请详细参考MS_TCN代码
import torch
import numpy as np

# -- mixup data augmentation
# from https://github.com/hongyi-zhang/mixup/blob/master/cifar/utils.py
def mixup_data(x, y, alpha=1.0, soft_labels = None, use_cuda=False):
    '''Compute the mixup data. Return mixed inputs, pairs of targets, and lambda'''

    if alpha > 0.:
        lam = np.random.beta(alpha, alpha)
    else:
        lam = 1.

    batch_size = x.size()[0]
    if use_cuda:
        index = torch.randperm(batch_size).cuda()
    else:
        index = torch.randperm(batch_size)

    mixed_x = lam * x + (1 - lam) * x[index,:]		# 同一batch中的样本进行特征间的相互融合
    
# Example:
# 假设batch_x.shape=[2,3,112,112],batch_size=2时,
# 如果index=[0,1]的话,则可看成mixed_x=lam*[[0,1],3,112,112]+(1-lam)*[[0,1],3,112,112]=		 [[0,1],3,112,112],即为同类混合
# 如果index=[1,0]的话,则可看成mixed_x=lam*[[0,1],3,112,112]+(1-lam)*[[1,0],3,112,112]=[batch_size,3,112,112],即为异类混合
    
    y_a, y_b = y, y[index]
    return mixed_x, y_a, y_b, lam


def mixup_criterion(y_a, y_b, lam):
    return lambda criterion, pred: lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)

参考资料:
1.https://blog.csdn.net/sinat_36618660/article/details/101633504;
2.https://blog.csdn.net/u013841196/article/details/81049968。

猜你喜欢

转载自blog.csdn.net/weixin_41807182/article/details/126939079