transpose算子优化的几种常见场景

很难写一个kernel就能同时在transpose的所有场景都最优,归纳transpose的几种常见场景可以针对性优化。这里只列出了transpose对轴变换的几种情况,没有考虑shape大小。因此在这几种场景上还应该考虑转置的轴 shape大小针对性优化。

场景1: batch 2D,perm:021

二维,三维,或者更高维的tensor,交换最内层的两个维度。这些都可以统一为batch 2D。对于2D矩阵的转置相当于batch=1,大于三维的tensor可以把两个最内层维度以外的所有维度合并看成一个维度。

一种大的021矩阵转置变换为小的矩阵转置方法:

比如输入是[4096,4096]的大矩阵进行转置,但是每个线程读取输入同一行的相邻位置,但是写出时写的是同一列的相邻位置,对于写的话,同一列跨的行长度太长对写回cache命中不友好,可以考虑降低跨的行长度来优化性能。

可以考虑把大的转置拆分为几个小的转置:

例如把MN转置为NM,考虑M和N都可以拆分为更小的维度:M拆分为M1M0,N拆分为N1N0

那么问题有MN转置为NM变为M1M0N1N0转置为N1N0M1M0

假设一次只能交换两个轴,那么可以通过如下三个步骤来实现:

M1M0N1N0->N1M1M0N0->N1M1N0M0->N1N0M1M0

扫描二维码关注公众号,回复: 15754062 查看本文章

其中第一步而第三部每次转置的都是一个tensor(例如下面的0213转置场景),这个一般都能实现的非常高效。

而第二步单元素的转置的维度从MN减小到了M0N0,更有利于缓存的利用。例如[4096,4096]可以拆分为[64,64,64,64]的大小来进行转置。

该方法可视化展示如下:相当于把大矩阵拆分为小矩阵,每个小矩阵独立转置,再把小矩阵看成一个整体转置一下。

要注意这个示意图中展示的是M1M0N1N0->M1N1M0N0->M1N1N0M0->N1M1N0M0,离最终要的还差一步。

场景2:0213

其特点是内部的两个相邻的维度进行交互,不包含最内层的一个或多个维度。跟上面一样,相邻不交换的维度可以合并看成一个整体,最外层的维度不足可以补1。

这个场景perm=[2, 0, 3, 1, 4],看上去同时转置了多个axes,但是由于shape元素1的特殊性,可以squeeze掉, 因此可以转换为[784, 3, 4, 12]到[3, 4, 784, 12]的transpose,可以使用0213的方法来解决。

删除transpose shape为1的算法

perm = [2, 0, 3, 1, 4]
in_shape = [1, 784, 1, 4, 12]

rm_axes = []
for idx, elem in enumerate(in_shape):
    if elem == 1:
        rm_axes.append(idx)

print("rm_axes:", rm_axes)

def remove_axis(in_shape, perm, rm_axis):
    del in_shape[rm_axis]
    perm_rm_idx = -1
    for idx, elem in enumerate(perm):
        if elem == rm_axis:
            perm_rm_idx = idx
        if elem > rm_axis:
            perm[idx] = perm[idx]-1

    del perm[perm_rm_idx]

for rm_axis in reversed(rm_axes):
    remove_axis(in_shape, perm, rm_axis)

print("perm:", perm)
print("in_shape:", in_shape)

场景3:交换两个相邻的axes,但是其中一个axis对应的shape是1

这个场景并不需要transpose,只需要reshape即可。

使用上面的删除transpose shape为1的算法后,这种transpose的perm会变成[0,1,2,3,...] 可以非常简单的判断这个transpose实际上不需要进行任何操作,直接删除即可。

场景4:交换多个axes,但是部分perm是相邻的

这里perm=[1, 2, 0], 看上去交换了3个axes,实际上1x64这两个是一起交换的,可以合并成一个维度,这个问题就变成了上面的场景1。因此解决方案可以是合并一起变换的相邻轴,从而把问题简化。

场景5:其他

当然还有少量场景无法使用上面的方法来解决,例如这里输入shape第一个维度不是1的情况。

猜你喜欢

转载自blog.csdn.net/u013701860/article/details/126738143