POJ3233--Matrix Power Series--(矩阵加速,嵌套矩阵)


目录

一.题目

题目描述

输入

输出

样例输入

样例输出

二.题解

三.代码

谢谢!


一.题目

题目描述

给定矩阵A,求矩阵S=A^1+A^2+……+A^k,输出矩阵,S矩阵中每个元都要模m。

数据范围: n (n ≤ 30) ,  k (k ≤ 109) ,m (m < 104)

输入

输入三个正整数n,k,m

输出

输出矩阵S mod m

样例输入

2 2 4

0 1

1 1

样例输出

1 2

2 3

二.题解

这道题目的矩阵其实特别好推:(E代表单位矩阵,O代表0矩阵,S代表答案矩阵)

对吧,但是构造矩阵里面的是一些矩阵呀,A,O,E都是矩阵。

那么这里就有两种方法,一种是直接用嵌套矩阵加速,和一般的矩阵加速道理一样,但是特别难写。

另一种方法:因为题目的k并不大,所以我们可以把这些矩阵打开,就变成了一般的矩阵加速。

三.代码

//第二种方法
#include <cstdio>
#include <cstring>
#include <iostream>
using namespace std;
#define M 65
#define LL long long
LL n, k, m;
struct node {
    LL R, c1, s[M][M];
    node operator * (const node& rhs){
        node ans;
        ans.R = R, ans.c1 = rhs.c1;
        for (int i = 0; i < M; i ++)
            for (int j = 0; j < M; j ++)
                ans.s[i][j] = 0;
        for (int i = 1; i <= R; i ++)
            for (int j = 1; j <= rhs.c1; j ++)
                for (int k = 1; k <= c1; k ++)
                    ans.s[i][j] = (ans.s[i][j] + s[i][k] * rhs.s[k][j] % m) % m;
        return ans;
    }
}A, B, C;
node qkpow (node x, LL y){
    node ans;
    ans.R = ans.c1 = 2 * n;
    for (int i = 1; i <= ans.R; i ++)
        for (int j = 1; j <= ans.c1; j ++)
            if (i == j)
                ans.s[i][j] = 1;
            else
                ans.s[i][j] = 0;
    while (y > 0){
        if (y % 2 == 1)
            ans = ans * x;
        x = x * x;
        y /= 2;
    }
    return ans;
}
int main (){
    scanf ("%lld %lld %lld", &n, &k, &m);
    memset (A.s, 0, sizeof A.s);
    for (int i = 1; i <= n; i ++)
        for (int j = 1; j <= n; j ++){
            scanf ("%lld", &A.s[i + n][j]);
            A.s[i + n][j + n] = A.s[i + n][j];
            B.s[i][j + n] = A.s[i + n][j];
            B.s[i][j] = A.s[i + n][j];
        }
    A.R = A.c1 = 2 * n;
    B.R = n, B.c1 = 2 * n;
    for (int i = 1; i <= n; i ++)
        for (int j = 1; j <= n; j ++)
            if (i == j) A.s[i][j] = 1;
            else    A.s[i][j] = 0;
    C = B * qkpow (A, k - 1);
    for (int i = 1; i <= n; i ++){
        int j;
        for (j = 1; j < n; j ++)
            printf ("%lld ", C.s[i][j]);
        printf ("%lld\n", C.s[i][j]);
    }
    return 0;
}

谢谢!

发布了61 篇原创文章 · 获赞 32 · 访问量 8365

猜你喜欢

转载自blog.csdn.net/weixin_43908980/article/details/89311033