hdu6057 Kanade's convolution 【FWT】

题目链接

hdu6057

题意

给出序列\(A[0...2^{m} - 1]\)\(B[0...2^{m} - 1]\),求所有
\[C[k] = \sum\limits_{i \; and \; j = k} A[i \; xor \; j]B[i \; or \; j]\]

题解

我只能感叹太神了
看到题目我是懵逼的

首先注意三者运算的关系:
\[(i \; and \; j) + (i \; xor \; j) = (i \; or \; j)\]
证明显然
于是我们枚举\(x = i \; or \; j,y = i \; xor \; j\),显然\(y \in x\)\(x \; and \; y = y\),且对于同一个\(x,y\),这样的\(i,j\)存在\(2^{bit(y)}\)对,\(bit(y)\)\(y\)二进制下\(1\)的个数
证明显然

于是我们有
\[ \begin{aligned} C[k] &= \sum\limits_{i \; and \; j = k} A[i \; xor \; j]B[i \; or \; j] \\ &= \sum\limits_{x - y = k} [x \; and \; y = y]B[x]A[y]2^{bit(y)} \\ &= \sum\limits_{x \; xor \; y = k} [bit(x) - bit(y) = bit(k)]B[x]A[y]2^{bit(y)} \\ \end{aligned} \]
出去中间那个限制,就是一个异或卷积了
考虑如何去掉中间的限制,我们只需将\(bit()\)不同的位置分离,分别做\(FWT\)
即设\(F(A,x)_{i} = [bit(i) = x]A_i\)
那么有
\[F(C,k) = \sum\limits_{i = k}^{m} F(B,i) \times F(A,i - k)\]
然后\(C[k]\)的结果就存在\(F(C,bit(k))\)

复杂度\(O(m^2 2^{m})\)

#include<algorithm>
#include<iostream>
#include<cstdlib>
#include<cstring>
#include<cstdio>
#include<vector>
#include<queue>
#include<cmath>
#include<map>
#define LL long long int
#define REP(i,n) for (int i = 1; i <= (n); i++)
#define Redge(u) for (int k = h[u],to; k; k = ed[k].nxt)
#define cls(s,v) memset(s,v,sizeof(s))
#define mp(a,b) make_pair<int,int>(a,b)
#define cp pair<int,int>
using namespace std;
const int maxn = (1 << 19),maxm = 100005,INF = 0x3f3f3f3f;
inline int read(){
    int out = 0,flag = 1; char c = getchar();
    while (c < 48 || c > 57){if (c == '-') flag = 0; c = getchar();}
    while (c >= 48 && c <= 57){out = (out << 1) + (out << 3) + c - 48; c = getchar();}
    return flag ? out : -out;
}
const int P = 998244353;
int m,A[21][maxn],B[21][maxn],C[21][maxn],a[maxn],b[maxn],inv2,deg;
inline int qpow(int a,int b){
    int re = 1;
    for (; b; b >>= 1,a = 1ll * a * a % P)
        if (b & 1) re = 1ll * re * a % P;
    return re;
}
inline int bit(int x){int re = 0; while (x) re += (x & 1),x >>= 1; return re;}
inline void fwt(int* a,int n,int f){
    for (int i = 1; i < n; i <<= 1)
        for (int j = 0; j < n; j += (i << 1))
            for (int k = 0; k < i; k++){
                int x = a[j + k],y = a[j + k + i];
                a[j + k] = (x + y) % P,a[j + k + i] = (x - y + P) % P;
                if (f == -1) a[j + k] = 1ll * a[j + k] * inv2 % P,a[j + k + i] = 1ll * a[j + k + i] * inv2 % P;
            }
}
int main(){
    inv2 = qpow(2,P - 2);
    m = read(); deg = (1 << m); int x;
    for (int i = 0; i < deg; i++){
        a[i] = read(); x = bit(i);
        A[x][i] = 1ll * a[i] * qpow(2,x) % P;
    }
    for (int i = 0; i < deg; i++){
        b[i] = read();
        B[bit(i)][i] = b[i];
    }
    for (int i = 0; i <= m; i++){
        fwt(A[i],deg,1);
        fwt(B[i],deg,1);
    }
    for (int k = 0; k <= m; k++){
        for (int x = k; x <= m; x++)
            for (int i = 0; i < deg; i++)
                C[k][i] = (C[k][i] + 1ll * B[x][i] * A[x - k][i] % P) % P;
    }
    for (int i = 0; i <= m; i++) fwt(C[i],deg,-1);
    int ans = 0,tmp = 1;
    for (int i = 0; i < deg; i++)
        ans = (ans + 1ll * C[bit(i)][i] * tmp % P) % P,tmp = 1ll * tmp * 1526 % P;
    printf("%d\n",ans);
    return 0;
}

猜你喜欢

转载自www.cnblogs.com/Mychael/p/9257928.html
FWT