PyTorch expand

basic usage

The following was generated by ChatGPT.

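PyTorch's expand method is used to expand the dimensions of a tensor. It is a Tensor method that returns a new view of the tensor. In that view, dimensions of size 1 are stretched to a larger size, and new dimensions can be added at the front. No data is copied. Each expanded dimension gets a stride of 0, so every position along it refers to the same underlying elements of the original tensor.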
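You can check the view behaviour directly. The minimal sketch below uses arbitrary tensor values and inspects the shape, stride and data pointer of an expanded tensor:

```python
import torch

a = torch.tensor([[1], [2], [3]])  # shape (3, 1)
b = a.expand(3, 4)                 # stretch the size-1 dimension to 4

print(b.shape)                       # torch.Size([3, 4])
print(b.stride())                    # (1, 0): stride 0 along the expanded dimension
print(b.data_ptr() == a.data_ptr())  # True: b reuses a's memory
```

The expanded dimension has stride 0, so b takes no extra memory no matter how large the target size is.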

The most common use is stretching a tensor so that its shape matches another tensor's shape. Suppose you want to combine a tensor A with a tensor B, where B has more dimensions than A or a larger size along a dimension where A has size 1. In this case, expand can stretch A to B's shape so that the two can be operated on together.
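Here is a minimal sketch of that scenario. The tensors A and B are made-up examples, and expand_as(B) is used as shorthand for expand(B.size()):

```python
import torch

# Hypothetical tensors standing in for "A" and "B" in the text above.
B = torch.tensor([[10, 20, 30],
                  [40, 50, 60]])   # shape (2, 3)
A = torch.tensor([1, 2, 3])        # shape (3,)

# Stretch A to B's shape without copying; same as A.expand(B.size()).
A_big = A.expand_as(B)
print(A_big.shape)   # torch.Size([2, 3])
print(B + A_big)
# tensor([[11, 22, 33],
#         [41, 52, 63]])
```

Element-wise operators such as + already broadcast automatically, so B + A gives the same result here. An explicit expand is mainly useful when some code needs the shapes to match exactly.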

Specifically, the usage of this method is as follows:

expanded_tensor = tensor.expand(sizes)

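Here, tensor is the tensor to be expanded, and sizes is the target shape. You can pass it as separate integers, as in tensor.expand(3, 4), or as a single tuple or torch.Size.

The target sizes are matched against the tensor's existing dimensions starting from the right. Each existing dimension must either equal its target size or be 1. Any extra target sizes become new leading dimensions.

For example, if tensor has shape (3, 4), then sizes (5, 3, 4) produces a tensor of shape (5, 3, 4). Sizes (3, 4, 5) raise a RuntimeError instead: the existing sizes 3 and 4 would have to become 4 and 5, and neither of them is 1.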
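You can test the shape rule with a short sketch. An existing (3, 4) tensor accepts a new leading dimension, but stretching a dimension whose size is not 1 fails:

```python
import torch

t = torch.zeros(3, 4)

# New dimensions are added on the left; existing ones are aligned from the right.
ok = t.expand(5, 3, 4)
print(ok.shape)  # torch.Size([5, 3, 4])

# Aligning (3, 4) with the last two entries of (3, 4, 5) pairs 3 with 4 and 4 with 5.
# Neither existing dimension has size 1, so expand refuses.
try:
    t.expand(3, 4, 5)
    failed = False
except RuntimeError as err:
    failed = True
    print("RuntimeError:", err)

# A size-1 dimension, by contrast, can be stretched to any size.
stretched = torch.zeros(3, 1).expand(3, 4)
print(stretched.shape)  # torch.Size([3, 4])
```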

Simple example:

import torch

a = torch.tensor([1, 2, 3])
b = a.expand(3, 3)
print(b)

# output:
# tensor([[1, 2, 3],
#         [1, 2, 3],
#         [1, 2, 3]])

In the above example, a has shape (3,), so a.expand(3, 3) adds a new leading dimension of size 3. The result is the 2-dimensional tensor b with shape (3, 3). Each row of b is a view of the same three elements of a, and nothing is copied.
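Because b shares memory with a, a write to a shows up in every row of b. The sketch below checks this and uses clone() to make an independent copy for comparison:

```python
import torch

a = torch.tensor([1, 2, 3])
b = a.expand(3, 3)
c = b.clone()        # independent copy with its own memory

print(b.stride())    # (0, 1): every row starts at the same memory location
a[0] = 100           # write through the original tensor
print(b[:, 0])       # tensor([100, 100, 100]): the change appears in every row of b
print(c[:, 0])       # tensor([1, 1, 1]): the clone is unaffected
```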

The case where sizes contains -1

If sizes contains -1, the corresponding dimension keeps its original size. Only the dimensions given an explicit size are expanded. You can use -1 only for dimensions the tensor already has, not for new leading dimensions.
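The sketch below shows -1 on a 2-D tensor with arbitrary values. It also shows that -1 is rejected for a new leading dimension:

```python
import torch

x = torch.tensor([[1], [2], [3]])  # shape (3, 1)
y = x.expand(-1, 4)                # -1 keeps dim 0 at 3; dim 1 goes from 1 to 4
print(y.shape)                     # torch.Size([3, 4])

# A new leading dimension has no existing size to keep, so -1 is not allowed there.
try:
    torch.tensor([1, 2, 3]).expand(-1, 3)
    leading_minus_one_ok = True
except RuntimeError as err:
    leading_minus_one_ok = False
    print("RuntimeError:", err)
```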

Simple example:

import torch

a = torch.tensor([1, 2, 3, 4])
b = a.expand(2, -1)
print(b)

# output:
# tensor([[1, 2, 3, 4],
#         [1, 2, 3, 4]])

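In the above example, the shape of a is expanded from (4,) to (2, 4). The -1 keeps the existing dimension at size 4, and a new dimension 0 of size 2 is added. As a result, the four elements of a appear twice along dimension 0.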
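The sketch below compares two ways of getting the (2, 4) result. The unsqueeze(0) call is there only to show that stretching an existing size-1 dimension gives the same values as adding a new dimension:

```python
import torch

a = torch.tensor([1, 2, 3, 4])        # shape (4,)
b = a.expand(2, -1)                   # new leading dimension of size 2
c = a.unsqueeze(0).expand(2, -1)      # (1, 4) -> (2, 4): stretches an existing size-1 dim

print(torch.equal(b, a.expand(2, 4)))  # True: -1 is the same as spelling out 4
print(torch.equal(b, c))               # True
print(b.stride(), c.stride())          # (0, 1) (0, 1)
```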


Origin blog.csdn.net/duoyasong5907/article/details/128991391