Using scatter and gather in PyTorch

Preamble

It has been a long time since I updated this blog — I basically abandoned it for all of 2019 and wrote nothing. With next year's spring recruiting season coming up, and even though I'm not looking for an algorithm position, I still want to seriously organize what I learned about deep learning and machine learning in the second half of 2019.

scatter usage

scatter means "to disperse". Let's first look at an example to get an intuitive feel for what this API does, using the example provided on the official PyTorch website.

>>> import torch
>>> x = torch.rand(2, 5)
>>> x
tensor([[0.2656, 0.5364, 0.8568, 0.5845, 0.2289],
        [0.0010, 0.8101, 0.5491, 0.6514, 0.7295]])
>>> y = torch.zeros(3, 5)
>>> index = torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]])
>>> index
tensor([[0, 1, 2, 0, 0],
        [2, 0, 0, 1, 2]])
>>> y.scatter_(dim=0, index=index, src=x)
>>> y
tensor([[0.2656, 0.8101, 0.5491, 0.5845, 0.2289],
        [0.0000, 0.5364, 0.0000, 0.6514, 0.0000],
        [0.0010, 0.0000, 0.8568, 0.0000, 0.7295]])

First, notice that every value of x appears somewhere in y, and the indexing happens along axis dim=0. Each element of x is mapped according to the formula y[index[i, j], j] = x[i, j]. For example, x[0,1] = 0.5364 and index[0,1] = 1, which means this value lands at position 1 along dimension dim=0 of y; therefore y[index[0,1], 1] = y[1,1] = x[0,1] = 0.5364.

Now we have an intuitive understanding of scatter, the "dispersing" function: it maps a source matrix into a target matrix via an index matrix. dim specifies the axis along which the mapping happens, and index gives the target subscripts on that axis. So for a 3D tensor, a call y.scatter_(dim, index, src) means:

y[index[i][j][k]][j][k] = src[i][j][k]  # if dim == 0
y[i][index[i][j][k]][k] = src[i][j][k]  # if dim == 1
y[i][j][index[i][j][k]] = src[i][j][k]  # if dim == 2
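These formulas can be verified directly. Below is a small sketch of my own (not from the official docs) that reimplements the dim=0 case with explicit Python loops and compares it against the built-in scatter_:

```python
import torch

def scatter_dim0(y, index, src):
    """Naive loop version of y.scatter_(0, index, src) for 2D tensors."""
    out = y.clone()
    for i in range(index.size(0)):
        for j in range(index.size(1)):
            out[index[i, j], j] = src[i, j]  # y[index[i][j]][j] = src[i][j]
    return out

src = torch.rand(2, 5)
index = torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]])
builtin = torch.zeros(3, 5).scatter_(0, index, src)
naive = scatter_dim0(torch.zeros(3, 5), index, src)
print(torch.equal(builtin, naive))  # True
```

Because no column of this index contains a repeated value, the write order does not matter and both versions agree exactly.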

Finally, let's look at the official documentation's English description of scatter:

Writes all values from the tensor src into self at the indices specified in the index tensor. For each value in src, its output index is specified by its index in src for dimension != dim and by the corresponding value in index for dimension = dim.

This matches the intuition above: scatter maps the src tensor into the target tensor self; along dimension dim the target subscript is given by index, while along every other dimension a value keeps the position it had in src.

self, index and src (if it is a Tensor) should have same number of dimensions. It is also required that index.size(d) <= src.size(d) for all dimensions d, and that index.size(d) <= self.size(d) for all dimensions d != dim.

Clearly self, index, and src must have the same number of dimensions, otherwise the subscripts in the formulas above would not line up. As for index.size(d) <= src.size(d) and index.size(d) <= self.size(d): from the formulas' point of view, an index larger than src would not be a logical problem, since the extra index entries would simply never be read. My guess is that this constraint is an engineering decision — index entries beyond the size of src are useless, so there is no point allowing that unused space.
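The other direction of the constraint — index strictly smaller than src — is allowed, and only the entries covered by index actually get written. A quick check of my own (assuming recent PyTorch semantics):

```python
import torch

src = torch.arange(10, dtype=torch.float32).reshape(2, 5)
index = torch.tensor([[0, 1, 2]])   # shape (1, 3): smaller than src in both dims
y = torch.zeros(3, 5)
y.scatter_(0, index, src)           # only src[0, :3] participates
# y[0, 0] = src[0, 0] = 0., y[1, 1] = src[0, 1] = 1., y[2, 2] = src[0, 2] = 2.
print(y)
```

The iteration is bounded by index's shape, so src[0, 3], src[0, 4] and all of src's second row are simply ignored.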

Moreover, as for gather(), the values of index must be between 0 and self.size(dim) - 1 inclusive, and all values in a row along the specified dimension dim must be unique.

All values in index must lie in the interval [0, self.size(dim) - 1]; this really is mandatory, otherwise the write goes out of bounds. The second sentence says that along the specified dimension dim the index values must be unique, but my tests show they can in fact repeat — look at the following code:

>>> x = torch.rand(2, 5)
>>> x
tensor([[0.6542, 0.6071, 0.7546, 0.4880, 0.1077],
        [0.9535, 0.0992, 0.0594, 0.0641, 0.7563]])
>>> index = torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]])
>>> y = torch.zeros(3, 5)
>>> y.scatter_(dim=0, index=index, src=x)
tensor([[0.6542, 0.0992, 0.0594, 0.4880, 0.1077],
        [0.0000, 0.6071, 0.0000, 0.0641, 0.0000],
        [0.9535, 0.0000, 0.7546, 0.0000, 0.7563]])
>>> index = torch.tensor([[0, 1, 2, 0, 0], [0, 1, 2, 0, 0]])
>>> y = torch.zeros(3, 5)
>>> y.scatter_(dim=0, index=index, src=x)
tensor([[0.9535, 0.0000, 0.0000, 0.0641, 0.7563],
        [0.0000, 0.0992, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0594, 0.0000, 0.0000]])

As you can see, along the dim=0 axis every one of the five columns of index contains a repeated value — (0,0), (1,1), (2,2), (0,0), (0,0) — so each target cell is written twice, yet the code raises no error or warning: the later write simply overwrites the earlier value. Presumably the API was updated but the documentation was not.
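When overwriting is not what you want for repeated indices, PyTorch also provides scatter_add_, which accumulates into the target instead. A side note of mine, not from the original post:

```python
import torch

x = torch.ones(2, 5)
index = torch.tensor([[0, 1, 2, 0, 0], [0, 1, 2, 0, 0]])  # every column repeats its target
y = torch.zeros(3, 5)
y.scatter_add_(0, index, x)   # repeated targets are summed, not overwritten
print(y)
# tensor([[2., 0., 0., 2., 2.],
#         [0., 2., 0., 0., 0.],
#         [0., 0., 2., 0., 0.]])
```

Each target cell receives two writes of 1.0 and ends up holding 2.0, whereas scatter_ would have left just 1.0 there.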

params:

  • dim (int) – the axis along which to index
  • index (LongTensor) – the indices of elements to scatter; can be either empty or the same size as src. When empty, the operation returns self unchanged
  • src (Tensor) – the source element(s) to scatter, in case value is not specified
  • value (float) – the source element(s) to scatter, in case src is not specified

It is worth noting the value parameter: when src is not specified, you can pass a floating-point scalar instead, which we can use to implement a one-hot function with scatter.

>>> x = torch.tensor([[1, 1, 1, 1, 1]], dtype=torch.float32)
>>> index = torch.tensor([[0, 1, 2, 3, 4]], dtype=torch.int64)
>>> y = torch.zeros(5, 5, dtype=torch.float32)
>>> x
tensor([[1., 1., 1., 1., 1.]])
>>> y.scatter_(0, index, x)
tensor([[1., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0.],
        [0., 0., 1., 0., 0.],
        [0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 1.]])
>>> y = torch.zeros(5, 5, dtype=torch.float32)
>>> y.scatter_(0, index, 1)
tensor([[1., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0.],
        [0., 0., 1., 0., 0.],
        [0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 1.]])

We can see that passing value=1 is equivalent to passing src=[[1,1,1,1,1]]. That wraps up scatter.
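In practice the one-hot trick is usually applied along dim=1 to a column of class labels. A minimal sketch of my own (labels and num_classes are made-up values):

```python
import torch

labels = torch.tensor([2, 0, 1])               # class id for each of 3 samples
num_classes = 4
onehot = torch.zeros(len(labels), num_classes)
# dim=1: onehot[i, labels[i]] = 1.0
onehot.scatter_(1, labels.unsqueeze(1), 1.0)
print(onehot)
# tensor([[0., 0., 1., 0.],
#         [1., 0., 0., 0.],
#         [0., 1., 0., 0.]])
```

unsqueeze(1) turns the label vector into an (N, 1) index so that its ndim matches the target, as required by scatter.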

gather usage

gather is the reverse process of scatter: it collects data from one tensor into another. Let's look at an example to get an intuitive feel.

>>> x = torch.tensor([[1, 2], [3, 4]])
>>> torch.gather(input=x, dim=1, index=torch.tensor([[0, 0], [1, 0]]))
tensor([[1, 1],
        [4, 3]])

You can probably guess the collection process: values are selected from x according to index and dim and placed into y, satisfying the formula y[i, j] = x[i, index[i, j]] for dim=1. Therefore y[0,0] = x[0, index[0,0]] = x[0,0] = 1, and y[1,0] = x[1, index[1,0]] = x[1,1] = 4. For 3D data, the following formulas hold:

out[i][j][k] = input[index[i][j][k]][j][k]  # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k]  # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]]  # if dim == 2
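A common real-world use of gather along dim=1 is picking one value per row, for example selecting each sample's score for its target class. The variable names here are my own illustration:

```python
import torch

scores = torch.tensor([[0.1, 0.9, 0.0],
                       [0.8, 0.1, 0.1]])
targets = torch.tensor([1, 0])                 # class picked for each row
picked = scores.gather(1, targets.unsqueeze(1)).squeeze(1)
# picked[i] = scores[i, targets[i]]
print(picked)   # tensor([0.9000, 0.8000])
```

This is essentially the inner step of a negative log-likelihood loss: one gathered value per sample, without any Python loop.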

That wraps up gather. Since gather is simply the reverse of scatter, once you understand scatter, gather needs little further explanation.

summary

  1. scatter maps one tensor onto another; one application is implementing a one-hot function.
  2. gather and scatter are inverse processes; gather can be used to compress a sparse tensor by collecting its non-zero elements.
  3. Don't let time go to waste — producing no results can't be blamed entirely on oneself.


Origin www.cnblogs.com/liuzhan709/p/11875743.html