【学习记录】矩阵乘法

基本

一个 n × m n \times m m × p m \times p 的矩阵相乘的时间复杂度为 O ( n m p ) O(nmp) ,得到的结果为 n × p n\times p 的矩阵。

由于矩阵乘法有结合律,因此对于一个 n n 阶方阵 A A ,可以利用快速幂在 O ( n 3 log k ) O(n^3 \log k) 的时间内计算 A k A^k

Strassen 矩阵乘法

应用了分治的想法,把原来的 n × n n \times n n × n n \times n 的矩阵相乘,优化成了 7 个 n 2 × n 2 \frac{n}{2} \times \frac{n}{2} 的矩阵相乘。

由 Master Theorem, T ( n ) = 7 T ( n 2 ) + O ( n ) T(n) = 7T\left( \frac{n}{2}\right) + O(n) 的解为 O ( n log 2 7 ) O(n^{\log_2 7})

算法细节参见其他相关材料。

循环矩阵乘法

循环矩阵在计算一类和马尔可夫链有关的概率问题时会被用到。

n n 阶循环矩阵是满足下面条件的 n n 阶方阵:每一行都由上一行循环右移一位得到。如 3 3 阶循环矩阵:
[ a 1 a 2 a 3 a 3 a 1 a 2 a 2 a 3 a 1 ] \begin{bmatrix} a_1 & a_2 & a_3 \\ a_3 & a_1 & a_2 \\ a_2 & a_3 & a_1 \end{bmatrix}

可以证明,对于两个 n n 阶循环矩阵 A , B A, B ,满足 A + B , A B A+B, AB 都是循环矩阵。

对于循环矩阵而言,由于其每一行都是相同的,因此只需要一行就可以保存整个矩阵的信息。这使得循环矩阵的乘法的时间复杂度可以从一般的 O ( n 3 ) O(n^3) 降低到 O ( n 2 ) O(n^2)

我们使用第一行代表矩阵。对于两个 n n 阶循环矩阵 A , B A, B ,它们的第一行为 a , b a, b (base-0),那么 C = A B C=AB 的第一行 c c 满足:
c k = ( i + j ) m o d n = k a i b j c_{k} = \sum_{(i + j) \bmod n = k} a_i b_j

以行的形式保存时,用行向量来乘矩阵比较方便。设行向量为 a a ,要乘的矩阵是 B B ,那么结果向量 c c 恰好等于 C C 的第一行,其中 C C 是以 a a 为第一行的循环矩阵 A A B B 相乘的结果。

列向量的话较为繁琐,要经过两次转化。

例:牛牛的粉丝

题意:一个环上有 n n 个节点,第 i i 个节点上有 x i x_i 个人。每一轮每一个人独立行动,有 p 3 p_3 概率不动, p 1 p_1 概率顺时针移动到下一个点, p 2 p_2 概率逆时针移动到下一个点。问 k k 轮后每个点上人数的期望。(原题:牛客练习赛 68 D

可以看出,本题所表示的随机过程是一个马尔可夫链,且概率转移矩阵是一个循环矩阵。因此可以用快速幂计算最终的期望。时间复杂度为 O ( n 2 log k ) O(n^2 \log k)

由于答案就是每一个点各自答案的线性组合,因此快速幂的“底”直接使用了 x x ,而不是单位矩阵 I I

#include <bits/stdc++.h>
#define MOD 998244353
using namespace std;
typedef long long ll;
inline int modadd(int x, int y){
    return (x + y >= MOD ? x + y - MOD: x + y);
}
int poww(int a, int b){
    int res = 1;
    while (b > 0){
        if (b & 1) res = 1ll * res * a % MOD;
        a = 1ll * a * a % MOD, b >>= 1;
    }
    return res;
}
int n, a, b, c, x[505];
int mat[505], tmp[505];
ll k;
void mul(int *u, int *v){
    memset(tmp, 0, sizeof(tmp));
    for (int i = 0; i < n; ++i)
        for (int j = 0; j < n; ++j)
            tmp[(i + j) % n] = modadd(tmp[(i + j) % n], 1ll * u[i] * v[j] % MOD);
    for (int i = 0; i < n; ++i)
        u[i] = tmp[i];
}
void init(){
    scanf("%d%lld%d%d%d", &n, &k, &a, &b, &c);
    for (int i = 0; i < n; ++i)
        scanf("%d", &x[i]);
    int s = a + b + c;
    s = poww(s, MOD - 2);
    mat[0] = 1ll * c * s % MOD;
    mat[n - 1] = 1ll * b * s % MOD;
    mat[1] = 1ll * a * s % MOD;
}
void solve(){
    while (k > 0){
        if (k & 1ll) mul(x, mat);
        mul(mat, mat), k >>= 1;
    }
    for (int i = 0; i < n; ++i)
        printf("%d%c", x[i], (i == n - 1 ? '\n': ' '));
}
int main(){
    init();
    solve();
    return 0;
}

稀疏矩阵乘法

对于稀疏矩阵,可以在更低的时间复杂度内完成乘法运算。

我们用三元组来记录矩阵内的非零元素,即 ( r , c , v a l ) (r, c, val) 表示在矩阵第 r r 行第 c c 列的值为 v a l val 。那么对于两个用二元组列表表示的稀疏矩阵,只要计算它们之间每一对三元组 v a l val 的乘积,并将乘积累加到结果的对应位置即可。

F 2 \mathbb{F}_2 下的矩阵乘法

可以用 bitset 进行加速。

对于两个位向量,它们的内积就是与运算之后向量中 1 的个数。

F 3 \mathbb{F}_3 下的矩阵乘法

也可以用 bitset 进行加速。

可以用两个位向量表示一个 F 3 n \mathbb{F}_3^n 向量1,利用这种向量内积的快速计算来加速矩阵乘法。具体运算方法可以参考下面的代码。

例:HDU 4920

本题是一个 F 3 \mathbb{F}_3 下的矩阵乘法模板题,实现如下。

#include <bits/stdc++.h>
using namespace std;
int n, a[805][805];
bitset<805> r1[805], r2[805];
bitset<805> c1[805], c2[805];
bitset<805> tmp[4];
void init(){
    for (int i = 0; i < n; ++i){
        r1[i].reset();
        r2[i].reset();
        for (int j = 0, t; j < n; ++j){
            scanf("%d", &t);
            t %= 3;
            if (t == 1) r1[i][j] = 1;
            else if (t == 2) r2[i][j] = 1;
        }
    }
    for (int i = 0; i < n; ++i)
        for (int j = 0, t; j < n; ++j)
            scanf("%d", &t), a[i][j] = t % 3;
    for (int j = 0; j < n; ++j){
        c1[j].reset();
        c2[j].reset();
        for (int i = 0; i < n; ++i){
            if (a[i][j] == 1) c1[j][i] = 1;
            else if (a[i][j] == 2) c2[j][i] = 1;
        }
    }
}
void solve(){
    for (int i = 0; i < n; ++i)
        for (int j = 0; j < n; ++j){
            tmp[0] = r1[i] & c2[j];
            tmp[1] = r2[i] & c1[j];
            tmp[2] = r1[i] & c1[j];
            tmp[3] = r2[i] & c2[j];
            tmp[0] |= tmp[1];
            tmp[2] |= tmp[3];
            int res = (2 * tmp[0].count() + tmp[2].count()) % 3;
            printf("%d%c", res, (j == n - 1 ? '\n': ' '));
        }
}
int main(){
    while (scanf("%d", &n) == 1){
        init();
        solve();
    }
    return 0;
}

参考资料

  1. C++位运算入门

猜你喜欢

转载自blog.csdn.net/zqy1018/article/details/108328737