[LOJ#2386]. 「USACO 2018.01 Platinum」Cow at Large[点分治]

题意

题目链接

分析

  • 假设当前的根为 rt ,我们能够在奶牛到达 \(u\) 之时拦住它,当且仅当到叶子节点到 \(u\) 的最短距离 \(mn_u \le dis_u\) 。容易发现,合法的区域是许多棵子树,而我们要求的就是有多少棵子树。
  • 由于除了以 rt 为根的子树都可以用 \(\sum\limits_{x\in subtree} 2-deg(x)\) 的形式表示 (如果 rt 是叶子特判掉即可),于是可以将问题转化成有多少个点满足 \(mn_u\le dis_u​\)
  • 考虑点分治,先补集转化这样不用处理负权。每次求对于 \(a\) 有多少 \(b\) 满足 \(dis_a+dis_b<mn_b\) 。把所有路径按照 \(mn_b-dis_b​\) 的大小排序后在序列上二分差后缀和即可。最后要容斥减去子树内的方案数。
  • 时间复杂度 \(O(nlog^2n)​\)

代码

#include<bits/stdc++.h>
using namespace std;
typedef long long LL;
#define go(u) for(int i = head[u], v = e[i].to; i; i=e[i].lst, v=e[i].to)
#define rep(i, a, b) for(int i = a; i <= b; ++i)
#define pb push_back
#define re(x) memset(x, 0, sizeof x)
inline int gi() {
    int x = 0,f = 1;
    char ch = getchar();
    while(!isdigit(ch)) { if(ch == '-') f = -1; ch = getchar();}
    while(isdigit(ch)) { x = (x << 3) + (x << 1) + ch - 48; ch = getchar();}
    return x * f;
}
template <typename T> inline bool Max(T &a, T b){return a < b ? a = b, 1 : 0;}
template <typename T> inline bool Min(T &a, T b){return a > b ? a = b, 1 : 0;}
const int N = 7e4 + 7, inf = 0x3f3f3f3f;
int n, rt, edc, sn;
int mn[N], deg[N], ans[N], mxs[N], head[N], son[N];
bool vis[N];
struct edge {
    int lst, to;
    edge(){}edge(int lst, int to):lst(lst), to(to){}
}e[N << 1];
void Add(int a, int b){
    ++deg[a], ++deg[b];
    e[++edc] = edge(head[a], b), head[a] = edc;
    e[++edc] = edge(head[b], a), head[b] = edc;
}
void bfs() {
    queue<int>Q;
    memset(mn, 0x3f, sizeof mn);
    rep(i, 1, n) if(deg[i] == 1) {
        mn[i] = 0, Q.push(i);
    }
    while(!Q.empty()) {
        int u = Q.front();Q.pop();
        go(u)if(mn[v] == inf) {
            mn[v] = mn[u] + 1;
            Q.push(v);
        }
    }
}
int tp;
typedef pair<int, int> pii;
#define mp make_pair
pii suf[N];
void getrt(int u, int fa) {
    mxs[u] = 0;son[u] = 1;
    go(u)if(!vis[v] && v ^ fa) {
        getrt(v, u);
        son[u] += son[v];
        Max(mxs[u], son[v]);
    }
    Max(mxs[u], sn - son[u]);
    if(mxs[u] < mxs[rt]) rt = u;
}
void getdep(int u, int fa, int dis) {
    if(mn[u] - dis > 0) suf[++tp] = mp(mn[u] - dis, 2 - deg[u]);
    go(u)if(!vis[v] && v ^ fa) {
        getdep(v, u, dis + 1);
    }
}
void getans(int u, int fa, int dis, int f) {
    int gg = upper_bound(suf + 1, suf + 1 + tp, mp(dis, inf)) - suf;
    if(gg != tp + 1)
    ans[u] += f * suf[gg].second;
    go(u)if(!vis[v] && v ^ fa) {
        getans(v, u, dis + 1, f);
    }
}
void solve(int u) {
    vis[u] = 1;
    tp = 0;
    getdep(u, 0, 0);
    sort(suf + 1, suf + 1 + tp);
    for(int j = tp - 1; j >= 1; --j) suf[j].second += suf[j + 1].second;
    getans(u, 0, 0, 1);
    
    go(u)if(!vis[v]) {
        tp = 0;
        getdep(v, u, 1);
        sort(suf + 1, suf + 1 + tp);
        for(int j = tp - 1; j >= 1; --j) suf[j].second += suf[j + 1].second;
        getans(v, u, 1, -1);
    }
    
    int old = sn;
    go(u)if(!vis[v]) {
        if(son[v] > son[u])
            sn = old - son[u];
        else
            sn = son[v];
        rt = 0, getrt(v, u), solve(rt);
    }
}
int main() {
    n = gi();
    rep(i, 1, n - 1) Add(gi(), gi());
    bfs();
    sn = n, mxs[rt = 0] = n + 1, getrt(1, 0), solve(rt);
    rep(i, 1, n) printf("%d\n", deg[i] == 1 ? 1 : 2 - ans[i]);
    return 0;
}

猜你喜欢

转载自www.cnblogs.com/yqgAKIOI/p/10548768.html
0条评论
添加一条新回复