LOJ #150. 挑战多项式/多项式全家桶

题目

这个题堪称 简易模板全家桶

关于各个函数的实现,大多数都是利用牛顿迭代公式( x = x 0 f ( x ) f ( x ) x=x_0-\dfrac {f(x)}{f'(x)} )+倍增。

以下为多项式全家桶证明:

下面的 f , b , a f,b,a 分别为已知多项式,当前所求多项式和上一个状态(规模小一半)的多项式.(简记 y = a ( x ) y=a(x) )

sqrt \text{sqrt}

b ( x ) 2 = f ( x )    > b ( x ) 2 f ( x ) = 0 b(x)^2=f(x) ~~->b(x)^2-f(x)=0
φ ( b ( x ) ) = b ( x ) 2 f ( x ) 设\varphi(b(x))=b(x)^2-f(x)
φ ( b ( x ) ) = 2 b ( x ) 则\varphi'(b(x))=2*b(x)
, 此时,我们求函数零点
b ( x ) = a ( x ) φ ( a ( x ) ) φ ( a ( x ) ) = a ( x ) a 2 ( x ) f ( x ) 2 a ( x ) = a 2 ( x ) + f ( x ) 2 a ( x ) b(x)=a(x)-\dfrac{\varphi(a(x))}{\varphi'(a(x))}=a(x)-\dfrac{a^2(x)-f(x)}{2a(x)}=\dfrac{a^2(x)+f(x)}{2a(x)}

y 2 + f 2 y \dfrac {y^2+f}{2y}

inv \text{inv}

φ ( b ( x ) ) = b ( x ) f ( x ) 1 \varphi(b(x))=b(x)f(x)-1

φ ( b ( x ) ) = f ( x ) \varphi'(b(x))=f(x)

b ( x ) = a ( x ) a ( x ) f ( x ) 1 f ( x ) b(x)=a(x)-\dfrac{a(x)f(x)-1}{f(x)}

b ( x ) = a ( x ) b ( x ) ( a ( x ) f ( x ) 1 ) b(x)=a(x)-b(x)(a(x)f(x)-1)

由于 a ( x ) f ( x ) 1 0 ( m o d    x n / 2 ) a(x)f(x)-1\equiv 0(\mod x^{n/2}) .

所以 b ( x ) b(x) 的高 n / 2 n/2 位乘上 a ( x ) f ( x ) 1 a(x)f(x)-1 m o d    x n \mod x^n 意义下为0.

所以 b ( x ) b(x) a ( x ) a(x) 在当前等价.

b ( x ) = a ( x ) a ( x ) ( a ( x ) f ( x ) 1 ) = a ( x ) ( 2 a ( x ) f ( x ) ) b(x)=a(x)-a(x)(a(x)f(x)-1)=a(x)(2-a(x)f(x))

y ( 2 y f ) y(2-yf)

ln \ln

ln x = 1 x \ln 'x=\dfrac 1 x

b ( x ) = ln f ( x ) b(x)=\ln f(x)

对两边求导: b ( x ) = i n v ( f ( x ) ) f ( x ) b'(x)=inv(f(x))f'(x) .(链式反应)

积分 b ( x ) b'(x) 即可得到 b ( x ) b(x) .

f f \int \dfrac {f'}f

exp \exp

因为 exp , ln \exp,\ln 为逆运算,所以可得:

b ( x ) = exp f ( x ) ln b ( x ) = f ( x ) b(x)=\exp f(x)\rightarrow \ln b(x)=f(x)

φ ( b ( x ) ) = ln b ( x ) f ( x ) \varphi(b(x))=\ln b(x)-f(x)

φ ( b ( x ) ) = 1 b ( x ) \varphi'(b(x))=\dfrac 1{b(x)}

则可得到: b ( x ) = a ( x ) ( 1 ln a ( x ) + f ( x ) ) b(x)=a(x)(1-\ln a(x)+f(x)) .

y ( 1 ln y + f ) y(1-\ln y+f)

pow \text{pow}

f ( x ) k = e l n ( f ( x ) ) k f(x)^k=e^{ln(f(x))*k}

f(x)/g(x),f(x)mod g(x) \text{f(x)/g(x),f(x)mod g(x)}

多项式带余除法.

f ( x ) = q ( x ) g ( x ) + r ( x ) ( m o d    x n ) f(x)=q(x)g(x)+r(x)(\mod x^n) .

已知 f , g f,g q , r q,r . f f n n 次多项式, g g m m 次多项式, r r 的次数 < m <m .

可以发现 x n f ( x 1 ) x^n f(x^{-1}) 的系数为 f f 系数的翻转,我们简记这个多项式为 F ( x ) F(x) .

则有: F ( x ) = x n q ( x ) g ( x ) + x n r ( x ) = x n m q ( x ) x m g ( x ) + x n m + 1 x m 1 r ( x ) = Q ( x ) G ( x ) + x n m + 1 R ( x ) ( m o d    x n ) F(x)=x^n q(x)g(x)+x^nr(x)=x^{n-m} q(x) x^m g(x)+x^{n-m+1}x^{m-1}r(x)= Q(x)G(x)+x^{n-m+1}R(x)(\mod x^n)

我们可以发现 F ( x ) = Q ( x ) G ( x ) ( m o d    x n m + 1 ) F(x)=Q(x)G(x)(\mod x^{n-m+1}) .

这样求出 Q Q 后即可得到 q , r q,r .

代码:

#include<ctime>
#include<cstdio>
#include<cstring>
#include<iostream>
#include<algorithm>
#define gc getchar()//(p1==p2&&(p2=(p1=buf)+fread(buf,1,N,stdin),p1==p2)?EOF:*p1++)
using namespace std;
typedef unsigned long long ull;
typedef long long ll;
const int N=1<<22|10,mod=998244353;

char buf[N],*p1=buf,*p2=buf;
template<class o>void qr(o &x) {
	char c=gc; x=0;
	while(!isdigit(c))c=gc;
	while(isdigit(c))x=(x*10+c-'0')%mod,c=gc;
}
template<class o>void qw(o x) {
	if(x/10) qw(x/10);
	putchar(x%10+'0');
}
template<class o>void pr1(o x) {qw(x); putchar(' ');}
template<class o>void pr2(o x)  {qw(x); puts("");}

ll power(ll a,ll b=mod-2,ll p=mod) {
	ll c=1;
	while(b) {
		if(b&1) c=c*a%p;
		b /= 2; a=a*a%p;
	}
	return c;
}

namespace Cipolla {
	ll p=mod,w;
	struct CP {
		ll x,y;
		CP(ll a=1,ll b=0) {x=a; y=b;}
		CP operator *(CP b) const {return CP( ((x*b.x+y*b.y%p*w)%p+p)%p , ((x*b.y+y*b.x)%p+p)%p);}
	} ;
	CP power(CP a,ll b=(p+1)/2) {
		CP c;
		while(b) {
			if(b&1) c=c*a;
			b /= 2; a=a*a;
		}
		return c;
	}
	bool pd(ll x) {return ::power(x,(p-1)/2,p)==p-1;}
	ll solve(ll n) {
		if(n<=1) return n;
		ll a; do {
			a=(rand()<<15|rand())%p;
			w=((a*a-n)%p+p)%p;
		} while(!pd(w));
		ll x=power(CP(a,1)).x,y=p-x;
		return min(x,y);
	}
}

namespace P {
	const int g=3,inv2=(mod+1)/2;
	int R[N],w[N],Inv[N];
	int calc(int x) {if((x&-x)==x) return x; int n=1; while(n<x) n*=2; return n;}//输入长度 
	void init(int m) {
		int n=calc(m)*4; 
		Inv[1]=1; for(int i=2;i<n;i++) Inv[i]=(ll)Inv[mod%i]*(mod-mod/i)%mod;
		for(int i=1;i<n;i*=2) {//枚举半区间长度,把对应的单位根填入w数组 
			ll t=power(g,(mod-1)/(2*i)),d=1;
			for(int j=0;j<i;j++) w[i+j]=d,d=d*t%mod;
		}
	}
	int pre(int m) {//输入总长度 
		int n=calc(m);
		for(int i=1;i<n;i++) R[i]=(R[i>>1]>>1)|(i&1?n>>1:0);
		return n;
	}
	void upd(int &x) {x+=x>>31&mod;}
	void DFT(int *f,int n) {
		static ull p[N];
		for(int i=0;i<n;i++) p[R[i]]=f[i];
		for(int i=1,t;i<n;i*=2) for(int j=0;j<n;j+=2*i)
			for(int k=0;k<i;k++) t=p[j+k+i]*w[i+k]%mod,p[j+k+i]=p[j+k]+mod-t,p[j+k]+=t;
		for(int i=0;i<n;i++) f[i]=p[i]%mod;
	}
	void IDFT(int *f,int n) {
		reverse(f+1,f+n); DFT(f,n); ll inv=power(n);
		for(int i=0;i<n;i++) f[i]=inv*f[i]%mod;
	}
	void copy(int *a,int *b,int n) {memcpy(a,b,sizeof(int[n]));}
	void clear(int *a,int len) {memset(a+len,0,sizeof(int[len]));}
	void clear(int *a,int x,int y) {if(x<y) memset(a+x,0,sizeof(int[y-x]));}
	void dao(int *a,int *b,int n) {
		for(int i=1;i<n;i++) b[i-1]=(ll)a[i]*i%mod;
		b[n-1]=0;
	}
	void ji(int *a,int *b,int n) {
		for(int i=n-1; i;i--) b[i]=(ll)a[i-1]*Inv[i]%mod;
		b[0]=0;
	}
	void mult(int *a,int *b,int n,int m) {
		static int c[N];
		int x=pre(n+m);
		clear(a,n,x); copy(c,b,m); clear(c,m,x);
		DFT(a,x); DFT(c,x);
		for(int i=0;i<x;i++) a[i]=(ll)a[i]*c[i]%mod;
		IDFT(a,x);
	}
	int h[N];
	void getinv(int *a,int *b,int n) {// 
		clear(b,0,2*n); clear(a,n); clear(h,0,2*n); b[0]=power(a[0]);
		for(int p=2;p<=n;p*=2) {
			int x=pre(p*2);
			copy(h,a,p); clear(h,p); DFT(h,x); DFT(b,x);
			for(int i=0;i<x;i++) b[i]=(ll)(2-(ll)b[i]*h[i]%mod+mod)*b[i]%mod;
			IDFT(b,x); clear(b,p);
		}
	}
	void getsqrt(int *a,int *b,int n) {
		static int c[N],f[N];
		clear(c,0,2*n); clear(f,0,2*n); clear(a,n); clear(b,0,2*n);
		b[0]=Cipolla::solve(a[0]); b[1]=0; 
		for(int p=2;p<=n;p*=2) {
			copy(f,a,p); clear(f,p); getinv(b,c,p); mult(f,c,p,p);
			for(int i=0;i<p;i++) b[i]=(ll)(b[i]+f[i])*inv2%mod;
		}
	}
	void getln(int *a,int *b,int n) {// 
		getinv(a,b,n); dao(a,h,n);
		mult(h,b,n,n); ji(h,b,n); clear(h,0,2*n);
	}
	void getexp(int *a,int *b,int n) {
		static int c[N]; clear(b,0,2*n); clear(a,n); b[0]=1;
		for(int p=2;p<=n;p*=2) {
			copy(c,b,p/2); clear(c,p/2); getln(c,b,p);upd(--b[0]);
			for(int i=0;i<p;i++) upd(b[i]=-b[i]+a[i]);
			mult(b,c,p,p); clear(b,p);
		}
	}
	void getdiv(int *a,int *b,int *c,int n,int m) {//c=a(n)/b(m). 
		static int h1[N],h2[N]; 
		int len=n-m+1,x=calc(len);
		copy(h1,a,n); copy(h2,b,m); 
		clear(h1,n,x); clear(h2,m,x);
		reverse(h1,h1+n); reverse(h2,h2+m); 
		getinv(h2,c,x); mult(c,h1,x,x); reverse(c,c+len); clear(c,len);
	}
	void getmod(int *a,int *b,int *c,int n,int m) {
		static int d[N];
		getdiv(a,b,c,n,m);
		copy(d,b,m);mult(c,d,n-m+1,m);
		for(int i=0;i<m;i++) upd(c[i]=a[i]-c[i]);
	}
}

int n,m,t,f[N],g[N],h[N];

int main() {
	qr(n); qr(m); n++; t=P::calc(n); P::init(t);
	for(int i=0;i<n;i++) qr(f[i]);
	P::getsqrt(f,g,t);
	P::getinv(g,h,t);
	P::ji(h,g,t);
	P::getexp(g,h,t);
	for(int i=0;i<t;i++) P::upd(g[i]=f[i]-h[i]);
	P::upd(g[0]+=2); P::upd(g[0]-=f[0]);
	P::getln(g,h,t); P::upd(h[0]+=1-mod);
	P::getln(h,g,t); for(int i=0;i<t;i++) g[i]=(ll)g[i]*m%mod; 
	P::getexp(g,h,t); P::dao(h,f,n); n--;
	for(int i=0;i<n;i++) pr1(f[i]);
	return 0;
}

猜你喜欢

转载自blog.csdn.net/qq_42886072/article/details/106823704