[torch] Summary of commonly used functions: torch.FloatTensor, torch.max, torch.numel, torch.sort, torch.clamp, torch.nonzero

Every blog, every motto: You can do more than you think.
https://blog.csdn.net/weixin_39190382?type=blog

0. Preface

This post summarizes commonly used torch functions, including:
torch.FloatTensor, torch.max, torch.numel, torch.sort, torch.clamp, torch.nonzero, torch.cat, torch.stack

Note: this post will continue to be updated.

1. Text

1.1 torch.FloatTensor()

  1. When the input is a list, its elements are converted to floating point (torch.FloatTensor is a 32-bit float tensor type)
import torch
a = torch.FloatTensor([1,2,3,4])  # tensor([1., 2., 3., 4.])

2. When the input is an integer n, a 1-D tensor of length n is created. Its memory is not initialized, so the values are whatever happened to be in memory (often, but not necessarily, 0)

b = torch.FloatTensor(4)

Note: because the memory is uninitialized, the values can differ between runs and between torch versions, and may be very large. Use torch.zeros(4) if you need a tensor of zeros.
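A minimal sketch, assuming a recent PyTorch, that checks both behaviours described above and contrasts them with torch.tensor and torch.zeros, which current PyTorch documentation generally recommends over the legacy torch.FloatTensor constructor when the contents must be well defined:

```python
import torch

# List input: the integer elements are converted to float32
a = torch.FloatTensor([1, 2, 3, 4])
print(a, a.dtype)  # tensor([1., 2., 3., 4.]) torch.float32

# Integer input: a 1-D float32 tensor of that length whose values are uninitialized
b = torch.FloatTensor(4)
print(b.shape)  # torch.Size([4])

# Constructors with well-defined contents
a2 = torch.tensor([1, 2, 3, 4], dtype=torch.float32)
b2 = torch.zeros(4)
print(torch.equal(a, a2))  # True
print(b2)  # tensor([0., 0., 0., 0.])
```

Only the shape and dtype of b are checked, since its values are not defined.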

1.2 torch.max()

  1. When only a tensor is passed in, it returns the largest element as a 0-dim tensor
a = torch.Tensor([5,6,9])
torch.max(a)  # tensor(9.)

2. When a tensor and a dimension dim are passed in, it compares the values along that dimension and returns both the maximum values and their indices

a = torch.Tensor([[5,6,9],
                  [2,3,1]])
torch.max(a,dim=1)

The returned maximum values are 9 and 3. Because dim=1, the comparison runs along each row and picks the largest element of that row:
The first row is [5, 6, 9]: the maximum is 9, at index 2.
The second row is [2, 3, 1]: the maximum is 3, at index 1.
The result can be unpacked into separate value and index tensors: values, indices = torch.max(a, dim=1)
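A short sketch of both call forms on the same 2x3 tensor. It also shows that the dim form returns a named tuple whose fields are values and indices, and that dim=0 compares down each column instead of along each row:

```python
import torch

a = torch.Tensor([[5, 6, 9],
                  [2, 3, 1]])

# Without dim: a 0-dim tensor holding the global maximum
print(torch.max(a))  # tensor(9.)

# With dim=1: per-row maxima plus their column indices
values, indices = torch.max(a, dim=1)
print(values)   # tensor([9., 3.])
print(indices)  # tensor([2, 1])

# The same fields are also reachable by name
result = torch.max(a, dim=1)
print(result.values, result.indices)

# dim=0 compares down each column instead
print(torch.max(a, dim=0).values)  # tensor([5., 6., 9.])
```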

1.3 torch.numel()

Returns the total number of elements in the tensor, i.e. the product of the sizes of all its dimensions

a = torch.Tensor([2,3,4,5,1,2,3])
a.numel()  # 7


a = torch.Tensor([[5,6,9],[2,3,1]])
a.numel()  # 6 (2 rows x 3 columns)

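A small sketch showing that the method form and the function form torch.numel agree, and that the count equals the product of the shape (math.prod needs Python 3.8 or later):

```python
import math
import torch

a = torch.Tensor([[5, 6, 9], [2, 3, 1]])

# Method form and function form return the same Python int
print(a.numel(), torch.numel(a))  # 6 6

# numel equals the product of the sizes of all dimensions
print(math.prod(a.shape))  # 6

x = torch.rand((2, 3, 4))
print(x.numel())  # 24
```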

1.4 torch.sort()

Sorts along the specified dimension. descending=True sorts in descending order; descending=False (the default) sorts in ascending order

Returns the sorted values together with the indices those values had in the original tensor

a = torch.Tensor([2,3,4,5,1,2,3])
a.sort(0,True)  # dim=0, descending=True; values: tensor([5., 4., 3., 3., 2., 2., 1.])

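As with max, the result can be unpacked into values and indices: values, indices = a.sort(0, True). The relative order of equal elements (the two 3s and the two 2s) in the indices is not guaranteed unless a stable sort is requested.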
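A minimal sketch of the sort above. It checks that indexing the original tensor with the returned indices reproduces the sorted values, and uses the keyword form with stable=True (available in PyTorch 1.9 and later, an assumption about your version) so that tied elements keep their original order:

```python
import torch

a = torch.Tensor([2, 3, 4, 5, 1, 2, 3])

# Descending sort along dim 0; the result unpacks into (values, indices)
values, indices = a.sort(0, True)
print(values)   # tensor([5., 4., 3., 3., 2., 2., 1.])
print(indices)  # positions of those values in the original tensor

# Indexing the original tensor with the indices reproduces the sorted values
print(torch.equal(a[indices], values))  # True

# Keyword form; stable=True keeps equal elements in their original order
values_asc, indices_asc = torch.sort(a, dim=0, descending=False, stable=True)
print(values_asc)   # tensor([1., 2., 2., 3., 3., 4., 5.])
print(indices_asc)  # tensor([4, 0, 5, 1, 6, 2, 3])
```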

1.5 torch.clamp()

Limits every element of the tensor to the range [min, max]; either bound may be omitted:

  • When the value is less than min, the value is changed to min
  • When the value is greater than max, the value is changed to max
a = torch.Tensor([2,3,4,5])
a.clamp(min=3)  # tensor([3., 3., 4., 5.])


a = torch.Tensor([2,3,4,5])
a.clamp(max=3)  # tensor([2., 3., 3., 3.])

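A short sketch that applies both bounds at once and shows that clamp returns a new tensor, while the trailing-underscore variant clamp_ modifies the tensor in place:

```python
import torch

a = torch.Tensor([2, 3, 4, 5])

print(a.clamp(min=3))        # tensor([3., 3., 4., 5.])
print(a.clamp(max=3))        # tensor([2., 3., 3., 3.])
print(torch.clamp(a, 3, 4))  # both bounds: tensor([3., 3., 4., 4.])

# clamp returns a new tensor; a itself is unchanged
print(a)  # tensor([2., 3., 4., 5.])

# clamp_ modifies a in place
a.clamp_(min=3)
print(a)  # tensor([3., 3., 4., 5.])
```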

1.6 torch.nonzero()

Returns the indices of the non-zero elements: one row per non-zero element, one column per dimension of the input

One-dimensional:

a = torch.Tensor([0, 1, 2, 3, 0, 5])
a.nonzero()

The result has shape (4, 1): the non-zero elements are at indices 1, 2, 3 and 5.
Two-dimensional:

a = torch.Tensor([[0, 1, 2, 3, 0, 5],
                  [0, 0, 0, 0, 0, 10]])
a.nonzero()

Each row of the result is the (row, column) index of one non-zero element: [0, 1], [0, 2], [0, 3], [0, 5] and [1, 5]. In other words, a[0,1], a[0,2], a[0,3], a[0,5] and a[1,5] are non-zero.
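A minimal sketch of the two-dimensional case. Besides the default (N, ndim) result, it uses as_tuple=True, which returns one 1-D index tensor per dimension; that form can be passed straight back into the tensor to pick out the non-zero values in row-major order:

```python
import torch

a = torch.Tensor([[0, 1, 2, 3, 0, 5],
                  [0, 0, 0, 0, 0, 10]])

# Default: an (N, ndim) tensor with one row per non-zero element
idx = torch.nonzero(a)
print(idx.shape)  # torch.Size([5, 2])

# as_tuple=True: one 1-D index tensor per dimension, usable for indexing
rows, cols = torch.nonzero(a, as_tuple=True)
print(a[rows, cols])  # the non-zero values 1, 2, 3, 5, 10
```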

1.7 torch.cat()

Concatenates tensors along an existing dimension (dim=0 by default). All inputs must have the same shape except in the dimension being concatenated

a = torch.rand((2, 3))
b = torch.rand((2, 3))
c = torch.cat((a, b))

a and b each have shape (2, 3). Concatenating along dim 0 appends the rows of b after the rows of a, so c has shape (4, 3).
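A short sketch comparing dim=0 and dim=1, plus a case where the inputs differ in the concatenated dimension, which torch.cat allows:

```python
import torch

a = torch.rand((2, 3))
b = torch.rand((2, 3))

c0 = torch.cat((a, b))         # dim=0 (default): rows of b follow rows of a
c1 = torch.cat((a, b), dim=1)  # dim=1: columns of b follow columns of a
print(c0.shape)  # torch.Size([4, 3])
print(c1.shape)  # torch.Size([2, 6])

# Only the concatenated dimension may differ between inputs
d = torch.rand((5, 3))
print(torch.cat((a, d)).shape)  # torch.Size([7, 3])
```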

1.8 torch.stack()

Joins tensors along a new dimension (dim=0 by default), so the result has one more dimension than the inputs. All inputs must have exactly the same shape

a = torch.rand((2, 3))
b = torch.rand((2, 3))
c = torch.stack((a, b))

a and b each have shape (2, 3). Stacking inserts a new leading dimension, so c has shape (2, 2, 3), with c[0] equal to a and c[1] equal to b.
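A minimal sketch relating stack to cat: stacking is equivalent to giving each input a new size-1 dimension with unsqueeze and then concatenating along it. It also shows that the dim argument chooses where the new dimension goes:

```python
import torch

a = torch.rand((2, 3))
b = torch.rand((2, 3))

s0 = torch.stack((a, b))         # new dim 0: shape (2, 2, 3)
s2 = torch.stack((a, b), dim=2)  # new last dim: shape (2, 3, 2)
print(s0.shape, s2.shape)  # torch.Size([2, 2, 3]) torch.Size([2, 3, 2])

# stack is equivalent to unsqueezing each input, then concatenating
same = torch.cat((a.unsqueeze(0), b.unsqueeze(0)), dim=0)
print(torch.equal(s0, same))  # True
print(torch.equal(s0[0], a), torch.equal(s0[1], b))  # True True
```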


Original post: blog.csdn.net/weixin_39190382/article/details/130323677