Detección de imágenes de prominencia de varias pérdidas

Detección de imágenes salientes (recorte)

pérdida de borde astuto

Pérdida=Saliente+Lglobalinserte la descripción de la imagen aquí

Canny obtiene el borde de la imagen predicha y la imagen de la máscara, y luego se calcula la pérdida.

bce_loss = nn.BCELoss(size_average=True)  #二分类交叉熵

def opencv(images):
    for i in range(images.shape[0]):
        image = images[i, 0, :, :]
        image = image // 0.5000001 * 255   # 二值化
        image_2 = image.cpu().detach().numpy()
        image_2 = image_2.astype(np.uint8)
        img = cv2.Canny(image_2, 30, 150)
        img = img.astype(np.float32)
        img = torch.from_numpy(img)
        img.type = torch.float32
        if i != 0:
            if i != 1:
                img_final = torch.cat((img_final, img), 0)
            else:
                img_final = torch.cat((img_first, img), 0)
        else:
            img_first = img
    return img_final / 255

loss=bce_loss(opencv(pre),opencv(label))

复制代码

Dos formas de agregar esta pérdida de borde: 1. Cada salida se cuenta 2. Solo se cuenta una salida

Pérdida de IoU

Destacar la perspectiva

def _iou(pred, objetivo):

b = pred.shape[0]
IoU = 0.0
for i in range(0,b):
    Iand1 = torch.sum(target[i,:,:,:]*pred[i,:,:,:])
    Ior1 = torch.sum(target[i,:,:,:]) + torch.sum(pred[i,:,:,:])-Iand1
    IoU1 = Iand1/Ior1

    IoU = IoU + (1-IoU1)    #因为要算的是错误的大小,所以要1-IoU

return IoU/b
复制代码

Pérdida de borde en EGNet

Documento: http://mftp.mmcheng.net/Papers/19ICCV_EGNetSOD.pdf GitHub: https://github.com/JXingZhao/EGNet/En inserte la descripción de la imagen aquí EGNet, la etiqueta binarizada con un umbral de 0,5 se incluye en la función de pérdida de borde. la parte de pérdida normal en el cálculo es una etiqueta binarizada con 0 como umbral

def load_edge_label(im):
    """
    pixels > 0.5 -> 1
    Load label image as 1 x height x width integer array of label indices.
    The leading singleton dimension is required by the loss.
    """
    label = np.array(im, dtype=np.float32)
    if len(label.shape) == 3:
        label = label[:,:,0]
    # label = cv2.resize(label, im_sz, interpolation=cv2.INTER_NEAREST)
    label = label / 255.
    label[np.where(label > 0.5)] = 1.   #  0.5当做阈值
    label = label[np.newaxis, ...]
    return label

def EGnet_edg(d,labels_v):
    target=load_edge_label(labels_v)
    # assert(d.size() == target.size())
    pos = torch.eq(target, 1).float()
    neg = torch.eq(target, 0).float()
    # ing = ((torch.gt(target, 0) & torch.lt(target, 1))).float()

    num_pos = torch.sum(pos)
    num_neg = torch.sum(neg)
    num_total = num_pos + num_neg

    alpha = num_neg  / num_total
    beta = 1.1 * num_pos  / num_total
    # target pixel = 1 -> weight beta
    # target pixel = 0 -> weight 1-beta
    weights = alpha * pos + beta * neg

    return F.binary_cross_entropy_with_logits(d, target, weights, reduction=None)
    
复制代码

En el documento, la parte de binarización de etiquetas se agrega a la parte del conjunto de datos. Quiero lograrlo directamente a través del procesamiento posterior, pero hay algunos errores extraños. . . Pegue la sección load_edge_label original a continuación:

def load_edge_label(pah):
    """
    pixels > 0.5 -> 1
    Load label image as 1 x height x width integer array of label indices.
    The leading singleton dimension is required by the loss.
    """
    if not os.path.exists(pah):
        print('File Not Exists')
    im = Image.open(pah)
    label = np.array(im, dtype=np.float32)
    if len(label.shape) == 3:
        label = label[:,:,0]
    # label = cv2.resize(label, im_sz, interpolation=cv2.INTER_NEAREST)
    label = label / 255.
    label[np.where(label > 0.5)] = 1.
    label = label[np.newaxis, ...]
    return label
复制代码

Luego, lea el edge_label procesado cuando lea la carga de datos y tírelo a EGnet_edg para el cálculo.

El documento ajusta el tamaño de la pérdida al calcular la pérdida nAveGrad=10

sal_loss = (sum(sal_loss1) + sum(sal_loss2)) / (nAveGrad * self.config.batch_size)
复制代码

Supongo que te gusta

Origin juejin.im/post/7085137835119869960
Recomendado
Clasificación