pytorch+python(三)

1、torch.unbind()

说明:移除指定维后,返回一个元组,包含了沿着指定维切片后的各个切片。

参数

  • tensor(Tensor) -- 输入张量
  • dim(int) -- 删除的维度
In [6]: x
Out[6]:
tensor([[ 1.2474,  0.1820, -0.0179],
        [ 0.1388, -1.7373,  0.5934],
        [ 0.2288,  1.1102,  0.6743]])

In [7]: torch.unbind(x, 0)
Out[7]:
(tensor([ 1.2474,  0.1820, -0.0179]),
 tensor([ 0.1388, -1.7373,  0.5934]),
 tensor([0.2288, 1.1102, 0.6743]))

In [8]: torch.unbind(x, 1)
Out[8]:
(tensor([1.2474, 0.1388, 0.2288]),
 tensor([ 0.1820, -1.7373,  1.1102]),
 tensor([-0.0179,  0.5934,  0.6743]))

2、

猜你喜欢

转载自blog.csdn.net/Vpn_zc/article/details/113183830
今日推荐