pytorch 数据增强cutmix的实现

摘要

CutMix 和 MixUp 是两种比较重要的数据增强手段。普通的数据增强只是在图片本身上做修改,主要增强网络提取特征的能力;而 CutMix 这类方法同时混合了 label,还能增强全连接层(fc)的学习能力。

cutmix的思想,

主要是用两张照片:从其中一张上随机截取一块区域,粘贴到另一张照片的相同位置,于是 label 也按被替换区域的面积比例发生变化。
在这里插入图片描述
在这里插入图片描述
示例中本来有四种花,我放上来的是两张换位比较明显的图片:照片被截取替换后,label 也随之改变,变成了带小数的值,这里大家可能有疑问。这是我在线下做的测试。实际训练时,混合后的 label 为 λ·y_a + (1−λ)·y_b:原本为 1 的那个下标仍在原来的位置,只是值变成了小数 λ;同时,被贴入区域所属类别的下标上多了 1−λ。值不为 0 的下标就是分类要预测的类别,即使是小数,下标(类别的位置)也没有改变,变的只是各个类别所占的权重。

读取照片

第一步是把一个文件夹下的照片读取出来,模仿 DataLoader 按批次加载的过程。请保证 ./data 目录下至少有四张 .jpg 照片,并且它们的尺寸必须一致。

import glob
import numpy as np
import matplotlib.pyplot as plt
plt.rcParams['figure.figsize'] = [10,10]
import cv2

# Path to the folder holding the sample images.
data_folder = "./data/"

# Number of images in the simulated batch (image k gets class k).
n_images = 4

# Read filenames in the data folder. glob's order is filesystem-dependent,
# so sort to make the demo reproducible.
filenames = sorted(glob.glob(f"{data_folder}*.jpg"))
# Keep only the first n_images filenames.
image_paths = filenames[:n_images]
if len(image_paths) < n_images:
    raise FileNotFoundError(
        f"need at least {n_images} .jpg images in {data_folder!r}, "
        f"found {len(image_paths)}"
    )

# Display a sample image
# plt.imshow(cv2.cvtColor(cv2.imread(image_paths[0]), cv2.COLOR_BGR2RGB)); plt.show();

print(image_paths)
image_batch = []
for path in image_paths:
    bgr = cv2.imread(path)
    if bgr is None:  # cv2.imread reports failure by returning None, not raising
        raise ValueError(f"cannot read image {path!r}")
    image_batch.append(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB))

# The images are later stacked into one (N, H, W, C) array, so they must all
# share the same shape.
if len({img.shape for img in image_batch}) != 1:
    raise ValueError("all images must have the same size")

# One-hot labels: image k belongs to class k.
image_batch_labels = np.eye(n_images, dtype=int)

# Show the batch on a grid with two rows.
n_cols = (n_images + 1) // 2
for k, img in enumerate(image_batch):
    plt.subplot(2, n_cols, k + 1)
    plt.imshow(img)
plt.show()
c = image_batch[0]
print(c.shape)

在这里插入图片描述
这里输出的是单张照片的 shape,为 (500,500,3)。4 张照片堆叠后 shape 是 (4,500,500,3),即 NHWC 格式;而 PyTorch 的 DataLoader 给出的通常是 NCHW 格式,即 (4,3,500,500),两者不一样。这里只是做测试。

随机截取

截取框的中心在图像内均匀随机选取,框的边长是原图边长的 sqrt(1−λ) 倍;超出图像边界的部分用 np.clip 裁掉,所以只要不超过边界就行,代码很容易看懂。

def rand_bbox(size, lamb):
    """Sample a random box covering roughly a (1 - lamb) fraction of an image.

    Args:
        size: shape of one image; only ``size[0]`` and ``size[1]`` are used.
            For an (H, W, C) image these are the height and width, even
            though the code names them W and H.
        lamb: value in [0, 1]. Each box side is ``sqrt(1 - lamb)`` times the
            matching image side, so ``lamb`` is about the fraction of area kept.

    Returns:
        (bbx1, bby1, bbx2, bby2): the box spans ``[bbx1, bbx2)`` along axis 0
        of ``size`` and ``[bby1, bby2)`` along axis 1. It is clipped to the
        image, so near the borders it can be smaller than requested.
    """
    W = size[0]
    H = size[1]
    cut_rat = np.sqrt(1. - lamb)
    # np.int was removed in NumPy 1.24; the builtin int truncates the same way.
    cut_w = int(W * cut_rat)
    cut_h = int(H * cut_rat)

    # uniform: the box centre can be anywhere inside the image
    cx = np.random.randint(W)
    cy = np.random.randint(H)

    bbx1 = np.clip(cx - cut_w // 2, 0, W)
    bby1 = np.clip(cy - cut_h // 2, 0, H)
    bbx2 = np.clip(cx + cut_w // 2, 0, W)
    bby2 = np.clip(cy + cut_h // 2, 0, H)

    return bbx1, bby1, bbx2, bby2

换位置

# Excerpt of generate_cutmix_image (the full function appears further below).
# Assumes `beta`, `image_batch` (a list of equally sized HWC images) and
# `image_batch_labels` are already defined.
lam = np.random.beta(beta, beta)
rand_index = np.random.permutation(len(image_batch))  # shuffled indices, e.g. [1, 0, 2, 3]
target_a = image_batch_labels
target_b = np.array(image_batch_labels)[rand_index]
print('img.shape', image_batch[0].shape)
bbx1, bby1, bbx2, bby2 = rand_bbox(image_batch[0].shape, lam)
print('bbx1', bbx1)
print('bby1', bby1)
print('bbx2', bbx2)
print('bby2', bby2)
# A list cannot be sliced along several axes at once, so stack the images
# into an (N, H, W, C) array first. np.array already makes a copy.
image_batch = np.array(image_batch)
image_batch_updated = image_batch.copy()
# The box covers [bbx1, bbx2) on axis 1 and [bby1, bby2) on axis 2 of every
# image; this is the same region the lambda adjustment below measures.
image_batch_updated[:, bbx1:bbx2, bby1:bby2, :] = image_batch[rand_index, bbx1:bbx2, bby1:bby2, :]

# adjust lambda to exactly match pixel ratio (the labels change accordingly)
lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (image_batch.shape[1] * image_batch.shape[2]))
label = target_a * lam + target_b * (1. - lam)

全部代码

import glob
import numpy as np
import matplotlib.pyplot as plt
plt.rcParams['figure.figsize'] = [10,10]
import cv2

# Path to the folder holding the sample images.
data_folder = "./data/"

# Number of images in the simulated batch (image k gets class k).
n_images = 4

# Read filenames in the data folder. glob's order is filesystem-dependent,
# so sort to make the demo reproducible.
filenames = sorted(glob.glob(f"{data_folder}*.jpg"))
# Keep only the first n_images filenames.
image_paths = filenames[:n_images]
if len(image_paths) < n_images:
    raise FileNotFoundError(
        f"need at least {n_images} .jpg images in {data_folder!r}, "
        f"found {len(image_paths)}"
    )

# Display a sample image
# plt.imshow(cv2.cvtColor(cv2.imread(image_paths[0]), cv2.COLOR_BGR2RGB)); plt.show();

print(image_paths)
image_batch = []
for path in image_paths:
    bgr = cv2.imread(path)
    if bgr is None:  # cv2.imread reports failure by returning None, not raising
        raise ValueError(f"cannot read image {path!r}")
    image_batch.append(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB))

# The images are later stacked into one (N, H, W, C) array, so they must all
# share the same shape.
if len({img.shape for img in image_batch}) != 1:
    raise ValueError("all images must have the same size")

# One-hot labels: image k belongs to class k.
image_batch_labels = np.eye(n_images, dtype=int)
# Uncomment to preview the batch:
# for k, img in enumerate(image_batch):
#     plt.subplot(2, (n_images + 1) // 2, k + 1)
#     plt.imshow(img)
# plt.show()
def rand_bbox(size, lamb):
    """Sample a random box covering roughly a (1 - lamb) fraction of an image.

    Args:
        size: shape of one image; only ``size[0]`` and ``size[1]`` are used.
            For an (H, W, C) image these are the height and width, even
            though the code names them W and H.
        lamb: value in [0, 1]. Each box side is ``sqrt(1 - lamb)`` times the
            matching image side, so ``lamb`` is about the fraction of area kept.

    Returns:
        (bbx1, bby1, bbx2, bby2): the box spans ``[bbx1, bbx2)`` along axis 0
        of ``size`` and ``[bby1, bby2)`` along axis 1. It is clipped to the
        image, so near the borders it can be smaller than requested.
    """
    W = size[0]
    H = size[1]
    cut_rat = np.sqrt(1. - lamb)
    # np.int was removed in NumPy 1.24; the builtin int truncates the same way.
    cut_w = int(W * cut_rat)
    cut_h = int(H * cut_rat)

    # uniform: the box centre can be anywhere inside the image
    cx = np.random.randint(W)
    cy = np.random.randint(H)

    bbx1 = np.clip(cx - cut_w // 2, 0, W)
    bby1 = np.clip(cy - cut_h // 2, 0, H)
    bbx2 = np.clip(cx + cut_w // 2, 0, W)
    bby2 = np.clip(cy + cut_h // 2, 0, H)

    return bbx1, bby1, bbx2, bby2
# Load the first image on its own (RGB) to try out rand_bbox on one HWC image.
first_image_bgr = cv2.imread(image_paths[0])
image = cv2.cvtColor(first_image_bgr, cv2.COLOR_BGR2RGB)

# Crop a random bounding box; lamb is roughly the fraction of area kept.
lamb = 0.3
size = image.shape
print('size', size)

# Uncomment to draw a random box on the image and show the cropped region.
# rand_bbox returns (bbx1, bby1, bbx2, bby2), where bbx* index axis 0 (rows)
# and bby* index axis 1 (columns); cv2 points are (column, row).
# bbx1, bby1, bbx2, bby2 = rand_bbox(size, lamb)
# im = image.copy()
# cv2.rectangle(im, (int(bby1), int(bbx1)), (int(bby2), int(bbx2)), (255, 0, 0), 3)
# plt.imshow(im)
# plt.title('Original image with random bounding box')
# plt.show()
#
# plt.imshow(image[bbx1:bbx2, bby1:bby2])
# plt.title('Cropped image')
# plt.show()
def generate_cutmix_image(image_batch, image_batch_labels, beta):
    """Apply CutMix to a batch of images and mix their one-hot labels.

    A random box from ``rand_bbox`` is taken from a shuffled copy of the
    batch and pasted at the same location into every image. The labels are
    then mixed using the exact fraction of pixels that was replaced.

    Args:
        image_batch: sequence or array of N equally sized images, assumed to
            be (H, W, C) as produced by cv2.imread.
        image_batch_labels: array-like of shape (N, num_classes), e.g. one-hot.
        beta: parameter of the Beta(beta, beta) distribution that the initial
            lambda is drawn from.

    Returns:
        (image_batch_updated, label): the mixed images as an (N, H, W, C)
        array, and the mixed soft labels ``lam * y_a + (1 - lam) * y_b``.
    """
    # Stack into arrays: a list cannot be sliced along several axes at once.
    image_batch = np.asarray(image_batch)
    labels = np.asarray(image_batch_labels)

    # generate mixed sample
    lam = np.random.beta(beta, beta)
    # Image k receives its patch from image rand_index[k], e.g. [1, 0, 2, 3].
    rand_index = np.random.permutation(len(image_batch))
    target_a = labels
    target_b = labels[rand_index]
    print('img.shape', image_batch[0].shape)
    # rand_bbox uses size[0] / size[1] of a single image, i.e. axes 1 and 2 of
    # the stacked batch.
    bbx1, bby1, bbx2, bby2 = rand_bbox(image_batch[0].shape, lam)
    print('bbx1', bbx1)
    print('bby1', bby1)
    print('bbx2', bbx2)
    print('bby2', bby2)

    image_batch_updated = image_batch.copy()
    # The box covers [bbx1, bbx2) on axis 1 and [bby1, bby2) on axis 2; this is
    # exactly the area used in the lambda adjustment below.
    image_batch_updated[:, bbx1:bbx2, bby1:bby2, :] = image_batch[rand_index, bbx1:bbx2, bby1:bby2, :]

    # adjust lambda to exactly match pixel ratio
    lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (image_batch.shape[1] * image_batch.shape[2]))
    label = target_a * lam + target_b * (1. - lam)

    return image_batch_updated, label
# Scratch check of the batched fancy-index assignment (kept for reference):
# image_batch=np.array(image_batch)
# image_batch_updated = image_batch.copy()
# c=[1,0,2,3]
# mm=np.array(image_batch_updated)
# mm[:, 10:200, 10:200, :] = image_batch[c, 10:200, 10:200, :]

# Generate the CutMix batch. input_image is the first image of the batch.
input_image = image_batch[0]
image_batch_updated, image_batch_labels_updated = generate_cutmix_image(
    image_batch, image_batch_labels, 1.0
)

# Show the original batch, then the CutMix batch, each on a 2x2 grid.
for title, batch in (("Original Images", image_batch),
                     ("CutMix Images", image_batch_updated)):
    print(title)
    for k in range(4):
        plt.subplot(2, 2, k + 1)
        plt.imshow(batch[k])
    plt.show()

# Print labels before and after mixing.
for heading, labels in (('Original labels:', image_batch_labels),
                        ('Updated labels', image_batch_labels_updated)):
    print(heading)
    print(labels)

我写的这种是方便大家观看,跟直接在pytorch模型里的代码修改还是有点差距的,
线上的我还未测试,大家可以按照下面的代码对比操作一下。

pytorch使用

def rand_bbox(size, lam):
    """Sample a random CutMix box for a batch tensor.

    Args:
        size: batch shape (e.g. ``tensor.size()``). Only ``size[2]`` and
            ``size[3]`` are used; for an assumed (N, C, H, W) layout these are
            the height and width, although the code names them W and H.
        lam: value in [0, 1]. Each box side is ``sqrt(1 - lam)`` times the
            matching spatial side, so ``lam`` is about the fraction of area kept.

    Returns:
        (bbx1, bby1, bbx2, bby2): the box spans ``[bbx1, bbx2)`` on dim 2 and
        ``[bby1, bby2)`` on dim 3, clipped to the image bounds.
    """
    W = size[2]
    H = size[3]
    cut_rat = np.sqrt(1. - lam)
    # np.int was removed in NumPy 1.24; the builtin int truncates the same way.
    cut_w = int(W * cut_rat)
    cut_h = int(H * cut_rat)

    # uniform: the box centre can be anywhere inside the image
    cx = np.random.randint(W)
    cy = np.random.randint(H)

    bbx1 = np.clip(cx - cut_w // 2, 0, W)
    bby1 = np.clip(cy - cut_h // 2, 0, H)
    bbx2 = np.clip(cx + cut_w // 2, 0, W)
    bby2 = np.clip(cy + cut_h // 2, 0, H)

    return bbx1, bby1, bbx2, bby2
def cutmix(data, targets1, targets2, targets3, alpha):
    """CutMix a batch for a model with three classification heads.

    Args:
        data: image batch tensor, assumed (N, C, H, W). It is modified IN
            PLACE: the random box region is overwritten with the matching
            region of a shuffled batch.
        targets1, targets2, targets3: per-sample class-index targets for the
            three heads, indexable by a permutation of the batch.
        alpha: parameter of the Beta(alpha, alpha) distribution for lambda.

    Returns:
        (data, targets), where targets is
        ``[targets1, shuffled1, targets2, shuffled2, targets3, shuffled3, lam]``
        and lam is the fraction of each image that was NOT replaced.
    """
    indices = torch.randperm(data.size(0))
    # Only the shuffled targets are needed here; the shuffled images are read
    # straight from `data` inside the box, so no full shuffled copy is built.
    shuffled_targets1 = targets1[indices]
    shuffled_targets2 = targets2[indices]
    shuffled_targets3 = targets3[indices]

    lam = np.random.beta(alpha, alpha)
    bbx1, bby1, bbx2, bby2 = rand_bbox(data.size(), lam)
    # Advanced indexing on the right-hand side returns a copy, so reading from
    # and writing into the same tensor is safe.
    data[:, :, bbx1:bbx2, bby1:bby2] = data[indices, :, bbx1:bbx2, bby1:bby2]
    # adjust lambda to exactly match pixel ratio
    lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (data.size()[-1] * data.size()[-2]))

    targets = [targets1, shuffled_targets1, targets2, shuffled_targets2, targets3, shuffled_targets3, lam]
    return data, targets

def mixup(data, targets1, targets2, targets3, alpha):
    """MixUp a batch for a model with three classification heads.

    Every sample is blended with a randomly chosen partner from the same
    batch: ``lam * x + (1 - lam) * x_partner``.

    Returns:
        (mixed_data, targets), where targets is
        ``[targets1, shuffled1, targets2, shuffled2, targets3, shuffled3, lam]``.
    """
    perm = torch.randperm(data.size(0))
    partners = data[perm]
    shuffled = [t[perm] for t in (targets1, targets2, targets3)]

    lam = np.random.beta(alpha, alpha)
    mixed = data * lam + partners * (1 - lam)

    return mixed, [targets1, shuffled[0],
                   targets2, shuffled[1],
                   targets3, shuffled[2],
                   lam]


def cutmix_criterion(preds1,preds2,preds3, targets):
    """Mixed cross-entropy loss for the three heads after CutMix.

    ``targets`` is the 7-item list returned by ``cutmix``: (original,
    shuffled) target pairs for each head, followed by ``lam``. Each head
    contributes ``lam * CE(pred, original) + (1 - lam) * CE(pred, shuffled)``.
    """
    lam = targets[6]
    criterion = nn.CrossEntropyLoss(reduction='mean')
    head_pairs = (
        (preds1, targets[0], targets[1]),
        (preds2, targets[2], targets[3]),
        (preds3, targets[4], targets[5]),
    )
    loss = None
    for preds, original, shuffled in head_pairs:
        part_a = lam * criterion(preds, original)
        part_b = (1 - lam) * criterion(preds, shuffled)
        # Accumulate strictly left to right, like one long sum expression.
        loss = part_a + part_b if loss is None else loss + part_a + part_b
    return loss

def mixup_criterion(preds1,preds2,preds3, targets):
    """Mixed cross-entropy loss for the three heads after MixUp.

    ``targets`` is the 7-item list returned by ``mixup``: (original,
    shuffled) target pairs for each head, followed by ``lam`` at index 6.
    """
    lam = targets[6]
    criterion = nn.CrossEntropyLoss(reduction='mean')
    # Each prediction is scored twice: against the original targets (weight
    # lam) and against the shuffled partner's targets (weight 1 - lam).
    heads = (preds1, preds1, preds2, preds2, preds3, preds3)
    weights = (lam, 1 - lam) * 3
    return sum(w * criterion(p, t) for w, p, t in zip(weights, heads, targets[:6]))
# Training-loop excerpt: each batch gets either MixUp or CutMix with equal
# probability. Only the loss computation is shown; backward() and the
# optimizer step are not part of this excerpt.
for i, (image_id, images, label1, label2, label3) in enumerate(data_loader_train):
    images = images.to(device)
    label1, label2, label3 = (lbl.to(device) for lbl in (label1, label2, label3))
    # print (image_id, label1, label2, label3)

    if np.random.rand() < 0.5:
        augment, criterion_fn = mixup, mixup_criterion
    else:
        augment, criterion_fn = cutmix, cutmix_criterion
    images, targets = augment(images, label1, label2, label3, 0.4)
    output1, output2, output3 = model(images)
    loss = criterion_fn(output1, output2, output3, targets)

总结

github也有完整版的修改,可以去搜索测试一下,这种特殊的数据增强对分类修改还是比较容易的,目标检测中基本都是不考虑label,只是组合新照片。

猜你喜欢

转载自blog.csdn.net/cp1314971/article/details/106612060
今日推荐