ZJOI2019 语言

ZJOI2019 语言

若城市 \(u,v\) 能够开展贸易,则称 \(u\) 可以到达

那么不难发现某个点能到达的所有点连通

如果路径 \(s \to t\) 包含点 \(u\)\(s,t\) 则是 \(u\) 的两个极远点,那么 \(u\) 的生成树为连通 \(u\) 的所有极远点的最小生成树

于是,我们需要事先思考子问题:给定树上若干的点,如何求最小生成树大小?

这应该是个经典问题。可能要用到一点建虚树的思想。

为了方便,我们硬点存在极远点 111 ,最后再计算其影响。

我们给这些点按照 dfs\mathrm{dfs}dfs 序排序,然后一个一个地加进去。假设当前加入 uuu 号点,之前加入的点构成点集 SSS ,SSS 的最小生成树是 TTT。显然,为了使 uuu 和 TTT 连通,只需要找到 TTT 中离 uuu 最近的点连起来就好了,这“最近的点”正是 SSS 中的某个点与 uuu 的 lca\mathrm{lca}lca 。如果你对 dfs\mathrm{dfs}dfs 序有较好的感觉,其实所谓 SSS 中的某个点,就是上一个加进去的点 vvv !所以 uuu 的贡献是 depu−deplca(u,v)\mathrm{dep}u - \mathrm{dep}{\mathrm{lca}(u,,v)}depu​−deplca(u,v)​ 。别忘了硬点了 111 号点,最后还要减 dep所有点的 lca\mathrm{dep}_{\text{所有点的} \mathrm{lca}}dep所有点的 lca​ 。

上面那段可能有点绕,强烈建议画个图感受一下!

暴力加点似乎不容易被利用,此时最擅长分治的线段树终于重磅登场了。

仍然是以 dfs\mathrm{dfs}dfs 序为下标,线段树的叶子结点控制原树上每个点选或不选,每个结点表示了硬点 111 后一段区间的点集的最小生成树大小,并存下每个点集中下标最小和最大的点。类似地,pushup\mathrm{pushup}pushup 时要减去 deplca(u,v)\mathrm{dep}{\mathrm{lca}(u,,v)}deplca(u,v)​ 。此时的 uuu 当然是左孩子已选下标最大的,vvv 当然是右孩子已选下标最小的。查询的话只要用根结点的答案减去 dep所有点的 lca\mathrm{dep}{\text{所有点的} \mathrm{lca}}dep所有点的 lca​ 即可。

白白套一个线段树不是多 log⁡\loglog 吗?别急,正解已经浮出水面了。

想象每个结点都有一棵线段树。注意到有一条路径 s→ts \to ts→t ,我们会给该路径上的每棵线段树都选上 s,ts,,ts,t 两点。按照套路,求出 lca=lca(u,v)\mathrm{lca} = \mathrm{lca}(u,,v)lca=lca(u,v) ,等同于 s→lca,t→lcas \to \mathrm{lca},,t\to \mathrm{lca}s→lca,t→lca 两条链上执行线段树修改。仍然是套路,链上修改可以转换成树上差分打标记,s,t,lca,fa[lca]s,,t,,\mathrm{lca},,\mathrm{fa}[\mathrm{lca}]s,t,lca,fa[lca] 分别打上 1,1,−1,−11,,1,,-1,,-11,1,−1,−1 。

慢着……线段树打标记是可以,但每个结点要把自己的信息继承给父亲,这怎么做?
动态开点 + 线段树合并!

至此,[ZJOI2019]语言 的大致做法讲完了。nnn 次线段树合并是 O(nlog⁡n)O(n \log n)O(nlogn) 的,而 nnn 次修改要求 nlog⁡nn \log nnlogn 次 lca\mathrm{lca}lca ,如果树剖或倍增,复杂度是 O(nlog⁡2n)O(n \log^2 n)O(nlog2n) ,如果使用欧拉序 st\mathrm{st}st 表,复杂度是优美的 O(nlog⁡n)O(n \log n)O(nlogn) 。
代码实现

刚才求的 (u,v)(u,,v)(u,v) 没有强制 u<vu < vu<v ,答案要除以 222 。

#include <cmath>
#include <cstdio>
#include <vector>
#include <algorithm>

const int N = 200005, V = 6400005, L = 18;

int n, m, tms, o[N], ft[N], dep[N], dfn[N], st[L][N];
std::vector<int> to[N], del[N];
long long ans;

inline int getLca(int u, int v);

struct SegmentTree {
    int tot, c[V], f[V], s[V], t[V], ls[V], rs[V], rt[N];

    inline void pushUp(int u) {
        f[u] = f[ls[u]] + f[rs[u]] - dep[getLca(t[ls[u]], s[rs[u]])];
        s[u] = s[ls[u]] ? s[ls[u]] : s[rs[u]];
        t[u] = t[rs[u]] ? t[rs[u]] : t[ls[u]];
    }
    inline int query(int u) { return f[u] - dep[getLca(s[u], t[u])]; }
    void modify(int &u, int l, int r, int p, int x) {
        if (!u) { u = ++tot; }
        if (l == r) {
            c[u] += x; f[u] = c[u] ? dep[p] : 0; s[u] = t[u] = c[u] ? p : 0;
            return;
        }
        int mid = l + r >> 1;
        if (dfn[p] <= mid) { modify(ls[u], l, mid, p, x); }
        else { modify(rs[u], mid + 1, r, p, x); }
        pushUp(u);
    }
    void merge(int &u, int v, int l, int r) {
        if (!u || !v) { u |= v; return; }
        if (l == r) {
            c[u] += c[v]; f[u] |= f[v]; s[u] |= s[v]; t[u] |= t[v];
            return;
        }
        int mid = l + r >> 1;
        merge(ls[u], ls[v], l, mid); merge(rs[u], rs[v], mid + 1, r); pushUp(u);
    }
} smt;

void build() {
    for (int i = 1; i <= tms; i++) { o[i] = log(i) / log(2) + 1e-7; }
    for (int i = 1; i <= o[tms]; i++) {
        for (int j = 1, u, v; j + (1 << i) - 1 <= tms; j++) {
            u = st[i - 1][j]; v = st[i - 1][j + (1 << i - 1)];
            st[i][j] = dep[u] < dep[v] ? u : v;
        }
    }
}
inline int getLca(int u, int v) {
    if (!u || !v) { return 0; } u = dfn[u]; v = dfn[v];
    if (u > v) { std::swap(u, v); }
    int d = o[v - u + 1]; u = st[d][u]; v = st[d][v - (1 << d) + 1];
    return dep[u] < dep[v] ? u : v;
}

void dfs(int u, int fa) {
    ft[u] = fa; dep[u] = dep[fa] + 1; st[0][++tms] = u; dfn[u] = tms;
    for (auto v : to[u]) { if (v != fa) { dfs(v, u); st[0][++tms] = u; } }
}
void solve(int u) {
    for (auto v : to[u]) { if (v != ft[u]) { solve(v); } }
    for (auto v : del[u]) { smt.modify(smt.rt[u], 1, tms, v, -1); }
    ans += smt.query(smt.rt[u]); smt.merge(smt.rt[ft[u]], smt.rt[u], 1, tms);
}

int main() {
    scanf("%d%d", &n, &m);
    for (int i = 2, u, v; i <= n; i++) {
        scanf("%d%d", &u, &v);
        to[u].push_back(v); to[v].push_back(u);
    }
    dfs(1, 0); build();
    for (int u, v, lca; m; m--) {
        scanf("%d%d", &u, &v); lca = getLca(u, v);
        smt.modify(smt.rt[u], 1, tms, u, 1); smt.modify(smt.rt[u], 1, tms, v, 1);
        smt.modify(smt.rt[v], 1, tms, u, 1); smt.modify(smt.rt[v], 1, tms, v, 1);
        del[lca].push_back(u); del[ft[lca]].push_back(u);
        del[lca].push_back(v); del[ft[lca]].push_back(v);
    }
    solve(1); printf("%lld\n", ans >> 1);
    return 0;
}

猜你喜欢

转载自www.cnblogs.com/wawawa8/p/11094612.html