cout一种新的正则化方法

COUT

Cutout[1]是一种新的正则化方法。原理是在训练时随机把图片的一部分减掉,这样能提高模型的鲁棒性。它的来源是计算机视觉任务中经常遇到的物体遮挡问题。通过cutout生成一些类似被遮挡的物体,不仅可以让模型在遇到遮挡问题时表现更好,还能让模型在做决定时更多地考虑环境(context)。

The implementation of cout

1.自己码

import torch
import numpy as np

class Cutout(object):
 """Randomly mask out one or more patches from an image.
    Args:
        n_holes (int): Number of patches to cut out of each image.
        length (int): The length (in pixels) of each square patch.
    """
 def __init__(self, n_holes, length):
        self.n_holes = n_holes
        self.length = length

 def __call__(self, img):
 """
        Args:
            img (Tensor): Tensor image of size (C, H, W).
        Returns:
            Tensor: Image with n_holes of dimension length x length cut out of it.
        """
        h = img.size(1)
        w = img.size(2)

        mask = np.ones((h, w), np.float32)

 		for n in range(self.n_holes):
            y = np.random.randint(h)
            x = np.random.randint(w)

            y1 = np.clip(y - self.length // 2, 0, h)
            y2 = np.clip(y + self.length // 2, 0, h)
            x1 = np.clip(x - self.length // 2, 0, w)
            x2 = np.clip(x + self.length // 2, 0, w)

            mask[y1: y2, x1: x2] = 0.

        mask = torch.from_numpy(mask)
        mask = mask.expand_as(img)
        img = img * mask

 return img

2.借助第三方库

from albumentations import Cutout
import matplotlib.pyplot as plt
import cv2

transform = Compose([
		             Cutout(num_holes=30, max_h_size=7, max_w_size=7, fill_value=128, p=1)
				    ])
images = cv2.imread("./data/input/images/00000060_000.png")
images2 = transform(image=images)["image"]
plt.subplot(121)
plt.imshow(images)
plt.subplot(122)
plt.imshow(images2)
plt.show()

效果展示:
在这里插入图片描述

Reference

[1] https://arxiv.org/pdf/1708.04552.pdf
[2] https://zhuanlan.zhihu.com/p/66080948

发布了33 篇原创文章 · 获赞 3 · 访问量 5536

猜你喜欢

转载自blog.csdn.net/weixin_42990464/article/details/104640998
今日推荐