CF 1323D Present

url

给定一个长度为 \(n\) 的序列 \(a\),求 \(\bigoplus_{1 \le i < j \le n}(a_i+a_j)\),即所有两两之和的异或值(这里的"连乘"记号中,运算是异或而不是乘法)。

\(2 \leq n \leq 400\,000,1 \leq a_i \leq 10^7\)

按二进制位进行计算,对于每一位单独考虑

假设当前考虑第 \(x\) 位(最低位记为第 0 位)。

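如果 \(a_i\) 和 \(a_j\) 的第 \(x\) 位都为 1,那么 \(a_i+a_j\) 的第 \(x\) 位为 1 当且仅当低 \(x\) 位相加产生进位,即 \(2^x \le (a_i \bmod 2^x) + (a_j \bmod 2^x)\),此时对该位有一次贡献。一般地,和的第 \(x\) 位等于两数的第 \(x\) 位与低位进位三者的异或。因此按"两数第 \(x\) 位是否相同"和"是否进位"共分四种情况讨论:两数该位相同时,需要进位才有贡献;两数该位不同时,需要不进位才有贡献。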

用值域树状数组维护即可。由于只关心满足条件的个数的奇偶性,树状数组里存异或和即可。时间复杂度为 \(O(n\log^2 \max a_i)\):共 \(O(\log \max a_i)\) 位,每一位做 \(n\) 次值域大小为 \(2^x\) 的树状数组操作。空间复杂度为 \(O(\max a_i)\)。

#include <bits/stdc++.h>
#define ll long long
#define X first
#define Y second
#define sz size()
#define all(x) x.begin(), x.end()
using namespace std;

typedef pair<int, int> pii;
typedef vector<int> vi;
typedef vector<long long> vl;

/**
 * Fast reader: parses the next (optionally negative) decimal integer.
 *
 * Characters that are neither digits nor '-' are skipped. The character that
 * terminates the number is consumed.
 *
 * @param ret receives the parsed value (unchanged when false is returned)
 * @param in  stream to read from; defaults to stdin
 * @return false if end of input is reached before a number starts, true otherwise
 */
template <class T>
inline bool scan(T &ret, FILE *in = stdin){
    // int, not char: a char cannot reliably hold/compare against EOF, and the
    // skip loop must stop at EOF instead of spinning forever on trailing blanks.
    int c = getc(in);
    while (c != EOF && c != '-' && (c < '0' || c > '9')) c = getc(in);
    if (c == EOF) return 0;
    const int sgn = (c == '-') ? -1 : 1;
    ret = (c == '-') ? 0 : (c - '0');
    while (c = getc(in), c >= '0' && c <= '9') ret = ret * 10 + (c - '0');
    ret *= sgn;
    return 1;
}

// Contest-template constants. mod is used by qp(); maxn sizes the Fenwick
// arrays; inf and eps are not referenced by this program.
constexpr ll mod = 1000000007;          // 1e9 + 7
// add() touches 1-based node (1 << bit) + 1 and main() runs bit up to 25,
// so maxn must be at least (1 << 25) + 2 = 33554434.
constexpr int maxn = 40000000 + 50;
constexpr int inf = 0x3f3f3f3f;
constexpr double eps = 1e-8;

/**
 * Modular exponentiation by repeated squaring: returns x^n modulo `mod`.
 * Returns 1 when n <= 0. A negative x yields a result in (-mod, 0].
 * Not called by this program.
 */
ll qp(ll x, ll n) {
    ll base = x % mod;
    ll acc = 1;
    for (; n > 0; n >>= 1) {
        if (n & 1) acc = acc * base % mod;
        base = base * base % mod;
    }
    return acc;
}

// Two XOR-Fenwick trees; main() keeps elements whose current bit is g in pre[g].
bool pre[2][maxn];
// n: element count; a[1..n]: the input; ct[i]: a[i] reduced modulo 2^bit.
int n;
int a[400000 + 50];
int ct[400000 + 50];

// Toggles 0-based position x of a XOR-Fenwick tree. The update walks 1-based
// nodes up to mx + 1, so `pre` must have at least mx + 2 elements.
inline void add(bool pre[], int x, int mx) {
    const int limit = mx + 1;
    for (int idx = x + 1; idx <= limit; idx += idx & -idx) {
        pre[idx] = !pre[idx];
    }
}

// Returns the XOR (parity) of all add() toggles made at 0-based positions 0..x.
inline bool query(bool pre[], int x) {
    bool parity = false;
    for (int idx = x + 1; idx > 0; idx &= idx - 1) {
        if (pre[idx]) parity = !parity;
    }
    return parity;
}

int main(int argc, char* argv[]) {
    scanf("%d", &n);
    for(int i = 1; i <= n; ++i) {
        scanf("%d", &a[i]);
    }
    int res = 0;
    for (int bit = 0; bit < 26; ++bit) {
        bool count = 0;
        int now = 1<<bit;
        bool c0 = 0, c1 = 0;
        for (int i = 1; i <= n; ++i) {
            ct[i] = a[i] % now;
            int ict = now - ct[i] - 1;
            if (a[i] >> bit & 1) {
                count ^= query(pre[0], ict);
                count ^= c1 ^ query(pre[1], ict);
                add(pre[1], ct[i], now);
                c1 ^= 1;
            } else {
                count ^= c0 ^ query(pre[0], ict);
                count ^= query(pre[1], ict);
                add(pre[0], ct[i], now);
                c0 ^= 1;
            }
        }
        for (int i = 1; i <= n; ++i) {
            if (a[i] >> bit & 1) {
                add(pre[1], ct[i], now);
            } else {
                add(pre[0], ct[i], now);
            }
        }
        if (count) res |= now;
    }
    printf("%d\n", res);
    return 0;
}

优化版用 bitset 存放树状数组:每个结点只占 1 个二进制位,而不是 1 个字节的 bool,因此空间约降为原来的 1/8。

#include <bits/stdc++.h>
#pragma comment(linker, "/STACK:102400000,102400000")
#define ll long long
#define X first
#define Y second
#define sz size()
#define all(x) x.begin(), x.end()
using namespace std;

typedef pair<int, int> pii;
typedef vector<int> vi;
typedef vector<long long> vl;

/**
 * Fast reader for non-negative integers: parses the next run of decimal digits.
 *
 * Every non-digit character, including '-', is skipped. The character that
 * terminates the number is consumed.
 *
 * @param ret receives the parsed value (unchanged when false is returned)
 * @param in  stream to read from; defaults to stdin
 * @return false if end of input is reached before any digit, true otherwise
 */
template <class T>
inline bool scan(T &ret, FILE *in = stdin){
    // int, not char: lets us detect EOF instead of looping forever on it.
    int c = getc(in);
    while (c != EOF && (c < '0' || c > '9')) c = getc(in);
    if (c == EOF) return 0;
    ret = c - '0';
    while (c = getc(in), c >= '0' && c <= '9') ret = ret * 10 + (c - '0');
    return 1;
}

// Contest-template constants. mod is used by qp(); maxn sizes the bitset
// Fenwick trees; inf and eps are not referenced by this program.
constexpr ll mod = 1000000007;          // 1e9 + 7
// add() touches 1-based node (1 << bit) + 1 and main() runs bit up to 24,
// so maxn must be at least (1 << 24) + 2 = 16777218.
constexpr int maxn = 35000000 + 50;
constexpr int inf = 0x3f3f3f3f;
constexpr double eps = 1e-8;

/**
 * Modular exponentiation: returns x^n modulo `mod` using binary exponentiation.
 * Returns 1 when n <= 0. A negative x yields a result in (-mod, 0].
 * Not called by this program.
 */
ll qp(ll x, ll n) {
    ll result = 1;
    ll square = x % mod;
    while (n > 0) {
        if (n % 2 == 1) result = result * square % mod;
        square = square * square % mod;
        n /= 2;
    }
    return result;
}

// Two XOR-Fenwick trees packed one bit per node; pre[g] holds the elements
// whose current bit (in main()) is g.
static bitset<maxn> pre[2];
// n: element count; a[1..n]: the input; ct[i]: a[i] reduced modulo 2^bit.
int n;
int a[400000 + 50];
int ct[400000 + 50];

// Flips 0-based position x of a bit-packed XOR-Fenwick tree. The update walks
// 1-based nodes up to mx + 1, which must be < maxn.
inline void add(bitset<maxn> &pre, int x, int mx) {
    for (int idx = x + 1, limit = mx + 1; idx <= limit; idx += idx & -idx) {
        pre[idx].flip();
    }
}

// Returns the XOR (parity) of all add() flips made at 0-based positions 0..x.
inline bool query(bitset<maxn> &pre, int x) {
    bool parity = false;
    for (int idx = x + 1; idx > 0; idx &= idx - 1) {
        if (pre[idx]) parity = !parity;
    }
    return parity;
}

/**
 * Reads n and a[1..n], then prints the XOR over all pairs i < j of (a_i + a_j).
 *
 * Bit `bit` of a_i + a_j equals (bit of a_i) ^ (bit of a_j) ^ carry. The carry
 * out of the low `bit` bits is set iff (a_i mod 2^bit) + (a_j mod 2^bit) >= 2^bit,
 * i.e. iff the partner's low part exceeds ict = 2^bit - 1 - (a_i mod 2^bit).
 *
 * Elements are processed left to right. pre[g] is a XOR-Fenwick tree over the
 * low parts of the already-seen elements whose current bit is g, so
 * query(pre[g], ict) is the parity of those that produce no carry.
 */
int main(int argc, char* argv[]) {
    scan(n);
    for(int i = 1; i <= n; ++i) {
        scan(a[i]);
    }
    int res = 0;
    // a_i <= 10^7, so a_i + a_j <= 2 * 10^7 < 2^25: bits 0..24 suffice.
    for (int bit = 0; bit < 25; ++bit) {
        bool count = 0;          // parity of pairs whose sum has this bit set
        int now = 1<<bit;        // 2^bit; low parts are taken modulo this
        bool c0 = 0, c1 = 0;     // parity of elements seen so far with this bit = 0 / 1
        for (int i = 1; i <= n; ++i) {
            ct[i] = a[i] % now;
            int ict = now - ct[i] - 1;   // largest partner low part that gives no carry
            if (a[i] >> bit & 1) {
                // Partner bit 0: the sum bit is 1 iff there is no carry.
                count ^= query(pre[0], ict);
                // Partner bit 1: the sum bit is 1 iff there is a carry
                // (all seen ones minus the non-carrying ones; minus == XOR mod 2).
                count ^= c1 ^ query(pre[1], ict);
                add(pre[1], ct[i], now);
                c1 ^= 1;
            } else {
                // Partner bit 0: needs a carry.  Partner bit 1: needs no carry.
                count ^= c0 ^ query(pre[0], ict);
                count ^= query(pre[1], ict);
                add(pre[0], ct[i], now);
                c0 ^= 1;
            }
        }
        // Clear both trees in place for the next bit. The former
        // `pre[0] = pre[1] = 0;` built a temporary bitset<maxn> (~4.4 MB, typically
        // on the stack) for every bit, relying on an MSVC-only stack-size pragma.
        pre[0].reset();
        pre[1].reset();
        if (count) res |= now;
    }
    printf("%d\n", res);
    return 0;
}

猜你喜欢

转载自www.cnblogs.com/badcw/p/12437166.html