HDU 5322 Hope (分治 + NTT)

题目链接:http://acm.hdu.edu.cn/showproblem.php?pid=5322

题目大意:对于1~n的全排列中的任意一个排列,对于在排列中的任意一个 i,如果存在一个离 i 最近的 j 满足 i < j 且 A[i] < A[j],就在 i 和 j 之间建一条无向边。照此规矩建完边之后,如果联通图内的点数为P,那么这一个联通图的贡献就是P*P,种排列的贡献就为所有联通图的贡献的乘积。现在给你一个n,要你求出n的所有排列的贡献之和。

题目思路:由于这个题目的组数特别多,所以我们考虑用dp预处理来解决这个问题。

dp[i] 表示 长度为 i 的排列的贡献之和。

我们考虑对于长度为 i 的序列中的某一个序列当 最大的数 i 位于第 j 位时,前 j - 1 个数必然会与 i 形成联通块,后面的 i - j 个数就不会与前面的点联通。这样我们就可以推出如下的状态转移方程:

dp[i] = \sum_{j = 1}^{i}C(i-1,j-1)*(j-1)!*j^2*dp[i-j]

表示从除了最大值 i 以外的 i - 1 个数中选出 j - 1 一个放到 前 j - 1个位置里,同时这个j - 1个数总共有 (j - 1) ! 种排列方式,这个联通块对答案的贡献就为 j^2,再乘上后面 i - j 个数自己形成联通块的贡献之和就是最终答案了。

我们接着再用组合数的性质化简一下这个式子

C(i-1,j-1)=\frac{(i-1)!}{(j-1)!*(i-j)!}

dp[i] = (i-1)!*\sum_{j = 1}^{i}(i-j)!*dp[i-j]*j^2

接着令A[i-j] = (i-j)!*dp[i-j]B[j]=j^2

式子就化成了dp[i]=(i-1)!*\sum_{j+k=i}^{ }A[j]*B[k]

这个式子就能用FFT来求卷积了(本题由于有模数,所以可以选择用NTT来降低精度的误差)。

但由于是要预处理出所有的dp的值,所以我们可以选择用分治来降低复杂度,这样总的复杂度就为O(n*logn*logn+T)。

具体实现看代码:

#include <bits/stdc++.h>
#define fi first
#define se second
#define lson l,m,rt<<1
#define rson m+1,r,rt<<1|1
#define pb push_back
#define MP make_pair
#define lowbit(x) x&-x
#define clr(a) memset(a,0,sizeof(a))
#define _INF(a) memset(a,0x3f,sizeof(a))
#define FIN freopen("in.txt","r",stdin)
#define debug(x) cout<<"["<<x<<"]"<<endl
using namespace std;
typedef unsigned long long ull;
typedef long long ll;
typedef pair<int,int>pii;
const int MX = 1e5+7;
const int inf = 0x3f3f3f3f;
const int mod = 1e9+7;

const int P = 998244353;
const int G = 3;
const int NUM = 20;

ll  wn[NUM];
ll  va[MX<<2],vb[MX<<2];
ll quick_mod(ll a, ll x, ll mod) {
    ll ans = 1;
    a %= mod;
    while(x) {
        if(x & 1)ans = ans * a % mod;
        x >>= 1;
        a = a * a % mod;
    }
    return ans;
}
//在程序的开头就要放
void GetWn() {
    for(int i = 0; i < NUM; i++) {
        int t = 1 << i;
        wn[i] = quick_mod(G, (P - 1) / t, P);
    }
}
void Rader(ll F[], int len) {
    int j = len >> 1;
    for(int i = 1; i < len - 1; i++) {
        if(i < j) swap(F[i], F[j]);
        int k = len >> 1;
        while(j >= k)j -= k,k >>= 1;
        if(j < k) j += k;
    }
}
void NTT(ll F[], int len, int t) {
    Rader(F, len);
    int id = 0;
    for(int h = 2; h <= len; h <<= 1) {
        id++;
        for(int j = 0; j < len; j += h) {
            ll E = 1;
            for(int k = j; k < j + h / 2; k++) {
                ll u = F[k];
                ll v = E * F[k + h / 2] % P;
                F[k] = (u + v) % P;
                F[k + h / 2] = (u - v + P) % P;
                E = E * wn[id] % P;
            }
        }
    }
    if(t == -1) {
        for(int i = 1; i < len / 2; i++)swap(F[i], F[len - i]);
        ll inv = quick_mod(len, P - 2, P);
        for(int i = 0; i < len; i++)F[i] = F[i] * inv % P;
    }
}
void Conv(ll a[], ll b[], int len) {
    NTT(a, len, 1);
    NTT(b, len, 1);
    for(int i = 0; i < len; i++)a[i] = a[i] * b[i] % P;
    NTT(a, len, -1);
}
ll dp[MX],f[MX],invf[MX];
void init(){
	f[0] = f[1] = 1;
	for(int i = 2;i < MX;i++) f[i] = (f[i-1] * i) % P;
	invf[MX-1] = quick_mod(f[MX-1],P-2,P);
	for(int i = MX-2;i >= 0; i--) invf[i] = (invf[i+1] * (i+1)) % P;
}
void cdq(int l,int r){
	if(l == r) return;
	int m = (l + r) >> 1;
	cdq(l,m);
	int mx = (r - l + 1),len = 1;
	while(len <= mx) len <<= 1;
	for(int i = 0;i < len;i++){
		if(l + i <= m) va[i] = dp[l+i]*invf[l+i] % P;
		else va[i] = 0;
		if(l + i < r) vb[i] = (ll)(i+1)*(i+1) % P;
		else vb[i] = 0;
	}
	Conv(va,vb,len);
	for(int i = m + 1;i <= r;i++)
		dp[i] = (dp[i] + f[i-1]*(va[i-l-1]) % P) % P;
	cdq(m+1,r);
}
void pre_solve(){
	GetWn();
	init();
	dp[0] = 1;
	cdq(0,100000);
}

int main(){
	//FIN;
	pre_solve();int n;
	while(~scanf("%d",&n)) printf("%lld\n",dp[n]);
	return 0;
}

猜你喜欢

转载自blog.csdn.net/Lee_w_j__/article/details/81988071