【代码源每日一题Div1】Dis「倍增求LCA维护树链异或和」

Dis

题目描述:

给出 n 个点的一棵树,每个点有各自的点权,多次询问两个点简单路径所构成点集的异或和。

思路:

一眼树上倍增求LCA维护路径异或

fa[i][j]表示从i开始往上走的第 2j个点

fuck[i][j]表示从i开始往上走共2i个点的点权的异或和

维护fuck数组的过程可以在求LCA的过程去维护

需要注意的是,跑LCA时

  • 如果xy在一条链上,在最后x=y的时候的答案是没有统计的,我们需要判断并进行计算
  • 如果不在一条链,则跑LCA的时候x和y最后会停留在LCA的下一层节点,也就是形成x-lca-y的一个情况,这三个的值我们都没有计算,需要特判计算一下
#include <bits/stdc++.h>
using namespace std;

#define endl '\n'
#define inf 0x3f3f3f3f
#define mod7 1000000007
#define mod9 998244353
#define m_p(a,b) make_pair(a, b)
#define mem(a,b) memset((a),(b),sizeof(a))
#define io ios::sync_with_stdio(false),cin.tie(0),cout.tie(0)
typedef long long ll;
typedef pair <int,int> pii;

#define MAX 1000000 + 50

int a, b;
int n, m;
int ar[MAX];
int tot;
int head[MAX];
struct ran{
    
    
    int to, nex;
}tr[MAX];
void add(int u, int v){
    
    
    tr[++tot].to = v;
    tr[tot].nex = head[u];
    head[u] = tot;
}

int deap[MAX];
int fa[MAX][22];
int fuck[MAX][22];
void dfs(int u, int last){
    
    
    deap[u] = deap[last] + 1;
    fa[u][0] = last;
    fuck[u][0] = ar[u];
    for(int i = 1; (1 << i) <= deap[u]; ++i){
    
    
        fa[u][i] = fa[fa[u][i - 1]][i - 1];
        fuck[u][i] = (fuck[u][i-1]^fuck[fa[u][i-1]][i-1]);
    }
    for(int i = head[u]; i; i = tr[i].nex){
    
    
        int v = tr[i].to;
        if(v != last)dfs(v, u);
    }
}

int lca(int x, int y){
    
    
    if(deap[x] < deap[y])swap(x, y);
    int ans = 0;
    for(int i = 20; i >= 0; --i){
    
    
        if(deap[fa[x][i]] >= deap[y]){
    
    
            ans ^= fuck[x][i];
            x = fa[x][i];
        }
    }
    if(x == y){
    
    
        ans ^= ar[x];
        return ans;
    }
    for(int i = 20; i >= 0; --i){
    
    
        if(fa[x][i] != fa[y][i]){
    
    
            ans ^= fuck[x][i];
            ans ^= fuck[y][i];
            x = fa[x][i];
            y = fa[y][i];
        }
    }
    ans ^= fuck[x][1] ^ fuck[y][1] ^ ar[fa[x][0]];
    return ans;
}


void work(){
    
    
    cin >> n >> m;
    for(int i = 1; i <= n; ++i)cin >> ar[i];
    for(int i = 1; i < n; ++i){
    
    
        cin >> a >> b;
        add(a, b);
        add(b, a);
    }
    dfs(1, 0);
    for(int i = 1; i <= m; ++i){
    
    
        cin >> a >> b;
        cout<<lca(a, b)<<endl;
    }
}


int main(){
    
    
    io;
    work();
    return 0;
}

猜你喜欢

转载自blog.csdn.net/weixin_51216553/article/details/127780872
今日推荐