cout un nuevo método de regularización

COSTE

Muesca [1] es un nuevo método de regularización. El principio es la parte aleatoria de la imagen perdida durante el entrenamiento, esto puede mejorar la robustez del modelo. Es el problema de la oclusión objeto de origen en las tareas de visión por ordenador a menudo encontrados. Por recorte generar algunos objetos oscurecidos similares, no sólo puede hacer que el modelo funcione mejor en la cara de la oclusión, puede hacer que el modelo al tomar decisiones más consideración con el medio ambiente (contexto).

La implementación de cout

1. propio código

import torch
import numpy as np

class Cutout(object):
 """Randomly mask out one or more patches from an image.
    Args:
        n_holes (int): Number of patches to cut out of each image.
        length (int): The length (in pixels) of each square patch.
    """
 def __init__(self, n_holes, length):
        self.n_holes = n_holes
        self.length = length

 def __call__(self, img):
 """
        Args:
            img (Tensor): Tensor image of size (C, H, W).
        Returns:
            Tensor: Image with n_holes of dimension length x length cut out of it.
        """
        h = img.size(1)
        w = img.size(2)

        mask = np.ones((h, w), np.float32)

 		for n in range(self.n_holes):
            y = np.random.randint(h)
            x = np.random.randint(w)

            y1 = np.clip(y - self.length // 2, 0, h)
            y2 = np.clip(y + self.length // 2, 0, h)
            x1 = np.clip(x - self.length // 2, 0, w)
            x2 = np.clip(x + self.length // 2, 0, w)

            mask[y1: y2, x1: x2] = 0.

        mask = torch.from_numpy(mask)
        mask = mask.expand_as(img)
        img = img * mask

 return img

2. Con una biblioteca de terceros

from albumentations import Cutout
import matplotlib.pyplot as plt
import cv2

transform = Compose([
		             Cutout(num_holes=30, max_h_size=7, max_w_size=7, fill_value=128, p=1)
				    ])
images = cv2.imread("./data/input/images/00000060_000.png")
images2 = transform(image=images)["image"]
plt.subplot(121)
plt.imshow(images)
plt.subplot(122)
plt.imshow(images2)
plt.show()

Los resultados muestran:
Aquí Insertar imagen Descripción

Referencia

[1] https://arxiv.org/pdf/1708.04552.pdf
[2] https://zhuanlan.zhihu.com/p/66080948

Publicado 33 artículos originales · ganado elogios 3 · Vistas 5536

Supongo que te gusta

Origin blog.csdn.net/weixin_42990464/article/details/104640998
Recomendado
Clasificación