MNN 中的矩阵乘法

背景

之前也写过sgemm,然后就想看看MNN是如何实现的,有没有什么可以借鉴的地方,看完之后发现MNN的实现也是简单的按行数据并行处理,记录一下。

矩阵乘法

矩阵乘法的目的是完成一个计算:C = A * B,其中A是h * k, B是k * w,所以C是h * w。
在这里插入图片描述
常用的方式是分行处理,对于C的第一行,可以按如下方式处理:

C(0,j) += A(0,i)*B(i,j)

对于行主序矩阵,每一行的数据是连续存储的,我们自然可以考虑使用SIMD指令,一次处理4个(假设是Float32)数据的相乘:

float32x4_t a0   = vdupq_n_f32(aLine[i]);
float32x4_t b0   = vld1q_f32(bLine);
float32x4_t sum0 = vdupq_n_f32(0.0);
sum0             = vmlaq_f32(sum0, a0, b0);
vst1q_f32(cLine, sum0);

需要注意的一点是,如果w不能被4整除,那么需要处理边界,逐个点进行计算并赋值:

C(0,j) += A(0,i) * B(i,j)

MNN的实现

void Matrix::multi(Tensor* C, const Tensor* A, const Tensor* B) {
    MNN_ASSERT(NULL != C);
    MNN_ASSERT(NULL != B);
    MNN_ASSERT(NULL != A);

    MNN_ASSERT(2 == C->dimensions());
    MNN_ASSERT(2 == B->dimensions());
    MNN_ASSERT(2 == A->dimensions());

    const auto a = A->host<float>();
    const auto b = B->host<float>();
    auto c       = C->host<float>();

    const int h = A->length(0);
    const int k = A->length(1);
    const int w = B->length(1);

    const int aw = A->stride(0);
    const int bw = B->stride(0);
    const int cw = C->stride(0);

    MNN_ASSERT(k == B->length(0));

    int y = 0;
    for (; y < h; ++y) {
        int x            = 0;
        const auto aLine = a + y * aw;
        auto cLine       = c + y * cw;
#ifdef MNN_USE_NEON
        // firstly, compute 16 together
        for (; x <= w - 16; x += 16) {
            auto bColumn     = b + x;
            float32x4_t sum0 = vdupq_n_f32(0.0);
            float32x4_t sum1 = vdupq_n_f32(0.0);
            float32x4_t sum2 = vdupq_n_f32(0.0);
            float32x4_t sum3 = vdupq_n_f32(0.0);
            for (int i = 0; i < k; ++i) {
                const auto bLine = bColumn + i * bw;
                float32x4_t a0   = vdupq_n_f32(aLine[i]);
                float32x4_t b0   = vld1q_f32(bLine);
                float32x4_t b1   = vld1q_f32(bLine + 4);
                float32x4_t b2   = vld1q_f32(bLine + 8);
                float32x4_t b3   = vld1q_f32(bLine + 12);
                sum0             = vmlaq_f32(sum0, a0, b0);
                sum1             = vmlaq_f32(sum1, a0, b1);
                sum2             = vmlaq_f32(sum2, a0, b2);
                sum3             = vmlaq_f32(sum3, a0, b3);
            }
            vst1q_f32(cLine + x, sum0);
            vst1q_f32(cLine + x + 4, sum1);
            vst1q_f32(cLine + x + 8, sum2);
            vst1q_f32(cLine + x + 12, sum3);
        }
        // secondly, compute 4 together
        for (; x <= w - 4; x += 4) {
            auto bColumn    = b + x;
            float32x4_t sum = vdupq_n_f32(0.0);
            for (int i = 0; i < k; ++i) {
                const auto bLine = bColumn + i * bw;
                float32x4_t a4   = vdupq_n_f32(aLine[i]);
                float32x4_t b4   = vld1q_f32(bLine);
                sum              = vmlaq_f32(sum, a4, b4);
            }
            vst1q_f32(cLine + x, sum);
        }
#endif
        for (; x < w; ++x) {
            auto bColumn = b + x;
            float sum    = 0.0f;
            for (int i = 0; i < k; ++i) {
                sum += aLine[i] * bColumn[i * bw];
            }
            cLine[x] = sum;
        }
    }
}

关键部分是MNN_USE_NEON宏包裹的部分,具体的思路,对输出矩阵C进行循环,因为是行主序(每一行连续存储),所以按行来进行计算,只不过它这里,先按16循环,可以利用流水线,提升效率,然后对于小于16的部分,先4个一组处理,对于小于4的边界部分,逐点处理。

评论

MNN的矩阵乘法实现可以说是标准的sgemm的一种简单加速版本,但是有两个点需要进一步考虑,一是没有做pack,对于大矩阵,这种直接的vld1q_f32可能会导致大量的cache miss,二是既然已经按行分别处理,且每一行的写入过程相互独立,所以可以考虑增加多线程来提高行间效率,可以使用Openmp或者自己起两个Thread来进行并行处理。

发布了42 篇原创文章 · 获赞 33 · 访问量 7万+

猜你喜欢

转载自blog.csdn.net/gaussrieman123/article/details/102798762