Several common scenarios for transpose operator optimization

It is difficult to write a single transpose kernel that is optimal in every scenario at once. Instead, several common transpose scenarios can be identified and optimized individually. The cases below only classify how a transpose permutes the axes and ignore the sizes in the shape; in practice, each scenario should also be tuned for the actual sizes of the transposed axes.

Scenario 1: batch 2D, perm: 021

For 2D, 3D, or higher-dimensional tensors, this pattern swaps the two innermost dimensions, and all of these cases can be unified as batch 2D. Transposing a 2D matrix is equivalent to batch = 1, and for a tensor with more than three dimensions, all dimensions other than the two innermost ones can be merged into a single batch dimension.
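A minimal sketch of this unification in Python, assuming shapes and perms are plain lists; the helper name `as_batch_2d` is hypothetical. It checks that perm only swaps the last two axes and then folds every leading axis into one batch dimension.

```python
from math import prod


def as_batch_2d(shape, perm):
    """Hypothetical helper: return (B, M, N) if perm is a batch-2D (021-style)
    transpose of `shape`, otherwise None."""
    k = len(shape)
    if k < 2:
        return None
    # The only accepted perm keeps every outer axis and swaps the two innermost
    expected = list(range(k - 2)) + [k - 1, k - 2]
    if list(perm) != expected:
        return None
    # prod([]) == 1, so a plain 2D matrix becomes batch = 1
    batch = prod(shape[:-2])
    return batch, shape[-2], shape[-1]


print(as_batch_2d([4096, 4096], [1, 0]))        # (1, 4096, 4096)
print(as_batch_2d([2, 3, 5, 7], [0, 1, 3, 2]))  # (6, 5, 7)
print(as_batch_2d([2, 3, 4], [1, 0, 2]))        # None
```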

One method: split a large 021 transpose into transposes of small matrices.

For example, take the transpose of a large [4096, 4096] matrix. Adjacent threads read adjacent elements of the same input row, but they write adjacent elements of the same output column, so consecutive writes are a whole row length (4096 elements) apart. Such a long write stride is unfriendly to write-back cache hits, so shortening the stride can improve performance.
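To make the stride concrete, the sketch below assumes one thread per element, with thread (i, j) reading in[i][j] and writing out[j][i], both matrices stored row-major; the function name is hypothetical. It lists the flat read and write offsets for a few adjacent threads in the same input row.

```python
def naive_transpose_offsets(rows, cols, i, js):
    """Hypothetical helper: flat offsets touched by threads (i, j) for j in js
    when transposing a row-major [rows, cols] matrix element by element."""
    reads = [i * cols + j for j in js]   # same input row: stride 1
    writes = [j * rows + i for j in js]  # same output column: stride `rows`
    return reads, writes


reads, writes = naive_transpose_offsets(4096, 4096, 0, range(4))
print(reads)   # [0, 1, 2, 3]
print(writes)  # [0, 4096, 8192, 12288]
```

Reads are contiguous, but each write lands 4096 elements (16 KiB for float32) past the previous one.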

Consider splitting a large transposition into several smaller ones:

For example, to transpose MN into NM, split both M and N into smaller dimensions: M into M1M0 and N into N1N0.

The problem of transposing MN to NM then becomes transposing M1M0N1N0 to N1N0M1M0.

Assuming only two axes (or two groups of adjacent axes) can be swapped at a time, this can be done in three steps:

M1M0N1N0->N1M1M0N0->N1M1N0M0->N1N0M1M0
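The three steps can be checked with a small pure-Python sketch. The `transpose` helper below is a hypothetical reference implementation on a flat row-major list (output axis d takes input axis perm[d], the same convention as NumPy's `transpose`); the sizes M1 = 2, M0 = 4, N1 = 3, N0 = 4 are arbitrary.

```python
from math import prod


def transpose(data, shape, perm):
    """Hypothetical reference transpose of a flat row-major list.
    Output axis d corresponds to input axis perm[d]."""
    out_shape = [shape[p] for p in perm]
    # Row-major strides of the input
    in_strides = [prod(shape[a + 1:]) for a in range(len(shape))]
    out = []
    for flat in range(prod(out_shape)):
        # Decompose the flat output index into a multi-index
        idx, rem = [], flat
        for size in reversed(out_shape):
            idx.append(rem % size)
            rem //= size
        idx.reverse()
        # Output coordinate d indexes input axis perm[d]
        src = sum(i * in_strides[p] for i, p in zip(idx, perm))
        out.append(data[src])
    return out, out_shape


M1, M0, N1, N0 = 2, 4, 3, 4
M, N = M1 * M0, N1 * N0
data = list(range(M * N))

# Direct 2D transpose: [M, N] -> [N, M]
direct, _ = transpose(data, [M, N], [1, 0])

# Three steps on the split view [M1, M0, N1, N0]
s1, shape1 = transpose(data, [M1, M0, N1, N0], [2, 0, 1, 3])  # -> N1 M1 M0 N0
s2, shape2 = transpose(s1, shape1, [0, 1, 3, 2])              # -> N1 M1 N0 M0
s3, shape3 = transpose(s2, shape2, [0, 2, 1, 3])              # -> N1 N0 M1 M0

print(shape3)        # [3, 4, 2, 4]
print(s3 == direct)  # True
```

Flattening [N1, N0, M1, M0] gives exactly [N, M], so nothing else needs to move at the end. Step 2 only transposes M0 x N0 tiles, while steps 1 and 3 move whole contiguous runs of N0 and M0 elements.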

The first and third steps are each a 0213-style transpose (see Scenario 2 below), which can generally be implemented very efficiently.

In the second step, the size of each individual 2D transpose shrinks from M×N to M0×N0, which makes much better use of the cache. For example, [4096, 4096] can be split into [64, 64, 64, 64] for transposition; within each 64×64 tile, consecutive writes are then M0 = 64 elements apart instead of M = 4096.

Visually, this method is equivalent to splitting a large matrix into small tiles, transposing each tile independently, and then transposing the grid of tiles as a whole.

Note that the schematic shows the sequence M1M0N1N0->M1N1M0N0->M1N1N0M0->N1M1N0M0, which is still one swap (M1 with N0) away from the final target N1N0M1M0.
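The same idea is often written directly as loop tiling (cache blocking): each tile is transposed straight into its final position, which fuses the per-tile transpose and the rearrangement of the tile grid into one pass. A minimal pure-Python sketch, with a hypothetical function name and a caller-chosen tile size (64 would match the [64, 64, 64, 64] split above):

```python
def tiled_transpose(data, rows, cols, tile):
    """Hypothetical sketch: transpose a row-major [rows, cols] list tile by tile."""
    out = [None] * (rows * cols)
    for i0 in range(0, rows, tile):          # which tile row (M1 index)
        for j0 in range(0, cols, tile):      # which tile column (N1 index)
            # Transpose one tile; min() handles ragged edge tiles
            for i in range(i0, min(i0 + tile, rows)):
                for j in range(j0, min(j0 + tile, cols)):
                    out[j * rows + i] = data[i * cols + j]
    return out


rows, cols = 5, 7
data = list(range(rows * cols))
expected = [data[i * cols + j] for j in range(cols) for i in range(rows)]
print(tiled_transpose(data, rows, cols, tile=3) == expected)  # True
```

Within one tile only `tile` output rows are written, so they can stay resident in cache; on GPUs the tile is commonly staged through shared memory so that the global writes become contiguous.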

Scenario 2: 0213

Its characteristic is that two adjacent middle dimensions are swapped while one or more innermost dimensions stay in place. As above, adjacent dimensions that are not swapped can be merged into one, and if there are fewer than four dimensions, leading dimensions of size 1 can be added.
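A minimal sketch of why 0213 is cheap, assuming a row-major flat list and a hypothetical helper name: for [A, B, C, D] -> [A, C, B, D], every innermost run of D elements stays contiguous, so the kernel copies whole runs of length D instead of single elements. A 3D [B, C, D] case simply uses A = 1.

```python
def transpose_0213(data, a, b, c, d):
    """Hypothetical sketch: [A, B, C, D] -> [A, C, B, D] on a flat row-major list."""
    out = [None] * (a * b * c * d)
    for ia in range(a):
        for ic in range(c):
            for ib in range(b):
                src = ((ia * b + ib) * c + ic) * d  # start of in[ia, ib, ic, :]
                dst = ((ia * c + ic) * b + ib) * d  # start of out[ia, ic, ib, :]
                # The innermost D elements are contiguous on both sides
                out[dst:dst + d] = data[src:src + d]
    return out


# [1, 2, 3, 2] -> [1, 3, 2, 2]
print(transpose_0213(list(range(12)), 1, 2, 3, 2))
# [0, 1, 6, 7, 2, 3, 8, 9, 4, 5, 10, 11]
```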

In this example, perm=[2, 0, 3, 1, 4] seems to move several axes at once, but because size-1 axes can be squeezed away, it can be converted into a transpose from [784, 3, 4, 12] to [3, 4, 784, 12]. The dimensions of size 3 and 4 move together and can be merged into one dimension of size 12, giving [784, 12, 12] to [12, 784, 12], which is the 0213 pattern with a leading 1 and can be solved with the 0213 method.

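Algorithm for deleting the axes of size 1 from a transpose (perm is renumbered so that it still describes the same data movement):

```python
# Example perm and input shape of the transpose
perm = [2, 0, 3, 1, 4]
in_shape = [1, 784, 1, 4, 12]

# Collect the indices of all axes whose size is 1
rm_axes = [idx for idx, elem in enumerate(in_shape) if elem == 1]
print("rm_axes:", rm_axes)  # rm_axes: [0, 2]


def remove_axis(in_shape, perm, rm_axis):
    """Delete input axis rm_axis in place and renumber perm accordingly."""
    del in_shape[rm_axis]
    perm_rm_idx = -1
    for idx, elem in enumerate(perm):
        if elem == rm_axis:
            # Remember where the removed axis appears in perm
            perm_rm_idx = idx
        if elem > rm_axis:
            # Axes after the removed one shift down by one
            perm[idx] = elem - 1
    del perm[perm_rm_idx]


# Remove from the highest index down so the remaining indices stay valid
for rm_axis in reversed(rm_axes):
    remove_axis(in_shape, perm, rm_axis)

print("perm:", perm)          # perm: [1, 0, 2]
print("in_shape:", in_shape)  # in_shape: [784, 4, 12]
```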

Scenario 3: Swap two adjacent axes, one of which has size 1

This scenario does not require a transpose; a reshape is enough, because the order of the elements in memory does not change.

After the size-1 axes are deleted with the algorithm above, the perm of such a transpose becomes [0, 1, 2, 3, ...]. It is then trivial to detect that this transpose needs no operation at all, and it can simply be deleted.
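The same check can be sketched compactly. The hypothetical helpers below drop size-1 axes with a dictionary-based renumbering (same effect as `remove_axis` above) and report whether only a reshape is needed; the shape [2, 1, 5] with perm [1, 0, 2] is an illustrative example, not taken from the original.

```python
def squeeze_ones(shape, perm):
    """Hypothetical helper: drop size-1 axes and renumber perm accordingly."""
    keep = [axis for axis, size in enumerate(shape) if size != 1]
    new_index = {axis: i for i, axis in enumerate(keep)}
    new_shape = [shape[axis] for axis in keep]
    new_perm = [new_index[p] for p in perm if p in new_index]
    return new_shape, new_perm


def is_reshape_only(shape, perm):
    """True if, after squeezing size-1 axes, perm is the identity."""
    _, new_perm = squeeze_ones(shape, perm)
    return new_perm == list(range(len(new_perm)))


print(is_reshape_only([2, 1, 5], [1, 0, 2]))  # True: only the size-1 axis moves
print(is_reshape_only([2, 3, 5], [1, 0, 2]))  # False: a real transpose
print(squeeze_ones([1, 784, 1, 4, 12], [2, 0, 3, 1, 4]))  # ([784, 4, 12], [1, 0, 2])
```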

Scenario 4: Swap multiple axes, some of which stay adjacent in perm

Here perm=[1, 2, 0]. It looks like 3 axes are exchanged, but axes 1 and 2 (the 1x64 block) move together and can be merged into one dimension, which turns the problem into Scenario 1 above. So one solution is to merge adjacent axes that are moved together, which simplifies the problem.
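A minimal sketch of this merging step, with a hypothetical helper name: runs of consecutive input axes that also appear consecutively and in order in perm are fused into one axis. The shape [8, 64, 64] is an assumed example; the second call reuses the [784, 3, 4, 12] case from Scenario 2. The sketch assumes perm is non-empty.

```python
from math import prod


def merge_adjacent_axes(shape, perm):
    """Hypothetical helper: fuse input axes that stay adjacent and in order in perm."""
    # Split perm (listed in output order) into runs where p[i + 1] == p[i] + 1
    groups = [[perm[0]]]
    for p in perm[1:]:
        if p == groups[-1][-1] + 1:
            groups[-1].append(p)
        else:
            groups.append([p])
    # The same groups in input order define the merged input shape
    in_order = sorted(groups, key=lambda g: g[0])
    new_shape = [prod(shape[axis] for axis in g) for g in in_order]
    # Each group's position in input order becomes its new axis number
    position = {g[0]: i for i, g in enumerate(in_order)}
    new_perm = [position[g[0]] for g in groups]
    return new_shape, new_perm


print(merge_adjacent_axes([8, 64, 64], [1, 2, 0]))         # ([8, 4096], [1, 0])
print(merge_adjacent_axes([784, 3, 4, 12], [1, 2, 0, 3]))  # ([784, 12, 12], [1, 0, 2])
```

The first result is a plain 2D transpose (Scenario 1 with batch = 1); the second is [784, 12, 12] to [12, 784, 12], the 0213 pattern with a leading 1.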

Scenario 5: Other

Of course, a small number of scenarios still cannot be solved with the methods above, for example, the case where the first dimension of the input shape is not 1, so it cannot be squeezed away.

Origin: blog.csdn.net/u013701860/article/details/126738143