Description
Input
Output
Sample Input
5 2
4 3 2 5 3
1 2
1 3
3 4
3 5
2 5
3 4
Sample Output
13
7
Data Constraint
思路
可以分别考虑每一个二进制位对答案的贡献。即,对于位2x ,维护从每一个点t 出发,向上 2i的距离之
内,与t 距离为d 满足 d&2x=2x且点权的二进制表示中包含 的点的个数就行了。
由于路径有向上的部分,也有向下的部分,因此还需要维护满足 的点的个数在从 倍增的时候使用。
代码
#include<bits/stdc++.h>
using namespace std;
#define ll long long
int n, q;
const int N = 300010;
const int M = 21;
int ai[N], dep[N];
int cnt[M][N], fa[M][N];
ll fu[M][N], fd[M][N];
int nxt[N * 2], bi[N * 2], lst[N], tot;
void dfs(int t, int f = 0, int d = 1)
{
fa[0][t] = f; dep[t] = d;
for (int i = 0; i <= 20; ++ i) cnt[i][t] = cnt[i][f] + ((ai[t] & (1 << i)) == 0);
for (int i = lst[t]; i; i = nxt[i]) if (bi[i] != f)
dfs(bi[i], t, d + 1);
}
int lca(int a, int b)
{
if (dep[a] < dep[b]) swap(a, b);
if (dep[a] > dep[b])
{
for (int i = 20; ~i; -- i) if (dep[fa[i][a]] > dep[b]) a = fa[i][a];
a = fa[0][a];
}
if (a != b)
{
for (int i = 20; ~i; -- i) if (fa[i][a] != fa[i][b]) a = fa[i][a], b = fa[i][b];
a = fa[0][a], b = fa[0][b];
}
return a;
}
ll work1(int u, int l)
{
ll ass = 0;
for (int i = 20; ~i; -- i)
if (dep[fa[i][u]] >= dep[l])
{
ass += fu[i][u];
u = fa[i][u];
ass += (1ll << i) * (cnt[i][u] - cnt[i][l]);
}
return ass;
}
ll work2(int v, int d)
{
ll ass = 0; int s = v; d ++;
for (int i = 0; i <= 20; ++ i)
if (d & (1 << i))
{
ass += fd[i][v];
ass += (1ll << i) * (cnt[i][s] - cnt[i][v]);
v = fa[i][v];
}
return ass;
}
int main()
{
freopen("c.in","r",stdin);
freopen("c.out","w",stdout);
scanf("%d%d", &n, &q);
for (int i = 1; i <= n; ++ i) scanf("%d", &ai[i]);
for (int i = 1; i < n; ++ i)
{
int u, v;
scanf("%d%d", &u, &v);
nxt[++ tot] = lst[u]; lst[u] = tot; bi[tot] = v;
nxt[++ tot] = lst[v]; lst[v] = tot; bi[tot] = u;
}
dfs(1);
for (int i = 1; i <= n; ++ i)
fu[0][i] = fd[0][i] = ai[i];
for (int i = 1; i <= 20; ++ i)
for (int j = 1; j <= n; ++ j)
{
fa[i][j] = fa[i - 1][fa[i - 1][j]];
fu[i][j] = fu[i - 1][j] + fu[i - 1][fa[i - 1][j]] + (1ll << (i - 1)) *
(cnt[i - 1][fa[i - 1][j]] - cnt[i - 1][fa[i][j]]);
fd[i][j] = fd[i - 1][j] + fd[i - 1][fa[i - 1][j]] + (1ll << (i - 1)) *
(cnt[i - 1][j] - cnt[i - 1][fa[i - 1][j]]);
}
for (int i = 1; i <= q; ++ i)
{
int u, v; scanf("%d%d", &u, &v); int l = lca(u, v), d = dep[u] + dep[v] - dep[l] * 2;
printf("%lld\n", work1(u, fa[0][l])+ work2(v, d) - work2(l, dep[u] - dep[l]));
}
}