Essential understanding of einsum

einsum is a fast and flexible tensor operation function available in both NumPy and PyTorch. I read many articles online explaining this operation, but none explained its essence, so the many special cases were hard to memorize. Today I finally found an explanation that gets to the essence, and I am recording it here for later use. Original address (which itself may not be the original source):
https://www.jianshu.com/p/27350d110caf
The essence is the picture below: the einsum operator is shorthand for the key parameters of the following nested loop:
[Image: nested-loop form of the einsum computation]
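As a sketch of the picture's idea, the nested-loop expansion of `einsum('ik,kj->ij', A, B)` (example shapes chosen here for illustration) looks like this:

```python
import numpy as np

A = np.random.rand(3, 4)
B = np.random.rand(4, 5)

# C = np.einsum('ik,kj->ij', A, B) expands to these nested loops:
C = np.zeros((3, 5))
for i in range(3):          # output axis i
    for j in range(5):      # output axis j
        for k in range(4):  # axis k is absent from the output, so it is summed
            C[i, j] += A[i, k] * B[k, j]

assert np.allclose(C, np.einsum('ik,kj->ij', A, B))
```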
A detailed explanation:
(1) In an einsum operator, ' -> ' separates the inputs on the left from the output on the right. There can be 1 to N inputs, separated by ','; there is exactly one output. Each letter in a variable's subscript denotes the corresponding axis, and the number of letters must match the variable's number of dimensions, otherwise an error is raised.
(2) First, the outermost for-loop nesting is determined by the output's dimensions. For example, the output C above has two axes, i and j, giving two levels of loop nesting.
(3) A subscript letter that disappears on the output side indicates that summation (additive aggregation) occurs over that axis, so it is looped over in the inner layer.
(4) The innermost loop body multiplies the corresponding elements of the input variables, e.g. A[i,k] * B[k,j] above.
(5) One more point: when ' -> ' is omitted, the output subscript is formed by taking all the input letters, removing the repeated ones, and sorting the rest alphabetically. For example, 'ik,kj' means 'ik,kj -> ij'. This operator really does compress notation to the extreme, saving every character it can; personally, I don't think saving a few characters is worth the added difficulty in understanding.
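A quick check of rule (5), with example shapes chosen here for illustration:

```python
import numpy as np

A = np.random.rand(2, 3)
B = np.random.rand(3, 4)

# 'ik,kj' omits '->': the repeated letter k is summed out, and the
# remaining letters are sorted alphabetically, giving 'ik,kj->ij'.
implicit = np.einsum('ik,kj', A, B)
explicit = np.einsum('ik,kj->ij', A, B)
assert np.allclose(implicit, explicit)
```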

The explanation above lets us understand any einsum operator, but in real applications we also need to construct an operator that meets our requirements. I think it can be built in the following steps:
(1) First, write the axes of the output variable to the right of ' -> '.
(2) Then identify which axes of the input variables undergo additive aggregation; each summed axis must be denoted by a letter that does not appear in the output.
(3) Then identify which axes of the input variables need to be aligned with each other, usually axes with the same physical meaning; denote them with the same letter.
(4) Finally, double-check the result.
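As one illustration of these steps (batched matrix multiplication, chosen here as an example; shapes are assumptions), the construction goes: output axes 'bij', summed axis gets the new letter 'k', and the shared batch and contraction axes reuse the same letters:

```python
import numpy as np

A = np.random.rand(8, 3, 4)   # axes b, i, k
B = np.random.rand(8, 4, 5)   # axes b, k, j

# Step 1: output is 'bij'.  Step 2: the summed axis gets a letter
# absent from the output, 'k'.  Step 3: axes with the same meaning
# (batch b, contraction k) share the same letter in both inputs.
C = np.einsum('bik,bkj->bij', A, B)

assert C.shape == (8, 3, 5)
assert np.allclose(C, A @ B)   # same as batched matmul
```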

Origin blog.csdn.net/Brikie/article/details/125661022