Matrix Power Series(构造矩阵+矩阵快速幂)

Matrix Power Series

Time Limit : 6000/3000ms (Java/Other)   Memory Limit : 262144/131072K (Java/Other)
Problem Description

Given an $n \times n$ matrix $A$ and a positive integer $k$, find the sum $S = A + A^2 + A^3 + \dots + A^k$.


Input
The input contains exactly one test case. The first line of input contains three positive integers $n$ ($n \le 30$), $k$ ($k \le 10^9$) and $m$ ($m < 10^4$). Then follow $n$ lines, each containing $n$ nonnegative integers below 32,768, giving $A$’s elements in row-major order.
 

Output
Output the elements of $S$ modulo $m$ in the same way as $A$ is given.
 

Sample Input
 
  
2 2 4
0 1
1 1
 

Sample Output
 
  
1 2
2 3
 

题意

给定一个 $n \times n$ 矩阵 $A$,求 $S = A + A^2 + A^3 + \dots + A^k$(结果对 $m$ 取模)。


思路

设 $S[k] = A + A^2 + A^3 + \dots + A^k$。

那么 $S[k] = A \cdot S[k-1] + A$(由于 $S[k-1]$ 是 $A$ 的多项式,与 $A$ 可交换,也可写成 $S[k-1] \cdot A + A$,代码中用的是后者)。


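因为 $(S[k-1],\ 1)\begin{pmatrix} A & 0 \\ A & 1 \end{pmatrix} = (S[k-1] \cdot A + A,\ 1) = (S[k],\ 1)$,且 $S[1] = A$,所以

$$(S[k],\ 1) = (A,\ 1)\begin{pmatrix} A & 0 \\ A & 1 \end{pmatrix}^{k-1}$$

所求 $S = S[k]$ 就是乘积左侧的 $n \times n$ 块。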


注意:矩阵套矩阵,里面的1代表单位矩阵,即对角线上元素为1,其余为0。


#include <iostream>
#include <cstdio>
#include <cstring>
#include <cmath>
#include <vector>
#include <queue>
#include <algorithm>
using namespace std;

typedef long long ll;

// Side length of the storage inside every Matrix. The largest matrix the
// program builds is the 2n x 2n augmented transition matrix, and the problem
// guarantees n <= 30, so 2n <= 60 always fits. Keep this close to that limit:
// Matrix values are created as locals (main, fastm, Mul, solve), several
// exist at the same time, and an oversized array (e.g. 120 x 120 long longs,
// ~115 KB each) can exhaust a 1 MB stack.
const int MAXD = 64;

// Square matrix stored as a fixed-size array; routines only touch the
// top-left corner whose size is derived from the global n.
typedef struct {
    ll m[MAXD][MAXD];
} Matrix;

// Problem parameters, read in main: matrix side n, number of terms k, and the
// modulus mod. The matrix helpers read n and mod as globals.
ll  n,k,mod;

// Returns a * b for two 2n x 2n matrices, with every entry reduced modulo the
// global `mod`. Only the top-left 2n x 2n corner is read; the rest of the
// result stays zero.
//
// Arguments are taken by const reference: a Matrix is a large fixed-size
// array, and copying two of them on every call (fastm calls this O(log k)
// times) is pure overhead and extra stack pressure. Callers such as
// `a = Mul(a, a)` remain safe because the product is built in the local `c`.
Matrix Mul(const Matrix &a, const Matrix &b)
{
    Matrix c;
    memset(c.m, 0, sizeof(c.m));
    const int dim = static_cast<int>(2 * n);
    // i-t-j order walks rows of b and c contiguously. The inner index is not
    // called `k`, so it does not shadow the global k.
    for (int i = 0; i < dim; i++)
    {
        for (int t = 0; t < dim; t++)
        {
            const ll lhs = a.m[i][t];
            if (lhs == 0)
                continue; // adds nothing; the transition matrix has a zero block
            for (int j = 0; j < dim; j++)
                c.m[i][j] = (c.m[i][j] + lhs * b.m[t][j]) % mod;
        }
    }
    return c;
}
// Multiplies the n x 2n block row stored in the first n rows of `a` by the
// 2n x 2n matrix `b`, reducing every entry modulo the global `mod`.
// Rows n..2n-1 of `a` are ignored, and the same rows of the result stay zero.
// main uses this to form (A, I) * T^(k-1); the left n x n block of the
// result is the requested sum S.
Matrix solve(const Matrix &a, const Matrix &b)
{
    Matrix c;
    memset(c.m, 0, sizeof(c.m));
    const int rows = static_cast<int>(n);
    const int dim = static_cast<int>(2 * n);
    for (int i = 0; i < rows; i++)
    {
        for (int t = 0; t < dim; t++)
        {
            const ll lhs = a.m[i][t];
            if (lhs == 0)
                continue; // contributes nothing to row i
            for (int j = 0; j < dim; j++)
                c.m[i][j] = (c.m[i][j] + lhs * b.m[t][j]) % mod;
        }
    }
    return c;
}

// Returns a^num for a 2n x 2n matrix, with entries reduced modulo the global
// `mod`, using binary exponentiation (O(log num) calls to Mul).
//
// `a` is taken by value on purpose: it is squared in place while the loop
// walks the bits of num.
//
// num is expected to be >= 0. The loop runs only while num > 0, so a negative
// exponent returns the identity instead of spinning forever (an arithmetic
// right shift keeps -1 at -1).
Matrix fastm(Matrix a, ll num)
{
    Matrix res;
    memset(res.m, 0, sizeof(res.m));
    const int dim = static_cast<int>(2 * n);
    // Start from the identity, reduced modulo `mod` so that mod == 1 gives
    // the (correct) zero matrix.
    for (int i = 0; i < dim; i++)
        res.m[i][i] = 1 % mod;
    while (num > 0)
    {
        if (num & 1)
            res = Mul(res, a);
        num >>= 1;
        if (num > 0) // skip the final squaring, whose result would be unused
            a = Mul(a, a);
    }
    return res;
}

int main()
{
    Matrix a;
    memset(a.m,0,sizeof(a.m));
    scanf("%lld%lld%lld",&n,&k,&mod);
    for(int i=0;i<n;i++)
    {
        for(int j=0;j<n;j++)
        {
            scanf("%lld",&a.m[i][j]);
            a.m[i][j]=a.m[i][j]%mod;
        }
    }
    Matrix b;
    memset(b.m,0,sizeof(b.m));
    for(int i=0;i<n;i++)
        for(int j=0;j<n;j++)
            b.m[i][j]=a.m[i][j];
    for(int i=0;i<n;i++)
        for(int j=n;j<2*n;j++)
            b.m[i][j]=0;
    for(int i=n;i<2*n;i++)
        for(int j=0;j<n;j++)
            b.m[i][j]=a.m[i-n][j];
    //注意:矩阵等于1表示该矩阵是单位矩阵,即主对角线为1,其余为0
    for(int i=n;i<2*n;i++)
            b.m[i][i]=1;
    for(int i=0;i<n;i++)
            a.m[i][i+n]=1;
    b=fastm(b,k-1);
    Matrix ans=solve(a,b);
    for(int i=0;i<n;i++)
    {
        printf("%lld",ans.m[i][0]);
        for(int j=1;j<n;j++)
        {
            printf(" %lld",ans.m[i][j]);
        }
        printf("\n");
    }
    return 0;
}


猜你喜欢

转载自blog.csdn.net/luyehao1/article/details/80470712
今日推荐