UOJ#428. 【集训队作业2018】普通的计数题(牛顿迭代)

传送门

题解:
把0操作看做是叶子,1操作看做非叶节点,一个操作在另一个操作删除,则另一个操作为这个操作的父亲,于是转化成了满足以下条件的 n n 个点的树的计数:
1.父亲标号>儿子。
2.若一个点为非叶节点,记其儿子中叶子节点的数量为 T T ,则若其儿子中有非叶节点, T A T\in A ,否则 T B T \in B
首先可以发现的是 B B 集合中有没有0都无所谓(因为必须选非空序列),所以假设其0次项为1。

然后就类似无根树计数的方法(这里注意大小为1的是有序的,其余是无序的),把他的儿子给拼出来,我们可以枚举有几个非叶节点作为儿子,则:
f n = [ n 1 B ] + i [ i A ] k 1 k ! i = 1 k a i = n 1 i , a i 2 ( n 1 ) ! i ! a i ! i f a i f_n = [n-1 \in B] + \sum_{i}[i \in A]\sum_{k}\frac{1}{k!}\sum_{\sum_{i=1}^k a_i=n-1-i,a_i \ge 2} \frac{(n-1)!}{i!\prod a_i!}\prod_i{f_{a_i}}

两边除个 n ! n! ,得到:
n f n n ! = [ n 1 B ] ( n 1 ) ! + i [ i A ] i ! k 1 k ! j = 1 k a i = n 1 i , a i 2 f a i a i ! n\frac{f_n}{n!} = \frac{[n-1 \in B]}{(n-1)!} + \sum_{i}\frac{[i \in A]}{i!} \sum_{k}\frac{1}{k!}\sum_{\sum_{j=1}^k a_i=n-1-i,a_i \ge 2}\prod{\frac{f_{a_i}}{a_i!}}

不妨设 F = i = 1 f i i ! x i F=\sum_{i=1}^{\infty}\frac{f_i}{i!}x^i ,可以得到:
F = B + A ( e F x 1 ) = A e x e F + ( B A ) \begin{matrix} F' &=& B+A(e^{F-x}-1) \\ &=& Ae^{-x}e^{F}+(B-A) \end{matrix}

C = A e x , D = B A C=Ae^{-x},D=B-A ,化简一下:
F = C e F + D F' =Ce^{F}+D

我们只需要解这个方程就行啦!

怎么解呢,可以用牛顿迭代法,先假设一下我们知道 f 0 = f   m o d   x n 2 f_0 = f \bmod{x^\frac{n}{2}} ,考虑怎么求 f   m o d   x n f \bmod {x^n}

可以得到:
F C e f 0 + D + C e f 0 ( f f 0 ) ( m o d x n ) F' \equiv Ce^{f_0}+D+Ce^{f_0}(f-f_0) \pmod {x^n}

整理一下常量,发现其实是要解这个方程:
F = T F + Z F' = TF+Z

这个可以先设 U = T U U'=TU ,解出 U U :
d U d x = T U U 1 d U = T d x ln U = T d x U = e T d x \begin{aligned} \frac{dU}{dx} = TU \\ U^{-1}dU=Tdx \\ \ln U=\int Tdx\\ U= e^{\int Tdx} \end{aligned}

然后设 V = U 1 F V=U^{-1}F ,则:
( U V ) = T U V + Z U V + U V = U V + Z V = Z U \begin{aligned} (UV)'=TUV+Z \\ U'V+UV'=U'V+Z \\ V=\int \frac{Z}{U} \end{aligned}

然后 F = U V F=UV ,这道题就做完辣,时间复杂度 O ( n log n ) O(n \log n)

#include <bits/stdc++.h>
using namespace std;

const int RLEN=1<<18|1;
inline char nc() {
	static char ibuf[RLEN],*ib,*ob;
	(ib==ob) && (ob=(ib=ibuf)+fread(ibuf,1,RLEN,stdin));
	return (ib==ob) ? -1 : *ib++;
}
inline int rd() {
	char ch=nc(); int i=0,f=1;
	while(!isdigit(ch)) {if(ch=='-')f=-1; ch=nc();}
	while(isdigit(ch)) {i=(i<<1)+(i<<3)+ch-'0'; ch=nc();}
	return i*f;
}

const int N=1e6+50, mod=998244353;
inline int add(int x,int y) {return (x+y>=mod) ? (x+y-mod) : (x+y);}
inline int dec(int x,int y) {return (x-y<0) ? (x-y+mod) : (x-y);}
inline int mul(int x,int y) {return (long long)x*y%mod;}
inline int power(int a,int b,int rs=1) {for(;b;b>>=1,a=mul(a,a)) if(b&1) rs=mul(rs,a); return rs;}
inline int sgn(int x) {return (x&1) ? (mod-1) : 1;}

namespace FFT {
	const int G=3;
	int A[N],B[N],w[N],pos[N],k;
	inline void init(int n) {
		for(k=1;k<=n;k<<=1);
		memset(A,0,sizeof(int)*k);
		memset(B,0,sizeof(int)*k);
		for(int i=1;i<k;i++) pos[i]=(i&1) ? ((pos[i>>1]>>1)^(k>>1)) : (pos[i>>1]>>1);
	}
	inline void dft(int *a) {
		for(int i=1;i<k;i++)
			if(pos[i]>i) swap(a[pos[i]],a[i]);
		for(int bl=1;bl<k;bl<<=1) {
			int tl=bl<<1, wn=power(G,(mod-1)/tl);
			w[0]=1; for(int i=1;i<bl;i++) w[i]=mul(w[i-1],wn);
			for(int bg=0;bg<k;bg+=tl)
				for(int j=0;j<bl;j++) {
					int &t1=a[bg+j], &t2=a[bg+j+bl], t=mul(t2,w[j]);
					t2=dec(t1,t); t1=add(t1,t);
				}
		}
	}
	inline void func() {
		dft(A); dft(B);
		for(int i=0;i<k;i++) B[i]=mul(B[i],A[i]);
		dft(B); const int inv=power(k,mod-2);
		for(int i=0;i<k;i++) B[i]=mul(B[i],inv);
		reverse(B+1,B+k); 
	}
}
struct combin {
	int fac[N],ifac[N];
	combin() {
		fac[0]=1;
		for(int i=1;i<N;i++) fac[i]=mul(fac[i-1],i);
		ifac[0]=ifac[1]=1;
		for(int i=2;i<N;i++) ifac[i]=mul(mod-mod/i,ifac[mod%i]);
		for(int i=2;i<N;i++) ifac[i]=mul(ifac[i-1],ifac[i]); 
	}
	inline int inv(int i) {return mul(ifac[i],fac[i-1]);}
} cb;

struct poly {
	vector <int> a;
	poly(int d=0,int t=0) {a.resize(d+1); a[d]=t;}
	inline int& operator [](const int &i) {return a[i];}
	inline const int& operator [](const int &i) const {return a[i];}
	inline int deg() const {return a.size()-1;}
	inline poly extend(int k) {poly c=*this; c.a.resize(k); return c;}
	friend inline poly operator +(const poly &a,const poly &b) {
		poly c(max(a.deg(),b.deg()),0);
		for(int i=0;i<=a.deg();i++) c[i]=add(c[i],a[i]);
		for(int i=0;i<=b.deg();i++) c[i]=add(c[i],b[i]);
		return c;
	}
	friend inline poly operator -(const poly &a,const poly &b) {
		poly c(max(a.deg(),b.deg()),0);
		for(int i=0;i<=a.deg();i++) c[i]=add(c[i],a[i]);
		for(int i=0;i<=b.deg();i++) c[i]=dec(c[i],b[i]);
		return c;
	}
	friend inline poly operator *(const poly &a,const int &b) {
		poly c=a;
		for(int i=0;i<=c.deg();i++) c[i]=mul(c[i],b);
		return c;
	}
	friend inline poly operator *(const poly &a,const poly &b) {
		poly c(a.deg()+b.deg()); FFT::init(c.deg());
		for(int i=0;i<=a.deg();i++) FFT::A[i]=a[i];
		for(int i=0;i<=b.deg();i++) FFT::B[i]=b[i];
		FFT::func();
		for(int i=0;i<=c.deg();i++) c[i]=FFT::B[i];
		return c;
	}
	inline poly dg() {
		if(!deg()) return poly(0,0);
		poly c(deg()-1,0);
		for(int i=0;i<=c.deg();i++)
			c[i]=mul(a[i+1],i+1);
		return c;
	}
	inline poly ig() {
		poly c(deg()+1);
		for(int i=1;i<=c.deg();i++)
			c[i]=mul(a[i-1],cb.inv(i));
		return c;
	}
	inline poly inv(poly f,int k) {
		if(k==1) {return poly(0,power(f[0],mod-2));}
		poly f0=inv(f.extend(k>>1),k>>1); 
		return f0*2-(((f0*f0).extend(k))*f).extend(k);
	}
	inline poly ln(poly f,int k) {
		poly f0=f.dg(), f1=inv(f,k);
		return (f0*f1).ig().extend(k);
	}
	inline poly exp(poly f,int k) {
		if(k==1) {return poly(0,1);}
		poly f0=exp(f.extend(k>>1),k>>1);
		return (f0*(f-ln(f0,k)+poly(0,1))).extend(k);
	}
	inline poly cinv(int k) {
		return inv(*this,k);
	}
	inline poly cexp(int k) {
		return exp(*this,k);
	}
} ;

inline poly get_poly(poly C,poly D,int k) {
	if(k==1) {return poly(0,0);}
	poly f0=get_poly(C.extend(k>>1),D.extend(k>>1),k>>1);
	poly T=(C*f0.cexp(k)).extend(k);
	poly Z=T+D-(T*f0).extend(k);
	poly U=T.ig().cexp(k);
	poly V=(Z*U.cinv(k)).extend(k).ig();
	return (U*V).extend(k);
}
int n,a,b;
int main() {
	n=rd();
	poly A(n,0), B(n,0);
	int a=rd(), b=rd();
	for(int i=1;i<=a;i++) {
		int x=rd();
		A[x]=cb.ifac[x];
	} 
	for(int i=1;i<=b;i++) {
		int x=rd();
		B[x]=cb.ifac[x];
	} B[0]=1;
	poly C(n,0),D;
	for(int i=0;i<=n;i++)
		C[i]=mul(sgn(i),cb.ifac[i]);
	C=(A*C).extend(n+1);
	D=B-A;
	int k=1; for(;k<=n;k<<=1);
	poly f=get_poly(C,D,k);
	cout<<mul(f[n],cb.fac[n]);
}

发布了553 篇原创文章 · 获赞 227 · 访问量 24万+

猜你喜欢

转载自blog.csdn.net/qq_35649707/article/details/84316673