[SOJ628] 基础卷积练习题

题意简述:有两个下标范围在\([0,2^n)\),值域为\([0, m)\)的整数序列\(a, b\)。定义\(c_i=\max_{j\operatorname{xor} k=i} f(a_j, b_k)\),其中\(f(x, y)\)是定义域和值域均为[0,m)的整数的二元函数,且\(f(x, y)\)的值均给定,求\(c\)\(n, m\leq 16\)

我会\(\operatorname{fwt}\)

我假了!\(\max\)不可减!因此\(\operatorname{ifwt}\)的时候会咕咕咕!

我会暴力拿\(10\)分!

注意到值域很小,因此我们可以强行维护\([0, m)\)每个数出现的次数。也就是说,我们用长度为\(m\)的序列在\(f(x, y)\)下的卷积代替普通整数乘法,用序列按位加减法代替普通整数加减法。

严谨的说,设\(a, b, c\)为长为\(m\)的整数序列,\(c=a\ast b \Leftrightarrow c_i=\sum_{f(j, k)=i} a_j\cdot b_k\)\(c=a\pm b \Leftrightarrow c_i=a_i\pm b_i\)

这样就可以愉快的\(\operatorname{fwt}\)啦!最后得到的\(c_i\)是一个整数序列,\(c_{i, j}\)代表\(j\)出现的次数,只要\(c_{i, j}\)不为\(0\),就会对\(\max\)产生贡献。复杂度\(O(n\cdot 2^n\cdot m+2^n \cdot m^2)\)

#include <cstdio>
#include <cctype>
#include <cstring>
#include <cassert>
#include <iostream>
#include <algorithm>
#define R register
#define ll long long
using namespace std;
const int N = 1 << 16, M = 16;

int n, m, lim, f[M][M];
struct node {
    int c[M];
    node() {
        memset(c, 0, sizeof (c));
        return;
    }
    inline node operator * (const node &x) const {
        node ret;
        for (R int i = 0; i < m; ++i)
            for (R int j = 0; j < m; ++j)
                ret.c[f[i][j]] += c[i] * x.c[j];
        return ret;
    }
    inline node operator + (const node &x) const {
        node ret;
        for (R int i = 0; i < m; ++i)
            ret.c[i] = c[i] + x.c[i];
        return ret;
    }
    inline node operator - (const node &x) const {
        node ret;
        for (R int i = 0; i < m; ++i)
            ret.c[i] = c[i] - x.c[i];
        return ret;
    }
}a[N], b[N], c[N];

template <class T> inline void read(T &x) {
    x = 0;
    char ch = getchar(), w = 0;
    while (!isdigit(ch)) w = (ch == '-'), ch = getchar();
    while (isdigit(ch)) x = (x << 1) + (x << 3) + (ch ^ 48), ch = getchar();
    x = w ? -x : x;
    return;
}

void fwt(node *a, int lim, int opt) {
    int n = 1 << lim;
    for (R int l = 2; l <= n; l <<= 1)
        for (R int m = l >> 1, i = 0; i < n; i += l)
            for (R int j = i; j < i + m; ++j)
                if (opt == 1)
                    a[j + m] = a[j + m] + a[j];
                else
                    a[j + m] = a[j + m] - a[j];
    return;
}

int main() {
    int x;
    read(n), read(m), lim = 1 << n;
    for (R int i = 0; i < lim; ++i)
        read(x), a[i].c[x] = 1;
    for (R int i = 0; i < lim; ++i)
        read(x), b[i].c[x] = 1;
    for (R int i = 0; i < m; ++i)
        for (R int j = 0; j < m; ++j)
            read(f[i][j]);
    fwt(a, n, 1), fwt(b, n, 1);
    for (R int i = 0; i < lim; ++i)
        c[i] = a[i] * b[i];
    fwt(c, n, -1);
    for (R int i = 0; i < lim; ++i) {
        for (R int j = m - 1; ~j; --j)
            if (c[i].c[j]) {
                printf("%d ", j);
                break;
            }
    }
    return 0;
}

猜你喜欢

转载自www.cnblogs.com/suwakow/p/11640852.html