busternet:Detecting copy-move image forgery with source/target localization

https://www.jianshu.com/p/d1ddf1a15941https://www.jianshu.com/p/d1ddf1a15941https://github.com/ntdat017/BusterNet_pytorchhttps://github.com/ntdat017/BusterNet_pytorch

wuyue在篡改上经常发文章,还是值得关注的,不过通常都是自然场景篡改类的文章,和文档场景差别挺大的。这篇文章由于时间较久,我尽量比较精简的来记录一下。首先论文本身是用来处理同图检测的copy-move这种,也就是在原图中找一块区域,然后复制粘贴到原图中,所以任务设计上也是找出原图中的copy和orign两处,既找出篡改地方,也找出在图中复制的哪一块。它的数据集制作比较简单,就是找coco上的目标在图上粘贴就可以,所以说这种任务很容易过拟合,在自己造的数据集上效果很好,一迁移拉胯,而且造数据的方式过于粗暴,网络也没有学到什么很好的信息。

第一张图是核心图,首先看上面,第一是原图,第二张是gt,第三张是预测图,先看第二张gt,有三个部分,红色是篡改区域,绿色是相似图区域,蓝色是背景区域,标签制作如下,这三个区域其实对应了第一张图的三个监督的loss,

fusion_preds, mani_preds, simi_preds = self.model(imgs)
# gt:红色是篡改,绿色是相似图,蓝色是背景 rg b
simi_gts = (1 - gts[:, 2, :, :]).type(torch.float)  # 把b减掉,剩下都是相似图
mani_gts = gts[:, 0, :, :].type(torch.float)  # r维度就是篡改图
_, fusion_gts = gts.max(dim=1)  # 三个维度的最大值,最终应该是一个0.1.2的图

看第一张图的三个分支,首先红色框起来是篡改检测分支,这个分支类似于我们传统做风格的定位区域分支,是一个unet结构,先降维在升维,假定输入图是256x256x3,使用VGG16的前4个blocks,输出是16x16x512,在经过上采样,主要是upsamplingbilinear2d和inception模块,这里的上采样有一个三分支的融合。最后接一个conv2d(6,1)转成一维,在接一个sigmoid,此处对应的是mani_gts,是一个只有篡改的二值输出。其次是蓝色框,这个分支是检测相似度的,网络里面有个self-correlation-percentpooling模块,这里个分支是三部分,首先也是vgg16提特征,其次是自关联感知的池化操作,之后是上采样。中间这个模块代码如下:

class CorrelationPercPooling(nn.Module):
    '''Custom Self-Correlation Percentile Pooling Layer
    '''

    def __init__(self, nb_pools=256, **kwargs):
        super(CorrelationPercPooling, self).__init__()
        self.nb_pools = nb_pools

        n_maps = 16 * 16

        if self.nb_pools is not None:
            self.ranks = torch.floor(torch.linspace(0, n_maps - 1, self.nb_pools)).type(torch.long)
        else:
            self.ranks = torch.range(1, n_maps, dtype=torch.long)

    def forward(self, x):
        '''
            x_shape: (n, c, h, w)
        '''
        n_bsize, n_feats, n_cols, n_rows = x.shape  # 16,16,512
        n_maps = n_cols * n_rows
        x_3d = x.reshape(n_bsize, n_feats, n_maps)

        x_corr_3d = torch.matmul(x_3d.transpose(1, 2), x_3d) / n_feats  # 皮尔逊相关系数
        x_corr = x_corr_3d.reshape(n_bsize, n_maps, n_cols, n_rows)

        # ranks = ranks.to(devices)
        x_sort, _ = torch.topk(x_corr, k=n_maps, dim=1, sorted=True)

        x_f1st_sort = x_sort.permute(1, 2, 3, 0)
        x_f1st_pool = x_f1st_sort[self.ranks]  # 16,16,256,topk取前256
        x_pool = x_f1st_pool.permute(3, 0, 1, 2)

        return x_pool

它其实是个通用模块,本身我觉得最可以拿出来的就是这个模块,这个模块也可以嵌入到通用的篡改文档检测中,大概的思路是输入是16x16x512,输出是16x16x256,中间先经过了一个皮尔逊系数计算相关性,然后topk取了前256个维度。这个模块的分支是相似度检测,是simi_gts = (1 - gts[:, 2, :, :]),注意也是一个sigmoid输出的二分类,相似区域其实是两处,包含了篡改和相似的这样一个gt图。最后是一个融合模块,融合模块的损失是交叉熵,前两个分支都是bce,交叉上是因为此时算上背景是一个三分类,不是二分类了,融合分支的输出是softmax2d,然后在直接插值变成原图,在原图上用分类计算。

代码很简单,复现的作者的代码也很精炼。基本就是一个三分支网络,三个loss,loss设计兼顾了篡改,同图像素相似和三分类任务,是对特定任务做设计。

猜你喜欢

转载自blog.csdn.net/u012193416/article/details/126630498