Petrozavodsk Winter Camp, Day 8, 2014, Second Trip

给你一棵树,每次询问一个点对 (a,b),问树上有多少条路径与 a-b 这条路径没有交集。

做法:找 LCA,预处理子树大小,统计删去 a-b 路径上的点后各连通块内的路径数。

#include <bits/stdc++.h>
using namespace std;
// rep(i, j, k): loop i upward over the inclusive range [j, k].
#define rep(i, j, k) for (int i = int(j); i <= int(k); ++ i) 
// dwn(i, j, k): loop i downward from j to k, inclusive.
#define dwn(i, j, k) for (int i = int(j); i >= int(k); -- i)
typedef long long LL;
// Pair alias; not used anywhere in the visible part of this file.
typedef pair<int, int> P;
// Capacity of the per-vertex arrays below. Vertices are indexed from 1;
// index 0 is passed by main as the parent of the root.
const int N = 1e5 + 7;
// Adjacency lists of the tree; main stores every edge in both directions.
vector<int> g[N];
// dep[u]   : depth assigned by dfs as dep[parent] + 1 (dep[0] stays 0).
// fa[u][j] : ancestor of u that is 2^j levels up (binary-lifting table built
//            in main); entries that would go above the root stay 0.
// sz[u]    : number of vertices in the subtree of u (computed by dfs).
// f1[u]    : sum of calc(sz[c]) over all children c of u (computed by dfs2).
// f2[u]    : f2[parent] + f1[parent] - calc(sz[u]) (computed by dfs3), i.e.
//            the accumulated f1 contributions of u's proper ancestors with the
//            child subtree leading towards u taken out at each step.
// n, q     : first two integers read by main; main reads n - 1 edges and then
//            answers q queries.
int dep[N], fa[N][20], sz[N]; LL f1[N], f2[N]; int n, q;
// Returns x * (x + 1) / 2, i.e. the number of pairs (a, b) with a <= b that
// can be picked among x items.
LL calc(LL x) {
    const LL successor = x + 1;
    return (x * successor) / 2LL;
}
// First traversal: for every vertex reachable from u without stepping back
// to f, records its direct parent (fa[.][0]), its depth (dep) and the size of
// its subtree (sz).
//   u - vertex being visited
//   f - vertex we arrived from (0 for the root, as passed by main)
void dfs(int u, int f) {
    sz[u] = 1;
    fa[u][0] = f;
    dep[u] = dep[f] + 1;
    for (const int child : g[u]) {
        if (child == f) continue;  // do not walk back up the edge we came by
        dfs(child, u);
        sz[u] += sz[child];
    }
}
void dfs2(int u, int f) {
    for (int v: g[u])
        if (v != f) {
            dfs2(v, u);
            f1[u] += calc(sz[v]);
        }
}
// Third traversal: pushes prefix sums down the tree. For every vertex other
// than vertex 1, f2[u] extends the parent's value by the parent's f1 with the
// term belonging to u's own subtree removed.
//   u - vertex being visited
//   f - vertex we arrived from
void dfs3(int u, int f) {
    if (u != 1) {
        f2[u] = f2[f] + f1[f] - calc(sz[u]);
    }
    for (const int child : g[u]) {
        if (child == f) continue;
        dfs3(child, u);
    }
}
// Lowest common ancestor of u and v via the binary-lifting table fa.
// Step 1 lifts the deeper vertex to the depth of the shallower one; step 2
// lifts both together while their 2^k-th ancestors exist and differ, after
// which the common parent is the answer.
int lca(int u, int v) {
    if (dep[u] < dep[v]) swap(u, v);  // make u the deeper (or equal) vertex

    for (int k = 19; k >= 0; --k) {
        if (dep[u] - (1 << k) >= dep[v]) {
            u = fa[u][k];
        }
    }
    if (u == v) return u;

    for (int k = 19; k >= 0; --k) {
        const int upU = fa[u][k];
        const int upV = fa[v][k];
        if (upU == 0 || upU == upV) continue;  // jump would overshoot or meet
        u = upU;
        v = upV;
    }
    return fa[u][0];
}
int find(int u, int d) {
    dwn(i, 19, 0) if (d >= (1 << i)) u = fa[u][d]; return u;
}
// Answers one query: the sum of calc(size) over the connected pieces that
// remain once every vertex on the x-y tree path is removed, i.e. the number
// of paths (single vertices included) that avoid the x-y path.
//   x, y - query endpoints (vertex indices)
// Returns the count as a 64-bit value.
//
// Pieces that contribute:
//   * child subtrees hanging below an endpoint           -> f1[endpoint]
//   * subtrees hanging off interior path vertices        -> differences of f2
//   * everything outside the subtree of the top vertex   -> calc(n - sz[top])
//
// When the path bends at lc (neither endpoint is the LCA), both
// f2[x] - f2[lc] and f2[y] - f2[lc] include f1[lc], each with only one of the
// two path-side children removed. Subtracting f1[lc] once leaves exactly
// f1[lc] minus both path-side children, so no explicit child lookup is
// needed. (The earlier code looked those children up through find() with a
// negative distance and then added and subtracted the same terms.)
LL solve(int x, int y) {
    if (dep[x] < dep[y]) swap(x, y);  // now dep[x] >= dep[y]
    const int lc = lca(x, y);

    if (y == lc) {
        // y is an ancestor of x (or x == y): the path is a vertical chain.
        return f1[x] + (f2[x] - f2[y]) + calc(n - sz[y]);
    }

    return f1[x] + f1[y] + calc(n - sz[lc])
         + (f2[x] - f2[lc]) + (f2[y] - f2[lc])
         - f1[lc];
}
int main() {
    scanf("%d%d", &n, &q);
    rep(i, 1, n - 1) {
        int x, y;
        scanf("%d%d", &x, &y);
        g[x].push_back(y); 
        g[y].push_back(x);
    }
    dfs(1, 0);
    rep(j, 1, 19) rep(i, 1, n) fa[i][j] = fa[fa[i][j - 1]][j - 1];
    dfs2(1, 0);
    dfs3(1, 0);
    // rep(i, 1, n) cout << f2[i] << ' '; cout << '\n';
    while (q --) {
        int x, y;
        scanf("%d%d", &x, &y);
        cout << solve(x, y) << '\n';
    }
}
/*
6 2 
1 2 
3 2 
3 4 
3 5 
6 3
5 4 
1 6
*/

猜你喜欢

转载自www.cnblogs.com/tempestT/p/10661546.html