pytorch 中 expand ()函数
expand函数的功能就是 用来扩展张量中某维数据的尺寸,它返回输入张量在某维扩展为更大尺寸后的张量。
例如:
x = torch.tensor([1, 2, 3])
y = x.expand(2, 3)
y1 = x.expand(3,3)
print(x.size())
print(x)
print(y)
print(y1)
输出:
torch.Size([3])
tensor([1, 2, 3])
tensor([[1, 2, 3],
[1, 2, 3]])
tensor([[1, 2, 3],
[1, 2, 3],
[1, 2, 3]])
由此可以看到expand扩展维度会复制原有的数据进行扩展,我们以下面的例子进行验证。
x = torch.tensor([[1], [2], [3]])
y = x.expand(3, 3)
y1 = x.expand(4,3)
print(x.size())
print(y)
print(y1)
输出:
torch.Size([3, 1])
tensor([[1, 1, 1],
[2, 2, 2],
[3, 3, 3]])
tensor([[1, 1, 1],
[2, 2, 2],
[3, 3, 3]])