虚树学习小记

问题引入:

在一些树上询问中,有一些询问,它们每次询问多个点的产生的什么什么东西,询问很多,但是总点数很少。

如果我们去暴力做,每次都要遍历整个树,复杂度就变成了 O ( n m )

如何利用询问的总点数少的特点去求解呢?

好吧,这就是虚树要干的事。

定义:

一般树的什么什么东西都和两点最短路径有关,而最短路径又和lca有关。

所以虚树就是所有询问的点和它们两两的lca构成的树。


什么?两两之间的lca?那不整棵树了吗?

*n个点两两之间的lca不同最多只有n-1个。

求点:

把所有的询问点按照dfs序排序。

相邻两个点的lca集合即是两两之间的lca集合。

简单归纳即可证明。

构边:

求出所有点后,按照dfs序搞。

设当前点为x,求出它于栈顶的lca,一直退栈直到当前栈顶不在那个lca的子树里为止,退栈的同时就可以求出每个点的父亲。

栈顶的父亲就是栈顶底下的那个点。

记得最后要清空栈。

例题:

【HNOI2014】世界树

分析:
显然构出虚树,然后做个树形dp。

Code:

#include<cstdio>
#include<cstring>
#include<algorithm>
#define fo(i, x, y) for(int i = x; i <= y; i ++)
#define fd(i, x, y) for(int i = x; i >= y; i --)
#define abs(a) ((a) > 0 ? (a) : -(a))
#define min(a, b) ((a) < (b) ? (a) : (b))
#define max(a, b) ((a) > (b) ? (a) : (b))
using namespace std;

const int N = 6e5 + 5;

const int inf = 1e9;

int n, x, y, m, h[N], hh[N], Q, d[N], z0, z[N], f[N];
int next[N], to[N], final[N], tot;

void link(int x, int y) {
    next[++ tot] = final[x], to[tot] = y, final[x] = tot;
    next[++ tot] = final[y], to[tot] = x, final[y] = tot;
}

int dep[N], fa[19][N], p[N], q[N], td, siz[N];

void dg(int x) {
    p[x] = ++ td;
    siz[x] = 1;
    for(int i = final[x]; i; i = next[i]) {
        int y = to[i]; if(p[y]) continue;
        dep[y] = dep[x] + 1;
        fa[0][y] = x;
        dg(y);
        siz[x] += siz[y];
    }
    q[x] = td;
}

int lca(int x, int y) {
    if(dep[x] < dep[y]) swap(x, y);
    fd(i, 18, 0) if(dep[fa[i][x]] >= dep[y]) x = fa[i][x];
    if(x == y) return x;
    fd(i, 18, 0) if(fa[i][x] != fa[i][y]) x = fa[i][x], y = fa[i][y];
    return fa[0][x];
}

int zen(int x, int y) {
    fd(i, 18, 0) if(dep[fa[i][x]] >= y) x = fa[i][x];
    return x;
}

int cmp(int x, int y) {return p[x] < p[y];}

int bz[N], ans[N];

struct edge {
    int next[N], to[N], final[N], tot;
    void link(int x, int y) {
        next[++ tot] = final[x], to[tot] = y, final[x] = tot;
    }
} e;

int mi[N], dis[N], li[N];

void dfs(int x) {
    if(bz[x] == Q) mi[x] = x, dis[x] = 0; else dis[x] = inf;
    for(int i = e.final[x]; i; i = e.next[i]) {
        int y = e.to[i];
        dfs(y);
        int p = dis[y] + dep[y] - dep[x];
        if(p < dis[x] || p == dis[x] && mi[y] < mi[x])
            dis[x] = p, mi[x] = mi[y];
    }
}

void dd(int x) {
    for(int i = e.final[x]; i; i = e.next[i]) {
        int y = e.to[i];
        int p = dis[x] + dep[y] - dep[x];
        if(p < dis[y] || p == dis[y] && mi[x] < mi[y])
            dis[y] = p, mi[y] = mi[x];
        dd(y);
    }
}

int main() {
    freopen("worldtree.in", "r", stdin);
    freopen("worldtree.out", "w", stdout);
    scanf("%d", &n);
    fo(i, 1, n - 1) {
        scanf("%d %d", &x, &y);
        link(x, y);
    }
    dep[1] = 1; dg(1);
    fo(i, 1, 18) fo(j, 1, n) fa[i][j] = fa[i - 1][fa[i - 1][j]];
    for(scanf("%d", &Q); Q; Q --) {
        scanf("%d", &m);
        fo(i, 1, m) scanf("%d", &h[i]), hh[i] = h[i];
        sort(h + 1, h + m + 1, cmp);
        d[0] = m; fo(i, 1, m) d[i] = h[i]; d[++ d[0]] = 1;
        fo(i, 2, m) d[++ d[0]] = lca(h[i - 1], h[i]);
        sort(d + 1, d + d[0] + 1, cmp);
        int d0 = 0; fo(i, 1, d[0]) if(i == 1 || d[i] != d[i - 1]) d[++ d0] = d[i];
        z[z0 = 1] = 1;
        fo(i, 2, d0) {
            int x = d[i];
            while(z0 && (p[x] < p[z[z0]] || p[x] > q[z[z0]])) f[z[z0]] = z[z0 - 1], z0 --;
            z[++ z0] = x;
        }
        while(z0) f[z[z0]] = z[z0 - 1], z0 --;

        fo(i, 2, d0) e.link(f[d[i]], d[i]);
        fo(i, 1, m) bz[h[i]] = Q;
        dfs(1); dd(1);
        fo(i, 1, m) ans[h[i]] = 0;
        fo(i, 1, d0) li[d[i]] = siz[d[i]];
        fo(i, 1, d0) {
            int x = d[i], y = f[x];
            int o = zen(x, dep[y] + 1);
            li[y] -= siz[o];
        }
        fo(i, 1, d0) ans[mi[d[i]]] += li[d[i]];
        fo(i, 1, d0) {
            int x = d[i], y = f[x];
            if(x == 1) continue;
            int z = (dep[x] + dep[y] + dis[x] - dis[y]) / 2;
            if((dep[x] + dep[y] + dis[x] - dis[y]) % 2 == 0)
                if(mi[y] > mi[x]) z --;
            int o = zen(x, max(z + 1, dep[y] + 1));
            if(z < dep[x]) ans[mi[x]] += siz[o] - siz[x];
            int u = zen(x, dep[y] + 1);
            if(z >= dep[y]) ans[mi[y]] += siz[u] - siz[o];
        }

        fo(i, 1, m) printf("%d ", ans[hh[i]]); printf("\n");

        //clear e
        fo(i, 1, e.tot) e.next[i] = 0;
        fo(i, 1, d0) e.final[d[i]] = 0;
        e.tot = 0;
    }
}

猜你喜欢

转载自blog.csdn.net/cold_chair/article/details/80867739