「LibreOJ NOI Round #2」不等关系——容斥DP+分治NTT

题目

「LibreOJ NOI Round #2」不等关系

题解

首先,序列列中既有"<"又有">"肯定不好处理,如果忽略">"的限制,序列就被分为由"<"连接的若干个联通块(当然,每个块是一条数值递增的链),

设第i个链的大小为siz_i套一个不用搞懂的公式答案为\frac{n!}{\prod siz_i!}

然后就可以用容斥,设Ans_i表示把 i 个“>”改为“<”、其它“>”忽略的情况下的方案数,最终答案是:\sum_{i=0}(-1)^iAns_i

所以我们的朴素做法就是枚举每一种修改情况的总贡献,然后配上系数加起来,

但是显然太暴力了,于是把它拍扁在一次遍历中,就得到如下dp:

设dp[i]表示前i个点的容斥总答案,有dp[i]=(-1)^{cnt_{i-1}}\sum_{j=0}^{i-1}[s[j]='>']dp[j](-1)^{cnt_j}\frac{1}{(i-j)!},其中cnt_i表示前i个字符中">"的个数。

观察式子,发现结构很典型,可以用分治NTT做。

代码

#include<cstdio>
#include<cstring>
#include<iostream>
#include<algorithm>
#include<cmath>
#include<vector>
#include<queue>
#define ll long long
#define MAXN 140005
using namespace std;
inline ll read(){
	ll x=0;bool f=1;char s=getchar();
	while((s<'0'||s>'9')&&s>0){if(s=='-')f^=1;s=getchar();}
	while(s>='0'&&s<='9')x=(x<<1)+(x<<3)+s-'0',s=getchar();
	return f?x:-x;
}
char s[MAXN];
int n,cnt[MAXN],dp[MAXN],fac[MAXN],inv[MAXN];
int a[MAXN<<1],b[MAXN<<1],c[MAXN<<1],d[MAXN<<1];

#define mod 998244353ll
#define g 3ll
int rev[MAXN<<1];
inline ll ksm(ll a,ll b,ll mo){
	ll res=1;
	for(;b;b>>=1,a=a*a%mo)if(b&1)res=res*a%mo;
	return res;
}
inline int NTT(int*a,int N,int inv){
	int bit=0,n=N;
	while((1<<bit)<n)bit++;n=(1<<bit);
	for(int i=0;i<n;i++){
		rev[i]=(rev[i>>1]>>1)|((i&1)<<(bit-1));
		if(i<rev[i])swap(a[i],a[rev[i]]);
	}
	for(int mid=1;mid<n;mid<<=1){
		ll tmp=ksm(g,(mod-1)/(mid<<1),mod),om=1,y;
		if(inv<0)tmp=ksm(tmp,mod-2,mod);
		for(int i=0;i<n;i+=(mid<<1),om=1)
			for(int j=0;j<mid;j++,om=(om*tmp)%mod)
				y=om*a[i+j+mid]%mod,a[i+j]=(y+a[i+j])%mod,a[i+j+mid]=((mod-y<<1)+a[i+j])%mod;
	}
    if(inv<0)for(int i=0,in=ksm(n,mod-2,mod);i<n;i++)a[i]=1ll*in*a[i]%mod;
	return n;
}

inline void CDQ_NTT(int l,int r){    //分治NTT原理是CDQ分治
	if(l==r){
		dp[l]=1ll*dp[l]*b[l]%mod;
		return;
	}
	int mid=(l+r)>>1;
	CDQ_NTT(l,mid);
	for(int i=l;i<=mid;i++)c[i-l]=1ll*dp[i]*a[i]%mod;
	for(int i=l;i<=r;i++)d[i-l]=inv[i-l];
	for(int i=min(MAXN<<1,r-l<<2)-2;i>mid-l;i--)c[i]=0;
	for(int i=min(MAXN<<1,r-l<<2)-2;i>r-l;i--)d[i]=0;
	int N=NTT(d,r-l<<1,1);NTT(c,r-l<<1,1);
	for(int i=0;i<N;i++)c[i]=1ll*c[i]*d[i]%mod;
	NTT(c,r-l<<1,-1);
	for(int i=mid+1;i<=r;i++)dp[i]=(0ll+c[i-l]+dp[i])%mod;
	CDQ_NTT(mid+1,r);
}

signed main()
{
	scanf("%s",s+1),n=strlen(s+1)+1;
	fac[0]=1;
	for(int i=1;i<=n;i++)fac[i]=1ll*fac[i-1]*i%mod;
	inv[n]=ksm(fac[n],mod-2,mod);
	for(int i=n-1;i>=0;i--)
		inv[i]=(1ll+i)*inv[i+1]%mod;
	dp[0]=a[0]=b[0]=1;
	for(int i=1;i<=n;i++)
		b[i]=a[i-1],a[i]=(s[i]=='<'?a[i-1]:mod-a[i-1]);
	for(int i=1;i<n;i++)if(s[i]=='<')a[i]=0;
	CDQ_NTT(0,n);
	printf("%d\n",1ll*dp[n]*fac[n]%mod);
	return 0;
}

猜你喜欢

转载自blog.csdn.net/weixin_43960287/article/details/113820192