every blog every motto: You can do more than you think.
https://blog.csdn.net/weixin_39190382?type=blog
0. Preface
This post summarizes commonly used torch functions.
It covers the following:
torch.FloatTensor torch.max torch.numel torch.sort torch.clamp torch.nonzero torch.cat torch.stack
Note: this post will be updated continuously.
1. Main text
1.1 torch.FloatTensor()
1. When the input is an array (for example, a Python list),
the elements are converted to a floating-point type (torch.float32)
a = torch.FloatTensor([1,2,3,4])
2. When the input is a single integer,
a tensor with that many elements is created. The values are uninitialized, so they are often 0 but are not guaranteed to be
b = torch.FloatTensor(4)
Note: because the memory is uninitialized, the values may differ between runs and between torch versions, and very large numbers can appear.
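The two cases above can be checked with a short sketch. The comparison with torch.zeros is an addition for illustration and is not part of the original post; it is simply the usual way to get zeros explicitly.

```python
import torch

# Case 1: from a list, the values are converted to 32-bit floats.
a = torch.FloatTensor([1, 2, 3, 4])
print(a.dtype)   # torch.float32
print(a)         # tensor([1., 2., 3., 4.])

# Case 2: from an integer, a tensor of that size is allocated.
# Its contents are uninitialized, so only the shape is reliable.
b = torch.FloatTensor(4)
print(b.shape)   # torch.Size([4])

# If zeros are actually wanted, request them explicitly.
z = torch.zeros(4)
print(z)         # tensor([0., 0., 0., 0.])
```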
1.2 torch.max()
1. When only a tensor is passed in, the largest value in it is returned
a = torch.Tensor([5,6,9])
torch.max(a)
2. When a tensor and a dimension are passed in (the dimension along which to compare),
the largest values and their indices are returned
a = torch.Tensor([[5,6,9],
[2,3,1]])
torch.max(a,dim=1)
The returned values are 9 and 3. The comparison is made along dim=1, that is, the largest element of each row is selected.
The data in the first row is [5, 6, 9]; the largest is 9, and its index is 2.
The data in the second row is [2, 3, 1]; the largest is 3, and its index is 1.
The result can be unpacked into the values and the indices, e.g. values, indices = torch.max(a, dim=1).
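A minimal sketch of both call forms, using the same tensor as above. The variable names overall, values and indices are chosen here for illustration.

```python
import torch

a = torch.Tensor([[5, 6, 9],
                  [2, 3, 1]])

# Without dim: a single 0-dimensional tensor with the overall maximum.
overall = torch.max(a)
print(overall)    # tensor(9.)

# With dim=1: one maximum per row, plus its column index.
values, indices = torch.max(a, dim=1)
print(values)     # tensor([9., 3.])
print(indices)    # tensor([2, 1])
```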
1.3 torch.numel()
Returns the total number of elements in the tensor
a = torch.Tensor([2,3,4,5,1,2,3])
a.numel()
a = torch.Tensor([[5,6,9],[2,3,1]])
a.numel()
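For the two tensors above, the counts can be verified as follows. The second tensor is named b here only to keep both in one script.

```python
import torch

a = torch.Tensor([2, 3, 4, 5, 1, 2, 3])
print(a.numel())        # 7

b = torch.Tensor([[5, 6, 9], [2, 3, 1]])
print(b.numel())        # 6, i.e. 2 rows * 3 columns
print(torch.numel(b))   # 6, the function form gives the same result
```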
1.4 torch.sort(0, descending=True)
Sorts along the specified dimension (here 0); descending=True sorts in descending order, False (the default) in ascending order.
Returns the sorted tensor and, for each sorted element, its index in the original tensor
a = torch.Tensor([2,3,4,5,1,2,3])
a.sort(0,True)
As with max above, the result can be unpacked into the sorted values and their indices.
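A short sketch of unpacking the result. Because the tensor contains repeated values (two 2s and two 3s), the relative order of the tied indices is not assumed here; the example only checks that indexing the original tensor with the returned indices reproduces the sorted values.

```python
import torch

a = torch.Tensor([2, 3, 4, 5, 1, 2, 3])

# Sort along dimension 0 in descending order and unpack the result.
values, indices = a.sort(dim=0, descending=True)
print(values)       # tensor([5., 4., 3., 3., 2., 2., 1.])

# indices maps each sorted position back to its position in a.
print(a[indices])   # same as values
```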
1.5 torch.clamp()
Limits the tensor values to a range (a new tensor is returned; the original is unchanged)
- When the value is less than min, the value is changed to min
- When the value is greater than max, the value is changed to max
a = torch.Tensor([2,3,4,5])
a.clamp(min=3)
a = torch.Tensor([2,3,4,5])
a.clamp(max=3)
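The two calls above, plus a combined min and max call added here for illustration, behave as sketched below.

```python
import torch

a = torch.Tensor([2, 3, 4, 5])

lower = a.clamp(min=3)           # values below 3 are raised to 3
print(lower)                     # tensor([3., 3., 4., 5.])

upper = a.clamp(max=3)           # values above 3 are lowered to 3
print(upper)                     # tensor([2., 3., 3., 3.])

both = a.clamp(min=3, max=4)     # both bounds at once
print(both)                      # tensor([3., 3., 4., 4.])

print(a)                         # tensor([2., 3., 4., 5.]), a itself is unchanged
```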
1.6 torch.nonzero()
Returns the indices of the non-zero elements
One-dimensional:
a = torch.Tensor([0, 1, 2, 3, 0, 5])
a.nonzero()
The result contains one row per non-zero element, here the indices 1, 2, 3 and 5.
Two-dimensional:
a = torch.Tensor([
[0, 1, 2, 3, 0, 5],
[0, 0, 0, 0, 0, 10] ])
a.nonzero()
Each row of the result is the (row, column) index of a non-zero element:
e.g., a[0,1], a[0,2], etc. are non-zero.
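The one- and two-dimensional cases can be sketched together. The as_tuple=True call at the end is an addition for illustration: it returns one index tensor per dimension, which can be used directly to index the original tensor.

```python
import torch

# One-dimensional: the result has shape (number of non-zeros, 1).
a1 = torch.Tensor([0, 1, 2, 3, 0, 5])
idx1 = torch.nonzero(a1)
print(idx1.tolist())    # [[1], [2], [3], [5]]

# Two-dimensional: each row of the result is a (row, column) pair.
a2 = torch.Tensor([
    [0, 1, 2, 3, 0, 5],
    [0, 0, 0, 0, 0, 10]])
idx2 = torch.nonzero(a2)
print(idx2.tolist())    # [[0, 1], [0, 2], [0, 3], [0, 5], [1, 5]]

# as_tuple=True gives per-dimension index tensors usable for indexing.
nonzero_values = a2[torch.nonzero(a2, as_tuple=True)]
print(nonzero_values)   # tensor([ 1.,  2.,  3.,  5., 10.])
```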
1.7 torch.cat()
Concatenates tensors along the specified dimension; the default is dim=0
a = torch.rand((2, 3))
b = torch.rand((2, 3))
c = torch.cat((a, b))
The shape of each tensor: a and b are (2, 3), and c is (4, 3).
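A small sketch that prints the shapes, and also concatenates along dim=1 for comparison. The dim=1 case is not in the original example and is added here for illustration.

```python
import torch

a = torch.rand((2, 3))
b = torch.rand((2, 3))

# Default dim=0: rows are appended, so the row count doubles.
c0 = torch.cat((a, b))
print(c0.shape)    # torch.Size([4, 3])

# dim=1: columns are appended instead.
c1 = torch.cat((a, b), dim=1)
print(c1.shape)    # torch.Size([2, 6])

# The inputs appear unchanged inside the result.
print(torch.equal(c0[:2], a), torch.equal(c0[2:], b))   # True True
```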
1.8 torch.stack()
Joins tensors along a new dimension (default dim=0), so the result has one more dimension than the inputs
a = torch.rand((2, 3))
b = torch.rand((2, 3))
c = torch.stack((a, b))
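The difference from torch.cat shows up in the shapes. The sketch below repeats the example and adds a dim=2 variant for illustration, which is not in the original post.

```python
import torch

a = torch.rand((2, 3))
b = torch.rand((2, 3))

# Default dim=0: a new leading dimension indexes the stacked tensors.
c = torch.stack((a, b))
print(c.shape)     # torch.Size([2, 2, 3])
print(torch.equal(c[0], a), torch.equal(c[1], b))   # True True

# dim=2: the new dimension is added last instead.
d = torch.stack((a, b), dim=2)
print(d.shape)     # torch.Size([2, 3, 2])
```

Unlike torch.cat, all inputs to torch.stack must have exactly the same shape, since each one becomes a slice along the new dimension.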