对 Spatial Transformer Networks(空间变换网络) 的思考

论文地址: https://arxiv.org/abs/1506.02025

这几天看了下stn,大概写一写吧。说实话,这个东西思想倒是蛮有意思的,但是实际用起来效果不好说,至少在我想要应用的场景下效果不怎么样。

这里先写论文的思路,再写一下我做的一些实验与相应的思考。

STN 目标

我们知道,CNN推动了计算机视觉的发展,但是还是有一些缺陷。在“Visualizing and Understanding Convolutional Networks”这篇论文里其实就有所论述,实际上,卷积神经网络有一定的空间不变性,包括平移,缩放等。

但实际上对于一些变换比较大的问题仍然存在瓶颈,STN的就是为了解决这些问题而提出的一个网络。

STN的结构

STN的结构并不复杂,首先还是拿论文里的图来说一下:

左边的U是输入图片(feature_map),在论文中,作者表示stn可以加在卷积网络的任何位置,所以这里可以是输入图像也可以是经过若干层卷积的特征图。

接下来是一个称为Localisation net的网络,这个网络可以是任意的结构,全连接或者卷积网络都是可以的,这个网络的目标就是通过给定的输入,学习一组用于变换的参数。这个参数的数量是6个,也就是说,Localisation net有6个输出。

接下来用这6个参数对原始的输入做线性变换,生成一个新的输出V,这个输出的channel数与U的channel数相同,并且各个channel所做的变换也是相同的。


这个变换并不复杂,不过需要一些线性代数的知识,我们把输入的矩阵的位置用(x, y)表示,我们就可以得到一个(w, h, 2)的坐标矩阵,将这个坐标增加一维,填1,就得到了一个(w, h, 3)的坐标矩阵。至于为什么要填这个1,主要是仿射变换的需要,旋转、平移、缩放是不需要的,具体细节可以去看一下这些操作的矩阵实现,这里就不细说了。然后我们学到的6个参数变换为(3,2)的矩阵,进行矩阵乘法,就可以得到变换完的坐标,同样是(w, h, 2)的矩阵。

获得新的坐标以后,下一步就是进行采样,把对应的内容填充到矩阵V中,这一步不难理解,可以用双线性插值来完成。

可求导的采样方法

上一步说了,要使用采样的方法把变换后的坐标映射到矩阵V上。理论上哪种方法都可以,但是有一点比较关键,因为要训练网络,需要产生梯度,像最邻近插值这种方法,实际上是不能产生梯度的,因为它只是把矩阵U的内容移动到V上面,本身并没有任何变化。所以作者使用了双线性插值的方法,由于一个点的位置,是由其周围四个点的值计算插值而来,就“产生”了新的值,导致梯度的产生,有了网络训练的基本条件。

效果与疑问

论文里发了个效果视频:
https://drive.google.com/file/d/0B1nQa_sA3W2iN3RQLXVFRkNXN0k/view
视频里的效果看起来还是很好的,能够把手写数字矫正,看起来还是很神奇的。

我也尝试使用stn训练旋转的文字,不过效果并不好,也不存在说能够把旋转文字矫正这一现象,所以我产生了一些疑惑,这个网络真的能矫正旋转的文字吗?

后来想了一下,这一点其实理论上并不能说通,因为文字的“正”方向实际上是人类的先验知识,这一点实际上并不能从数据和网络上体现,也就是说,网络并不具备这部分知识,它本身并不知道文字的正反, 那又如何能“矫正”文字呢?另一方面,无论卷积网络还是全连接网络,其实也不存在旋转不变性,Localisation net又怎么学习旋转的性质呢?

带着疑问,我做了几个实验,想要复现论文的结果。
找了几个stn的实现,其实大同小异了,尝试了一下:https://github.com/zsdonghao/Spatial-Transformer-Nets的实现。

发现效果并不理想,stn看起来只是找到了文字并把它填充了图片的视野,并没有旋转的效果:
输入:

stn矫正后:

这个实现不知道是不是有些问题,调用了tensorlayer的一些api,不太好说。找了一个pytorch的实现,是pytorch官方教程中的:https://pytorch.org/tutorials/intermediate/spatial_transformer_tutorial.html#depicting-spatial-transformer-networks

这个代码有点问题的,没有对输入数据进行旋转,直接stn,得不到我想要的效果,所以稍微改了一下,把输入数据先旋转,然后再扔进去训。下面是完整代码:

# coding: utf-8

from __future__ import print_function
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import numpy as np
import cv2
import random

plt.ion()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Training dataset
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST(root='.', train=True, download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ])), batch_size=64, shuffle=True, num_workers=4)
# Test dataset
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST(root='.', train=False, transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])), batch_size=128, shuffle=True, num_workers=4)


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

        # Spatial transformer localization-network
        self.localization = nn.Sequential(
            nn.Conv2d(1, 8, kernel_size=7),
            nn.MaxPool2d(2, stride=2),
            nn.ReLU(True),
            nn.Conv2d(8, 10, kernel_size=5),
            nn.MaxPool2d(2, stride=2),
            nn.ReLU(True)
        )

        # Regressor for the 3 * 2 affine matrix
        self.fc_loc = nn.Sequential(
            nn.Linear(10 * 3 * 3, 32),
            nn.ReLU(True),
            nn.Linear(32, 3 * 2)
        )

        # Initialize the weights/bias with identity transformation
        self.fc_loc[2].weight.data.zero_()
        self.fc_loc[2].bias.data.copy_(torch.tensor([1, 0, 0, 0, 1, 0], dtype=torch.float))

    # Spatial transformer network forward function
    def stn(self, x):
        xs = self.localization(x)
        xs = xs.view(-1, 10 * 3 * 3)
        theta = self.fc_loc(xs)
        theta = theta.view(-1, 2, 3)

        grid = F.affine_grid(theta, x.size())
        x = F.grid_sample(x, grid)

        return x

    def forward(self, x):
        # transform the input
        x = self.stn(x)

        # Perform the usual forward pass
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)


model = Net().to(device)


optimizer = optim.SGD(model.parameters(), lr=0.01)


def train(epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)

        optimizer.zero_grad()
        output = model(rotate_image_tensors(data))
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % 500 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))

def test():
    with torch.no_grad():
        model.eval()
        test_loss = 0
        correct = 0
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(rotate_image_tensors(data))

            # sum up batch loss
            test_loss += F.nll_loss(output, target, size_average=False).item()
            # get the index of the max log-probability
            pred = output.max(1, keepdim=True)[1]
            correct += pred.eq(target.view_as(pred)).sum().item()

        test_loss /= len(test_loader.dataset)
        print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'
              .format(test_loss, correct, len(test_loader.dataset),
                      100. * correct / len(test_loader.dataset)))

def rotate_image(img):
    rows, cols, ch = img.shape
    M = cv2.getRotationMatrix2D((cols/2, rows/2), random.randint(-135, -45), 1)
    dst = cv2.warpAffine(img, M, (cols, rows))
    return dst.reshape((28, 28, 1))

def rotate_image_tensors(image_tensors):
    batch_size, c, h, w = image_tensors.shape
    sp_imgs = torch.chunk(image_tensors, batch_size, dim=0)
    # print(sp_img[0].shape)
    rotated_imgs = []
    for single_img in sp_imgs:
        img = single_img.squeeze(0).numpy().transpose(1, 2, 0)
        img = rotate_image(img)
        r_img = torch.from_numpy(img.transpose(2, 0, 1))
        rotated_imgs.append(r_img)
        # print(r_img.shape)
    res = torch.stack(rotated_imgs, dim=0)
    return res


def convert_image_np(inp):
    """Convert a Tensor to numpy image."""
    inp = inp.numpy().transpose((1, 2, 0))
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    inp = std * inp + mean
    inp = np.clip(inp, 0, 1)
    return inp

def visualize_stn():
    with torch.no_grad():
        # Get a batch of training data
        data = next(iter(test_loader))[0].to(device)

        input_tensor = data.cpu()
        input_tensor = rotate_image_tensors(input_tensor)
        transformed_input_tensor = model.stn(input_tensor).cpu()

        # transformed_input_tensor = rotate_image_tensors(transformed_input_tensor)
        in_grid = convert_image_np(
            torchvision.utils.make_grid(input_tensor))

        out_grid = convert_image_np(
            torchvision.utils.make_grid(transformed_input_tensor))

        # Plot the results side-by-side
        f, axarr = plt.subplots(1, 2)
        axarr[0].imshow(in_grid)
        axarr[0].set_title('Dataset Images')

        axarr[1].imshow(out_grid)
        axarr[1].set_title('Transformed Images')

# visualize_stn()
for epoch in range(1, 20):
    train(epoch)
    test()

visualize_stn()

plt.ioff()
plt.savefig('./result.png', format='png')
# plt.show()

接下来是训练结果:

上图是旋转了(-90, 90)度,看起来并不是完全没有效果,目测大部分图都被转到了差不多0,-45,45的角度。
在这里插入图片描述
接下来,是旋转了(0, 90)度,这就有趣了,输出的图基本都被转到了45度。

那么旋转了(45, 135)度又如何呢?输出的图都转到了90度。

旋转了(0, 360)度貌似没什么规律了。

上面几个实验基本上也印证了我的猜测。首先,网络肯定是没有这个字被转了多少度这个知识的,至于为什么文字都被转到了范围中间的角度了呢?我有个猜测,首先是和神经网络的特点有关系,记得有个论文猜测,神经网络会先去学简单的特征,然后随着训练的进行,才会学比较困难的特征,对于stn可能也差不多,当把这些字转到中间的时候,对于识别来说更简单。

简单来说,Localisation net比后面的分类网络复杂,那么Localisation net可能会倾向于把文字转到一个或几个固定角度(为什么是中间某个角度,大概是因为转这些更“容易”?)。分类网络如果更复杂,Localisation net比较简单,可能就会出现第一个实验的情况,Localisation net只是把文字这个特征缩放了一下,并没有旋转。感觉由于随机梯度下降这个方法,每次都是微调,导致整个神经网络想用最小的能量达到目标,才导致了这个结果。

那么stn到底有没有用呢?应该还是有一定用处的,但是是不是能用来矫正旋转文字呢?看起来小的角度是可以的,但它能矫正并不是因为它知道哪些是正的,只是因为它恰好会把字旋转到中间那个角度而已。

发布了443 篇原创文章 · 获赞 149 · 访问量 55万+

猜你喜欢

转载自blog.csdn.net/qian99/article/details/89921112
今日推荐