2019牛客暑期多校训练营(第一场)- H. XOR(线性基)

题目链接:XOR

题意:给定n个整数,求满足子集异或和为0的子集大小之和

思路:先求出n个整数的线性基r,线性基的大小为cr,讨论每个元素对答案的贡献

  • 线性基r外的元素共有n-cr个,对于每个元素,都能够与其他n-cr-1个线性基外的元素组合,组合后一定能在r内找到唯一的对应元素,所以每个元素对答案的贡献为$2^{n-cr-1}$种,有n-cr个元素,所以总共的贡献为$(n-cr)*2^{n-cr-1}$
  • 扫描一遍线性基r内的每个元素,每次去掉第i个元素,对剩下的n-1元素求线性基d,线性基的大小为cd,如果第i个元素还能插入线性基d,则一定不能异或出0,否则他能够与其他n-cd-1个线性基外的元素组合,对答案的贡献为$2^{n-cd-1}$

对于求线性基d,我们可以先对没有在r内的元素求一个线性基b,每次将线性基b和去掉第i个元素的线性基r进行合并即可

#include <iostream>
#include <algorithm>
#include <cstdio>
#include <vector>

using namespace std;

typedef long long ll;

const int N = 100010;
const int M = 70;
const ll mod = 1000000007;

int n, vis[N];
ll a[N], r[M], b[M], d[M];
vector<ll> v;

ll power(ll a, int n, ll p)
{
    ll res = 1;
    while (n) {
        if (n & 1) res = (res * a) % p;
        a = (a * a) % p;
        n >>= 1;
    }
    return res % p;
}

bool insert(ll x, ll b[])
{
    for (int k = 63; k >= 0; k--) {
        if (x >> k & 1) {
            if (b[k]) x ^= b[k];
            else {
                b[k] = x;
                return true;
            }
        }
    }
    return false;
}

int main()
{
    while (scanf("%d", &n) != EOF) {
        int cr = 0;
        v.clear();
        for (int i = 0; i < M; i++) r[i] = b[i] = 0;
        for (int i = 1; i <= n; i++) {
            vis[i] = 0;
            scanf("%lld", &a[i]);
            if (insert(a[i], r)) {
                cr++;
                vis[i] = 1;
                v.push_back(a[i]);
            }
        }
        if (cr == n) {
            printf("0\n");
            continue;
        }
        for (int i = 1; i <= n; i++) {
            if (vis[i]) continue;
            insert(a[i], b);
        }
        ll res = (n - cr) * power(2, n - cr - 1, mod) % mod;
        int len = (int)v.size();
        for (int i = 0; i < len; i++) {
            int cd = 0;
            for (int k = 0; k <= 63; k++) d[k] = 0;
            for (int k = 0; k < len; k++) {
                if (i == k) continue;
                if (insert(v[k], d)) cd++;
            }
            for (int k = 0; k <= 63; k++)
                if (b[k] && insert(b[k], d)) cd++;
            if (!insert(v[i], d))
                res = (res + power(2, n - cd - 1, mod)) % mod;
        }
        printf("%lld\n", res);
    }
    return 0;
}

猜你喜欢

转载自www.cnblogs.com/zzzzzzy/p/12373696.html