【基于PyTorch】torch.unsqueeze() 使用

  • 输入:
import torch

x = torch.tensor([1, 2, 3, 4])
print(torch.unsqueeze(x, 0))
print("**************")
print(torch.unsqueeze(x, 1))
  • 输出:
tensor([[1, 2, 3, 4]])
**************
tensor([[1],
        [2],
        [3],
        [4]])

猜你喜欢

转载自blog.csdn.net/weixin_42521185/article/details/124668530