Pytorch——基础API用法总结

Pytorch Documentation

Tensor维度等变换操作

不同Tensor之间的合并与分离

PyTorch常用张量切割和拼接方法

x = torch.randn(B, C, W, H)
y = torch.cat([x, x],dim=-1)
y.shape # [B, C, W, 2H]
x = torch.randn(B, C, W, H)
y = torch.stack([x, x],dim=0)
y.shape # [2, B, C, W, 2H]
x = torch.randn(B, C, 400, 200)
out = torch.split(x, split_size_or_sections=[200,200], dim=-2)
out # 是一个tuple,存储了分割后的tensor
out[0].shape # [B, C, 200, 200]
out[1].shape # [B, C, 200, 200]

单个Tensor的维度变换

B, N, C, H, W = feat.shape
feat = feat.view(B, N, C, H*W)
B, N, C, H, W = feat.shape
feat = feat.permute(1, 0, 3, 4, 2)
feat.shape # [N, B, H, W, C]
  • Tensor.unsqueeze()(扩展维度)
C, H, W = feat.shape
feat = feat.unsqueeze(0) # 在第0维扩展一个维度
feat.shape # [1, C, H, W]
  • Tensor.squeeze()(删除一个空维度)
1, C, H, W = feat.shape
feat = feat.squeeze(0) # 在第0维删除一个空维度
feat.shape # [C, H, W]
  • Tensor.repeat()(在某一维进行重复扩充)
x = torch.randn(1, 3, 224, 224)
y = x.repeat(3, 1, 1, 1)
y.shape # [3, 3, 224, 224]

Tensor内存优化

  • Tensor.contiguous()
    • 功能:Tensor.contiguous()函数不会对原始数据进行任何修改,而仅仅对其进行复制,并在内存空间上进行对齐,即在内存空间上,tensor元素的内存地址保持连续。
    • 意义: 这么做的目的是,在对tensor元素进行转换和维度变换等操作之后,元素地址在内存空间中保证连续性,在后续利用指针对tensor元素进行读取时,能够减少读取便利,提高内存空间优化。

Tensor计算

Tensor 的归并运算

Tensor 的归并运算(torch.mean、sum、median、mode、norm、dist、std、var、cumsum、cumprod)

B, N, C, H, W = feat.shape
feat = feat.max(dim=1)
feat.shape # [B, 1, C, H, W]
  • torch.sum()
  • torch.mean()
  • torch.median()
  • torch.mode()
  • torch.norm()
  • torch.dist()
  • torch.std()
  • torch.var()
  • torch.cumsum()
  • torch.cumprod()

Tensor平移和旋转

Pytorch中的仿射变换(affine_grid)

  • torch.nn.functional.affine_grid()
  • torch.nn.functional.grid_sample()

猜你喜欢

转载自blog.csdn.net/qq_45779334/article/details/124557760
今日推荐