9.13 异或序列

题意

给定\(a,b\)数组,定义\(c_i=a_i\oplus b_i\),其中\(\oplus\)运算是异或运算

请合理安排\(a,b\)的顺序使得得到的\(c\)的字典序最小,并输出\(c\)


解法

第一次做\(01Trie\)的问题,感觉好神奇

\(a\)数组与\(b\)数组分别构建两颗\(01Trie\)(用二进制位构造)

在两颗\(Trie\)上遍历,贪心的走边

能走全\(0\)或全\(1\)就贪心的走,否则就分开走

由于是从高位到低位的走\(Trie\)边,所以当前取得的一定是最优的

判断当前子树内有无未匹配的串可以记录一个子树大小,每次匹配后进行更新


代码

#include <cstdio>
#include <algorithm>

using namespace std;

const int N = 2e5 + 10;

int n, top;
int pw[N], ans[N];

int read() {
    int c = getchar(), x = 0;
    while (c < '0' || c > '9')  c = getchar();
    while (c >= '0' && c <= '9')    x = x * 10 + c - 48, c = getchar(); 
    return x;
}

struct Trie {
    
    int root, cnt;
    int sz[N * 30];
    
    struct node {
        int ch[2];
    } t[N * 30];
    
    Trie() : root(1), cnt(1) {}
    
    void ins(int x) {
        int p = root;
        for (int i = 30; i >= 0; --i) {
            int k = (pw[i] & x) ? 1 : 0;
            sz[p]++;
            if (!t[p].ch[k])
                t[p].ch[k] = ++cnt;
            p = t[p].ch[k];
        }
        sz[p]++;
    }
    
} a, b;

void DFS(int x, int y, int val, int dep) {
//  printf("%d %d %d %d\n", x, y, val, dep);
    if (dep == -1) {
        while (a.sz[x] && b.sz[y]) {
            a.sz[x]--, b.sz[y]--;
            ans[++top] = val;
        }
        return;
    }
    
    int ls_a = a.t[x].ch[0], rs_a = a.t[x].ch[1];
    int ls_b = b.t[y].ch[0], rs_b = b.t[y].ch[1];
    
    if (a.sz[ls_a] && b.sz[ls_b])
        DFS(ls_a, ls_b, val, dep - 1);
    if (a.sz[rs_a] && b.sz[rs_b])
        DFS(rs_a, rs_b, val, dep - 1);
    if (a.sz[ls_a] && b.sz[rs_b])
        DFS(ls_a, rs_b, val | pw[dep], dep - 1);
    if (a.sz[rs_a] && b.sz[ls_b])   
        DFS(rs_a, ls_b, val | pw[dep], dep - 1);
    
    a.sz[x] = a.sz[ls_a] + a.sz[rs_a];
    b.sz[y] = b.sz[ls_b] + b.sz[rs_b];
}

int main() {
    
    n = read(); 
    
    pw[0] = 1;
    for (int i = 1; i <= 30; ++i)   pw[i] = pw[i - 1] << 1;
    
    for (int i = 1; i <= n; ++i)  a.ins(read());
    for (int i = 1; i <= n; ++i)  b.ins(read());
    
    DFS(1, 1, 0, 30);
    
    sort(ans + 1, ans + top + 1);
    
    for (int i = 1; i <= n; ++i)    printf("%d%c", ans[i], " \n"[i == n]);
    
    return 0;
}

猜你喜欢

转载自www.cnblogs.com/VeniVidiVici/p/11536180.html