In-depth understanding of STN in registration tasks

Foreword

I have been doing registration work for some time, writing articles whenever I had an idea, and treating each module in my work as a black box. It wasn't until some recent in-depth study that I discovered how shallow my knowledge was, so I decided to slowly understand the registration task from the most basic level. This article records my simple understanding of the STN.

First, here is the most basic STN code:

import torch
import torch.nn as nn

class SpatialTransform(nn.Module):
    def __init__(self):
        super(SpatialTransform, self).__init__()

    def forward(self, x, flow, sample_grid):
        # warp the identity grid with the displacement field
        sample_grid = sample_grid + flow
        size_tensor = sample_grid.size()  # [N, D, H, W, 3] for a 3D volume
        # normalize the warped coordinates to [-1, 1] (grid_sample's convention);
        # channel 0 is the x/width axis, 1 is y/height, 2 is z/depth
        sample_grid[0, :, :, :, 0] = (sample_grid[0, :, :, :, 0] - ((size_tensor[3] - 1) / 2)) / size_tensor[3] * 2
        sample_grid[0, :, :, :, 1] = (sample_grid[0, :, :, :, 1] - ((size_tensor[2] - 1) / 2)) / size_tensor[2] * 2
        sample_grid[0, :, :, :, 2] = (sample_grid[0, :, :, :, 2] - ((size_tensor[1] - 1) / 2)) / size_tensor[1] * 2
        # sample the moving image x at the warped, normalized coordinates
        image = torch.nn.functional.grid_sample(x, sample_grid, mode='bilinear')
        return image

First of all, what exactly are the inputs x, flow, and sample_grid?

x is an image (the moving image); we assume its coordinate system lives in X space.
flow is the displacement field the network predicts from x and y; the displacements of this field are limited to [-1, 1] (more on this below).
sample_grid is an initial coordinate grid: it stores the coordinates of every voxel of the image.
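
To make this concrete, here is a minimal sketch of how such an identity sample_grid could be built and the module called. The volume size and the meshgrid construction are my own assumptions, chosen to match the indexing in forward above (grid channel 0 = x/width, 1 = y/height, 2 = z/depth; torch.meshgrid with indexing='ij' needs a reasonably recent PyTorch):

import torch

def make_identity_grid(D, H, W):
    # identity coordinate grid in voxel units, shape [1, D, H, W, 3]
    z, y, x = torch.meshgrid(torch.arange(D), torch.arange(H), torch.arange(W),
                             indexing='ij')
    return torch.stack((x, y, z), dim=3).float().unsqueeze(0)

D, H, W = 32, 64, 64                       # assumed volume size
moving = torch.rand(1, 1, D, H, W)         # x: [N, C, D, H, W]
flow = torch.zeros(1, D, H, W, 3)          # zero displacement -> identity warp
grid = make_identity_grid(D, H, W)
warped = SpatialTransform()(moving, flow, grid)  # approximately equals moving when flow is zero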

Why is the displacement field limited to [-1, 1]?

To be clear, not all work constrains the displacement field to [-1, 1]; VoxelMorph [1], for example, does not.
Recently I have seen works that do constrain the displacement field, such as IC-Net [3] and SYMNet [2].

# ICNet/Code/Models
# final layer of the flow decoder: Tanh caps each output component to [-1, 1]
def outputs(self, in_channels, out_channels, kernel_size=3, stride=1, padding=0,
            bias=False, batchnorm=False):
    if batchnorm:
        layer = nn.Sequential(
            nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, bias=bias),
            nn.BatchNorm3d(out_channels),
            nn.Tanh())
    else:
        layer = nn.Sequential(
            nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, bias=bias),
            nn.Tanh())
    return layer

First of all, functionally this is implemented with a Tanh, which normalizes the final output layer and limits the flow to [-1, 1].
The flow is then multiplied by range_flow to expand its range.

In fact, range_flow does two things: together with the Tanh it bounds the flow, and it makes that bound quite large.
This tends to make the flow output by the network more stable, although omitting the tanh + range_flow combination is usually not a big problem.
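
Putting the two pieces together, the decoder's flow head looks roughly like this. This is only a sketch under my own assumptions (16 input channels, range_flow = 100, and the names flow_head/features are illustrative), not ICNet's exact code:

import torch
import torch.nn as nn

range_flow = 100.0                         # assumed scalar: maximum displacement in voxels
flow_head = nn.Sequential(
    nn.Conv3d(16, 3, kernel_size=3, padding=1, bias=False),
    nn.Tanh(),                             # caps each component to [-1, 1]
)
features = torch.rand(1, 16, 8, 8, 8)      # assumed decoder features
flow = flow_head(features) * range_flow    # displacements now in [-range_flow, range_flow]

So the Tanh guarantees a hard bound, and range_flow turns that unit bound into a bound in voxels.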

Internal operation of STN

The ultimate goal of the flow is to warp image A so that the result is similar to the fixed image B, so the question comes back to how it transforms image A.

First step

sample_grid = sample_grid+flow

sample_grid is the identity transformation grid, which some diffeomorphic registration papers define as the initial x.
It should be clear that at this point sample_grid has not been normalized to [-1, 1], and likewise the flow has no predetermined range before training.
Adding them, sample_grid + flow, gives the coordinate grid warped by the flow.

Second step

sample_grid[0, :, :, :, 0] = (sample_grid[0, :, :, :, 0] - ((size_tensor[3] - 1) / 2)) / size_tensor[3] * 2
sample_grid[0, :, :, :, 1] = (sample_grid[0, :, :, :, 1] - ((size_tensor[2] - 1) / 2)) / size_tensor[2] * 2
sample_grid[0, :, :, :, 2] = (sample_grid[0, :, :, :, 2] - ((size_tensor[1] - 1) / 2)) / size_tensor[1] * 2

image = torch.nn.functional.grid_sample(x, sample_grid, mode='bilinear')

Normalize the warped coordinate grid (torch.nn.functional.grid_sample expects coordinates normalized to [-1, 1]), then interpolate.
Here I can't help asking: wouldn't it be fine to normalize sample_grid at the very beginning?
In fact no: if we normalized at the beginning, there would be no way to ensure that the grid warped by the flow still falls in [-1, 1]; normalizing after adding the flow avoids this.
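
A quick worked check of the x axis: with W = size_tensor[3], a voxel coordinate u is mapped to (u - (W-1)/2) / W * 2 = (2u - (W-1)) / W. The endpoints u = 0 and u = W-1 land at -(W-1)/W and +(W-1)/W, just inside [-1, 1]; and with grid_sample's default align_corners=False, which maps a normalized coordinate g back to the voxel coordinate ((g + 1)·W - 1) / 2, this normalization recovers exactly u.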

As you can see, we sample the pixel values of x (the moving image) at the warped coordinates. That completes the interpolation: the moving image is interpolated at the coordinate grid warped by the flow.

Is the flow Moving->Fixed or Fixed->Moving?

The most intuitive guess would be a Moving->Fixed flow, but it is actually a Fixed->Moving flow. I was shocked when I first learned this, but an example makes the trick easy to see.

For example, take a point (5,5) where the flow value is (3,3), so intuitively this point "becomes" (8,8) after warping.
But look at what the STN actually computes: output(5,5) = moving(8,8), i.e. the value at (8,8) in Moving is fetched and written to position (5,5). If the flow were Moving->Fixed, flow(5,5) = (3,3) would mean "Moving's point (5,5) lands at (8,8) in Fixed", and sampling Moving at (8,8) would then fetch the wrong content, since Moving's (8,8) is not the point that started at (5,5).
Therefore the point before warping, (5,5), must be a point on the Fixed grid, and the flow points from Fixed into Moving; that is, the flow is Fixed->Moving, which is why sampling values from Moving makes sense.
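
To see this numerically, here is a small self-contained toy check (my own example, not from the original post). The "moving image" is a single row whose pixel value equals its x index, so the output shows directly which Moving voxel each position sampled:

import torch
import torch.nn.functional as F

W = 10
# single-row moving image whose value at x is x
moving = torch.arange(W, dtype=torch.float32).view(1, 1, 1, W)

# identity grid in voxel coordinates, plus a flow of +3 at fixed-space x = 5
grid = torch.zeros(1, 1, W, 2)
grid[0, 0, :, 0] = torch.arange(W, dtype=torch.float32)  # x coordinates
flow = torch.zeros_like(grid)
flow[0, 0, 5, 0] = 3.0

warped = grid + flow
# same normalization as the STN above (x axis only; the single row stays at y = 0)
warped[..., 0] = (warped[..., 0] - (W - 1) / 2) / W * 2

out = F.grid_sample(moving, warped, mode='bilinear', align_corners=False)
print(out[0, 0, 0])  # tensor([0., 1., 2., 3., 4., 8., 6., 7., 8., 9.])

Position 5 of the output receives Moving's value at 8, i.e. output(p) = moving(p + flow(p)): the flow is indexed by positions in Fixed space and points into Moving space, which is exactly Fixed->Moving.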


References:

[1] Unsupervised Learning for Fast Probabilistic Diffeomorphic Registration
[2] Fast Symmetric Diffeomorphic Image Registration with Convolutional Neural Networks
[3] Inverse-Consistent Deep Networks for Unsupervised Deformable Image Registration


Origin: blog.csdn.net/xiufan1/article/details/127006449