题面
题面是中文的,这里就不赘述了。
按照给定的序列走,其实就是在树上依次走若干条路径。但有一点特殊:如果序列中连续出现 $3-5-2$,那么点 $5$ 只计算一次。
OK,这是树上路径修改问题,直接套树链剖分即可。剖分完成后用线段树维护区间和:对序列中每一对相邻的点,把它们之间的路径在线段树上对应的若干区间整体加一。但这样每个中间点都会被多算一次(它既是上一段路径的终点,又是下一段路径的起点);最后一个点虽然只被计算一次,但按题意终点同样要减去一次。因此除了第一个点以外,其余点的查询结果都减一即可。
#include <bits/stdc++.h>
#define mem(a, b) memset(a, b, sizeof a)
using namespace std;
const int N = 300100;
// Forward-star adjacency list. head[v] is the newest edge leaving v,
// nex[e] is the edge that was head[v] before e was inserted, and to[e] is
// the target of edge e. Edge ids start at 1; cnt is the last id handed out.
int head[N], nex[N * 2], to[N * 2], cnt;
/// Inserts the directed edge a -> b at the front of a's edge list.
void add(int a, int b) {
    const int e = ++cnt;
    nex[e] = head[a];
    to[e] = b;
    head[a] = e;
}
int fa[N], depth[N], son[N], siz[N];
/// First heavy-light pass, rooted at x with depth dep.
/// Fills depth[], fa[] (for every descendant), siz[] (subtree sizes) and
/// son[] (the child with the largest subtree; the first one wins on ties).
/// Expects depth[] to be 0 for unvisited vertices and fa[x] to already hold
/// x's parent (init() sets every fa to -1, which covers the root).
/// Recursive: the recursion depth equals the height of the tree.
void dfs1(int x, int dep) {
    depth[x] = dep;
    siz[x] = 1;
    int best = -1;  // subtree size of the heaviest child found so far
    for (int e = head[x]; e != -1; e = nex[e]) {
        const int child = to[e];
        if (child == fa[x] || depth[child] != 0) continue;  // parent or already visited
        fa[child] = x;
        dfs1(child, dep + 1);
        siz[x] += siz[child];
        if (siz[child] > best) {  // siz >= 1, so the first child always qualifies
            best = siz[child];
            son[x] = child;
        }
    }
}
int dfn[N], rk[N], num, top[N];
/// Second heavy-light pass: numbers vertices in DFS order (dfn, with rk as
/// its inverse) and records top[], the first vertex of each vertex's chain.
/// The heavy child is visited first with the same chain top, so it receives
/// dfn[x] + 1 and every heavy chain occupies a contiguous dfn range.
/// tp is the chain top to assign to x.
void dfs2(int x, int tp) {
    top[x] = tp;
    dfn[x] = ++num;
    rk[dfn[x]] = x;
    if (son[x] != -1) {
        dfs2(son[x], tp);  // extend the current chain
        for (int e = head[x]; e != -1; e = nex[e]) {
            const int v = to[e];
            if (v == son[x] || v == fa[x]) continue;
            dfs2(v, v);  // a light child starts its own chain
        }
    }
}
/// Resets the edge list and all heavy-light bookkeeping before a tree is read.
/// -1 marks "none" for edges, parents, heavy children and chain data
/// (memset with -1 writes 0xFF to every byte, i.e. int value -1).
void init() {
    // edge list
    cnt = 0;
    mem(head, -1);
    mem(nex, -1);
    // dfs1 state
    mem(fa, -1);
    mem(son, -1);
    mem(depth, 0);
    mem(siz, 0);
    // dfs2 state
    num = 0;
    mem(dfn, -1);
    mem(rk, -1);
    mem(top, -1);
}
/// One segment-tree node covering dfn positions [l, r].
/// sum   - total of all positions in [l, r];
/// lazy  - pending per-position addition not yet pushed to the children.
/// Both are 64-bit: the root accumulates every position's count, and with
/// inputs near N vertices that total can exceed INT_MAX.
struct p {
    int l, r;
    long long sum, lazy;
};

/// Range-add / range-sum segment tree over dfn order, plus helpers that
/// split a tree path into heavy-chain dfn ranges. (The misspelled name is
/// kept for compatibility; ST is the alias the rest of the file uses.)
struct SegementTree {
    p c[N * 4];

    /// Builds node k covering [l, r] with every sum and tag set to zero.
    void build(int l, int r, int k) {
        c[k].l = l;
        c[k].r = r;
        c[k].sum = 0;
        c[k].lazy = 0;
        if (l == r) return;
        const int mid = (l + r) >> 1;
        build(l, mid, k << 1);
        build(mid + 1, r, k << 1 | 1);
    }

    /// Pushes node k's pending addition down to its two children.
    /// A child's sum grows by lazy * (child length), not by lazy alone.
    void down(int k) {
        if (!c[k].lazy) return;
        const int lc = k << 1, rc = k << 1 | 1;
        c[lc].sum += c[k].lazy * (c[lc].r - c[lc].l + 1);
        c[rc].sum += c[k].lazy * (c[rc].r - c[rc].l + 1);
        c[lc].lazy += c[k].lazy;
        c[rc].lazy += c[k].lazy;
        c[k].lazy = 0;
    }

    /// Returns the sum over positions [l, r] intersected with node k's range.
    long long query(int l, int r, int k) {
        if (l <= c[k].l && c[k].r <= r) return c[k].sum;
        down(k);
        const int mid = (c[k].l + c[k].r) >> 1;
        long long ans = 0;
        if (l <= mid) ans += query(l, r, k << 1);
        if (r > mid) ans += query(l, r, k << 1 | 1);
        return ans;
    }

    /// Adds d to every position in [l, r] intersected with node k's range.
    void update(int l, int r, int k, int d) {
        if (l <= c[k].l && c[k].r <= r) {
            c[k].sum += static_cast<long long>(d) * (c[k].r - c[k].l + 1);
            c[k].lazy += d;
            return;
        }
        down(k);
        const int mid = (c[k].l + c[k].r) >> 1;
        if (l <= mid) update(l, r, k << 1, d);
        if (r > mid) update(l, r, k << 1 | 1, d);
        c[k].sum = c[k << 1].sum + c[k << 1 | 1].sum;
    }

    /// Sum of the values on every vertex of the tree path x - y, both
    /// endpoints included. Relies on top[], depth[], dfn[] and fa[] from the
    /// heavy-light passes.
    long long getAns(int x, int y) {
        long long ans = 0;
        while (top[x] != top[y]) {
            // Always lift the endpoint whose chain top is deeper.
            if (depth[top[x]] < depth[top[y]]) swap(x, y);
            ans += query(dfn[top[x]], dfn[x], 1);
            x = fa[top[x]];
        }
        if (depth[x] < depth[y]) swap(x, y);
        ans += query(dfn[y], dfn[x], 1);  // both now on one chain
        return ans;
    }

    /// Adds d to every vertex of the tree path x - y, both endpoints included.
    void change(int x, int y, int d) {
        while (top[x] != top[y]) {
            if (depth[top[x]] < depth[top[y]]) swap(x, y);
            update(dfn[top[x]], dfn[x], 1, d);
            x = fa[top[x]];
        }
        if (depth[x] < depth[y]) swap(x, y);
        update(dfn[y], dfn[x], 1, d);
    }
};
typedef SegementTree ST;
ST st;       // segment tree indexed by dfn position
int n;       // number of vertices; also the length of the visiting sequence
int Q;       // declared but not used by the code shown here
int root;    // set to vec[1] in main and used as the tree root
int vec[N];  // vec[1..n]: the visiting sequence
/// Entry point. Input: n, the visiting sequence vec[1..n], then n-1
/// undirected edges. Adds 1 to every vertex on the tree path between each
/// pair of consecutive sequence entries, then prints, for vertices 1..n (one
/// per line), that count minus one for every vertex except the starting one
/// (root == vec[1]). The decrement compensates for each intermediate sequence
/// vertex being the shared endpoint of two consecutive paths. Applying it to
/// every non-root vertex assumes vec is a permutation of 1..n
/// (NOTE(review): confirm against the problem statement).
int main()
{
#ifdef LOCAL
    // Debug-only input redirection. Calling freopen unconditionally breaks a
    // normal run: if "in.txt" is missing, freopen fails and stdin is left
    // closed, so every later read fails.
    freopen("in.txt", "r", stdin);
#endif
    ios::sync_with_stdio(0);
    cin.tie(0);
    cin >> n;
    init();
    for (int i = 1; i <= n; i++) cin >> vec[i];
    root = vec[1];
    for (int i = 1; i < n; i++) {
        int a, b;
        cin >> a >> b;
        add(a, b);
        add(b, a);
    }
    dfs1(root, 1);
    dfs2(root, root);
    st.build(1, num, 1);
    // Cover the path from each sequence entry to the next one.
    for (int i = 2; i <= n; i++) st.change(vec[i], vec[i - 1], 1);
    for (int i = 1; i <= n; i++) {
        // A single vertex is a zero-length path, so this is a point query.
        long long ans = st.getAns(i, i);
        if (i != root) ans--;
        cout << ans << "\n";
    }
    return 0;
}