HDU5322 - cdq分治FFT加速dp

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/weixin_37517391/article/details/83218581

5322 Hope [CDQ分治FFT加速计算dp]


题意

每一个每一个排列,排列中每个数向它后面第一个比它大的数连一条边.

每个排列对于答案的贡献是这个排列所生成的图中的每一个联通量中点的个数的平方之积.

例如:排列

1 , 2 , 3 , 6 , 4 , 5 1,2,3,6,4,5

其中

1 , 2 , 3 , 6 1,2,3,6 形成一个大小为 4 4 的联通分量.

4 , 5 4,5 形成一个大小为 2 2 的联通分量.

那么这个排列的贡献就是 4 2 2 2 = 64 4^2*2^2 = 64 .

我们需要求所有排列的贡献.

题解

这种题做法一般就是 d p dp .

我们枚举排列中最大的数 n n 的所在的位置,当 n n 的位置固定以后,所有在 n n 前面的数都会与 n n 形成一个连通分量.而后面的所有的数就组成了一个子问题.

定义 d p [ n ] dp[n] 表示 n n 的所有排列所形成的贡献之和.

我们可以得到递推公式.

d p [ n ] = i = 1 n C n 1 i 1 ( i 1 ) ! i 2 d p [ n i ] dp[n] = \sum_{i=1}^{n}C_{n-1}^{i-1}*(i-1)!*i^2*dp[n-i]

化简得到

d p [ n ] = ( n 1 ) ! i = 1 n i 2 d p [ n i ] ( n i ) ! dp[n] = (n-1)!\sum_{i=1}^{n}i^2*\frac{dp[n-i]}{(n-i)!}

F [ i ] = d p [ i ] i ! , G [ i ] = i 2 F[i]=\frac{dp[i]}{i!},G[i]=i^2

那么

d p [ n ] = ( n 1 ) ! i = 1 n G [ i ] F [ n i ] dp[n] = (n-1)!\sum_{i=1}^{n}G[i]F[n-i]

观察到 d p [ n ] dp[n] 的递推公式是一个卷积的形式.

求卷积我们显然想到了 F F T / N T T FFT/NTT 来加速 d p dp 的计算,对于要求出所有的 d p [ n ] dp[n] ,因此我们需要用 C D Q CDQ 分治来辅助计算 F F T / N T T FFT/NTT .

采用分治策略,即将区间 F [ l , r ] F[l,r] 分成区间 F [ l , m i d ] F[l,mid] F [ m i d + 1 , r ] F[mid+1,r] 两个区间.

F [ l , m i d ] F[l,mid] 区间的内容和 G G 多项式做卷积,即可计算出 F [ l , m i d ] F[l,mid] 部分对于 F [ m i d + 1 , r ] F[mid+1,r] 的影响.注意!这里要求 F [ l , m i d ] F[l,mid] 必须已经完整的被计算出来了.随后再地归计算 F [ m i d + 1 , r ] F[mid+1,r] 部分,解决它们内部的贡献以来.

因此代码的大致框架就是:

void solve(int l,int r) {
	if(l == r) return ;
	int mid = (l + r) >> 1;
	solve(l,mid);//保证区间F[l,mid]已经被完整的计算出来了
	Conv();//求卷积,计算F[l,mid]对F[mid+1,r]的贡献.注意此时F[1,l-1]对F[mid+1,r]的贡献早已被求过了.
	//注意在此时,F[mid+1,r]内部依赖于F[1,mid]的贡献已经全部被计算出来了,因此下一步求内部F[mid+1,r]对F[mid+1,r]的贡献.
	solve(mid+1,r);
}

举个栗子

n = 4 n=4 时刻
初始 [ 1 , 1 ] [1,1] 就算完整的被计算出来了.
先求 [ 1 , 1 ] [1,1] [ 2 , 2 ] [2,2] 的贡献,从而得到了完整的 [ 1 , 2 ] [1,2] .
然后计算 [ 1 , 2 ] [1,2] [ 3 , 4 ] [3,4] 的贡献,此时 [ 3 , 3 ] [3,3] 已经被完整的计算出来了.
然后算 [ 3 , 3 ] [3,3] 对于 [ 4 , 4 ] [4,4] 的贡献,导致 [ 4 , 4 ] [4,4] 被完整的计算出来.
至此 [ 1 , 4 ] [1,4] 都被完整的计算出来了.

注意

其中计算 [ l , m i d ] [l,mid] [ m i d + 1 , r ] [mid+1,r] 的贡献的时候,需要注意:

x [ m i d + 1 , r ] x \in [mid+1,r]

F [ x ] F [ l + 0 ] G [ x l ] + . . . + F [ l + ( m i d l ) ] G [ x m i d ] F[x] \leftarrow F[l+0]*G[x-l] + ... + F[l+(mid-l)]*G[x-mid]

F [ l ] , F [ l + 1 ] , . . . , F [ m i d ] F[l],F[l+1],...,F[mid] 填在数组 a [ 0 ] , a [ 1 ] , . . . , a [ m i d l ] a[0],a[1],...,a[mid-l] 的位置.

G [ 0 ] , G [ 1 ] , G [ 2 ] , . . . , G [ n ] G[0],G[1],G[2],...,G[n] 填在数组 b [ 0 ] , b [ 1 ] , b [ 2 ] , . . , b [ n ] b[0],b[1],b[2],..,b[n] 的位置.
然后求个卷积 c = a b c = a\bigotimes b

最后卷积数组的 c [ x l ] c[x-l] 位置的值就是 F [ x ] F[x] .

描述上稍微有点问题, F F d p dp 的关系有点混淆,诸位看我代码就ok了.

代码

#include <iostream>
#include <algorithm>
#include <cstring>
#define pr(x) std::cout << #x << ':' << x << std::endl
#define rep(i,a,b) for(int i = a;i <= b;++i)
#define clr(x) memset(x,0,sizeof(x))
#define setinf(x) memset(x,0x3f,sizeof(x))
#define Max(x,y) x = std::max(x,y)
#define Min(x,y) x = std::min(x,y)
#define Add(x,y) x = (((x)+(y))%P)
#define Sub(x,y) x = (((x)-(y)+P%P)
#define Mul(x,y) x = ((x)*(y)%P)
typedef long long LL;
const int N = 1 << 20;
const int P = 998244353;
const int G = 3;
const int NUM = 20;

LL  wn[NUM];
LL  a[N], b[N];

LL quick_mod(LL a, LL b, LL m)
{
    LL ans = 1;
    a %= m;
    while(b)
    {
        if(b & 1)
        {
            ans = ans * a % m;
            b--;
        }
        b >>= 1;
        a = a * a % m;
    }
    return ans;
}

void GetWn()
{
    for(int i = 0; i < NUM; i++)
    {
        int t = 1 << i;
        wn[i] = quick_mod(G, (P - 1) / t, P);
    }
}
void Rader(LL a[], int len)
{
    int j = len >> 1;
    for(int i = 1; i < len - 1; i++)
    {
        if(i < j) std::swap(a[i], a[j]);
        int k = len >> 1;
        while(j >= k)
        {
            j -= k;
            k >>= 1;
        }
        if(j < k) j += k;
    }
}

void NTT(LL a[], int len, int on)
{
    Rader(a, len);
    int id = 0;
    for(int h = 2; h <= len; h <<= 1)
    {
        id++;
        for(int j = 0; j < len; j += h)
        {
            LL w = 1;
            for(int k = j; k < j + h / 2; k++)
            {
                LL u = a[k] % P;
                LL t = w * a[k + h / 2] % P;
                a[k] = (u + t) % P;
                a[k + h / 2] = (u - t + P) % P;
                w = w * wn[id] % P;
            }
        }
    }
    if(on == -1)
    {
        for(int i = 1; i < len / 2; i++)
            std::swap(a[i], a[len - i]);
        LL inv = quick_mod(len, P - 2, P);
        for(int i = 0; i < len; i++)
            a[i] = a[i] * inv % P;
    }
}

void Conv(LL a[], LL b[], int n)
{
    NTT(a, n, 1);
    NTT(b, n, 1);
    for(int i = 0; i < n; i++)
        a[i] = a[i] * b[i] % P;
    NTT(a, n, -1);
}

LL dp[N],Fac[N],iFac[N];

void solve(int l,int r) {
    if(l == r) 
        return ;

    int mid = (l + r) >> 1;
    
    solve(l,mid);
    
    int len = 1;
    while(len <= (r-l+1)) len <<= 1;
    rep(i,0,len) {
        b[i] = 1LL*i*i%P;
    }

    rep(i,l,mid) {
        a[i-l] = dp[i]*iFac[i]%P;
    }

    rep(i,mid-l+1,len) a[i] = 0;

    Conv(a,b,len);

    rep(i,mid+1,r) {
        dp[i] = (dp[i] + (a[i-l]*Fac[i-1]%P))%P;
    }


    solve(mid+1,r);
}

int main()
{
    
    Fac[0] = iFac[0] = 1;
    rep(i,1,N-1) {
        Fac[i] = Fac[i-1]*i % P;
        iFac[i] = quick_mod(Fac[i],P-2,P);
    }
    GetWn();
    dp[0] = 1;
    
    solve(0,100000);

    int n;
    while(std::cin >> n) {
        std::cout << dp[n] << std::endl;
    }
    return 0;
}

猜你喜欢

转载自blog.csdn.net/weixin_37517391/article/details/83218581
今日推荐