深入浅出Pytorch函数——torch.squeeze

分类目录:《深入浅出Pytorch函数》总目录
相关文章:
· 深入浅出Pytorch函数——torch.squeeze
· 深入浅出Pytorch函数——torch.unsqueeze


将输入张量形状为1的维度去除并返回。比如输入向量的形状为 A × 1 × B × 1 × C × 1 × D A\times1\times B\times1\times C\times1\times D A×1×B×1×C×1×D,则输出向量形状就为 A × B × C × D A\times B\times C\times D A×B×C×D。当给定参数dim时,则操作只在给定维度dim上。例如,输入向量的形状为 A × 1 × B A\times1\times B A×1×B,使用squeeze(input, 0),输出向量的形状将会保持张量不变,只有使用 squeeze(input, 1),输出向量的形状才会变成 A × B A\times B A×B。需要注意的是,返回张量与输入张量共享内存,所以改变其中一个的内容会改变另一个。

语法

torch.squeeze(input, dim=None) → Tensor

参数

  • input:[Tensor] 输入张量
  • dim:[可选,int/tuple] 挤压维度的位置索引

实例

输入:

x = torch.zeros(2, 1, 2, 1, 2)
x.size()

输出:

torch.Size([2, 1, 2, 1, 2])

输入:

y = torch.squeeze(x)
y.size()

输出:

torch.Size([2, 2,, 2])

输入:

y = torch.squeeze(x, 0)
y.size()

输出:

torch.Size([2, 1, 2, 1, 2])

输入:

y = torch.squeeze(x, 1)
y.size()

输出:

torch.Size([2, 2, 1, 2])

输入:

y = torch.squeeze(x, (1, 2, 3))
y.size()

输出:

torch.Size([2, 2, 2])

猜你喜欢

转载自blog.csdn.net/hy592070616/article/details/131872582