[2018 集训队互测 Day 5] LOJ 2504 小 H 爱染色 - 拉格朗日插值 - NTT

题解：列出式子后可以发现，答案是关于 n-m 的、次数不超过 3m+1 的多项式，因此只需求出它在 0~3m+1 处的点值，再做一次拉格朗日插值即可。计算这些点值需要 F(0)~F(3m+1)，而输入只给出 F(0)~F(m)：把拉格朗日插值写成卷积的形式，就可以用 NTT 一次性求出 F(m+1)~F(3m+1)。

#include<bits/stdc++.h>
// Contest-style shorthand macros and type aliases (several are unused in this file).
// NOTE(review): these are object-like macros, so each one rewrites every matching token
// in the rest of the file; in particular no identifier below may be spelled `p`.
// Inclusive integer loop: i runs from a to b.
#define rep(i,a,b) for(int i=a;i<=b;i++)
// Loop over every index of a container v.
#define Rep(i,v) rep(i,0,(int)v.size()-1)
#define lint long long
// Prime modulus 998244353 = 119 * 2^23 + 1; NTT_space uses 3 as its generator.
#define p 998244353
#define ull unsigned lint
#define db long double
#define pb push_back
#define mp make_pair
#define fir first
#define sec second
// Debug printing helpers (write to stderr); `ln` appends endl, which also flushes.
#define debug(x) cerr<<#x<<"="<<x
#define sp <<" "
#define ln <<endl
using namespace std;
typedef pair<int,int> pii;
typedef set<int>::iterator sit;
// Buffered stdin reader: refills a (1<<24)+5 byte buffer with fread and parses
// non-negative decimal integers from it.
namespace INPUT_SPACE{
	const int BS=(1<<24)+5;char Buffer[BS],*HD,*TL;
	// Returns the next byte of stdin as a value in 0..255, or EOF once input is exhausted.
	inline int gc()
	{
		if(HD==TL) TL=(HD=Buffer)+fread(Buffer,1,BS,stdin);
		// Go through unsigned char so a 0xFF byte is not mistaken for EOF where char is signed.
		return (HD==TL)?EOF:static_cast<unsigned char>(*HD++);
	}
	// Reads the next non-negative decimal integer, skipping any non-digit bytes before it.
	// Returns 0 if the input ends before any digit is found (rather than looping forever on EOF).
	inline int inn()
	{
		int ch=gc();
		while(ch!=EOF&&(ch<'0'||ch>'9')) ch=gc();
		if(ch==EOF) return 0;
		int x=ch^'0';
		while((ch=gc())>='0'&&ch<='9') x=(x<<1)+(x<<3)+(ch^'0');
		return x;
	}
}using INPUT_SPACE::inn;
// Size of the per-point tables; main indexes cb[] up to u+1 = 3m+2, so N must exceed that.
const int N=3000100;
// fac/facinv/inv: factorials, inverse factorials and inverses mod p (filled by prelude).
// cb[i] = C(n-i, m) mod p; v: prefix sums interpolated in main;
// pre/suf: prefix/suffix products for the final Lagrange evaluation.
int fac[N],facinv[N],inv[N],v[N],cb[N],pre[N],suf[N];
// (-1)^s * x mod p, for x already reduced into [0, p): x if s is even, p-x (or 0) if s is odd.
inline int sol(int x,int s)
{
	if((s&1)==0||x==0) return x;
	return p-x;
}
// x squared, reduced mod p.
inline int squ(int x)
{
	const long long sq=(long long)x*x;
	return sq%p;
}
// ans * x^k mod p by binary exponentiation (ans defaults to 1).
inline int fast_pow(int x,int k,int ans=1)
{
	while(k)
	{
		if(k&1) ans=(long long)ans*x%p;
		k>>=1;
		x=(long long)x*x%p;
	}
	return ans;
}
// Allocates n zero-initialized ints with new[]; the caller must release them with delete[].
inline int *newInt(int n)
{
	return new int[n]();
}
// Zeroes a[0..n); returns 0 so calls can be chained with the comma operator.
inline int clr(int *a,int n)
{
	memset(a,0,n*sizeof(int));
	return 0;
}
// Copies b[0..n) into a[0..n) (ranges must not overlap); returns 0 for comma chaining.
inline int cpy(int *a,int *b,int n)
{
	memcpy(a,b,n*sizeof(int));
	return 0;
}
inline int prelude(int n)
{
	rep(i,fac[0]=1,n) fac[i]=(lint)fac[i-1]*i%p;
	facinv[n]=fast_pow(fac[n],p-2);
	for(int i=n-1;i>=0;i--) facinv[i]=(i+1ll)*facinv[i+1]%p;
	rep(i,1,n) inv[i]=(lint)fac[i-1]*facinv[i]%p;
	return 0;
}
// Precomputes ivs[i] = 1/(n-i) mod p, used by main to step C(n-i, m) from i to i+1.
namespace IVS_space{
	int *p0,*p1,ivs[N];
	// Fills ivs[0..u] with the inverses of n, n-1, ..., n-u.
	// Small n (n <= u): read from the inv[] table, which must cover n.  Only ivs[0..n] is
	// written, and ivs[n] = inv[0] = 0 because the global inv[0] is never set.
	// Large n: uses prefix products p0[i] = n(n-1)...(n-i) and a single modular exponentiation
	// for 1/p0[u].  The suffix walk then gives p1[i] = 1/p0[i], and 1/(n-i) = p0[i-1] * p1[i].
	// NOTE(review): assumes none of n..n-u is a multiple of p (true when n < p); confirm constraints.
	inline int init_ivs(int n,int u)
	{
		if(n<=u) { rep(i,0,n) ivs[i]=inv[n-i];return 0; }
		p0=newInt(u+1),p1=newInt(u+1);
		p0[0]=n;rep(i,1,u) p0[i]=(lint)p0[i-1]*(n-i)%p;
		p1[u]=fast_pow(p0[u],p-2),ivs[0]=fast_pow(n,p-2);
		for(int i=u-1;i>=0;i--) p1[i]=p1[i+1]*(n-i-1ll)%p;
		rep(i,1,u) ivs[i]=(lint)p0[i-1]*p1[i]%p;
		// newInt allocates with new[], so the matching release is delete[].
		delete[] p0;
		delete[] p1;
		p0=p1=nullptr;
		return 0;
	}
}using IVS_space::ivs;
namespace NTT_space{
	const int N=4194304+100;
	int *dwg[25],*dwgi[25],*r,*A,*B;
	inline int prelude_dwg()
	{
		int n=N-10;
		for(int i=2,c=1;i<=n;i<<=1,c++)
		{
			dwg[c]=newInt(i>>1),dwgi[c]=newInt(i>>1);
			int *d=dwg[c],*di=dwgi[c];d[0]=di[0]=1;
			int w=fast_pow(3,(p-1)/i),wn=fast_pow(3,p-1-(p-1)/i);
			rep(j,1,(i>>1)-1) d[j]=(lint)d[j-1]*w%p,di[j]=(lint)di[j-1]*wn%p;
		}
		return 0;
	}
	inline int NTT(int *a,int n,int s)
	{
		rep(i,1,n-1) if(i<r[i]) swap(a[i],a[r[i]]);
		for(int i=2,c=1;i<=n;i<<=1,c++)
		{
			int *d=(s>0?dwg[c]:dwgi[c]);
			for(int j=0,t=i>>1,v,x,y;j<n;j+=i) rep(k,0,t-1)
				x=a[j+k],y=(lint)d[k]*a[j+k+t]%p,
				a[j+k]=((v=x+y)>=p?v-p:v),a[j+k+t]=((v=x-y)<0?v+p:v);
		}
		if(s<0) for(int i=0,ninv=fast_pow(n,p-2);i<n;i++) a[i]=(lint)a[i]*ninv%p;
		return 0;
	}
	inline int tms(int *a,int m1,int *b,int m2,int *c,int m3)
	{
		int n=1,L=0;while(n<m1+m2-1) n<<=1,L++;
		r=newInt(n),A=newInt(n),B=newInt(n);
		rep(i,1,n-1) r[i]=(r[i>>1]>>1)|((i&1)<<(L-1));
		cpy(A,a,m1),clr(A+m1,n-m1),NTT(A,n,1);
		cpy(B,b,m2),clr(B+m2,n-m2),NTT(B,n,1);
		rep(i,0,n-1) A[i]=(lint)A[i]*B[i]%p;
		return NTT(A,n,-1),cpy(c,A,m3),delete A,delete B,0;
	}
}
// Extends the polynomial F of degree m from its given values f[0..m] to f[0..u].
namespace F_space{
	int m,u,f[N],*a,*b;
	// Lagrange interpolation written as a convolution.  For x > m:
	//   F(x) = x! / (x-m-1)! * sum_{i=0..m} a_i / (x-i),
	//   where a_i = f(i) * (-1)^(m-i) / (i! (m-i)!).
	// The inner sum is (a * b)[x] with b[k] = 1/k (b[0] = 0), computed by one NTT multiply.
	// Requires f[0..m] to be filled and prelude(u) to have built fac/facinv/inv up to u.
	inline int prelude(int _m,int _u)
	{
		m=_m,u=_u,NTT_space::prelude_dwg(),a=newInt(u+1),b=newInt(u+1);
		rep(i,0,m) a[i]=sol((lint)f[i]*facinv[i]%p*facinv[m-i]%p,m-i);
		rep(i,1,u) b[i]=inv[i];
		NTT_space::tms(a,m+1,b,u+1,a,u+1);
		// x!/(x-m-1)! = x (x-1) ... (x-m)
		rep(i,m+1,u) f[i]=(lint)fac[i]*facinv[i-m-1]%p*a[i]%p;
		// newInt allocates with new[], so the matching release is delete[].
		delete[] a;
		delete[] b;
		a=b=nullptr;
		return 0;
	}
	// F(n) for 0 <= n <= u (valid after prelude).
	inline int F(int n) { return f[n]; }
}
int main()
{
	// Input: n, m, then F(0..m).  The prefix sums below are interpolated with u+1 = 3m+2 points.
	int n=inn(),m=inn(),u=3*m+1;prelude(u);
	rep(i,0,m) F_space::f[i]=inn();
	F_space::prelude(m,u); // F(0..u) are now available
	// cb[i] = C(n-i, m) mod p for i = 0..u+1.
	// Starts from C(n, m) and steps with C(t-1, m) = C(t, m) * (t-m) / t, where t = n-i+1.
	cb[0]=facinv[m];rep(i,0,m-1) cb[0]=(lint)cb[0]*(n-i)%p;
	IVS_space::init_ivs(n,u);rep(i,1,u+1) cb[i]=cb[i-1]*(n-i-m+1ll)%p*ivs[i-1]%p;
	// v[k] = sum_{i=0..k} F(i) * (C(n-i,m)^2 - C(n-i-1,m)^2) mod p, for k = 0..u
	rep(i,0,u) v[i]=(lint)F_space::F(i)*(squ(cb[i])-squ(cb[i+1]))%p,(v[i]<0?v[i]+=p:0);
	rep(i,1,u) v[i]+=v[i-1],(v[i]>=p?v[i]-=p:0);
	// The answer is v (as a polynomial of degree <= u) evaluated at n-m.
	n-=m;
	// n < m: every C(n-i, m) vanishes, so the answer is 0.
	// Returning here also avoids reading v[] at a negative index.
	if(n<0) return !printf("0\n");
	if(n<=u) return !printf("%d\n",v[n]);
	// Lagrange interpolation through (0, v[0]), ..., (u, v[u]) evaluated at n > u:
	// pre[i] = prod_{j<=i} (n-j), suf[i] = prod_{j>=i} (n-j).
	pre[0]=n%p,suf[u]=(n-u)%p;int ans=0;
	for(int i=1;i<=u;i++) pre[i]=(lint)pre[i-1]*(n-i)%p;
	for(int i=u-1;i>=0;i--) suf[i]=(lint)suf[i+1]*(n-i)%p;
	// term i: v[i] * prod_{j != i} (n-j) / (i! (u-i)! (-1)^(u-i))
	rep(i,0,u) ans+=sol((lint)v[i]*(i>0?pre[i-1]:1)%p*(i<u?suf[i+1]:1)%p*facinv[i]%p*facinv[u-i]%p,u-i),(ans>=p?ans-=p:0);
	return !printf("%d\n",ans);
}

猜你喜欢

转载自blog.csdn.net/Mys_C_K/article/details/88534875