hdu6395 Sequence(分段矩阵快速幂)

Sequence

题目传送门

解题思路

可以比较容易的推出矩阵方程,但是由于p/i向下取整的值在变,所以要根据p/i的变化将矩阵分段快速幂。p/i一共有sqrt(p)种结果,所以最多可以分为sqrt(p)段进行快速幂。

代码如下

#include <bits/stdc++.h>
#define INF 0x3f3f3f3f
using namespace std;
typedef long long ll;

inline int read(){
    int res = 0, w = 0; char ch = 0;
    while(!isdigit(ch)){
        w |= ch == '-', ch = getchar();
    }
    while(isdigit(ch)){
        res = (res << 3) + (res << 1) + (ch ^ 48);
        ch = getchar();
    }
    return w ? -res : res;
}

const int N = 100005;
const int mod = 1e9+7;
struct Matrix{
    ll m[3][3];
    Matrix(){
        memset(m, 0, sizeof(m));
    }
};

Matrix mul(Matrix& x, Matrix& y)
{
    Matrix ans;
    for(int i = 0; i < 3; i ++){
        for(int j = 0; j < 3; j ++){
            for(int k = 0; k < 3; k ++)
                ans.m[i][j] = (x.m[i][k] * y.m[k][j] % mod + ans.m[i][j]) % mod;
        }
    }
    return ans;
}

Matrix sq_pow(Matrix& x, int k)
{
    Matrix t = x;
    Matrix ans;
    ans.m[0][0] = 1, ans.m[1][1] = 1, ans.m[2][2] = 1;
    while(k){
        if(k & 1)
            ans = mul(ans, t);
        t = mul(t, t);
        k >>= 1;
    }
    return ans;
}

int main()
{
    int t;
    cin >> t;
    while(t --){
        ll a, b;
        int c, d, p, n;
        cin >> a >> b >> c >> d >> p >> n;
        Matrix t;
        t.m[0][0] = d, t.m[0][1] = c, t.m[1][0] = 1, t.m[2][2] = 1;
        if(n == 1)
            printf("%lld\n", a % mod);
        else if(n == 2)
            printf("%lld\n", b % mod);
        else {
            for(int i = 3; i <= n; ){
                //printf("b: %lld\n", b % mod);
                t.m[0][2] = p / i;
                if(p / i == 0){
                    Matrix temp = sq_pow(t, n - i + 1);
                    b = (b * temp.m[0][0] + a * temp.m[0][1] + temp.m[0][2]) % mod;
                    break;
                }
                else {
                    int j = p / (p / i);
                    j = min(j, n);
                    Matrix temp = sq_pow(t, j - i + 1);
                    ll tb = (b * temp.m[0][0] + a * temp.m[0][1] + temp.m[0][2]) % mod;
                    ll ta = (b * temp.m[1][0] + a * temp.m[1][1] + temp.m[1][2]) % mod;
                    a = ta, b = tb;
                    i = j + 1;
                }
            }
            printf("%lld\n", b % mod);
        }
    }
    return 0;
}

猜你喜欢

转载自www.cnblogs.com/whisperlzw/p/11211908.html