Test time augmentation(TTA)

Test time augmentation

数据扩充是模型训练期间通常使用的一种方法,它使用来自训练数据集的样本的修改副本来扩充训练集。
数据增强通常使用图像数据执行,其中使用一些执行的图像处理技术(如缩放、翻转、移位等)创建训练数据集中的图像副本。
人工扩展的训练数据集可以产生更熟练的模型,因为深度学习模型的性能通常会随着训练数据集的大小而不断扩展。此外,训练数据集中图像的修改或增强版本可以帮助模型以不改变其位置、光线等的方式提取和学习特征。
Test time augmentation(简称TTA)是对测试数据集进行数据扩展的应用程序。
具体来说,它涉及到在测试集中创建每个图像的多个增强副本,让模型为每个图像做出预测,然后返回这些预测的集合。
选择增强是为了让模型有最佳的机会正确地对给定的图像进行分类,而且模型必须对图像进行预测的副本的数量通常很少,比如少于10或20个。
通常,只执行一个简单的测试时间扩展,比如移位、裁剪或图像翻转。
我们还通过图像的水平翻转来增加测试集;对原始图像和翻转后的图像进行软最大类后验平均,得到图像的最终得分。

How to use

1.借助pytorch_toolbelt

from pytorch_toolbelt.inference import tta

# Truly functional TTA for image classification using horizontal flips:
logits = tta.fliplr_image2label(model, input)

# Truly functional TTA for image segmentation using D4 augmentation:
logits = tta.d4_image2mask(model, input)

# TTA using wrapper module:
tta_model = tta.TTAWrapper(model, tta.fivecrop_image2label, crop_size=512)
logits = tta_model(input)

2.自己动手

import torch

# 水平翻转
def flip_horizontal_tensor(batch):
    columns = batch.data.size()[-1]
    return batch.index_select(-1, torch.LongTensor(list(reversed(range(columns)))).cuda())

在这里插入图片描述

#   垂直翻转
def flip_vertical_tensor(batch):
    rows = batch.data.size()[-2]
    return batch.index_select(-2, torch.LongTensor(list(reversed(range(rows)))).cuda())

在这里插入图片描述

发布了33 篇原创文章 · 获赞 3 · 访问量 5539

猜你喜欢

转载自blog.csdn.net/weixin_42990464/article/details/104558980