「WC2018」通道-边分治+虚树+DP

Description

给定三棵 $n$ 个点的带边权的树,求点对 $(x,y)$,最大化 $dist_a(x,y)+dist_b(x,y)+dist_c(x,y)$,其中 $dist_a, dist_b, dist_c$ 分别表示在第一、二、三棵树上的距离。

$n \leq 100000$

Solution

先考虑两棵树怎么做。

点对 $(x,y)$ 的贡献为 $depa_x+depa_y-2\times depa_{lca(x,y)}+dist_b(x,y)$,其中 $depa$ 为第一棵树上的带权深度,$lca(x,y)$ 为 $x,y$ 在第一棵树上的最近公共祖先。

对于第二棵树上的每个节点 $x$,新建一个点 $x'$ 与之相连,边权为 $depa_x$。那么 $dist_b(x',y')=depa_x+depa_y+dist_b(x,y)$。结合上式,在第一棵树上枚举 $lca$ 之后,我们的目标就是最大化 $dist_b(x',y')$,其中 $x,y$ 属于 $lca$ 的不同子树。

对于边权非负的树,有这样的性质:跨越点集 $A,B$ 的最长链(一个端点在 $A$ 中,另一个端点在 $B$ 中)的端点,一定可以取 $A$ 中最长链的两个端点和 $B$ 中最长链的两个端点中的某两个。可以用类似证明树的直径的方式来证明这一结论。

因此,在第一棵树上进行树形DP,维护当前集合中最长链的两个端点,在合并时更新答案即可。

现在考虑三棵树的情况。

在第三棵树上进行边分治,对于当前边分治的点集在第二棵树上建立虚树;对于第一棵树上的每个节点 $x$,新建点 $x'$ 与其相连,边权为 $depb_x+distc_x$,其中 $depb_x$ 为 $x$ 在第二棵树上的带权深度,$distc_x$ 为点 $x$ 在第三棵树上到当前分治中心(被断开的边的对应端点)的距离。

所以,在第二棵树的虚树上进行树形DP,对每个节点分别维护属于第三棵树上两个分治连通块的点集各自的最长链的两个端点,在合并子树时用分属两个连通块的端点更新答案即可。

#include <bits/stdc++.h>
using namespace std;

// 64-bit signed integer used for edge weights, distances and the answer.
typedef long long lint;

// Capacity of the per-node arrays declared in this file.
const int maxn = 200005;

// Number of nodes in each input tree (read in main).
int n;
// Log[i] = floor(log2(i)) for i >= 2 (filled in main); used by the
// sparse-table LCA queries.
int Log[maxn];

// Work lists for the current split of the third tree: q1[1..cnt1] and
// q2[1..cnt2] hold the original nodes found on either side of the cut edge.
int cnt1;
int cnt2;
int q1[maxn];
int q2[maxn];
// typ[u] is the side (0 or 1) that original node u was assigned to.
int typ[maxn];

// w[u]: per-node additive weight; the first and second tree each add their
// root distance to it, and the third tree temporarily adds its own distance.
lint w[maxn];
// Weight of the edge currently being cut in the third tree.
lint k;
// Best value found so far; printed at the end of main.
lint ans;

// Reads the next non-negative decimal integer from stdin.
//
// Every byte that is not an ASCII digit is skipped first, so separators are
// ignored (a leading '-' is skipped as well; it is NOT interpreted as a sign).
// Digits are accumulated without overflow checking.
//
// Returns the parsed value, or 0 if end-of-input is reached before any digit.
inline lint gi()
{
    // getchar() returns int so that EOF stays distinguishable from every
    // character value. Keeping it in a char and not testing for EOF would make
    // the skip loop below spin forever on truncated or empty input.
    int c = getchar();
    while (c != EOF && (c < '0' || c > '9')) c = getchar();

    lint sum = 0;
    while ('0' <= c && c <= '9') {
        sum = sum * 10 + (c - '0');
        c = getchar();
    }
    return sum;
}

namespace t1
{

    struct edge 
    {
        int to, next;
        lint w;    
    } e[maxn * 2];
    int h[maxn], tot;
    int dep[maxn], Time, ord[maxn], dfn[maxn], Min[21][maxn];
    lint dis[maxn];

    inline void add(int u, int v, lint w)
    {
        e[++tot] = (edge) {v, h[u], w}; h[u] = tot;
        e[++tot] = (edge) {u, h[v], w}; h[v] = tot;
    }

    void dfs(int u, int fa)
    {
        dep[u] = dep[fa] + 1;
        w[u] += dis[u];
        ord[dfn[u] = ++Time] = u; Min[0][Time] = u;
        for (int i = h[u], v; v = e[i].to, i; i = e[i].next)
            if (v != fa) {
                dis[v] = dis[u] + e[i].w;
                dfs(v, u);
                ord[++Time] = u; Min[0][Time] = u;
            }
    }

    void prest()
    {
        for (int j = 0; (1 << (j + 1)) <= Time; ++j)
            for (int i = 1; i + (1 << (j + 1)) - 1 <= Time; ++i) {
                if (dep[Min[j][i]] <= dep[Min[j][i + (1 << j)]]) Min[j + 1][i] = Min[j][i];
                else Min[j + 1][i] = Min[j][i + (1 << j)];
            }
    }

    inline int lca(int u, int v) 
    {
        static int k;
        u = dfn[u]; v = dfn[v];
        if (u > v) swap(u, v);
        k = Log[v - u + 1];
        if (dep[Min[k][u]] <= dep[Min[k][v - (1 << k) + 1]]) return Min[k][u];
        else return Min[k][v - (1 << k) + 1];
    }
    
    inline lint t1dis(int a, int b) {return w[a] + w[b] - (dis[lca(a, b)] << 1);}

}

// Second tree, rooted at node 1: Euler-tour + sparse-table LCA, and the
// virtual-tree DP (namespace virtual_tree) that t3::solve runs once for every
// edge it cuts in the third tree.
namespace t2
{

    // Adjacency-list entry: target node, index of the next entry of the same
    // source node (index 0 terminates the list), and the edge weight.
    struct edge 
    {
        int to, next;
        lint w;
    } e[maxn * 2];
    // h[u]: index of u's first list entry; tot: number of entries used.
    int h[maxn], tot;
    // dep[u] = dep[parent] + 1; Time: Euler-tour length; ord[t]: node at tour
    // position t; dfn[u]: first tour position of u; Min[j][i]: node with the
    // smallest dep among tour positions [i, i + 2^j - 1].
    int dep[maxn], Time, ord[maxn], dfn[maxn], Min[21][maxn];
    // dis[u]: sum of edge weights from the dfs start node to u.
	lint dis[maxn];

    // Inserts the undirected edge (u, v) of weight w as two list entries.
    inline void add(int u, int v, lint w)
    {
        e[++tot] = (edge) {v, h[u], w}; h[u] = tot;
        e[++tot] = (edge) {u, h[v], w}; h[v] = tot;
    }

    // Depth-first walk from u (entered from fa): fills dep/dis, records the
    // Euler tour, and adds this tree's root distance of u to the global w[u].
    void dfs(int u, int fa)
    {
        dep[u] = dep[fa] + 1;
        w[u] += dis[u];
        ord[dfn[u] = ++Time] = u; Min[0][Time] = u;
        // The loop condition is `i`: iteration stops at list index 0.
        for (int i = h[u], v; v = e[i].to, i; i = e[i].next)
            if (v != fa) {
                dis[v] = dis[u] + e[i].w;
                dfs(v, u);
                // Re-append u to the tour after returning from a child.
                ord[++Time] = u; Min[0][Time] = u;
            }
    }

    // Fills sparse-table levels 1.. from level 0 (written by dfs); each entry
    // keeps the node of smaller dep from its two half-ranges.
    void prest()
    {
        for (int j = 0; (1 << (j + 1)) <= Time; ++j)
            for (int i = 1; i + (1 << (j + 1)) - 1 <= Time; ++i) {
                if (dep[Min[j][i]] <= dep[Min[j][i + (1 << j)]]) Min[j + 1][i] = Min[j][i];
                else Min[j + 1][i] = Min[j][i + (1 << j)];
            }
    }

    // Lowest common ancestor of u and v: the node of smallest dep on the
    // Euler tour between their first occurrences, found with two overlapping
    // power-of-two ranges. Requires the global Log table to be filled.
    inline int lca(int u, int v) 
    {
        // Note: this local k is unrelated to the global lint k.
        static int k;
        u = dfn[u]; v = dfn[v];
        if (u > v) swap(u, v);
        k = Log[v - u + 1];
        if (dep[Min[k][u]] <= dep[Min[k][v - (1 << k) + 1]]) return Min[k][u];
        else return Min[k][v - (1 << k) + 1];
    }

    // Virtual tree over the nodes listed in the global q1/q2, plus the DP that
    // updates the global ans.
    namespace virtual_tree
    {
    
        // Unweighted adjacency-list entry of the virtual tree.
        struct edge 
        {
            int to, next;
        } e[maxn];
        // h/tot: adjacency heads and entry counter; x[1..cnt]: key nodes;
        // in[u] == 1 marks u as a key node; stk[1..top]: construction stack.
        int h[maxn], tot, x[maxn], in[maxn], cnt, stk[maxn], top;

        // A candidate pair of endpoints (u, v) together with its value dis.
        // u == 0 is used as the "empty" marker by operator+ and merge().
        struct node 
        {
            int u, v;
            lint dis;

            // Empty pair.
            node() {u = v = dis = 0;}
            // Pair (a, b) valued by t1::t1dis(a, b).
            node(int a, int b) {u = a; v = b; dis = t1::t1dis(a, b);}
            // Pair (a, b) with an explicitly supplied value.
            node(int a, int b, lint d) {u = a; v = b; dis = d;}
            // Pairs are ordered by value only.
            bool operator < (const node &a) const {return dis < a.dis;}
            // Combines two sets summarised by their best pairs: the result is
            // the largest of a, b and the four pairs taking one endpoint from
            // each. An empty operand yields the other operand unchanged.
            friend node operator + (node a, node b) {
                if (a.u == 0) return b;
                if (b.u == 0) return a;
                node res = max(a, b);
                res = max(res, max(node(a.u, b.u), node(a.v, b.v)));
                res = max(res, max(node(a.u, b.v), node(a.v, b.u)));
                return res;
            }
        } f[maxn][2];   // f[u][s]: best pair among side-s key nodes gathered at u so far

        // Largest t1::t1dis over the four pairs with one endpoint from a and
        // one from b; returns 0 when either pair is empty.
        inline lint merge(node a, node b)
        {
            if (a.u == 0 || b.u == 0) return 0;
            return max(max(t1::t1dis(a.u, b.u), t1::t1dis(a.v, b.v)), max(t1::t1dis(a.u, b.v), t1::t1dis(a.v, b.u)));
        }

        // Adds the undirected virtual-tree edge (u, v).
        void add(int u, int v)
        {
            e[++tot] = (edge) {v, h[u]}; h[u] = tot;
            e[++tot] = (edge) {u, h[v]}; h[v] = tot;
        }

        // Orders nodes by their first Euler-tour position in the second tree.
        inline bool cmp(const int &a, const int &b)
        {
            return dfn[a] < dfn[b];
        }

        // DP over the virtual tree. A key node starts with the pair (u, u) on
        // its own side typ[u]; every other slot starts empty. For each child,
        // ans is first updated with the best cross-side combination between
        // what u holds and what the child holds, plus the global k, minus
        // twice t2::dis[u]; only then is the child folded into f[u]. The
        // adjacency head of u is cleared on exit so the arrays can be reused
        // by the next build().
        void dfs(int u, int fa)
        {
            if (in[u] == 1) f[u][typ[u]] = node(u, u), f[u][typ[u] ^ 1] = node();
            else f[u][0] = f[u][1] = node();
            for (int i = h[u], v; v = e[i].to, i; i = e[i].next)
                if (v != fa) {
					dfs(v, u);
                    ans = max(ans, max(merge(f[u][0], f[v][1]), merge(f[u][1], f[v][0])) + k - (dis[u] << 1));
                    f[u][0] = f[u][0] + f[v][0];
                    f[u][1] = f[u][1] + f[v][1];
                }
            h[u] = 0;
        }

        // Builds the virtual tree of the nodes in q1[1..cnt1] and q2[1..cnt2]
        // (always including node 1 as the root), runs the DP, then unmarks
        // the key nodes.
        void build()
        {
			cnt = 0;
            for (int i = 1; i <= cnt1; ++i) x[++cnt] = q1[i];
            for (int i = 1; i <= cnt2; ++i) x[++cnt] = q2[i];
            sort(x + 1, x + cnt + 1, cmp);
            for (int i = 1; i <= cnt; ++i) in[x[i]] = 1;

            // The stack starts with node 1 and holds the chain currently being
            // extended.
           	tot = 0; stk[top = 1] = 1;
	        for (int i = 1; i <= cnt; ++i) {
		        int p = lca(stk[top], x[i]);
                if (stk[top] != p) {
                    // Pop while the node below the top is not above p, linking
                    // each popped node to the node below it.
                    // NOTE(review): with top == 1 this reads stk[0]; the loop
                    // presumably ends there because dfn[0] is never assigned
                    // and stays 0 - confirm.
                    while (dfn[p] <= dfn[stk[top - 1]]) add(stk[top - 1], stk[top]), --top;
                    // If p itself is not on the stack, hang the old top under
                    // p and put p in its place.
                    if (stk[top] != p) add(p, stk[top]), stk[top] = p;
                }
                if (stk[top] != x[i]) stk[++top] = x[i];
            }
            // Link whatever chain is left on the stack.
            --top;
            while (top) add(stk[top], stk[top + 1]), --top;

			dfs(1, 0);

            for (int i = 1; i <= cnt; ++i) in[x[i]] = 0;
        }
    }

}

// Third tree: the input tree is first rebuilt with auxiliary nodes
// (namespace original_tree), then processed by repeatedly cutting one edge
// per component (solve) and handing both sides to t2::virtual_tree::build().
namespace t3
{

    // Adjacency-list entry of the rebuilt tree. cut marks an edge that has
    // already been removed by solve().
    struct edge 
    {
        int to, next;
        lint w;
        bool cut;
    } e[maxn * 2];
    // h[u]: index of u's first list entry; tot: last entry index used.
    int h[maxn], tot;
    // all: size of the current component; siz[u]: subtree size from dfs1;
    // res: index of the best edge found by dfs2; rev: the value that edge
    // achieved (size of the larger side).
    int all, siz[maxn], res, rev;
    // dis[u]: distance from u to the cut-edge endpoint on its side (dfs3).
    lint dis[maxn];

    // Inserts the undirected edge (u, v) of weight w as two uncut entries.
    inline void add(int u, int v, lint w)
    {
        e[++tot] = (edge) {v, h[u], w, 0}; h[u] = tot;
        e[++tot] = (edge) {u, h[v], w, 0}; h[v] = tot;
    }

    // Holds the third tree exactly as read from the input and converts it
    // into the t3 structure above.
    namespace original_tree
    {
        
        struct edge 
        {
            int to, next;
            lint w;
        } e[maxn * 2];
        // now[u]: last node of the auxiliary chain grown from u so far;
        // cnt: largest node id allocated in the rebuilt tree.
        int h[maxn], tot, now[maxn], cnt;

        // Inserts the undirected input edge (u, v) of weight w.
        inline void add(int u, int v, lint w)
        {
            e[++tot] = (edge) {v, h[u], w}; h[u] = tot;
            e[++tot] = (edge) {u, h[v], w}; h[v] = tot;
        }

        // Attaches child v to u through a freshly allocated auxiliary node:
        // the new node is joined to v with the real weight w and to the
        // current end of u's chain with weight 0, then becomes the new end.
        // NOTE(review): presumably done to keep node degrees small so that
        // cutting one edge splits components evenly - confirm.
        inline void link(int u, int v, lint w)
        {
            ++cnt;
            t3::add(cnt, v, w);
            t3::add(now[u], cnt, 0);
            now[u] = cnt;
        }

        // Walks the input tree from u (entered from fa) and links every
        // child edge into the rebuilt tree.
        void dfs(int u, int fa)
        {
            for (int i = h[u], v; v = e[i].to, i; i = e[i].next)
                if (v != fa) {
                    link(u, v, e[i].w);
                    dfs(v, u);
                }
        }

        // Rebuilds the tree starting at node 1. Auxiliary nodes get ids
        // n + 1, n + 2, ...
        void build()
        {
            for (int i = 1; i <= n; ++i) now[i] = i;
            cnt = n;
            dfs(1, 0);
        }

    }

    // Computes siz[] for the component containing u, ignoring cut edges.
    void dfs1(int u, int fa)
    {
        siz[u] = 1;
        for (int i = h[u], v; v = e[i].to, i; i = e[i].next)
            if (v != fa && !e[i].cut) dfs1(v, u), siz[u] += siz[v];;
    }

    // Scans every uncut edge of the component and records in res the one
    // whose larger side, max(siz[v], all - siz[v]), is smallest (kept in rev).
    void dfs2(int u, int fa)
    {
        static int mxsz;
        for (int i = h[u], v; v = e[i].to, i; i = e[i].next)
            if (v != fa && !e[i].cut) {
                mxsz = max(siz[v], all - siz[v]);
                if (rev > mxsz) res = i, rev = mxsz;
                dfs2(v, u);
            }
    }    

    // Walks one side of the cut from u: every original node (id <= n) gets
    // typ = f and is appended to q (cnt is its running length); dis[] is
    // accumulated along uncut edges from the starting endpoint.
    void dfs3(int u, int fa, int f, int &cnt, int *q)
    {
        if (u <= n) typ[u] = f, q[++cnt] = u;
        for (int i = h[u], v; v = e[i].to, i; i = e[i].next)
            if (v != fa && !e[i].cut)
                dis[v] = dis[u] + e[i].w, dfs3(v, u, f, cnt, q);
    }

    // Processes the component containing u: picks an edge with dfs2, cuts
    // it, evaluates all pairs separated by it, then recurses on both sides.
    void solve(int u)
    {
        dfs1(u, 0);
        if (siz[u] == 1) return ;
        res = 0; rev = 1 << 30; all = siz[u]; 
        dfs2(u, 0);

        // e[res] and e[res ^ 1] are the two directions of the chosen edge;
        // this pairing relies on tot being 1 before the first t3::add call
        // (main sets it), so that edges occupy indices (2,3), (4,5), ...
        int st = e[res].to, ed = e[res ^ 1].to;
        e[res].cut = e[res ^ 1].cut = 1;
        cnt1 = cnt2 = dis[st] = dis[ed] = 0;
        dfs3(st, 0, 0, cnt1, q1); dfs3(ed, 0, 1, cnt2, q2);

        // Temporarily fold the distance to the cut edge into w so that
        // t1::t1dis sees it, run the virtual-tree DP with k set to the cut
        // edge's weight, then restore w.
        for (int i = 1; i <= cnt1; ++i) w[q1[i]] += dis[q1[i]];
        for (int i = 1; i <= cnt2; ++i) w[q2[i]] += dis[q2[i]];
        k = e[res].w; t2::virtual_tree::build();
        for (int i = 1; i <= cnt1; ++i) w[q1[i]] -= dis[q1[i]];
        for (int i = 1; i <= cnt2; ++i) w[q2[i]] -= dis[q2[i]];

        solve(st);
		solve(ed);
    }

}

int main()
{
    n = gi();
    for (int u, v, i = 1; i < n; ++i) u = gi(), v = gi(), t1::add(u, v, gi());
    for (int u, v, i = 1; i < n; ++i) u = gi(), v = gi(), t2::add(u, v, gi());
    for (int u, v, i = 1; i < n; ++i) u = gi(), v = gi(), t3::original_tree::add(u, v, gi());

    for (int i = 2; i <= n * 2; ++i) Log[i] = Log[i >> 1] + 1;

    t1::dfs(1, 0);
    t1::prest();
    t2::dfs(1, 0);
    t2::prest();
	t3::tot = 1;
    t3::original_tree::build();
    t3::solve(1);

	printf("%lld\n", ans);
	
    return 0;
}

猜你喜欢

转载自blog.csdn.net/DSL_HN_2002/article/details/84980962