[UOJ500]任意基DFT

题目

传送门 to UOJ

思路

直接做多点求值是 $\mathcal O(q\log^2 n)$ 的,无法通过,必须利用求值点的特殊性。题面里声称是因为“输入量较大”,但这显然不是真正的原因。

设 $q_n=xq_{n-1}+y\;(n\in\mathbb N^+)$,当 $x\neq 1$ 时很容易求出通项公式
$$q_n=\frac{x^n-1}{x-1}\cdot y+x^nq_0$$
注意到其中唯一随 $n$ 变化的量是 $x^n$,所以可以把它简化为 $kx^n+b$ 的形式,其中 $k=q_0+\frac{y}{x-1}$,$b=-\frac{y}{x-1}$。

接下来把这个值直接代入多项式,观察式子的性质。设 $f(x)=\sum_i a_ix^i$,则
$$\begin{aligned}f(kx^m+b)&=\sum_{i,j}a_i\binom{i}{j}\cdot k^jx^{mj}b^{i-j}\\&=\sum_{j}x^{mj}\cdot\frac{k^j}{j!}\sum_{i}(i!\cdot a_i)\cdot\frac{b^{i-j}}{(i-j)!}\end{aligned}$$
注意到 $x^{mj}$ 就是 $(x^m)^j$,而右侧其余部分是只与 $j$ 有关的系数,所以可以直接求出系数
$$v_j=\frac{k^j}{j!}\sum_{i}(i!\cdot a_i)\cdot\frac{b^{i-j}}{(i-j)!}$$
这是一个差卷积,一次 $\mathcal O(n\log n)$ 的卷积就可以完成。记 $g(x)=\sum_{j}v_jx^j$,接下来只需对每个 $m$ 求出 $g(x^m)$。

利用 Bluestein 算法(chirp-z 变换)的思路,由 $mj=\binom{m+j}{2}-\binom{m}{2}-\binom{j}{2}$ 就可以凑出卷积形式:
$$g(x^m)=x^{-\binom{m}{2}}\sum_{j}\left(v_jx^{-\binom{j}{2}}\right)x^{\binom{m+j}{2}}$$
(这要求 $x\neq 0$)。时间复杂度 $\mathcal O\big((n+q)\log(n+q)\big)$。

代码

#include <cstdio> // Dangerous Dark Ghost!!!
#include <iostream>
#include <algorithm>
#include <cstring>
#include <cctype>
using namespace std;
// Inclusive ascending loop: i takes every value from a up to b.
# define rep(i,a,b) for(int i=(a); i<=(b); ++i)
// Inclusive descending loop: i takes every value from a down to b.
# define drep(i,a,b) for(int i=(a); i>=(b); --i)
// Wide integer type used for intermediate products before reducing mod MOD.
typedef long long llong;
// Reads the next integer from stdin.
// Non-digit characters before the number are skipped; every '-' seen while
// skipping flips the sign. Decimal digits are then accumulated until the
// first non-digit, which is consumed. If the input runs out before any digit
// appears, returns 0 rather than looping forever on EOF.
inline int readint(){
	int c = getchar(), sign = 1;
	for(; c != EOF && !isdigit(c); c = getchar())
		if(c == '-') sign = -sign;
	int value = 0;
	// isdigit(EOF) is false, so EOF also terminates the number.
	for(; isdigit(c); c = getchar())
		value = value*10 + (c-'0');
	return value*sign;
}
// Prints x in decimal to stdout, with no separator and no trailing newline.
void writeint(unsigned x){
	char digits[sizeof(unsigned)*3]; // ample room for every decimal digit of x
	int len = 0;
	do{
		digits[len++] = char('0'+x%10); // least significant digit first
		x /= 10;
	}while(x);
	while(len > 0)
		putchar(digits[--len]); // emit most significant digit first
}

// NTT modulus: 998244353 = 119 * 2^23 + 1, so MOD-1 contains the factor 2^23.
// LOGMOD sizes the per-level tables g[] and inv2[]; prepare() writes indices
// 0..23, so it must be at least 24.
const int MOD = 998244353, LOGMOD = 30;
// Returns (a + b) mod MOD, assuming both operands already lie in [0, MOD).
inline int modAdd(int a,int b){
	const int sum = a + b;
	return sum >= MOD ? sum - MOD : sum;
}
// In-place a = (a + b) mod MOD, assuming both operands already lie in [0, MOD).
inline void modAddUp(int &a,int b){
	a += b;
	if(a >= MOD)
		a -= MOD;
}
// Binary exponentiation: returns b^q mod MOD as a value in [0, MOD).
// The base is first reduced into [0, MOD). Without this, a negative base
// (e.g. gx-1 with gx == 0) would produce a negative "residue", and a base
// of about 3.04e9 or more would overflow b*b.
// q must be non-negative: a negative exponent never reaches 0 under the
// right shift, so the loop would not terminate.
inline int qkpow(llong b,int q){
	b %= MOD;
	if(b < 0) b += MOD; // C++ '%' keeps the dividend's sign
	llong a = 1;
	for(; q; q>>=1,b=b*b%MOD)
		if(q&1) a = a*b%MOD;
	return static_cast<int>(a);
}

int g[LOGMOD], inv2[LOGMOD];
void prepare(){
    
    
	int p = MOD-1, x = 0; inv2[0] = 1;
	for(inv2[1]=(MOD+1)>>1; !(p&1); p>>=1,++x)
		inv2[x+1] = int(llong(inv2[x])*inv2[1]%MOD);
	for(g[x]=qkpow(3,p); x; --x)
		g[x-1] = int(llong(g[x])*g[x]%MOD);
}

// Capacity of the global work arrays a[], tmp[] and inv[]. main() transforms
// 1<<N entries of a[] and tmp[], so 1<<N must not exceed MAXN.
// NOTE(review): nothing checks this bound at runtime.
const int MAXN = 3000005;
void NTT(int a[],int n){
    
    
	const int *end_a = a+(1<<n);
	for(int w=1<<n>>1,x=n; w; w>>=1,--x)
	for(int *p=a; p!=end_a; p+=(w<<1))
	for(int i=0,v=1; i!=w; ++i){
    
    
		const llong l = p[i];
		modAddUp(p[i],p[i+w]);
		p[i+w] = int(llong(l+MOD-p[i+w])*v%MOD);
		v = int(llong(v)*g[x]%MOD);
	}
}
// Inverse transform of length 2^n. Expects the bit-reversed spectrum produced
// by NTT(). The decimation-in-time pass reuses the forward roots, giving
// 2^n * x[(-k) mod 2^n]. Reversing entries 1..2^n-1 and scaling by 2^{-n}
// then yields the inverse transform in natural order.
void DNTT(int a[],int n){
	const int len = 1<<n;
	for(int lvl=1; lvl<=n; ++lvl){
		const int half = 1<<(lvl-1);
		const int root = g[lvl];
		for(int base=0; base<len; base+=2*half){
			int tw = 1;
			for(int j=0; j<half; ++j){
				int &lo = a[base+j];
				int &hi = a[base+j+half];
				const int t = int(llong(hi)*tw%MOD); // root^j * hi
				hi = modAdd(lo,MOD-t);               // lo - t
				modAddUp(lo,t);                      // lo + t
				tw = int(llong(tw)*root%MOD);
			}
		}
	}
	std::reverse(a+1,a+len);
	const int scale = inv2[n];
	for(int j=0; j<len; ++j)
		a[j] = int(llong(a[j])*scale%MOD);
}

// Global work arrays: a[] is the main polynomial/convolution buffer, tmp[] the
// second convolution operand, and inv[i] the modular inverse of i (filled for 1..n).
int a[MAXN], tmp[MAXN], inv[MAXN];
// Reads n, q, the coefficients a_0..a_n and three parameters g0, gx, gy.
// Writes the points as q_m = k*gx^m + b and computes f(q_m) for m = 0..q in two
// convolutions, then prints the XOR of f(q_1)..f(q_q):
//   1. Taylor shift: v_j = k^j/j! * sum_i (i! a_i) * b^{i-j}/(i-j)!
//   2. Bluestein / chirp-z: sum_j v_j gx^{mj} for every m, using
//      mj = C(m+j,2) - C(m,2) - C(j,2).
int main(){
	prepare(); // fills g[] and inv2[] used by NTT/DNTT
	int n = readint(), q = readint();
	// Linear-time modular inverses: inv[i] = -(MOD/i) * inv[MOD % i].
	rep(i,(inv[1]=1)+1,n)
		inv[i] = int(llong(MOD-MOD/i)*inv[MOD%i]%MOD);
	rep(i,0,n) a[i] = readint();
	// Presumably q_0, x and y of q_m = x*q_{m-1} + y: this matches the closed
	// form computed below. Confirm against the problem statement.
	int g0 = readint();
	int gx = readint(), gy = readint();

	// q_m = k*gx^m + b with b = -gy/(gx-1) and k = g0 + gy/(gx-1).
	// NOTE(review): gx == 1 leaves gx-1 non-invertible (qkpow returns 0 and the
	// closed form silently breaks). gx == 0 passes a negative base to qkpow and
	// also makes invgx below 0. Presumably both are excluded by the constraints.
	int b = int(llong(gy)*qkpow(gx-1,MOD-2)%MOD);
	// k is left unreduced (b+g0 may exceed MOD). It is only ever multiplied with
	// an already reduced value inside llong, so the product still fits.
	int k = b+g0; if(b) b = MOD-b;
	// Smallest N with 2^N > 2n: room for the degree-2n product below.
	int N = 1; for(; (1<<N)<=(n<<1); ++N);
	for(int i=0,jc=1; i<=n; ++i){
		a[i] = int(llong(jc)*a[i]%MOD); // a[i] <- i! * a_i
		jc = int(llong(jc)*(i+1)%MOD); // (i+1)! for the next iteration
	}
	// tmp[i] = b^{n-i} / (n-i)!: the sequence b^d/d! stored reversed, so the
	// difference-indexed sum becomes an ordinary convolution.
	for(int i=n-(tmp[n]=1); ~i; --i)
		tmp[i] = int(llong(tmp[i+1])
			*inv[n-i]%MOD*b%MOD);
	NTT(a,N), NTT(tmp,N); // first polynomial multiplication
	for(int i=0; i!=(1<<N); ++i)
		a[i] = int(llong(a[i])*tmp[i]%MOD);
	// Product index n+j holds sum_i (i! a_i) b^{i-j}/(i-j)!. Shift it to a[0..n].
	// The ranges overlap, hence memmove.
	// NOTE(review): the byte count (n+1)<<2 assumes sizeof(int) == 4.
	DNTT(a,N); memmove(a,a+n,(n+1)<<2);

	// Grow N until 2^N > 2n+q so the second product does not wrap around.
	for(; (1<<N)<=(n<<1)+q; ++N);
	const int invgx = qkpow(gx,MOD-2);
	// a[i] <- v_i * gx^{-C(i,2)}. Here v = k^i/i! and avv = gx^{-C(i,2)}.
	for(int i=0,v=1,av=1,avv=1; i<=n; ++i){
		a[i] = int(llong(v)*a[i]%MOD*avv%MOD);
		// At i == n this reads inv[n+1], which was never filled (0). That is
		// harmless because v is not used afterwards.
		v = int(llong(v)*inv[i+1]%MOD*k%MOD);
		avv = int(llong(avv)*av%MOD); // gx^{-C(i+1,2)} for the next iteration
		av = int(llong(av)*invgx%MOD); // gx^{-(i+1)}
	}
	std::reverse(a,a+n+1); // reversed so the sum over j becomes a convolution
	memset(a+n+1,0,((1<<N)-n-1)<<2); // clear leftovers of the first product
	for(int i=0,v=1,av=1; i<=n+q; ++i){
		tmp[i] = av; // gx^{C(i,2)} = gx^{i*(i-1)/2}
		av = int(llong(av)*v%MOD);
		v = int(llong(v)*gx%MOD);
	}
	memset(tmp+n+q+1,0,((1<<N)-q-n-1)<<2);
	NTT(a,N), NTT(tmp,N);
	for(int i=0; i!=(1<<N); ++i)
		a[i] = int(llong(a[i])*tmp[i]%MOD);
	// Index n+m holds sum_j v_j gx^{-C(j,2)} gx^{C(m+j,2)}. Shift it to a[0..q].
	DNTT(a,N); memmove(a,a+n,(q+1)<<2);
	// Multiply by gx^{-C(m,2)} so that a[m] = sum_j v_j gx^{mj} = f(q_m).
	for(int i=0,v=1,av=1; i<=q; ++i){
		a[i] = int(llong(a[i])*av%MOD);
		av = int(llong(av)*v%MOD);
		v = int(llong(v)*invgx%MOD);
	}
	int xyx = 0; // XOR of f(q_1), ..., f(q_q); f(q_0) is not included
	for(int i=1; i<=q; ++i)
		xyx ^= a[i];
	printf("%d\n",xyx);
	return 0;
}

转载自 blog.csdn.net/qq_42101694/article/details/122278629