1、torch.unbind()
说明:移除指定维后,返回一个元组,包含了沿着指定维切片后的各个切片。
参数:
- tensor(Tensor) -- 输入张量
- dim(int) -- 删除的维度
In [6]: x
Out[6]:
tensor([[ 1.2474, 0.1820, -0.0179],
[ 0.1388, -1.7373, 0.5934],
[ 0.2288, 1.1102, 0.6743]])
In [7]: torch.unbind(x, 0)
Out[7]:
(tensor([ 1.2474, 0.1820, -0.0179]),
tensor([ 0.1388, -1.7373, 0.5934]),
tensor([0.2288, 1.1102, 0.6743]))
In [8]: torch.unbind(x, 1)
Out[8]:
(tensor([1.2474, 0.1388, 0.2288]),
tensor([ 0.1820, -1.7373, 1.1102]),
tensor([-0.0179, 0.5934, 0.6743]))
2、