2d版本:
二、pytorch中的仿射变换
pytorch中就使用的为后向变换。主要涉及两个函数
- F.affine_grid(theta,size)
- F.grid_sample(input, grid, mode=’bilinear’, padding_mode=’zeros’)
1.F.affine_grid根据输入的变换矩阵theta和尺寸利用后向变换求出目标图像每个像素在原图像的位置。
theta是一个\[N,2,3\]的tensor,N为batchsize大小;2行3列共六个参数,为affine的变换矩阵,第一行为x坐标,即横坐标的变换参数,前两个为权重,最后一个为偏移,值得注意的是偏移值是一个相对于图像宽归一化的参数a,c,e(并非像素值),例如0.5表示左移半个图像的宽度。第二行表示y坐标的变换参数(b,d,f)。
size是一个tuple,为(N,C,H,W)
output为[N,h,w,2]的Tensor,表示在原图中的对应位置。
- F.grid_sample()为重采样函数,根据输入的原图和位置对应关系矩阵(F.affine_grid的输出)对原图像素进行重采样,构成变换后的图像。由于重采样过程中,在原图中的位置会出现小数,因此需要对原图进行插值,插值方式为可选参数,默认双线性插值。
下面我们来看一个例子:
将图像顺时针旋转45度,注意pytorch使用的为后向变换。
对于前向变换来说,顺时针旋转45度的变换矩阵为,后向变换应该对其求逆。但是我们可以换一个角度理解,原图到目标图需要顺时针旋转45度,那么目标图到原图不就是逆时针旋转45度吗,因此直接取带入原公式计算即可
代码如下:
import torch
import cv2
import torch.nn.functional as F
import matplotlib.pyplot as plt
theta = torch.Tensor([[0.707,0.707,0],[-0.707,0.707,0]]).unsqueeze(dim=0)
img = cv2.imread('achor.png',cv2.IMREAD_GRAYSCALE)
plt.subplot(2,1,1)
plt.imshow(img,cmap='gray')
plt.axis('off')
img = torch.Tensor(img).unsqueeze(0).unsqueeze(0)
grid = F.affine_grid(theta,size=img.shape)
output = F.grid_sample(img,grid)[0].numpy().transpose(1,2,0).squeeze()
plt.subplot(2,1,2)
plt.imshow(output,cmap='gray')
plt.axis('off')
plt.show()
结果如下(pytorch中以图像中心点为原点,与一般的左上角为原点不太一样):
pytorch理论上支持,也确实提供了API,但是经过我的研究发现,pytorch3D的仿射变换目前存在bug,bug主要体现在平移操作上,旋转和缩放暂时看起来还算正常
3D的仿射变换
2D的仿射变换我就不细说了在pytorch中的操作可以参考这个
Pytorch中的仿射变换(affine_grid)www.jianshu.com/p/723af68beb2e正在上传…重新上传取消
其实在老版本的pytorch里面似乎是不支持3D的仿射变换的,但是1.6以后的pytorch应该都是支持的。
平移变换(translate)
平移变换的转化矩阵就是这样的,a,b,c分别就是沿x, y, z轴平移多少。
在pytorch中实现沿x轴平移10个像素的操作如下
import torch
import torch.nn.functional as F
#3D图像输入
img = torch.randn(1, 1, 160, 192, 160) #[batch_size, channel, D, H, W]
#3D仿射变换矩阵
theta = torch.tensor([[1, 0, 0, 10],
[0, 1, 0, 0],
[0, 0, 1, 0]], dtype=torch.float)
grid = F.affine_grid(theta.unsqueeze(0), img.size())
output = F.grid_sample(img, grid)
这里需要注意一下,pytorch输入的变换矩阵是 的,而不是 , 变换矩阵是齐次坐标的表示,对pytorch来说不需要,所以把最后一行去掉就行了。
旋转变换(rotate)
旋转比较复杂,分为沿x轴旋转,y轴旋转,和z轴旋转,每个旋转的变换矩阵都不一样。
首先是沿x轴旋转
然后是沿y轴旋转
最后是沿z轴旋转
在pytorch中实现沿x轴旋转1度如下:
import torch
import torch.nn.functional as F
import math
#3D图像输入
img = torch.randn(1, 1, 160, 192, 160) #[batch_size, channel, D, H, W]
#3D仿射变换矩阵
theta = torch.tensor([[1, 0, 0, 0],
[0, math.cos(1), -math.sin(1), 0],
[0, math.sin(1), math.cos(1), 0]], dtype=torch.float)
grid = F.affine_grid(theta.unsqueeze(0), img.size())
output = F.grid_sample(img, grid)
缩放变换(scale)
缩放操作比较简单
意思就是把每个轴的坐标放大缩小至a, b, c倍。
比如对x轴坐标放大至1.1倍,其他轴不变的pytorch实现就是:
import torch
import torch.nn.functional as F
#3D图像输入
img = torch.randn(1, 1, 160, 192, 160) #[batch_size, channel, D, H, W]
#3D仿射变换矩阵
theta = torch.tensor([[1.1, 0, 0, 0],
[0, 1, 0, 0],
[0, 0, 1, 0]], dtype=torch.float)
grid = F.affine_grid(theta.unsqueeze(0), img.size())
output = F.grid_sample(img, grid)
原文:pytorch的3D仿射变换 - 知乎 (zhihu.com)