[Solution] [SDOI2015] Sequence Statistics

Problem statement

Solution

Let \(f[i][j]\) denote the number of sequences of length \(i\) whose product is congruent to \(j\) modulo \(m\).

The transition equation is
\[f[i + j][C] = \sum_{A \cdot B \equiv C \pmod{m}} f[i][B] \cdot f[j][A]\]
The complexity of this is \(O(nm^2)\).

Consider doubling, in the same way as fast exponentiation:
\[f[2i][C] = \sum_{A \cdot B \equiv C \pmod{m}} f[i][B] \cdot f[i][A]\]
With this, the complexity becomes \(O(m^2 \log n)\).

Continue to optimize

The transition has the same shape as the formula
\[c[z] = \sum_{x \cdot y \equiv z \pmod{m}} a[x]\, b[y]\]
If it instead had the form
\[c[z] = \sum_{x + y = z} a[x]\, b[y]\]
we could use NTT to speed it up.

We know that taking logarithms turns multiplication into addition.

But the ordinary logarithm is real-valued, so we need a logarithm that makes sense under the modulus (a discrete logarithm).

Using a primitive root \(g\) of \(m\) as the base works, so the formula becomes
\[c[\log_g z] = \sum_{\log_g x + \log_g y \equiv \log_g z \pmod{m - 1}} a[\log_g x]\, b[\log_g y]\]
Note that \(\log_g x + \log_g y\) may be as large as or larger than \(m - 1\).

However, it is always less than \(2m\), so for each position \(z\) we add \(c[z + m - 1]\) to \(c[z]\) and then clear \(c[z + m - 1]\) to zero.

Code

#include <algorithm>
#include <iostream>
#include <cstring>
#include <cstdio>
#include <map>
// Capacity of the shared work arrays below; every index used by ntt/mul is < lim,
// so lim (chosen in main) must not exceed this.
const int N = 40005;
// Modulus used for all NTT arithmetic and for the counts stored in f/res.
const int mod = 1004535809; 
using namespace std;

// n, m, X, S : the four integers read first in main (n drives the repeated-squaring
//              loop, m is the small modulus, X selects the printed entry, S is the
//              number of values read afterwards).
// lim, cnt   : transform length (a power of two) and the shift used to build r[].
// r[]        : bit-reversal permutation applied at the start of ntt.
// g, gg      : root and its modular inverse; main sets them first for m, then for mod.
// a[], b[]   : scratch buffers used by mul.
// res[], f[] : accumulated result and current base of the repeated squaring in main.
// top, fact[]: scratch list of distinct prime factors filled by getroot.
int n, m, X, S, lim, cnt, r[N], g, gg, a[N], b[N], res[N], f[N], top, fact[20005]; 
// mp[v] = e where v is the e-th power of g modulo m (filled in main).
map<int, int> mp; 

/**
 * Reads one decimal integer from stdin.
 *
 * Every character that is not a digit is skipped; if a '-' is seen while
 * skipping, the parsed value is negated. Digits are then accumulated until the
 * first non-digit character (which is consumed).
 *
 * End of input terminates the skip phase, so calling this on exhausted input
 * returns 0 instead of looping forever.
 *
 * @tparam T integral type of the result
 * @return the parsed value, or 0 if no digit was found before end of input
 */
template < typename T >
inline T read()
{
    T value = 0, sign = 1;
    // getchar() returns int; keeping it as int lets EOF be told apart from data.
    int c = getchar();
    while(c != EOF && (c < '0' || c > '9'))
    {
        if(c == '-') sign = -1;
        c = getchar();
    }
    while(c >= '0' && c <= '9')
    {
        value = value * 10 + (c - '0');
        c = getchar();
    }
    return value * sign;
}

/**
 * Modular exponentiation by repeated squaring.
 *
 * @param x base; reduced into [0, p) before use, so negative or oversized
 *          values are accepted
 * @param y exponent; expected to be non-negative
 * @param p modulus; expected to be positive
 * @return x^y mod p, always in the range [0, p)
 */
int fpow(int x, int y, int p)
{
    // Start from 1 % p rather than 1 so that p == 1 yields 0 even when y == 0.
    int res = 1 % p;
    // Normalise the base so the returned residue is never negative.
    x %= p;
    if(x < 0) x += p;
    for( ; y; y >>= 1, x = 1ll * x * x % p)
        if(y & 1) res = 1ll * res * x % p;
    return res;
}

/**
 * Finds the smallest primitive root modulo x.
 *
 * The group order is taken to be x - 1, i.e. x is assumed to be prime.
 * The distinct prime factors of x - 1 are stored in fact[1..top] (both are
 * file-level scratch variables and are overwritten on every call). A candidate
 * c is accepted when c^((x-1)/q) != 1 (mod x) for every such factor q.
 *
 * @param x the modulus
 * @return the smallest primitive root, 1 for x == 2, or -1 if none is found
 */
int getroot(int x)
{
    top = 0;
    // The multiplicative group modulo 2 is {1}, so 1 generates it; the search
    // below starts at 2 and would otherwise report failure.
    if(x == 2) return 1;

    const int order = x - 1;
    int rem = order;
    // Trial division against the shrinking remainder. Writing the bound as
    // i <= rem / i avoids the signed overflow that i * i can hit for large x.
    for(int i = 2; i <= rem / i; i++)
        if(rem % i == 0)
        {
            fact[++top] = i;
            while(rem % i == 0) rem /= i;
        }
    // Whatever is left above 1 is a single prime factor.
    if(rem > 1) fact[++top] = rem;

    for(int cand = 2; cand <= order; cand++)
    {
        bool generates = true;
        for(int j = 1; j <= top && generates; j++)
            if(fpow(cand, order / fact[j], x) == 1) generates = false;
        if(generates) return cand;
    }
    return -1;
}

/**
 * In-place number-theoretic transform of p[0..lim-1] modulo mod.
 *
 * Relies on file-level state: lim (transform length), r[] (bit-reversal
 * permutation), and g / gg (the roots used for opt == 1 and for any other
 * opt, respectively).
 *
 * @param p   array of at least lim values, transformed in place
 * @param opt 1 selects g as the root; any other value selects gg. When opt is
 *            -1 the result is additionally multiplied by the modular inverse
 *            of lim.
 */
void ntt(int *p, int opt)
{
    // Reorder into bit-reversed order; i < r[i] makes each pair swap once.
    for(int i = 0; i < lim; i++)
        if(i < r[i]) std::swap(p[i], p[r[i]]);

    for(int half = 1; half < lim; half <<= 1)
    {
        const int step = half << 1;
        const int rt = fpow(opt == 1 ? g : gg, (mod - 1) / step, mod);
        for(int base = 0; base < lim; base += step)
        {
            int w = 1;
            for(int k = base; k < base + half; k++, w = 1ll * w * rt % mod)
            {
                const int x = p[k], y = 1ll * w * p[k + half] % mod;
                p[k] = (1ll * x + y) % mod;
                p[k + half] = (1ll * x - y + mod) % mod;
            }
        }
    }

    if(opt == -1)
    {
        // Scale the array that was actually transformed (p), not a fixed
        // global buffer, so the inverse transform is correct for any argument.
        const int inv = fpow(lim, mod - 2, mod);
        for(int i = 0; i < lim; i++) p[i] = 1ll * p[i] * inv % mod;
    }
}

/**
 * Multiplies A and B as polynomials modulo mod, with exponents wrapped
 * around modulo m - 1, and writes lim coefficients to C.
 *
 * A, B and C may alias each other: the inputs are copied into the file-level
 * scratch buffers a[] and b[] before any output is written.
 *
 * @param A first operand, lim entries are read
 * @param B second operand, lim entries are read
 * @param C destination, lim entries are written
 */
void mul(int *A, int *B, int *C)
{
    for(int i = 0; i < lim; i++)
    {
        a[i] = A[i];
        b[i] = B[i];
    }

    ntt(a, 1);
    ntt(b, 1);
    for(int i = 0; i < lim; i++)
        a[i] = 1ll * a[i] * b[i] % mod;
    ntt(a, -1);

    // Fold the upper part of the product back by m - 1 positions and clear it.
    const int period = m - 1;
    for(int lo = 0, hi = period; lo < period; lo++, hi++)
    {
        a[lo] = (1ll * a[lo] + a[hi]) % mod;
        a[hi] = 0;
    }

    std::copy(a, a + lim, C);
}

/**
 * Reads n, m, X, S and then S values, maps every non-zero value to its
 * exponent with respect to a primitive root of m, raises the resulting count
 * vector to the n-th power with mul(), and prints the entry for X.
 */
int main()
{
    n = read <int> ();
    m = read <int> ();
    X = read <int> ();
    S = read <int> ();

    // Root of the small modulus m; gg set here is replaced further down
    // before any transform runs.
    g = getroot(m);
    gg = fpow(g, m - 2, m);

    // mp[g^e mod m] = e for e = 0 .. m-2.
    int power = 1;
    for(int e = 0; e < m - 1; e++)
    {
        mp[power] = e;
        power = 1ll * power * g % m;
    }

    // Count the input values by exponent; zeros are ignored.
    for(int i = 1; i <= S; i++)
    {
        const int x = read <int> ();
        if(x != 0) f[mp[x]]++;
    }
    res[mp[1]] = 1;

    // Smallest power of two strictly greater than 2 * m; cnt ends up one less
    // than its exponent, which is the shift needed for the bit-reversal table.
    lim = 1;
    while(lim <= 2 * m)
    {
        lim <<= 1;
        cnt++;
    }
    cnt--;
    for(int i = 0; i < lim; i++)
        r[i] = (r[i >> 1] >> 1) | ((i & 1) << cnt);

    // From here on g / gg are the roots used by ntt for the large modulus.
    g = getroot(mod);
    gg = fpow(g, mod - 2, mod);

    // Repeated squaring on the exponent n.
    for( ; n; n >>= 1)
    {
        if(n & 1) mul(res, f, res);
        mul(f, f, f);
    }

    printf("%d\n", res[mp[X]]);
    return 0;
}

Guess you like

Origin www.cnblogs.com/ztlztl/p/11989296.html