expand() function in PyTorch

The expand() function enlarges a tensor along dimensions of size 1, and it can also add new leading dimensions. It returns a view of the input tensor in which those dimensions have the requested larger size.
For example:

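import torch

x = torch.tensor([1, 2, 3])  # shape (3,)
y = x.expand(2, 3)   # add a new leading dimension of size 2
y1 = x.expand(3, 3)  # add a new leading dimension of size 3
print(x.size())
print(x)
print(y)
print(y1)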

Output:

torch.Size([3])
tensor([1, 2, 3])
tensor([[1, 2, 3],
        [1, 2, 3]])
tensor([[1, 2, 3],
        [1, 2, 3],
        [1, 2, 3]])
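
The target shape passed to expand() follows a few rules: a size of -1 keeps an existing dimension unchanged, new dimensions can only be added at the front, and a dimension whose size is not 1 cannot be changed. A minimal sketch of these rules (the variable names `a` and `failed` are only for illustration):

```python
import torch

x = torch.tensor([1, 2, 3])  # shape (3,)

# -1 keeps the existing size of that dimension; 2 adds a new leading dimension.
a = x.expand(2, -1)
print(a.shape)  # torch.Size([2, 3])

# The last dimension has size 3, not 1, so it cannot be expanded to 4.
try:
    x.expand(2, 4)
    failed = False
except RuntimeError as err:
    failed = True
    print("expand failed:", err)
```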

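From this we can see that expand() repeats the original data along the expanded dimension. No data is actually copied: the result is a view of the same memory, and only dimensions of size 1 (or new leading dimensions) can be expanded. The following example expands a size-1 dimension of a 2-D tensor.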

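import torch

x = torch.tensor([[1], [2], [3]])  # shape (3, 1)
y = x.expand(3, 3)
y1 = x.expand(3, 4)  # dimension 1 has size 1, so it can be expanded to any size
# x.expand(4, 3) would raise a RuntimeError: dimension 0 has size 3, not 1
print(x.size())
print(y)
print(y1)

Output:

torch.Size([3, 1])
tensor([[1, 1, 1],
        [2, 2, 2],
        [3, 3, 3]])
tensor([[1, 1, 1, 1],
        [2, 2, 2, 2],
        [3, 3, 3, 3]])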
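
Because expand() returns a view rather than a copy, the expanded tensor shares memory with the original. A minimal sketch to check this, using repeat() as a copying counterpart for comparison (the variable names are only for illustration):

```python
import torch

x = torch.tensor([[1], [2], [3]])  # shape (3, 1)
y = x.expand(3, 4)                 # view: no new memory is allocated
z = x.repeat(1, 4)                 # copy: new memory is allocated

# The expanded dimension gets stride 0, so every column reads the same element.
print(y.stride())                    # (1, 0)
print(y.data_ptr() == x.data_ptr())  # True
print(z.data_ptr() == x.data_ptr())  # False

# Writing to x is visible through the view y, but not through the copy z.
x[0, 0] = 10
print(y[0])  # tensor([10, 10, 10, 10])
print(z[0])  # tensor([1, 1, 1, 1])
```

Since several elements of an expanded tensor refer to the same memory location, it is safer to call clone() or contiguous() on it before writing to it in place.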

Origin blog.csdn.net/weixin_46088099/article/details/125502426