The expand() function in PyTorch
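The expand() function enlarges dimensions of size 1 in a tensor to a larger size. It returns a new view of the input tensor with the requested size. New dimensions can also be added at the front, and passing -1 for a dimension keeps its size unchanged. A dimension whose size is not 1 cannot be expanded.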
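The shape rules above can be illustrated with a minimal sketch, assuming a recent PyTorch version. The tensor names `x` and `col` are illustrative only, and each call prints the shape of the resulting view.

```python
import torch

# A 1-D tensor of shape [3]; expand may prepend a new leading dimension.
x = torch.tensor([1, 2, 3])
print(x.expand(2, 3).shape)        # torch.Size([2, 3])

# A column of shape [3, 1]; only the size-1 dimension (dim 1) can grow.
col = torch.tensor([[1], [2], [3]])
print(col.expand(3, 4).shape)      # torch.Size([3, 4])

# -1 keeps that dimension's current size (dim 0 stays 3).
print(col.expand(-1, 4).shape)     # torch.Size([3, 4])

# A new leading dimension of size 2 is added in front; -1 still refers
# to the existing dim 0 of col, so the result has shape [2, 3, 4].
print(col.expand(2, -1, 4).shape)  # torch.Size([2, 3, 4])
```

Note that -1 is only allowed for dimensions the tensor already has; a newly added leading dimension needs an explicit size.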
For example:
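import torch

x = torch.tensor([1, 2, 3])  # shape [3]
y = x.expand(2, 3)           # adds a new leading dimension of size 2
y1 = x.expand(3, 3)          # adds a new leading dimension of size 3
print(x.size())
print(x)
print(y)
print(y1)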
Output:
torch.Size([3])
tensor([1, 2, 3])
tensor([[1, 2, 3],
        [1, 2, 3]])
tensor([[1, 2, 3],
        [1, 2, 3],
        [1, 2, 3]])
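From this we can see that expand repeats the original data along the expanded dimension. The data is not physically copied, though: the result is a view that reuses the memory of the original tensor. We verify the repetition with the following example, where the size-1 dimension is the second one.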
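x = torch.tensor([[1], [2], [3]])  # shape [3, 1]: dimension 1 has size 1
y = x.expand(3, 3)
y1 = x.expand(3, 4)  # x.expand(4, 3) would raise a RuntimeError: dimension 0 has size 3, not 1
print(x.size())
print(y)
print(y1)
Output:
torch.Size([3, 1])
tensor([[1, 1, 1],
        [2, 2, 2],
        [3, 3, 3]])
tensor([[1, 1, 1, 1],
        [2, 2, 2, 2],
        [3, 3, 3, 3]])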
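Because expand only creates a view, the repeated values are not stored separately. The sketch below, assuming a recent PyTorch version (`z` is an illustrative name), checks this through the stride and data pointer, shows that an in-place change to `x` is visible through `y`, uses `clone()` to get an independent copy, and shows the error raised when a dimension whose size is not 1 is expanded. The PyTorch documentation also warns that several elements of an expanded tensor may share one memory location, so clone it before writing to it in place.

```python
import torch

x = torch.tensor([[1], [2], [3]])  # shape [3, 1]
y = x.expand(3, 4)

# expand returns a view: no new memory, and the expanded dim has stride 0.
print(y.stride())                    # (1, 0)
print(y.data_ptr() == x.data_ptr())  # True

# y shares storage with x, so an in-place change to x shows up in y.
x[0, 0] = 10
print(y[0])                          # tensor([10, 10, 10, 10])

# clone() materializes the expanded values into new memory.
z = y.clone()
x[0, 0] = 1                          # restore x; z keeps the old values
print(z[0])                          # tensor([10, 10, 10, 10])

# Dimension 0 has size 3 (not 1), so it cannot be expanded to 4.
try:
    x.expand(4, 3)
except RuntimeError as err:
    print("RuntimeError:", err)
```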