[pytorch] Simple implementations of mixup / cutout / margin loss for image recognition

My Kaggle discussion post: https://www.kaggle.com/c/bengaliai-cv19/discussion/128592

Mixup

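from torchtoolbox.tools import mixup_data, mixup_criterion

# Assumes model, optimizer, device, train_data and Loss (the base criterion) are already defined.
alpha = 0.2  # mixing strength
for i, (data, labels) in enumerate(train_data):
    data = data.to(device, non_blocking=True)
    labels = labels.to(device, non_blocking=True)

    # Mix the batch; returns both label sets and the mixing coefficient lam
    data, labels_a, labels_b, lam = mixup_data(data, labels, alpha)
    optimizer.zero_grad()
    outputs = model(data)
    # Combine the criterion on both label sets according to lam
    loss = mixup_criterion(Loss, outputs, labels_a, labels_b, lam)

    loss.backward()
    optimizer.step()  # PyTorch optimizers use step(), not update()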
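
For readers who want to see what the two helpers are doing, here is a minimal plain-PyTorch sketch of the usual mixup recipe. The names `mixup_data_sketch` and `mixup_criterion_sketch` are made up for illustration, and the code is an assumption about the general technique, not a copy of torchtoolbox's implementation.

```python
import torch
import torch.nn.functional as F


def mixup_data_sketch(x, y, alpha=0.2):
    """Blend each sample with a randomly chosen partner from the same batch."""
    if alpha > 0:
        # lam is drawn from Beta(alpha, alpha); small alpha keeps it near 0 or 1
        lam = torch.distributions.Beta(alpha, alpha).sample().item()
    else:
        lam = 1.0  # alpha <= 0 disables mixing
    index = torch.randperm(x.size(0), device=x.device)  # partner for each sample
    mixed_x = lam * x + (1.0 - lam) * x[index]
    return mixed_x, y, y[index], lam


def mixup_criterion_sketch(criterion, pred, y_a, y_b, lam):
    """Weight the loss on both label sets by the same coefficient used for the inputs."""
    return lam * criterion(pred, y_a) + (1.0 - lam) * criterion(pred, y_b)


# Tiny usage example with random data (4 images, 5 classes)
x = torch.randn(4, 3, 8, 8)
y = torch.tensor([0, 1, 2, 3])
mixed, y_a, y_b, lam = mixup_data_sketch(x, y)
logits = torch.randn(4, 5)  # stand-in for model(mixed)
loss = mixup_criterion_sketch(F.cross_entropy, logits, y_a, y_b, lam)
```

Because the same `lam` weights both the inputs and the two loss terms, the target stays consistent with how much of each image ended up in the mixed sample.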
 

Cutout

from torchvision import transforms
from torchtoolbox.transform import Cutout

_train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224),
    Cutout(),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(0.4, 0.4, 0.4),
    transforms.ToTensor(),
    normalize,  # assumed to be a transforms.Normalize(...) instance defined elsewhere
])
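
The idea behind cutout is to blank out a random patch of the input so the network cannot rely on any single region. Below is a rough sketch on a `(C, H, W)` tensor; `cutout_sketch` and its `size` / `fill` parameters are hypothetical names for illustration, and torchtoolbox's `Cutout` may choose the patch shape and position differently.

```python
import torch


def cutout_sketch(img, size=8, fill=0.0):
    """Return a copy of a (C, H, W) image with one random square patch set to `fill`."""
    _, h, w = img.shape
    # Pick the patch centre uniformly; the patch is clipped at the image border
    cy = int(torch.randint(0, h, (1,)))
    cx = int(torch.randint(0, w, (1,)))
    y0, y1 = max(cy - size // 2, 0), min(cy + size // 2, h)
    x0, x1 = max(cx - size // 2, 0), min(cx + size // 2, w)
    out = img.clone()  # leave the input untouched
    out[:, y0:y1, x0:x1] = fill
    return out


# Usage: erase one patch from a dummy 32x32 RGB image
img = torch.ones(3, 32, 32)
erased = cutout_sketch(img, size=8)
```

Note that this sketch works on tensors, so in a transform pipeline it would sit after `ToTensor()`, unlike the `Cutout()` transform in the snippet above.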

ArcLoss
CosLoss
L2Softmax

from torchtoolbox.nn.loss import ArcLoss, CosLoss, L2Softmax
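
The import above does not show what these losses change. As far as I understand, ArcLoss and CosLoss are margin-based softmax variants (ArcFace-style additive angular margin and CosFace-style additive cosine margin). The sketch below shows only that margin step in plain PyTorch. `margin_logits_sketch` and its parameters `kind`, `m`, `s` are illustrative assumptions, not torchtoolbox's API, and the sketch ignores the special handling real implementations use when the angle plus margin exceeds pi.

```python
import torch
import torch.nn.functional as F


def margin_logits_sketch(cosine, labels, kind="arc", m=0.5, s=64.0):
    """Apply a margin to the target-class cosine, then scale all logits by s.

    cosine: (N, C) cosine similarities between L2-normalised features
            and L2-normalised class weights.
    """
    one_hot = F.one_hot(labels, num_classes=cosine.size(1)).bool()
    if kind == "arc":
        # additive angular margin: cos(theta + m)
        theta = torch.acos(cosine.clamp(-1 + 1e-7, 1 - 1e-7))
        target = torch.cos(theta + m)
    elif kind == "cos":
        # additive cosine margin: cos(theta) - m
        target = cosine - m
    else:
        raise ValueError("kind must be 'arc' or 'cos'")
    # Only the true-class entry gets the margin; other classes keep cos(theta)
    return s * torch.where(one_hot, target, cosine)


# Usage: 4 samples, 16-dim features, 10 classes
features = F.normalize(torch.randn(4, 16), dim=1)
weights = F.normalize(torch.randn(10, 16), dim=1)
cosine = features @ weights.t()
labels = torch.tensor([1, 3, 5, 7])
loss = F.cross_entropy(margin_logits_sketch(cosine, labels, kind="arc"), labels)
```

The margin makes the true class harder to score well on during training, which pushes features of the same class closer together in angle.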

Reference: https://github.com/PistonY/torch-toolbox


Reposted from blog.csdn.net/u014365862/article/details/104216265