K 进制 FWT 学习笔记

需要快速求出 C k = i j = k A i B j C_k=\sum_{i\bigotimes j=k}A_i*B_j
考虑构造一个可逆的矩阵满足 ( f A ) ( f B ) = ( f C ) (f*A)*(f*B)=(f*C)
那么只需要快速求出 f A , f B , f 1 ( f C ) f*A,f*B,f^{-1}*(f*C) 即可

推导矩阵的系数:
( f A ) [ x ] ( f B ) [ x ] = ( f C ) [ x ] (f*A)[x]*(f*B)[x]=(f*C)[x]
i = 0 n f x , i a i i = 0 n f x , i b i = i = 0 n f x , i c i \sum_{i=0}^nf_{x,i}*a_i*\sum_{i=0}^nf_{x,i}b_i=\sum_{i=0}^nf_{x,i}c_i
需要满足
f x , k = i j = k f x , i f x , j f_{x,k}=\sum_{i\bigotimes j=k}f_{x,i}*f_{x,j}
这样的矩阵显然太大了

每一位都是相同的,考虑构造一个 k k 位的矩阵
设一个数 x 的 m m 位分别是 x 0 , x 1 , . . . , x m 1 x_0,x_1,...,x_{m-1}
g x , y = i x i y i g_{x,y}=\prod_i x_iy_i ,这样子矩阵的大小就是 k
考虑如果知道 g g 如何求 f A , f B f*A,f*B
x ^ \hat x 为 x 的最高位, x x' 为 x 除去最高位后剩下的
( f A ) [ x ] = i = 0 m 1 f x , i a i (f*A)[x]=\sum_{i=0}^{m-1}f_{x,i}a_i
= i = 0 k 1 g x ^ , i j ^ = i f x , j a j =\sum_{i=0}^{k-1}g_{\hat x,i}\sum_{\hat j=i}f_{x',j'}a_j
= i = 0 k 1 f x ^ , i ( f A ) [ x ] =\sum_{i=0}^{k-1} f_{\hat x,i}(f*A')[x']
A A' 就是原序列把最高位是 i i 的那一段截出来的序列
然后就可以通过分治求得这一个东西

以下定义卷积符号是 k 进制的不进位加法
如果要让 f f 满足条件,也就是每一位都满足条件,对应到 g g 上面就是
g x , k = K i + j k g x , i g x , j g_{x,k}=\sum_{K|i+j-k}g_{x,i}*g_{x,j}
想到了单位根反演,并且循环有意义,那么可以按如下方式构造

g = 1 1 1 . . . 1 w K 1 w K 2 . . . 1 w K 2 w K 4 . . . . . . . . . . . . . . . 1 w K K 1 w K 2 ( K 1 ) . . . g=\begin{matrix} 1 & 1 & 1&... \\ 1 & w_K^1 & w_K^2 &...\\ 1 & w_K^2 & w_K^4 & ... \\ ... &...&...&...\\ 1 & w_K^{K-1} & w_K^{2*(K-1)} & ...\end{matrix}

g 1 = 1 K 1 1 1 . . . 1 w K 1 w K 2 . . . 1 w K 2 w K 4 . . . . . . . . . . . . . . . 1 w K ( K 1 ) w K 2 ( K 1 ) . . . g^{-1}=\frac{1}{K}\begin{matrix} 1 & 1 & 1&... \\ 1 & w_K^{-1} & w_K^{-2} &...\\ 1 & w_K^{-2} & w_K^{-4} & ... \\ ... &...&...&...\\ 1 & w_K^{-(K-1)} & w_K^{-2*(K-1)} & ...\end{matrix}
【清华集训2016】石家庄的工人阶级队伍比较坚强

#include<bits/stdc++.h>
#define cs const
using namespace std;
int read(){
	int cnt = 0, f = 1; char ch = 0;
	while(!isdigit(ch)){ ch = getchar(); if(ch == '-') f = -1; }
	while(isdigit(ch)) cnt = cnt*10 + (ch-'0'), ch = getchar();
	return cnt * f;
}
cs int N = 1e6 + 5, M = 15;
int Mod, n, m, T, pw[M];
int add(int a, int b){ return a + b >= Mod ? a + b - Mod : a + b; }
int mul(int a, int b){ return 1ll * a * b % Mod; }
int dec(int a, int b){ return a - b < 0 ? a - b + Mod : a - b; }
int ksm(int a, int b){ int ans = 1; for(;b;b>>=1,a=mul(a,a)) if(b&1) ans=mul(ans,a); return ans; }
void exgcd(int a, int b, int &x, int &y){
	if(!b){ x = 1; y = 0; return; } exgcd(b,a-a/b*b,x,y); int t=x; x=y;y=t-a/b*y;
}
int inv(int a){ int x,y,b=Mod; exgcd(a,b,x,y); return (x%b+b)%b; }

struct data{
	int x, y;
	data(int _x = 0, int _y = 0){ x = _x, y = _y; }
	data operator + (cs data &a){ return data(add(x,a.x),add(y,a.y)); }
	data operator - (cs data &a){ return data(dec(x,a.x),dec(y,a.y)); }
	data operator * (cs data &a){ return data(dec(mul(x,a.x),mul(y,a.y)), dec(add(mul(x,a.y), mul(y,a.x)), mul(y,a.y))); }
	bool operator < (cs data &a)cs{ return x^a.x ? x<a.x : y<a.y;} 
}w[3], f[N], g[N];
data power(data a, int b){ data ans(1,0); for(;b;b>>=1,a=a*a) if(b&1) ans=ans*a; return ans; }
int b[M][M];
void dfs(int u, int vl, int ct1, int ct2){
	if(u == m){ g[vl].x = b[ct1][ct2]; return; }
	dfs(u+1, vl, ct1, ct2);
	dfs(u+1, vl+pw[u], ct1+1, ct2);
	dfs(u+1, vl+pw[u]+pw[u], ct1, ct2+1);
}
void dft(data *a, int typ){
	for(int i = 1; i < n; i *= 3){
		for(int j = 0, len = i * 3; j < n; j += len){
			for(int k = 0; k < i; k++){
				data a0 = a[j + k], a1 = a[j + k + i], a2 = a[j + k + i + i];
				a[j + k] = a0 + a1 + a2;
				a[j + k + i] = a0 + a1 * w[1] + a2 * w[2];
				a[j + k + i + i] = a0 + a1 * w[2] + a2 * w[1];
				if(typ == -1) swap(a[j + k + i], a[j + k + i + i]);
			}
		}
	}
}
map<data, data> S;
data calc(data x){ if(S.count(x)) return S[x];return S[x] = power(x, T); }
int main(){
	scanf("%d%d%d", &m, &T, &Mod);
	pw[0] = 1; for(int i = 1; i <= m; i++) pw[i] = pw[i-1] * 3;
	w[0] = data(1, 0); w[1] = data(0, 1); w[2] = data(Mod-1,Mod-1);
	n = pw[m];
	for(int i = 0; i < n; i++) f[i].x = read();
	for(int i = 0; i <= m; i++) for(int j = 0; j <= m-i; j++) b[i][j] = read();
	dfs(0, 0, 0, 0);

	dft(f, 1); dft(g, 1);
	for(int i = 0; i < n; i++) f[i] = f[i] * calc(g[i]);
	dft(f, -1); int iv = inv(n);
	for(int i = 0; i < n; i++) cout << mul(iv, f[i].x) << '\n';
	return 0;
}
发布了610 篇原创文章 · 获赞 94 · 访问量 5万+

猜你喜欢

转载自blog.csdn.net/sslz_fsy/article/details/103319235
FWT