思路比较好想, 但是蒟蒻码力不行啊
题意:
给定一些障碍,用这些障碍来切断根节点到所有叶子节点的路径,可以移动障碍,求总时间最少
思路:
因为士兵们可以同时移动,很容易想到,我们要求的是所有军队中所用时间最长的时间是多少,显然要最小化这个值,于是二分!!!
如何判定呢???
试想是不是一个点的深度越浅,那么它控制的节点一定不比深度深的差
那么,可以将给定的m个士兵在有限时间内(mid)往上跳,跳一个赚一个
目的则是控制根节点的所有儿子
然后就会出线这2中情况
- 1.跳不到根节点
- 2.跳得到根节点
对于情况1,我们在这个点做一个标记
对于情况2,我们记录它还有多少剩余时间,然后这些点有的是可以到根节点得另一颗子树去的。
如:
绿色节点可以跳到根节点然后到红色节点,于是就可以控制红色节点为根的子树
对于每一个可以到根节点的点,将它剩余时间从小到大排序(这样的点记为a)
再将未控制的根节点的儿子到根节点的距离排序 (点记为b)
然后看a是否能覆盖b即可
用倍增来优化
复杂度O(nlognlogn)
#include <bits/stdc++.h> using namespace std; #define LL long long const int maxn = 5e5 + 7; int head[maxn], cnt = 1, n, m, st[maxn], tot; int dep[maxn], f[maxn], F[maxn][22], ds[maxn][22]; struct Node {int v, nxt, w;} G[maxn << 1]; struct DATA {LL v; int id;} a[maxn], b[maxn]; bool vis[maxn]; int tt; void ins(int u, int v, int w) { G[cnt] = (Node) {v, head[u], w}; head[u] = cnt++; } void dfs(int x, int fa, int deep, int w) { dep[x] = deep; f[x] = fa; ds[x][0] = w; F[x][0] = fa; for (int i = 1; i <= 20 && F[F[x][i - 1]][i - 1]; ++i) { F[x][i] = F[F[x][i - 1]][i - 1]; ds[x][i] = ds[F[x][i - 1]][i - 1] + ds[x][i - 1]; } for (int i = head[x]; i; i = G[i].nxt) { int v = G[i].v; if(v == fa) continue; dfs(v, x, deep + 1, G[i].w); } } void put_tag(int x) { bool p = 1, q = 0; for (int i = head[x]; i; i = G[i].nxt) { int v = G[i].v; if(v == f[x]) continue; put_tag(v); p &= vis[v]; q = 1; } if(x != 1 && p && q) vis[x] = 1; } void init() { tot = tt = 0; memset(vis, 0, sizeof vis); } bool cmp(DATA a, DATA b) { return a.v < b.v; } bool check(int res) { init(); for (int i = 1; i <= m; ++i) { int x = st[i], t = 0; for (int j = 20; j >= 0; --j) if(F[x][j] && t + ds[x][j] <= res) { t += ds[x][j]; x = F[x][j];} if(x != 1) vis[x] = 1; else { x = st[i]; a[++tot].v = res - t; for (int j = 20; j >= 0; --j) if(F[x][j] > 1) x = F[x][j]; a[tot].id = x; } } put_tag(1); for (int i = head[1]; i; i = G[i].nxt) { int v = G[i].v; if(vis[v]) continue; b[++tt].v = G[i].w; b[tt].id = v; } sort(a + 1, a + tot + 1, cmp); sort(b + 1, b + tt + 1, cmp); int j = 1; for (int i = 1; i <= tot; ++i) { if(!vis[a[i].id]) vis[a[i].id] = 1; else if(a[i].v >= b[j].v) vis[b[j].id] = 1; while(vis[b[j].id] && j <= tt) ++j; } return j > tt; } int main() { scanf("%d", &n); int res = 0; for (int i = 1; i <= n - 1; ++i) { int u, v, w; scanf("%d%d%d", &u, &v, &w); ins(u, v, w); ins(v, u, w); res += w; } dfs(1, 0, 1, 0); scanf("%d", &m); for (int i = 1; i <= m; ++i) scanf("%d", &st[i]); int l = 0, r = res, ans = 0, ret = 0; for (int i = 1; i <= n; ++i) ret += (dep[i] == 2); if(m < ret) {puts("-1"); return 0;} while(l <= r) { int mid = (l + r) >> 1; if(check(mid)) {ans = mid; r = mid - 1;} else l = mid + 1; } printf("%d\n", ans); return 0; }