pytorch 中 expand ()函数

pytorch 中 expand ()函数

expand函数的功能就是 用来扩展张量中某维数据的尺寸,它返回输入张量在某维扩展为更大尺寸后的张量。
例如:

x = torch.tensor([1, 2, 3])
y = x.expand(2, 3)
y1 = x.expand(3,3)
print(x.size())
print(x)
print(y)
print(y1)

输出:

torch.Size([3])
tensor([1, 2, 3])
tensor([[1, 2, 3],
        [1, 2, 3]])
tensor([[1, 2, 3],
        [1, 2, 3],
        [1, 2, 3]])

由此可以看到expand扩展维度会复制原有的数据进行扩展,我们以下面的例子进行验证。

x = torch.tensor([[1], [2], [3]])
y = x.expand(3, 3)
y1 = x.expand(4,3)
print(x.size())
print(y)
print(y1)

输出:

torch.Size([3, 1])
tensor([[1, 1, 1],
        [2, 2, 2],
        [3, 3, 3]])
tensor([[1, 1, 1],
        [2, 2, 2],
        [3, 3, 3]])

猜你喜欢

转载自blog.csdn.net/weixin_46088099/article/details/125502426