Tensor维度等变换操作
不同Tensor之间的合并与分离
Tensor.cat()
(指定维度进行合并)
x = torch.randn(B, C, W, H)
y = torch.cat([x, x],dim=-1)
y.shape # [B, C, W, 2H]
Tensor.stack()
(指定并创建一个新维度进行合并)
x = torch.randn(B, C, W, H)
y = torch.stack([x, x],dim=0)
y.shape # [2, B, C, W, 2H]
Tensor.split()
(指定并创建一个新维度进行合并)
x = torch.randn(B, C, 400, 200)
out = torch.split(x, split_size_or_sections=[200,200], dim=-2)
out # 是一个tuple,存储了分割后的tensor
out[0].shape # [B, C, 200, 200]
out[1].shape # [B, C, 200, 200]
单个Tensor的维度变换
Tensor.view()
(维度合并)
B, N, C, H, W = feat.shape
feat = feat.view(B, N, C, H*W)
Tensor.permute()
(维度交换)
B, N, C, H, W = feat.shape
feat = feat.permute(1, 0, 3, 4, 2)
feat.shape # [N, B, H, W, C]
Tensor.unsqueeze()
(扩展维度)
C, H, W = feat.shape
feat = feat.unsqueeze(0) # 在第0维扩展一个维度
feat.shape # [1, C, H, W]
Tensor.squeeze()
(删除一个空维度)
1, C, H, W = feat.shape
feat = feat.squeeze(0) # 在第0维删除一个空维度
feat.shape # [C, H, W]
Tensor.repeat()
(在某一维进行重复扩充)
x = torch.randn(1, 3, 224, 224)
y = x.repeat(3, 1, 1, 1)
y.shape # [3, 3, 224, 224]
Tensor内存优化
Tensor.contiguous()
- 功能:
Tensor.contiguous()
函数不会对原始数据进行任何修改,而仅仅对其进行复制,并在内存空间上进行对齐,即在内存空间上,tensor元素的内存地址保持连续。 - 意义: 这么做的目的是,在对tensor元素进行转换和维度变换等操作之后,元素地址在内存空间中保证连续性,在后续利用指针对tensor元素进行读取时,能够减少读取便利,提高内存空间优化。
- 功能:
Tensor计算
Tensor 的归并运算
Tensor 的归并运算(torch.mean、sum、median、mode、norm、dist、std、var、cumsum、cumprod)
Tensor.max()
(指定某一维取最大值并合并)
B, N, C, H, W = feat.shape
feat = feat.max(dim=1)
feat.shape # [B, 1, C, H, W]
torch.sum()
torch.mean()
torch.median()
torch.mode()
torch.norm()
torch.dist()
torch.std()
torch.var()
torch.cumsum()
torch.cumprod()
Tensor平移和旋转
torch.nn.functional.affine_grid()
torch.nn.functional.grid_sample()