Matrix Power Series - 矩阵快速幂对分块矩阵加速

题目
给定一个\(n \times n\)的矩阵\(A\)、正整数\(k\)和模数\(m\),令\(S_k = A + A^2 + A^3 + \dots + A^k\),求\(S_k\)中每个元素模\(m\)的结果。
类比数列求和的思路,构造一个让和式与幂次一起递推的式子:\(S_k = S_{k - 1} + A^k\)。
先假设\(A\)只是一个普通的数(而不是矩阵),则有

\[\left[\begin{array}{cc} 1 & 1\\ 0 & A \end{array}\right] \times \left[\begin{array}{c} S_{k - 1}\\ A^k \end{array}\right] = \left[\begin{array}{c} S_k\\ A^{k + 1} \end{array}\right]\]

把 1 换成单位矩阵\(E\)、0 换成零矩阵,并把数换成矩阵,就得到

\[\left[\begin{array}{cc} E & E\\ 0 & A \end{array}\right] \times \left[\begin{array}{c} S_{k - 1}\\ A^k \end{array}\right] = \left[\begin{array}{c} S_k\\ A^{k + 1} \end{array}\right]\]

这样就构建了一个分块矩阵,大小为\(2n \times 2n\)。
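加速矩阵为\(\left[\begin{array}{cc} E & E\\ 0 & A \end{array}\right]\),初始矩阵为\(\left[\begin{array}{c} S_1\\ A^2 \end{array}\right]\)。将加速矩阵自乘\(k - 1\)次后再乘以初始矩阵,结果的上半块就是\(S_k\)。注意:下面代码中的变量 `k` 表示矩阵阶数\(n\),变量 `n` 表示题目中的幂次\(k\),`mod` 即模数\(m\)。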

#include <iostream>
#include <cstdio>
#include <cstring>
#define ll long long
using namespace std;
const int N = 61;  // storage capacity of Matrix::a; solve() builds a 2k x 2k matrix, so 2k must be <= N
int mod, k, n;     // modulus, matrix order (k) and highest power (n), all read in main()

/// Fixed-capacity dense int matrix; only the top-left n x m entries are
/// meaningful. Products are reduced modulo the global `mod`.
struct Matrix{
    int n, m;     // rows, columns (each expected to be <= N)
    int a[N][N];  // storage; entries outside the n x m window stay zero

    /// Creates an x-by-y zero matrix.
    Matrix(int x, int y) : n(x), m(y) { std::memset(a, 0, sizeof(a)); }

    /// Returns (*this) * b with every entry reduced modulo `mod`.
    /// Requires this->m == b.n; the result is n x b.m.
    Matrix operator * (const Matrix &b) const {
        Matrix ans(n, b.m);
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < b.m; j++) {
                long long sum = 0;
                for (int t = 0; t < m; t++) {
                    // Widen before multiplying: the product of two int entries
                    // can exceed INT_MAX (signed overflow is undefined behavior).
                    sum = (sum + 1LL * a[i][t] * b.a[t][j] % mod) % mod;
                }
                ans.a[i][j] = static_cast<int>(sum);
            }
        }
        return ans;
    }
};
/// Binary (fast) exponentiation: returns a^b using Matrix::operator*.
///
/// @param a  square matrix (a.n == a.m is assumed)
/// @param b  exponent; b <= 0 yields the identity matrix
/// @return   a^b; for b >= 1 every entry is reduced modulo the global `mod`
Matrix ksm(Matrix a, long long b){
	// Identity of the same order. Only indices < a.n belong to the matrix;
	// writing index max(n, m) would touch one slot past it (out of bounds
	// when the order equals the storage capacity).
	Matrix ans(a.n, a.m);
	for(int i = 0; i < a.n; i++)
		ans.a[i][i] = 1;

	// `b > 0` instead of `b`: for a negative exponent, `b >>= 1` typically
	// stays at -1, so `while(b)` would never terminate.
	while(b > 0){
		if(b & 1) ans = ans * a;
		a = a * a;
		b >>= 1;
	}
	return ans;
}
Matrix a(35, 35);
void solve(){
	Matrix base(k * 2, k * 2);
	for(int i = 0; i < k; i++)
		base.a[i][i] = base.a[i][i + k] = 1;

	for(int i = k; i < 2 * k; i++)
		for(int j = k; j < 2 * k; j++)
			base.a[i][j] = a.a[i - k][j - k];

	base = ksm(base, n - 1);
	

	Matrix ans(k * 2, k);
	for(int i = 0; i < k; i++)
		for(int j = 0; j < k; j++)
			ans.a[i][j] = a.a[i][j];

	a = a * a;
	for(int i = k; i < 2 * k; i++)
		for(int j = 0; j < k; j++)
			ans.a[i][j] = a.a[i - k][j];

	ans = base * ans;
	for(int i = 0; i < k; i++){
		for(int j = 0; j < k; j++)
			printf("%d ", ans.a[i][j]);
		putchar('\n');
	}
}
int main(){
	scanf("%d%d%d", &k, &n, &mod);
	for(int i = 0; i < k; i++)
		for(int j = 0; j < k; j++)
			scanf("%d", &a.a[i][j]);
	solve();
	return 0;
}

猜你喜欢

转载自www.cnblogs.com/Emcikem/p/12891573.html
今日推荐