多项式多点求值的小常数解法

定义转置乘法(transposed multiplication):$$\mathrm{MUL}^T(f(z),g(z)) = \sum_{i\ge 0} z^i\sum_{j\ge 0}\left([z^{i+j}]f(z)\right)\left([z^j]g(z)\right)$$
也就是差卷积。
可以发现
$$F(x_0)=\sum_{i=0}^{n} x_0^i\,[z^i]F(z)=[z^0]\,\mathrm{MUL}^T\!\left(F(z),\frac{1}{1-x_0z}\right)$$
(因为 $\frac{1}{1-x_0z}=\sum_{j\ge 0}x_0^j z^j$。)
又因为 $\mathrm{MUL}^T(f(z),g(z)h(z)) = \mathrm{MUL}^T\!\left(\mathrm{MUL}^T(f(z),g(z)),h(z)\right)$,
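所以我们可以先求出 $\prod_{i=1}^{m}(1-x_iz)$,
再求出 $\mathrm{MUL}^T\!\left(F(z),\prod_{i=1}^{m}\frac{1}{1-x_iz}\right)$(只需要一次多项式求逆)。
然后在线段树上往下走:设结点 $[l,r]$ 上的多项式为 $G(z)=\mathrm{MUL}^T\!\left(F(z),\prod_{i=l}^{r}\frac{1}{1-x_iz}\right)$。如果往 $[l,mid]$ 走,就需要把它变成 $\mathrm{MUL}^T\!\left(G(z),\prod_{i=mid+1}^{r}(1-x_iz)\right)$;往 $[mid+1,r]$ 走同理,乘的是 $\prod_{i=l}^{mid}(1-x_iz)$。到达叶子 $[l,l]$ 时,$[z^0]G(z)=F(x_l)$。
每一步只需要卷积,不需要多项式取模。
(注:下面的代码实现的是经典的"多项式取模 + 线段树"做法,而不是上面所述的转置乘法做法。)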

#include<bits/stdc++.h>
#define maxn 300005
#define mod 998244353
#define rep(i,j,k) for(int i=(j);i<=(k);i++)
using namespace std;

// NTT tables shared by every routine below (filled by init):
//   Wl  - largest power of two <= init's argument; transforms may have length up to 2*Wl
//   W   - W[i] = w^i for i in [0, 2*Wl], where w is a primitive (2*Wl)-th root of unity mod p
//   lg  - lg[i] = floor(log2(i)) for i in [2, 2*Wl]; lg[0] = lg[1] = 0
//   inv - inv[i] = i^{-1} mod p for i in [1, 2*Wl]
//   r   - scratch storage for NTT's bit-reversal permutation
int Wl,W[maxn],lg[maxn],inv[maxn],r[maxn];
// Binary exponentiation: returns b^k mod p for k >= 0.
// If b is negative the result may be a negative representative in (-p, p).
int Pow(int b,int k){
	int res=1;
	while(k){
		if(k&1) res=1ll*res*b%mod;
		b=1ll*b*b%mod;
		k>>=1;
	}
	return res;
}
void init(int n){
	for(W[0]=inv[0]=inv[1]=Wl=1;n>=Wl<<1;Wl<<=1);int pw=Pow(3,(mod-1)/Wl/2);
	rep(i,1,Wl<<1) W[i]=1ll*W[i-1]*pw%mod,(i>1)&&(inv[i]=1ll*(mod-mod/i)*inv[mod%i]%mod,lg[i]=lg[i>>1]+1);
}
void NTT(int *A,int n,int tp){
	rep(i,0,n-1) i<(r[i]=(r[i>>1]>>1)|((i&1)<<(lg[n]-1)))&&(swap(A[i],A[r[i]]),0);
	for(int L=1,B=Wl;L<n;L<<=1,B>>=1) for(int s=0;s<n;s+=L<<1) for(int k=s,x=0,t;k<s+L;k++,x+=B)
		t=1ll*(tp==1?W[x]:W[(Wl<<1)-x])*A[k+L]%mod,A[k+L]=(A[k]-t)%mod,A[k]=(A[k]+t)%mod;
	if(tp^1) rep(i,0,n-1) A[i]=1ll*A[i]*inv[n]%mod;
}
void MUL(int *A,int *B,int *C,int n,int m){
	static int t[2][maxn];int L=1<<lg[n+m]+1;
	rep(i,0,L-1) t[0][i]=i<=n?A[i]:0,t[1][i]=i<=m?B[i]:0;NTT(t[0],L,1),NTT(t[1],L,1);
	rep(i,0,L-1) C[i]=1ll*t[0][i]*t[1][i]%mod;NTT(C,L,-1);
}
// Power-series inverse modulo p: B = A^{-1} mod x^n, computed by Newton
// iteration B <- B*(2 - A*B), which doubles the precision k each round.
//   A : input series; only A[0..n-1] is read. A[0] must be nonzero mod p.
//   B : output; B[0..n-1] receives the inverse and the rest of the last
//       transform window is zeroed (B needs room for fewer than 4n entries,
//       and at least 2).
void INV(int *A,int *B,int n){
	static int t[maxn];
	B[1]=0;
	B[0]=Pow(A[0],mod-2);
	for(int k=2,L=4;k<(n<<1);k<<=1,L<<=1){
		const int lim=min(n,k);  // precision kept after this round
		rep(i,0,L-1){
			// Never look at A beyond the requested precision n: coefficients
			// of degree >= n only affect B at degrees >= n, which are cleared below.
			t[i]=i<lim?A[i]:0;
			if(i>=k) B[i]=0;
		}
		NTT(B,L,1),NTT(t,L,1);
		rep(i,0,L-1) B[i]=1ll*B[i]*(2-1ll*B[i]*t[i]%mod)%mod;
		NTT(B,L,-1);
		rep(i,lim,L-1) B[i]=0;
	}
}
void DIV(int *A,int *B,int *C,int *R,int n,int m){
	if(n<m){ rep(i,0,m-1) R[i]=A[i];return; }
	reverse(A,A+n+1),reverse(B,B+m+1),INV(B,C,n-m+1);
	MUL(A,C,C,n-m,n-m),fill(C+n-m+1,C+2*n-2*m+1,0);
	reverse(A,A+n+1);reverse(B,B+m+1);reverse(C,C+n-m+1);
	MUL(B,C,R,m,n-m);rep(i,0,n) R[i]=(A[i]-R[i])%mod;
}
// Children of heap-ordered segment-tree node u (root is 1). Parenthesised so the
// macros stay correct inside larger expressions.
#define lc (u<<1)
#define rc (u<<1|1)
// Segment-tree storage indexed by node id u:
//   M[u]  - coefficients of prod_{i in node} (x - X[i]); dM[u] is its degree
//   F[u]  - the input polynomial reduced modulo M[u] (F[0] is DIV scratch)
int *M[maxn],*F[maxn],dM[maxn];
// Allocate a zero-initialised buffer of 2^(floor(log2 a)+2) ints (more than 2a
// entries for a >= 1). Value-initialisation ensures no tree buffer ever exposes
// uninitialised coefficients. Nothing in this file frees these buffers.
#define NAF(a) new int[1<<(lg[(a)]+2)]()
// Build the product tree over X[l..r]: M[u] = prod_{i=l}^{r} (x - X[i]) and
// dM[u] = r-l+1 (its degree). Children are lc/rc; inner nodes are the product
// of their children.
void BDM(int u,int l,int r,int *X){
	const int len=r-l+1;
	M[u]=NAF(len);
	dM[u]=len;
	if(l==r){
		// leaf polynomial: x - X[l]
		M[u][0]=-X[l];
		M[u][1]=1;
		return;
	}
	const int mid=(l+r)>>1;
	BDM(lc,l,mid,X);
	BDM(rc,mid+1,r,X);
	MUL(M[lc],M[rc],M[u],dM[lc],dM[rc]);
}
// Descend the remainder tree. For u > 1, F[u] = F[parent] mod M[u] (the parent's
// remainder has degree < dM[parent]). At leaf l the remainder modulo (x - X[l])
// is a constant, which is stored in C[l].
void EVL(int u,int l,int r,int *C){
	if(u>1){
		const int parent=u>>1;
		F[u]=NAF(dM[parent]);
		DIV(F[parent],M[u],F[0],F[u],dM[parent]-1,dM[u]);
	}
	if(l==r){
		C[l]=F[u][0];
		return;
	}
	const int mid=(l+r)>>1;
	EVL(lc,l,mid,C);
	EVL(rc,mid+1,r,C);
}
// Multipoint evaluation: C[i] = A(X[i]) mod p for i in [0, m), where A has degree n
// (coefficients A[0..n]). Results may be negative representatives in (-p, p).
// The NTT tables must already be initialised (main calls init(2*max(n,m)) first).
void MEV(int *A,int *X,int *C,int n,int m){//0...m-1
	if(m<=0) return;  // no points; BDM(1,0,-1) would recurse forever
	BDM(1,0,m-1,X);
	const int cap=max(n,m);
	F[0]=NAF(cap);    // scratch buffer shared by every DIV call
	F[1]=NAF(cap);    // A mod M[1]
	DIV(A,M[1],F[0],F[1],n,dM[1]);
	EVL(1,0,m-1,C);
}

// Program input/output: n = degree of the polynomial, m = number of points,
// A[0..n] = coefficients, X[0..m-1] = evaluation points, C[0..m-1] = results
// (possibly negative representatives; main normalises them before printing).
int n,m,A[maxn],X[maxn],C[maxn];

// Reads n, m, the n+1 coefficients and the m points, then prints A(X[i]) mod p,
// one per line, normalised to [0, p). Exits with status 1 on malformed input
// instead of evaluating uninitialised or partially read data.
int main(){
	if(scanf("%d%d",&n,&m)!=2) return 1;
	init(2*max(n,m));
	rep(i,0,n) if(scanf("%d",&A[i])!=1) return 1;
	rep(i,0,m-1) if(scanf("%d",&X[i])!=1) return 1;
	MEV(A,X,C,n,m);
	rep(i,0,m-1) printf("%d\n",(C[i]+mod)%mod);
}

猜你喜欢

转载自blog.csdn.net/qq_35950004/article/details/107647353