Luogu4220 WC2018 通道 边分治、虚树

传送门


毒瘤数据结构题qwq

设三棵树分别为$T1,T2,T3$

首先将$T1$边分治,具体步骤如下:

①多叉树->二叉树,具体操作是对于每一个父亲,建立与儿子个数相同的虚点,将父亲与这些虚点穿成一条链(父亲在链顶),在虚点的另一边接上儿子,之前父亲到儿子的边权移动到虚点到这个儿子的边上。代码长下面这样

    void rebuild(int x , int f){
        int pre = ++cntNode , p = x;//pre是当前虚点的编号
        for(int i = Thead[x] ; i ; i = TEd[i].upEd)
            if(TEd[i].end != f){//连接上一个虚点、当前虚点和一个儿子。
                addEd(Ed , head , cntEd , x , pre , 0);
                addEd(Ed , head , cntEd , pre , x , 0);
                addEd(Ed , head , cntEd , pre , TEd[i].end , TEd[i].w);
                addEd(Ed , head , cntEd , TEd[i].end , pre , TEd[i].w);
                rebuild(TEd[i].end , p);
                x = pre;//替换当前接的虚点
                pre = ++cntNode;
            }
    }

②选择一条边,使得两侧的子树大小差距尽可能小

③计算经过这条边的路径的答案

④分治旁边的两个子树

至于为什么不用点分治……等会儿再提

设我们分治的边为$(s,t)$,切掉这条边之后的两棵子树分别为$L,R$,设$dis1_i$表示$i$到$(s,t)$边的距离,$dis_{Tx}(i,j)$表示在第$x$棵树上$i$到$j$的距离

那么我们在③步骤中需要求的就是满足$i \in L , j \in R$的$$ans=dis_1i+dis_1j+w(s,t)+dis_{T2}(i,j)+dis_{T3}(i,j)$$的最大值

稍微魔改一下式子,设在$T2$上根到$i$的路径上所有边的边权和为$dis_2i$

那么$ans=dis_1i+dis_1j+dis_2i+dis_2j-2 \times dis_2LCA(i,j)+dis_{T3}(i,j)+w(s,t)$

我们考虑枚举$LCA(i,j)$,那么$dis_2LCA(i,j)$与$w(s,t)$对答案最大值就不会产生影响了,会产生影响的部分就是

$$delta = dis_1i+dis_1j+dis_2i+dis_2j+dis_{T3}(i,j)$$

我们把$dis_1i+dis_2i$看做在$T3$上新增了一个$i'$点,与$i$分属同一子树($L$或$R$),且只与$i$连了一条权值为$dis_1i+dis_2i$的边。我们设新的树为$T3'$,那么我们需要在$T3'$上找到两个点$x,y$,满足$x \in L , y \in R$且$dis_{T3'}(x,y)$最大

到这里我们需要一个结论帮助:

如果点集$A$中最长路的两个端点为$u,v$,点集$B$中最长路的两个端点为$x,y$,则跨过$A,B$的最长路的端点可能的组合只有$(u,x)(u,y)(v,x)(v,y)$

这意味着我们可以使用树形$DP$来维护$T2$中的一个子树内所有点在$T3'$上的最长路对应的两个端点。

考虑在$T2$上对$L \cup R$的点建立虚树。设$f_{i,0/1}$表示在$T2$树上$i$的子树中,且属于$T1$中$L/R$子树,在$T3'$上取到最长路的两个点,合并直接把四个端点拿出来一一在$T3$上算路径长度取$max$即可。然后我们考虑虚树上每一个点的答案贡献。我们在合并答案的时候进行计算,也就是说在合并$f_{i,0/1}$和$f_{son_i,0/1}$的时候,将$f_{i,0}$与$f_{son_i,1}$拿出来跑$T3'$上最大值,再将$f_{i,1}$与$f_{son_i,0}$拿出来跑$T3'$上最大值,获得的总的最大值就是$delta = dis_1i+dis_1j+dis_2i+dis_2j+dis_{T3}(i,j)$的最大值了。而这个时候两个端点的$LCA$必定是$i$,再减掉$2 \times dis2_i$、加上$w(s,t)$就能够得到$ans$了。

那么为什么不使用点分治也就很明了了。边分治必定将原树分成两个子树,但是点分治可能会分成很多,很多的子树之间的合并会很麻烦,所以在这种方法中点分治不能用……

代码巨长……8.4K史上最长代码

#include<bits/stdc++.h>
#define int long long
//This code is written by Itst
using namespace std;

inline int read(){
    int a = 0;
    char c = getchar();
    while(c != EOF && !isdigit(c))
        c = getchar();
    while(c != EOF && isdigit(c)){
        a = (a << 3) + (a << 1) + (c ^ '0');
        c = getchar();
    }
    return a;
}

const int MAXN = 200010;
struct Edge{
    int end , upEd , w;
};
int N , all , logg2[MAXN << 1];

namespace Tree3{
//最长路
    Edge Ed[MAXN << 1];
    int head[MAXN] , dep[MAXN] , val[MAXN] , len[MAXN] , ST[21][MAXN << 1] , fir[MAXN];
    int cntST , cntEd;

    inline void addEd(int a , int b , int c){
        Ed[++cntEd].end = b;
        Ed[cntEd].upEd = head[a];
        Ed[cntEd].w = c;
        head[a] = cntEd;
    }

    void dfs(int x , int p , int l){
        dep[x] = dep[p] + 1;
        len[x] = l;
        fir[x] = ++cntST;
        ST[0][cntST] = x;
        for(int i = head[x] ; i ; i = Ed[i].upEd)
            if(Ed[i].end != p){
                dfs(Ed[i].end , x , l + Ed[i].w);
                ST[0][++cntST] = x;
            }
    }

    inline int cmp(int a , int b){
        return dep[a] < dep[b] ? a : b;
    }
    
    void init_st(){
        for(int i = 1 ; 1 << i <= cntST ; ++i)
            for(int j = 1 ; j + (1 << i) - 1 <= cntST ; ++j)
                ST[i][j] = cmp(ST[i - 1][j] , ST[i - 1][j + (1 << (i - 1))]);
    }

    inline int LCA(int x , int y){
        x = fir[x];
        y = fir[y];
        if(y < x)
            swap(x , y);
        int t = logg2[y - x + 1];
        return cmp(ST[t][x] , ST[t][y - (1 << t) + 1]);
    }

    inline int calcLen(int x , int y){
        if(!x || !y)
            return 0;
        return len[x] + len[y] - (len[LCA(x , y)] << 1) + val[x] + val[y];
    }

    inline void input(){
        for(int i = 1 ; i < N ; ++i){
            int a = read() , b = read() , c = read();
            addEd(a , b , c);
            addEd(b , a , c);
        }
        dfs(1 , 0 , 0);
        init_st();
    }
}

namespace Tree2{
//虚树
    Edge Ed[MAXN << 1] , REd[MAXN];
    int head[MAXN] , Rhead[MAXN] , dep[MAXN] , s[MAXN] , ST[21][MAXN << 1] , fir[MAXN] , dfn[MAXN] , len[MAXN] , ans[MAXN][2][2];
    int cntEd , cntREd , ts , cntST , root , headS;
    vector < int > v;
    
    inline void addEd(Edge* Ed , int* head , int& cntEd , int a , int b , int c = 0){
        Ed[++cntEd].end = b;
        Ed[cntEd].upEd = head[a];
        Ed[cntEd].w = c;
        head[a] = cntEd;
    }

    void dfs(int x , int p , int l){
        len[x] = l;
        dep[x] = dep[p] + 1;
        dfn[x] = ++ts;
        fir[x] = ++cntST;
        ST[0][cntST] = x;
        for(int i = head[x] ; i ; i = Ed[i].upEd)
            if(Ed[i].end != p){
                dfs(Ed[i].end , x , l + Ed[i].w);
                ST[0][++cntST] = x;
            }
    }

    inline int cmp(int a , int b){
        return dep[a] < dep[b] ? a : b;
    }
    
    void init_st(){
        for(int i = 1 ; 1 << i <= cntST ; ++i)
            for(int j = 1 ; j + (1 << i) - 1 <= cntST ; ++j)
                ST[i][j] = cmp(ST[i - 1][j] , ST[i - 1][j + (1 << (i - 1))]);
    }

    inline int LCA(int x , int y){
        x = fir[x];
        y = fir[y];
        if(y < x)
            swap(x , y);
        int t = logg2[y - x + 1];
        return cmp(ST[t][x] , ST[t][y - (1 << t) + 1]);
    }
    
    inline void input(){
        for(int i = 1 ; i < N ; ++i){
            int a = read() , b = read() , c = read();
            addEd(Ed , head , cntEd , a , b , c);
            addEd(Ed , head , cntEd , b , a , c);
        }
        dfs(1 , 0 , 0);
        init_st();
    }

    bool c(int a , int b){
        return dfn[a] < dfn[b];
    }

    inline void maintain(int x , int y){
        int temp[4];
        temp[0] = ans[x][0][0];
        temp[1] = ans[x][0][1];
        temp[2] = ans[y][0][0];
        temp[3] = ans[y][0][1];
        sort(temp , temp + 4);
        if(!temp[2])
            ans[x][0][0] = temp[3];
        else{
            int maxN = 0;
            for(int i = 3 ; i >= 0 ; --i)
                for(int j = i - 1 ; j >= 0 ; --j)
                    if(Tree3::calcLen(temp[i] , temp[j]) > maxN){
                        maxN = Tree3::calcLen(temp[i] , temp[j]);
                        ans[x][0][0] = temp[i];
                        ans[x][0][1] = temp[j];
                    }
        }
        temp[0] = ans[x][1][0];
        temp[1] = ans[x][1][1];
        temp[2] = ans[y][1][0];
        temp[3] = ans[y][1][1];
        sort(temp , temp + 4);
        if(!temp[2])
            ans[x][1][0] = temp[3];
        else{
            int maxN = 0;
            for(int i = 3 ; i >= 0 ; --i)
                for(int j = i - 1 ; j >= 0 ; --j)
                    if(Tree3::calcLen(temp[i] , temp[j]) > maxN){
                        maxN = Tree3::calcLen(temp[i] , temp[j]);
                        ans[x][1][0] = temp[i];
                        ans[x][1][1] = temp[j];
                    }
        }
    }
    
    void dp(int x , int l){
        for(int& i = Rhead[x] ; i ; i = REd[i].upEd){
            dp(REd[i].end , l);
            for(int j = 0 ; j <= 1 ; ++j)
                for(int k = 0 ; k <= 1 ; ++k)
                    all = max(all , max(Tree3::calcLen(ans[x][0][j] , ans[REd[i].end][1][k]) , Tree3::calcLen(ans[x][1][j] , ans[REd[i].end][0][k])) - 2 * len[x] + l);
            maintain(x , REd[i].end);
            ans[REd[i].end][0][0] = ans[REd[i].end][0][1] = ans[REd[i].end][1][0] = ans[REd[i].end][1][1] = 0;
        }
    }
    
    inline void solve(const vector < int >& v1 , const vector < int >& v2 , int l){
        v.clear();
        cntREd = 0;
        for(int i = 0 ; i < v1.size() ; ++i)
            ans[v1[i]][0][0] = v1[i];
        for(int i = 0 ; i < v2.size() ; ++i)
            ans[v2[i]][1][0] = v2[i];
        v.insert(v.end() , v1.begin() , v1.end());
        v.insert(v.end() , v2.begin() , v2.end());
        sort(v.begin() , v.end() , c);
        for(int i = 0 ; i < v.size() ; ++i){
            Tree3::val[v[i]] += len[v[i]];
            if(!headS)
                s[++headS] = v[i];
            else{
                int t = LCA(s[headS] , v[i]);
                if(dep[s[headS]] > dep[t]){
                    while(dep[s[headS - 1]] > dep[t]){
                        addEd(REd , Rhead , cntREd , s[headS - 1] , s[headS]);
                        --headS;
                    }
                    addEd(REd , Rhead , cntREd , t , s[headS]);
                    if(s[--headS] != t)
                        s[++headS] = t;
                }
                s[++headS] = v[i];
            }
        }
        while(headS - 1){
            addEd(REd , Rhead , cntREd , s[headS - 1] , s[headS]);
            --headS;
        }
        root = s[headS--];
        dp(root , l);
        ans[root][0][0] = ans[root][1][0] = ans[root][0][1] = ans[root][1][1] = 0;
        for(int i = 0 ; i < v.size() ; ++i)
            Tree3::val[v[i]] -= len[v[i]];
    }
    
}

namespace Tree1{
//边分治
    Edge Ed[MAXN << 2] , TEd[MAXN << 1];
    int head[MAXN << 1] , Thead[MAXN] , size[MAXN << 1] , dis[MAXN << 1];
    int cntEd = 1 , cntTEd , cntNode , nowSize , minSize , minInd;
    bool vis[MAXN << 1];
    vector < int > v1 , v2;
    
    inline void addEd(Edge* Ed , int* head , int& cntEd , int a , int b , int c){
        Ed[++cntEd].end = b;
        Ed[cntEd].upEd = head[a];
        Ed[cntEd].w = c;
        head[a] = cntEd;
    }

    void rebuild(int x , int f){
        int pre = ++cntNode , p = x;
        for(int i = Thead[x] ; i ; i = TEd[i].upEd)
            if(TEd[i].end != f){
                addEd(Ed , head , cntEd , x , pre , 0);
                addEd(Ed , head , cntEd , pre , x , 0);
                addEd(Ed , head , cntEd , pre , TEd[i].end , TEd[i].w);
                addEd(Ed , head , cntEd , TEd[i].end , pre , TEd[i].w);
                rebuild(TEd[i].end , p);
                x = pre;
                pre = ++cntNode;
            }
    }
    
    inline void input(){
        cntNode = N;
        for(int i = 1 ; i < N ; ++i){
            int a = read() , b = read() , c = read();
            addEd(TEd , Thead , cntTEd , a , b , c);
            addEd(TEd , Thead , cntTEd , b , a , c);
        }
        rebuild(1 , 0);
    }

    void getSize(int x){
        vis[x] = 1;
        ++nowSize;
        for(int i = head[x] ; i ; i = Ed[i].upEd)
            if(Ed[i].end != -1 && !vis[Ed[i].end])
                getSize(Ed[i].end);
        vis[x] = 0;
    }

    void getRoot(int x){
        vis[x] = size[x] = 1;
        for(int i = head[x] ; i ; i = Ed[i].upEd)
            if(Ed[i].end != -1 && !vis[Ed[i].end]){
                getRoot(Ed[i].end);
                if(max(size[Ed[i].end] , nowSize - size[Ed[i].end]) < minSize){
                    minSize = max(size[Ed[i].end] , nowSize - size[Ed[i].end]);
                    minInd = i;
                }
                size[x] += size[Ed[i].end];
            }
        vis[x] = 0;
    }

    void get(int x , int l , vector < int > &v){
        dis[x] = l;
        vis[x] = 1;
        if(x <= N)
            v.push_back(x);
        for(int i = head[x] ; i ; i = Ed[i].upEd)
            if(Ed[i].end != -1 && !vis[Ed[i].end])
                get(Ed[i].end , Ed[i].w + l , v);
        vis[x] = 0;
    }
        
    void bfz(int x){
        nowSize = 0;
        minSize = 0x7fffffff;
        v1.clear();
        v2.clear();
        getSize(x);
        if(nowSize == 1)
            return;
        getRoot(x);
        int L = Ed[minInd].end , R = Ed[minInd ^ 1].end;
        Ed[minInd].end = Ed[minInd ^ 1].end = -1;
        get(L , 0 , v1);
        get(R , 0 , v2);
        if(!v1.empty() && !v2.empty()){
            for(int i = 0 ; i < v1.size() ; ++i)
                Tree3::val[v1[i]] += dis[v1[i]];
            for(int i = 0 ; i < v2.size() ; ++i)
                Tree3::val[v2[i]] += dis[v2[i]];
            Tree2::solve(v1 , v2 , Ed[minInd].w);
            for(int i = 0 ; i < v1.size() ; ++i)
                Tree3::val[v1[i]] -= dis[v1[i]];
            for(int i = 0 ; i < v2.size() ; ++i)
                Tree3::val[v2[i]] -= dis[v2[i]];
        }
        bool f = v2.empty();
        if(!v1.empty())
            bfz(L);
        if(!f)
            bfz(R);
    }
    
    inline void work(){
        bfz(1);
    }
}

signed main(){
    N = read();
    for(int i = 2 ; i <= N << 1 ; ++i)
        logg2[i] = logg2[i >> 1] + 1;
    Tree1::input();
    Tree2::input();
    Tree3::input();
    Tree1::work();
    cout << all;
    return 0;
}

猜你喜欢

转载自www.cnblogs.com/Itst/p/10079400.html