多项式模板

无聊写了一下。

暂时写到多点求值。

#include<bits/stdc++.h>
#define RAN(a)a.begin(),a.end()
#define pb push_back
using namespace std;
typedef unsigned u32;
typedef unsigned long long u64;
const u32 p=1811939329;
u32 imod(u32 a){
	return a<p?a:a-p;
}
u32 ipow(u32 a,u32 n){
	u32 s=1;
	for(;n;n>>=1){
		if(n&1)
			s=(u64)s*a%p;
		a=(u64)a*a%p;
	}
	return s;
}
u32 iinv(u32 a){
	return ipow(a,p-2);
}
u32 gen(int n,int f=0){
	u32 g=ipow(13,(p-1)/n);
	if(f&1)
		g=iinv(g);
	return g;
}
int len(int n){
	while(n^n&-n)
		n+=n&-n;
	return n;
}
struct poly{
	vector<u32>a;
	u32&operator[](int i){
		return a[i];
	}
	const u32&operator[](int i)const{
		return a[i];
	}
	int size()const{
		return a.size();
	}
	u32 val(u32 x)const{
		u32 s=0;
		for(int i=a.size()-1;~i;--i)
			s=((u64)s*x+a[i])%p;
		return s;
	}
	u32 operator()(u32 x)const{
		return val(x);
	}
	void fix(){
		while(a.size()&&!a.back())
			a.pop_back();
	}
	void mod(int n){
		a.resize(n);
	}
	void fft(int n,int f){
		a.resize(n);
		if(n<=1)
			return;
		for(int i=0,j=0;i<n;++i){
			if(i<j)
				swap(a[i],a[j]);
			int k=n>>1;
			while((j^=k)<k)
				k>>=1;
		}
		vector<u32>w(n/2);
		w[0]=1;
		for(int i=1;i<n;i<<=1){
			for(int j=i/2-1;~j;--j)
				w[j<<1]=w[j];
			u64 g=gen(i<<1,f);
			for(int j=1;j<i;j+=2)
				w[j]=g*w[j-1]%p;
			for(int j=0;j<n;j+=i<<1){
				u32*b=&a[0]+j,*c=b+i;
				for(int k=0;k<i;++k){
					u32 v=(u64)w[k]*c[k]%p;
					c[k]=imod(b[k]+p-v);
					b[k]=imod(b[k]+v);
				}
			}
		}
	}
	void dft(int n){
		fft(n,0);
	}
	void idft(){
		int n=a.size();
		fft(n,1);
		u64 f=iinv(n);
		for(int i=0;i<n;++i)
			a[i]=f*a[i]%p;
	}
	void inv(int n){
		int m=len(n);
		vector<u32>c(m);
		for(int i=0;i<n&&i<a.size();++i)
			c[i]=a[i];
		a.assign(1,iinv(c[0]));
		for(int i=2;i<=m;i<<=1){
			int l=i<<1;
			poly b={vector<u32>(l)};
			for(int j=0;j<i;++j)
				b[j]=c[j];
			b.dft(l);
			a.resize(l);
			dft(l);
			for(int j=0;j<l;++j)
				a[j]=a[j]*(2+p-(u64)a[j]*b[j]%p)%p;
			idft();
			mod(i);
		}
		mod(n);
	}
	void mul(poly b){
		int n=len(a.size()+b.size()-1);
		dft(n);
		b.dft(n);
		for(int i=0;i<n;++i)
			a[i]=(u64)a[i]*b[i]%p;
		idft();
		fix();
	}
	void div(poly b){
		fix();
		int n=a.size()-b.size()+1;
		if(n<=0)
			return a.clear();
		reverse(RAN(a));
		fix();
		reverse(RAN(b.a));
		b.fix();
		b.inv(n);
		mul(b);
		a.resize(n);
		reverse(RAN(a));
		fix();
	}
	void mod(poly b){
		if(a.size()>=b.size()){
			poly c={a};
			c.div(b);
			c.mul(b);
			for(int i=0;i<b.size()-1;++i)
				a[i]=imod(a[i]+p-c[i]);
		}
		mod(b.size()-1);
	}
	void val(u32*l,u32*r)const{
		if(r-l<=pow(log(a.size()+1),2)+10){
			for(;l<r;++l)
				*l=val(*l);
			return;
		}
		if(r-l>a.size()/2){
			u32*m=l+(r-l)/2;
			val(l,m);
			val(m,r);
			return;
		}
		const int lim=500;
		class{
			vector<poly>f;
			int pre(u32*l,u32*r){
				int k=f.size();
				f.pb(poly());
				if(r-l<=lim){
					vector<u32>&a=f[k].a;
					a.assign(1,1);
					for(;l<r;++l){
						a.insert(a.begin(),0);
						for(int j=0;j<a.size()-1;++j)
							a[j]=(a[j]+(u64)(p-*l)*a[j+1])%p;
					}
				}else{
					u32*m=l+(r-l)/2;
					int i=pre(l,m);
					int j=pre(m,r);
					f[k]=f[i];
					f[k].mul(f[j]);
				}
				return k;
			}
			void sol(u32*l,u32*r,poly a){
				a.mod(f.back());
				f.pop_back();
				if(r-l<=lim)
					for(;l<r;++l)
						*l=a.val(*l);
				else{
					u32*m=l+(r-l)/2;
					sol(l,m,a);
					sol(m,r,a);
				}
			}
		public:
			void operator()(u32*l,u32*r,const poly&a){
				pre(l,r);
				reverse(RAN(f));
				sol(l,r,a);
			}
		}sol;
		sol(l,r,*this);
	}
};

  

猜你喜欢

转载自www.cnblogs.com/f321dd/p/9363493.html