Luogu 4103 [HEOI2014]大工程

BZOJ 3611

明明在BZOJ上是$6s$的时限,怎么到Luogu上就变成$4s$了……

按照套路建出虚树,点之间的距离可以变成边权表示在虚树上,然后考虑如何树形$dp$。

最大值和最小值应当比较简单,类似于树形$dp$求树的直径的方法,设$f_x$表示$x$的子树中的关键点到$x$的最远距离,$g_x$则表示最近距离。

对于每一个$x$,如果$x$是一个关键点,则有$f_x = g_x = 0$,否则$f_x = -inf, g_x = inf$。

对于$x$的每一个儿子$y$,先用$f_x + f_y + val(x, y)$和$g_x + g_y + val(x, y)$更新$ans$,再用$f_x + val(x, y)$和$g_x + val(x, y)$更新$f_x$和$g_x$。

考虑一下怎么求所有边权的总和,我们发现一条边被计算的次数等于把这条边断开所分割成的两个联通块中的关键点的数量的乘积,那么它对答案的贡献就是这条边的边权乘上这个乘积。

其实除了根结点,所有的点都唯一对应了一条入边,我们先预处理一个点$x$,子树中关键点的数目$siz_x$,那么每一条入边$inEdge$的贡献就是$val(inEdge) * siz_x * (siz_{root} - x)$。

一开始犯了一个错误,就是不能把$1$强制作根,类似于点分治的时候需要减掉的贡献,这样会导致最长的路径变长,我们只要找到第一个入栈的点为$root$即可。

我的代码在$Luogu$上需要一发$O2$,感觉写欧拉序求$lca$会更好。

时间复杂度$O(nlogn)$。

Code:

#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
typedef long long ll;

const int N = 1e6 + 5;
const int Lg = 22;
const ll inf = 1LL << 60;

int n, qn, tot = 0, head[N], fa[N][Lg], dep[N];
int rt, top = 0, sta[N * 2], dfsc = 0, in[N], out[N], a[N * 2], siz[N];
ll sum, maxD, minD, f[N], g[N];
bool vis[N], flag[N];

struct Edge {
    int to, nxt, val;
} e[N << 1];

inline void add(int from, int to, int val = 0) {
    e[++tot].to = to;
    e[tot].val = val;
    e[tot].nxt = head[from];
    head[from] = tot;
}  

namespace IOread {
    const int L = 1 << 15;
    
    char buffer[L], *S, *T;
    
    inline char Getchar() {
        if(S == T) {
            T = (S = buffer) + fread(buffer, 1, L, stdin);
            if(S == T) return EOF;
        }
        return *S++;
    }
    
    template <class T> 
    inline void read(T &X) {
        char ch; T op = 1;
        for(ch = Getchar(); ch > '9' || ch < '0'; ch = Getchar())
            if(ch == '-') op = -1;
        for(X = 0; ch >= '0' && ch <= '9'; ch = Getchar()) 
            X = (X << 1) + (X << 3) + ch - '0'; 
        X *= op;
    }
    
} using namespace IOread;

bool cmp(int x, int y) {
    int dfx = x > 0 ? in[x] : out[-x];
    int dfy = y > 0 ? in[y] : out[-y];
    return dfx < dfy;
}

inline void swap(int &x, int &y) {
    int t = x; x = y; y = t;
}

template <typename T>
inline void chkMax(T &x, T y) {
    if(y > x) x = y;
}

template <typename T>
inline void chkMin(T &x, T y) {
    if(y < x) x = y;
}

void dfs(int x, int fat, int depth) {
    in[x] = ++dfsc, fa[x][0] = fat, dep[x] = depth;
    for(int i = 1; i <= 20; i++)
        fa[x][i] = fa[fa[x][i - 1]][i - 1];
    for(int i = head[x]; i; i = e[i].nxt) {
        int y = e[i].to;
        if(y == fat) continue;
        dfs(y, x, depth + 1);
    }
    out[x] = ++dfsc;
}

inline int getLca(int x, int y) {
    if(dep[x] < dep[y]) swap(x, y);
    for(int i = 20; i >= 0; i--)
        if(dep[fa[x][i]] >= dep[y])
            x = fa[x][i];
    if(x == y) return x;
    for(int i = 20; i >= 0; i--)
        if(fa[x][i] != fa[y][i])
            x = fa[x][i], y = fa[y][i];
    return fa[x][0]; 
}

inline int getDis(int x, int y) {
    int z = getLca(x, y);
    return dep[x] + dep[y] - 2 * dep[z];
}

void dfs1(int x, int fat) {
    if(flag[x]) ++siz[x], g[x] = f[x] = 0;
    else g[x] = inf, f[x] = -inf;
    for(int i = head[x]; i; i = e[i].nxt) {
        int y = e[i].to;
        if(y == fat) continue;
        dfs1(y, x);
        siz[x] += siz[y];
        
        chkMax(maxD, f[x] + f[y] + e[i].val);
        chkMax(f[x], f[y] + e[i].val);
        
        chkMin(minD, g[x] + g[y] + e[i].val);
        chkMin(g[x], g[y] + e[i].val);
    }
}

void dfs2(int x, int fat, int inEdge) {
    if(fat != 0) 
        sum += 1LL * e[inEdge].val * (siz[rt] - siz[x]) * siz[x];
    for(int i = head[x]; i; i = e[i].nxt) {
        int y = e[i].to;
        if(y == fat) continue;
        dfs2(y, x, i);
    } 
}

void solve() {
    int K, cnt; read(K);
    for(int i = 1; i <= K; i++) {
        read(a[i]);
        if(!vis[a[i]]) {
            vis[a[i]] = 1;
            flag[a[i]] = 1;
        }
    }
    
    cnt = K;
    sort(a + 1, a + 1 + K, cmp);
    for(int i = 1; i < cnt; i++) {
        int now = getLca(a[i], a[i + 1]);
        if(!vis[now]) {
            vis[now] = 1;
            a[++cnt] = now;
        }
    }
    
    for(int cur = cnt, i = 1; i <= cur; i++)
        a[++cnt] = -a[i];
//    if(!vis[1]) a[++cnt] = 1, a[++cnt] = -1, vis[1] = 1;
    sort(a + 1, a + 1 + cnt, cmp);
    
/*    for(int i = 1; i <= cnt; i++)
        printf("%d ", a[i]);
    printf("\n");    */
    
    top = rt = 0;
    for(int i = 1; i <= cnt; i++) {
        if(a[i] > 0) {
            sta[++top] = a[i];
            if(!rt) rt = a[i];
        } else {
            int x = sta[top--], y = sta[top];
            if(y) {
                int nowDis = getDis(x, y);
                add(x, y, nowDis), add(y, x, nowDis);
            }
        }
    }
    
    sum = 0LL, minD = inf, maxD = -inf;
    dfs1(rt, 0), dfs2(rt, 0, 0);
    
    printf("%lld %lld %lld\n", sum, minD, maxD);
    
    tot = 0;
    for(int i = 1; i <= cnt; i++) 
        if(a[i] > 0) {
            vis[a[i]] = flag[a[i]] = 0;
            head[a[i]] = siz[a[i]] = 0;
//            f[a[i]] = -inf, g[a[i]] = inf;
        }
}

int main() {
    read(n);
    for(int x, y, i = 1; i < n; i++) {
        read(x), read(y);
        add(x, y), add(y, x);
    }
    dfs(1, 0, 1);
    
    tot = 0; memset(head, 0, sizeof(head));
    for(read(qn); qn--; ) solve();
    
    return 0;
}
View Code

猜你喜欢

转载自www.cnblogs.com/CzxingcHen/p/9726299.html