题目链接:http://acm.hdu.edu.cn/showproblem.php?pid=5322
题目大意:对于1~n的全排列中的任意一个排列,对于在排列中的任意一个 i,如果存在一个离 i 最近的 j 满足 i < j 且 A[i] < A[j],就在 i 和 j 之间建一条无向边。照此规矩建完边之后,如果联通图内的点数为P,那么这一个联通图的贡献就是P*P,种排列的贡献就为所有联通图的贡献的乘积。现在给你一个n,要你求出n的所有排列的贡献之和。
题目思路:由于这个题目的组数特别多,所以我们考虑用dp预处理来解决这个问题。
dp[i] 表示 长度为 i 的排列的贡献之和。
我们考虑对于长度为 i 的序列中的某一个序列当 最大的数 i 位于第 j 位时,前 j - 1 个数必然会与 i 形成联通块,后面的 i - j 个数就不会与前面的点联通。这样我们就可以推出如下的状态转移方程:
表示从除了最大值 i 以外的 i - 1 个数中选出 j - 1 一个放到 前 j - 1个位置里,同时这个j - 1个数总共有 (j - 1) ! 种排列方式,这个联通块对答案的贡献就为 j^2,再乘上后面 i - j 个数自己形成联通块的贡献之和就是最终答案了。
我们接着再用组合数的性质化简一下这个式子
接着令,
式子就化成了。
这个式子就能用FFT来求卷积了(本题由于有模数,所以可以选择用NTT来降低精度的误差)。
但由于是要预处理出所有的dp的值,所以我们可以选择用分治来降低复杂度,这样总的复杂度就为O(n*logn*logn+T)。
具体实现看代码:
#include <bits/stdc++.h>
#define fi first
#define se second
#define lson l,m,rt<<1
#define rson m+1,r,rt<<1|1
#define pb push_back
#define MP make_pair
#define lowbit(x) x&-x
#define clr(a) memset(a,0,sizeof(a))
#define _INF(a) memset(a,0x3f,sizeof(a))
#define FIN freopen("in.txt","r",stdin)
#define debug(x) cout<<"["<<x<<"]"<<endl
using namespace std;
typedef unsigned long long ull;
typedef long long ll;
typedef pair<int,int>pii;
const int MX = 1e5+7;
const int inf = 0x3f3f3f3f;
const int mod = 1e9+7;
const int P = 998244353;
const int G = 3;
const int NUM = 20;
ll wn[NUM];
ll va[MX<<2],vb[MX<<2];
ll quick_mod(ll a, ll x, ll mod) {
ll ans = 1;
a %= mod;
while(x) {
if(x & 1)ans = ans * a % mod;
x >>= 1;
a = a * a % mod;
}
return ans;
}
//在程序的开头就要放
void GetWn() {
for(int i = 0; i < NUM; i++) {
int t = 1 << i;
wn[i] = quick_mod(G, (P - 1) / t, P);
}
}
void Rader(ll F[], int len) {
int j = len >> 1;
for(int i = 1; i < len - 1; i++) {
if(i < j) swap(F[i], F[j]);
int k = len >> 1;
while(j >= k)j -= k,k >>= 1;
if(j < k) j += k;
}
}
void NTT(ll F[], int len, int t) {
Rader(F, len);
int id = 0;
for(int h = 2; h <= len; h <<= 1) {
id++;
for(int j = 0; j < len; j += h) {
ll E = 1;
for(int k = j; k < j + h / 2; k++) {
ll u = F[k];
ll v = E * F[k + h / 2] % P;
F[k] = (u + v) % P;
F[k + h / 2] = (u - v + P) % P;
E = E * wn[id] % P;
}
}
}
if(t == -1) {
for(int i = 1; i < len / 2; i++)swap(F[i], F[len - i]);
ll inv = quick_mod(len, P - 2, P);
for(int i = 0; i < len; i++)F[i] = F[i] * inv % P;
}
}
void Conv(ll a[], ll b[], int len) {
NTT(a, len, 1);
NTT(b, len, 1);
for(int i = 0; i < len; i++)a[i] = a[i] * b[i] % P;
NTT(a, len, -1);
}
ll dp[MX],f[MX],invf[MX];
void init(){
f[0] = f[1] = 1;
for(int i = 2;i < MX;i++) f[i] = (f[i-1] * i) % P;
invf[MX-1] = quick_mod(f[MX-1],P-2,P);
for(int i = MX-2;i >= 0; i--) invf[i] = (invf[i+1] * (i+1)) % P;
}
void cdq(int l,int r){
if(l == r) return;
int m = (l + r) >> 1;
cdq(l,m);
int mx = (r - l + 1),len = 1;
while(len <= mx) len <<= 1;
for(int i = 0;i < len;i++){
if(l + i <= m) va[i] = dp[l+i]*invf[l+i] % P;
else va[i] = 0;
if(l + i < r) vb[i] = (ll)(i+1)*(i+1) % P;
else vb[i] = 0;
}
Conv(va,vb,len);
for(int i = m + 1;i <= r;i++)
dp[i] = (dp[i] + f[i-1]*(va[i-l-1]) % P) % P;
cdq(m+1,r);
}
void pre_solve(){
GetWn();
init();
dp[0] = 1;
cdq(0,100000);
}
int main(){
//FIN;
pre_solve();int n;
while(~scanf("%d",&n)) printf("%lld\n",dp[n]);
return 0;
}