Visualization of deformation field in medical image registration (drawing deformation field)

This article describes three ways to visualize the deformation field in medical image registration: one that views the field directly in an existing tool, and two that draw it by hand in code.

First, a brief introduction to the deformation field. For a two-dimensional image of size [W, H], the deformation field has size [W, H, 2]; the last dimension has size 2 and holds the displacement along the x-axis and the y-axis. In the same way, for a three-dimensional image of size [D, W, H], the deformation field has size [D, W, H, 3]; the last dimension has size 3 and holds the displacement along the x-axis, the y-axis and the z-axis. The figure below shows a deformation field obtained after registration of a two-dimensional brain image.

[Figure: deformation field obtained after registration of a two-dimensional brain image]
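The shape convention described above can be checked with a few lines of NumPy. This is only a minimal sketch: the helper names `identity_grid` and `warp_coordinates` are made up for illustration, and it assumes the components on the last axis are ordered like the array axes, which may differ from the convention of a particular registration tool.

```python
import numpy as np


def identity_grid(shape):
    """Pixel coordinates of a regular grid, shape (*shape, ndim)."""
    axes = [np.arange(s) for s in shape]
    # indexing="ij" keeps the component order equal to the array-axis order
    return np.stack(np.meshgrid(*axes, indexing="ij"), axis=-1).astype(np.float64)


def warp_coordinates(flow):
    """Add a displacement field (last axis = vector components) to the regular grid."""
    return identity_grid(flow.shape[:-1]) + flow


flow2d = np.zeros((4, 5, 2))      # 2D image of size [4, 5] -> field of size [4, 5, 2]
flow2d[..., 0] = 1.5              # constant shift along component 0
flow3d = np.zeros((3, 4, 5, 3))   # 3D image of size [3, 4, 5] -> field of size [3, 4, 5, 3]

print(warp_coordinates(flow2d).shape)  # (4, 5, 2)
print(warp_coordinates(flow3d).shape)  # (3, 4, 5, 3)
```

Each position in the field stores one displacement vector, so the field always has one more dimension than the image it belongs to.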

The three ways to visualize the deformation field are introduced below.

1. Use ITK-Snap to visualize the deformation field

This method is the most convenient and, in my opinion, gives the best result, but it only works for three-dimensional deformation fields. First save the deformation field as a ".nii" image and open it with ITK-Snap. Then hover for a moment over the position marked by the red frame, click the icon once it appears, and select "Grid" under "Multi-Component Display". You will then see the effect shown in the figures below.
[Figure: ITK-Snap screenshot with the position to hover over marked by a red frame]

[Figure: the deformation field displayed as a grid in ITK-Snap]

2. Maomao's method

For the idea behind this method, see Maomao's article "plt.contour Drawing Image Deformation Field (Deformation Field)". To be honest, I did not fully understand the details. Roughly, it first generates a regular grid, adds the deformation field to it, and then displays the result as contour lines. A grid of size [W,H] cannot be added directly to a deformation field of size [W,H,2], which confused me at first; the code handles this by stacking the x and y coordinate grids into an array with two channels, so that each displacement component is added to its own coordinate channel, and then drawing the contours of each channel. The code is as follows:

import matplotlib.pyplot as plt
import SimpleITK as sitk
import numpy as np


def grid2contour(grid, title):
    '''
    grid--image_grid used to show deform field
    type: numpy ndarray, shape: (h, w, 2), value range:(-1, 1)
    '''
    assert grid.ndim == 3
    x = np.arange(-1, 1, 2.0 / grid.shape[1])
    y = np.arange(-1, 1, 2.0 / grid.shape[0])
    X, Y = np.meshgrid(x, y)
    Z1 = grid[:, :, 0] + 2  # remove the dashed line
    Z1 = Z1[::-1]  # vertical flip
    Z2 = grid[:, :, 1] + 2

    plt.figure()
    plt.contour(X, Y, Z1, levels=50, colors='k')  # change levels to change the density of the grid lines
    plt.contour(X, Y, Z2, levels=50, colors='k')
    plt.xticks(()), plt.yticks(())  # remove x, y ticks
    plt.title(title)
    plt.show()


def show_grid():
    img = sitk.ReadImage("./2D1.nii")
    img_arr = sitk.GetArrayFromImage(img)
    img_shape = img_arr.shape

    # start, stop, step (the step may be fractional)
    x = np.arange(-1, 1, 2 / img_shape[1])
    y = np.arange(-1, 1, 2 / img_shape[0])
    X, Y = np.meshgrid(x, y)
    regular_grid = np.stack((X, Y), axis=2)
    grid2contour(regular_grid, "regular_grid")

    rand_field = np.random.rand(*img_shape[:2], 2)  # the leading * unpacks the tuple into separate arguments
    rand_field_norm = rand_field.copy()
    rand_field_norm[:, :, 0] = rand_field_norm[:, :, 0] * 2 / img_shape[1]
    rand_field_norm[:, :, 1] = rand_field_norm[:, :, 1] * 2 / img_shape[0]
    sampling_grid = regular_grid + rand_field_norm
    grid2contour(sampling_grid, "sampling_grid")

    img_arr[..., 0] = img_arr[..., 0] * 2 / img_shape[1]
    img_arr[..., 1] = img_arr[..., 1] * 2 / img_shape[0]
    img_grid = regular_grid + img_arr
    grid2contour(img_grid, "img_grid")


if __name__ == "__main__":
    show_grid()
    print("end")

Changing the value of levels in the two plt.contour calls inside grid2contour changes the density of the grid lines in the deformation field.
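The core of this method can be condensed into a small reusable sketch. The names `sampling_grid` and `plot_field` are made up for illustration, the field is assumed to be an [H, W, 2] array in pixel units with channel 0 holding the x displacement, and `linestyles="solid"` is used here in place of the "+ 2" offset to avoid dashed lines for negative values.

```python
import matplotlib
matplotlib.use("Agg")  # off-screen backend, so the sketch also runs without a display
import matplotlib.pyplot as plt
import numpy as np


def sampling_grid(flow):
    """flow: (H, W, 2) displacement in pixels; channel 0 = x, channel 1 = y (assumed)."""
    h, w = flow.shape[:2]
    X, Y = np.meshgrid(np.linspace(-1, 1, w, endpoint=False),
                       np.linspace(-1, 1, h, endpoint=False))
    regular = np.stack((X, Y), axis=2)     # (H, W, 2): x coordinates, then y coordinates
    scale = np.array([2.0 / w, 2.0 / h])   # converts pixels to normalized units
    return regular + flow * scale          # each component is added to its own channel


def plot_field(flow, levels=50):
    grid = sampling_grid(flow)
    fig, ax = plt.subplots()
    # Iso-lines of the warped x coordinate form the deformed vertical grid lines,
    # iso-lines of the warped y coordinate form the deformed horizontal ones.
    ax.contour(grid[:, :, 0], levels=levels, colors="k", linestyles="solid")
    ax.contour(grid[:, :, 1], levels=levels, colors="k", linestyles="solid")
    ax.invert_yaxis()                      # put row 0 at the top, as in an image
    ax.set_xticks([])
    ax.set_yticks([])
    return fig
```

A call such as `plot_field(flow, levels=30).savefig("field.png")` would then draw a coarser grid than the default of 50 levels.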

3. Generate a regular grid first, and then deform it with a spatial transformer network

This method was suggested by a friend in a registration discussion group. The idea is to first generate a regular grid and save it as a picture, and then deform that picture with a spatial transformer network according to the deformation field. The code is as follows:

import numpy as np
import SimpleITK as sitk
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F


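# Spatial transformer network
class SpatialTransformer(nn.Module):
    # 1. build the grid; 2. new_grid = grid + flow, i.e. the old grid plus a displacement; 3. normalize the grid to [-1, 1]; 4. sample the source image at the new grid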
    def __init__(self, size, mode='bilinear'):
        """
        Instantiate the block
            :param size: size of input to the spatial transformer block
            :param mode: method of interpolation for grid_sampler
        """
        super(SpatialTransformer, self).__init__()

        # Create sampling grid
        vectors = [torch.arange(0, s) for s in size]
        grids = torch.meshgrid(vectors)
        grid = torch.stack(grids)  # y, x, z
        grid = torch.unsqueeze(grid, 0)  # add batch
        grid = grid.type(torch.FloatTensor)
        self.register_buffer('grid', grid)

        self.mode = mode

    def forward(self, src, flow):
        """
        Push the src and flow through the spatial transform block
            :param src: the original moving image
            :param flow: the output from the U-Net
        """
        new_locs = self.grid + flow

        shape = flow.shape[2:]

        # Need to normalize grid values to [-1, 1] for resampler
        for i in range(len(shape)):
            new_locs[:, i, ...] = 2 * (new_locs[:, i, ...] / (shape[i] - 1) - 0.5)

        if len(shape) == 2:
            new_locs = new_locs.permute(0, 2, 3, 1)  # move the channel dimension last: (N, C, H, W) -> (N, H, W, C)
            new_locs = new_locs[..., [1, 0]]
        elif len(shape) == 3:
            new_locs = new_locs.permute(0, 2, 3, 4, 1)
            new_locs = new_locs[..., [2, 1, 0]]

        return F.grid_sample(src, new_locs, mode=self.mode, align_corners=True)  # matches the (shape - 1) normalization above


# Generate the grid image; size is (rows, cols)
def create_grid(size, path):
    num1, num2 = (size[1] + 10) // 10, (size[0] + 10) // 10  # change the divisor (10) to change the grid density
    x, y = np.meshgrid(np.linspace(-2, 2, num1), np.linspace(-2, 2, num2))

    plt.figure(figsize=((size[1] + 10) / 100.0, (size[0] + 10) / 100.0))  # figsize is (width, height) in inches, so cols come first
    plt.plot(x, y, color="black")
    plt.plot(x.transpose(), y.transpose(), color="black")
    plt.axis('off')  # hide the axes
    # remove the white border
    plt.gca().xaxis.set_major_locator(plt.NullLocator())
    plt.gca().yaxis.set_major_locator(plt.NullLocator())
    plt.subplots_adjust(top=1, bottom=0, left=0, right=1, hspace=0, wspace=0)
    plt.margins(0, 0)
    plt.savefig(path)  # save the image
    # plt.show()


if __name__ == "__main__":
    out_path = r"C:\Users\zuzhiang\Desktop\new_img.jpg"  # path where the image is saved
    # Read the deformation field
    phi = sitk.ReadImage("./2D1.nii")  # [324,303,2]
    phi_arr = torch.from_numpy(sitk.GetArrayFromImage(phi)).float()
    phi_shape = phi_arr.shape
    # Generate the grid image
    create_grid(phi_shape, out_path)
    img = sitk.GetArrayFromImage(sitk.ReadImage(out_path))[..., 0]
    img = np.squeeze(img)[np.newaxis, np.newaxis, :phi_shape[0], :phi_shape[1]]
    # Warp the grid image with the STN according to the deformation field
    STN = SpatialTransformer(phi_shape[:2])
    phi_arr = phi_arr.permute(2, 0, 1)[np.newaxis, ...]
    warp = STN(torch.from_numpy(img).float(), phi_arr)
    # Save the image
    warp_img = sitk.GetImageFromArray(warp[0, 0, ...].numpy().astype(np.uint8))
    sitk.WriteImage(warp_img, out_path)
    print("end")

Note that the exact pixel size of the saved grid image is hard to control, so the image is generated 10 pixels larger than needed and then cropped when it is read back. Changing the divisor (10) on the first line of create_grid changes the density of the grid lines.
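The normalization step in the forward method of SpatialTransformer above maps pixel coordinates to the [-1, 1] range that F.grid_sample expects. A minimal NumPy sketch of that formula is given below; the helper names and the axis length of 303 are chosen only for illustration.

```python
import numpy as np


def normalize_coords(coords, size):
    """Map pixel coordinates in [0, size - 1] to [-1, 1], as in the forward pass above."""
    coords = np.asarray(coords, dtype=np.float64)
    return 2.0 * (coords / (size - 1) - 0.5)


def denormalize_coords(norm, size):
    """Inverse mapping, from [-1, 1] back to pixel coordinates."""
    norm = np.asarray(norm, dtype=np.float64)
    return (norm / 2.0 + 0.5) * (size - 1)


size = 303  # hypothetical axis length, for illustration only
# the first, middle and last pixel map to -1, 0 and 1
print(normalize_coords([0, (size - 1) / 2, size - 1], size))
```

Dividing by size - 1 places -1 and 1 at the centers of the first and last pixels, which is the convention F.grid_sample uses when align_corners=True. A displaced location that ends up outside [-1, 1] falls outside the image and is filled with zeros under the default padding mode.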

I find this method a little more elegant, although its result is similar to that of the second method. The results of the second and third methods are shown in the following figure:

Origin blog.csdn.net/zuzhiang/article/details/107423465