毒瘤数据结构题qwq
设三棵树分别为$T1,T2,T3$
首先将$T1$边分治,具体步骤如下:
①多叉树->二叉树,具体操作是对于每一个父亲,建立与儿子个数相同的虚点,将父亲与这些虚点穿成一条链(父亲在链顶),在虚点的另一边接上儿子,之前父亲到儿子的边权移动到虚点到这个儿子的边上。代码长下面这样
void rebuild(int x , int f){
int pre = ++cntNode , p = x;//pre是当前虚点的编号
for(int i = Thead[x] ; i ; i = TEd[i].upEd)
if(TEd[i].end != f){//连接上一个虚点、当前虚点和一个儿子。
addEd(Ed , head , cntEd , x , pre , 0);
addEd(Ed , head , cntEd , pre , x , 0);
addEd(Ed , head , cntEd , pre , TEd[i].end , TEd[i].w);
addEd(Ed , head , cntEd , TEd[i].end , pre , TEd[i].w);
rebuild(TEd[i].end , p);
x = pre;//替换当前接的虚点
pre = ++cntNode;
}
}
②选择一条边,使得两侧的子树大小差距尽可能小
③计算经过这条边的路径的答案
④分治旁边的两个子树
至于为什么不用点分治……等会儿再提
设我们分治的边为$(s,t)$,切掉这条边之后的两棵子树分别为$L,R$,设$dis1_i$表示$i$到$(s,t)$边的距离,$dis_{Tx}(i,j)$表示在第$x$棵树上$i$到$j$的距离
那么我们在③步骤中需要求的就是满足$i \in L , j \in R$的$$ans=dis_1i+dis_1j+w(s,t)+dis_{T2}(i,j)+dis_{T3}(i,j)$$的最大值
稍微魔改一下式子,设在$T2$上根到$i$的路径上所有边的边权和为$dis_2i$
那么$ans=dis_1i+dis_1j+dis_2i+dis_2j-2 \times dis_2LCA(i,j)+dis_{T3}(i,j)+w(s,t)$
我们考虑枚举$LCA(i,j)$,那么$dis_2LCA(i,j)$与$w(s,t)$对答案最大值就不会产生影响了,会产生影响的部分就是
$$delta = dis_1i+dis_1j+dis_2i+dis_2j+dis_{T3}(i,j)$$
我们把$dis_1i+dis_2i$看做在$T3$上新增了一个$i'$点,与$i$分属同一子树($L$或$R$),且只与$i$连了一条权值为$dis_1i+dis_2i$的边。我们设新的树为$T3'$,那么我们需要在$T3'$上找到两个点$x,y$,满足$x \in L , y \in R$且$dis_{T3'}(x,y)$最大
到这里我们需要一个结论帮助:
如果点集$A$中最长路的两个端点为$u,v$,点集$B$中最长路的两个端点为$x,y$,则跨过$A,B$的最长路的端点可能的组合只有$(u,x)(u,y)(v,x)(v,y)$
这意味着我们可以使用树形$DP$来维护$T2$中的一个子树内所有点在$T3'$上的最长路对应的两个端点。
考虑在$T2$上对$L \cup R$的点建立虚树。设$f_{i,0/1}$表示在$T2$树上$i$的子树中,且属于$T1$中$L/R$子树,在$T3'$上取到最长路的两个点,合并直接把四个端点拿出来一一在$T3$上算路径长度取$max$即可。然后我们考虑虚树上每一个点的答案贡献。我们在合并答案的时候进行计算,也就是说在合并$f_{i,0/1}$和$f_{son_i,0/1}$的时候,将$f_{i,0}$与$f_{son_i,1}$拿出来跑$T3'$上最大值,再将$f_{i,1}$与$f_{son_i,0}$拿出来跑$T3'$上最大值,获得的总的最大值就是$delta = dis_1i+dis_1j+dis_2i+dis_2j+dis_{T3}(i,j)$的最大值了。而这个时候两个端点的$LCA$必定是$i$,再减掉$2 \times dis2_i$、加上$w(s,t)$就能够得到$ans$了。
那么为什么不使用点分治也就很明了了。边分治必定将原树分成两个子树,但是点分治可能会分成很多,很多的子树之间的合并会很麻烦,所以在这种方法中点分治不能用……
代码巨长……8.4K史上最长代码
#include<bits/stdc++.h> #define int long long //This code is written by Itst using namespace std; inline int read(){ int a = 0; char c = getchar(); while(c != EOF && !isdigit(c)) c = getchar(); while(c != EOF && isdigit(c)){ a = (a << 3) + (a << 1) + (c ^ '0'); c = getchar(); } return a; } const int MAXN = 200010; struct Edge{ int end , upEd , w; }; int N , all , logg2[MAXN << 1]; namespace Tree3{ //最长路 Edge Ed[MAXN << 1]; int head[MAXN] , dep[MAXN] , val[MAXN] , len[MAXN] , ST[21][MAXN << 1] , fir[MAXN]; int cntST , cntEd; inline void addEd(int a , int b , int c){ Ed[++cntEd].end = b; Ed[cntEd].upEd = head[a]; Ed[cntEd].w = c; head[a] = cntEd; } void dfs(int x , int p , int l){ dep[x] = dep[p] + 1; len[x] = l; fir[x] = ++cntST; ST[0][cntST] = x; for(int i = head[x] ; i ; i = Ed[i].upEd) if(Ed[i].end != p){ dfs(Ed[i].end , x , l + Ed[i].w); ST[0][++cntST] = x; } } inline int cmp(int a , int b){ return dep[a] < dep[b] ? a : b; } void init_st(){ for(int i = 1 ; 1 << i <= cntST ; ++i) for(int j = 1 ; j + (1 << i) - 1 <= cntST ; ++j) ST[i][j] = cmp(ST[i - 1][j] , ST[i - 1][j + (1 << (i - 1))]); } inline int LCA(int x , int y){ x = fir[x]; y = fir[y]; if(y < x) swap(x , y); int t = logg2[y - x + 1]; return cmp(ST[t][x] , ST[t][y - (1 << t) + 1]); } inline int calcLen(int x , int y){ if(!x || !y) return 0; return len[x] + len[y] - (len[LCA(x , y)] << 1) + val[x] + val[y]; } inline void input(){ for(int i = 1 ; i < N ; ++i){ int a = read() , b = read() , c = read(); addEd(a , b , c); addEd(b , a , c); } dfs(1 , 0 , 0); init_st(); } } namespace Tree2{ //虚树 Edge Ed[MAXN << 1] , REd[MAXN]; int head[MAXN] , Rhead[MAXN] , dep[MAXN] , s[MAXN] , ST[21][MAXN << 1] , fir[MAXN] , dfn[MAXN] , len[MAXN] , ans[MAXN][2][2]; int cntEd , cntREd , ts , cntST , root , headS; vector < int > v; inline void addEd(Edge* Ed , int* head , int& cntEd , int a , int b , int c = 0){ Ed[++cntEd].end = b; Ed[cntEd].upEd = head[a]; Ed[cntEd].w = c; head[a] = cntEd; } void dfs(int x , int p , int l){ len[x] = l; dep[x] = dep[p] + 1; dfn[x] = ++ts; fir[x] = ++cntST; ST[0][cntST] = x; for(int i = head[x] ; i ; i = Ed[i].upEd) if(Ed[i].end != p){ dfs(Ed[i].end , x , l + Ed[i].w); ST[0][++cntST] = x; } } inline int cmp(int a , int b){ return dep[a] < dep[b] ? a : b; } void init_st(){ for(int i = 1 ; 1 << i <= cntST ; ++i) for(int j = 1 ; j + (1 << i) - 1 <= cntST ; ++j) ST[i][j] = cmp(ST[i - 1][j] , ST[i - 1][j + (1 << (i - 1))]); } inline int LCA(int x , int y){ x = fir[x]; y = fir[y]; if(y < x) swap(x , y); int t = logg2[y - x + 1]; return cmp(ST[t][x] , ST[t][y - (1 << t) + 1]); } inline void input(){ for(int i = 1 ; i < N ; ++i){ int a = read() , b = read() , c = read(); addEd(Ed , head , cntEd , a , b , c); addEd(Ed , head , cntEd , b , a , c); } dfs(1 , 0 , 0); init_st(); } bool c(int a , int b){ return dfn[a] < dfn[b]; } inline void maintain(int x , int y){ int temp[4]; temp[0] = ans[x][0][0]; temp[1] = ans[x][0][1]; temp[2] = ans[y][0][0]; temp[3] = ans[y][0][1]; sort(temp , temp + 4); if(!temp[2]) ans[x][0][0] = temp[3]; else{ int maxN = 0; for(int i = 3 ; i >= 0 ; --i) for(int j = i - 1 ; j >= 0 ; --j) if(Tree3::calcLen(temp[i] , temp[j]) > maxN){ maxN = Tree3::calcLen(temp[i] , temp[j]); ans[x][0][0] = temp[i]; ans[x][0][1] = temp[j]; } } temp[0] = ans[x][1][0]; temp[1] = ans[x][1][1]; temp[2] = ans[y][1][0]; temp[3] = ans[y][1][1]; sort(temp , temp + 4); if(!temp[2]) ans[x][1][0] = temp[3]; else{ int maxN = 0; for(int i = 3 ; i >= 0 ; --i) for(int j = i - 1 ; j >= 0 ; --j) if(Tree3::calcLen(temp[i] , temp[j]) > maxN){ maxN = Tree3::calcLen(temp[i] , temp[j]); ans[x][1][0] = temp[i]; ans[x][1][1] = temp[j]; } } } void dp(int x , int l){ for(int& i = Rhead[x] ; i ; i = REd[i].upEd){ dp(REd[i].end , l); for(int j = 0 ; j <= 1 ; ++j) for(int k = 0 ; k <= 1 ; ++k) all = max(all , max(Tree3::calcLen(ans[x][0][j] , ans[REd[i].end][1][k]) , Tree3::calcLen(ans[x][1][j] , ans[REd[i].end][0][k])) - 2 * len[x] + l); maintain(x , REd[i].end); ans[REd[i].end][0][0] = ans[REd[i].end][0][1] = ans[REd[i].end][1][0] = ans[REd[i].end][1][1] = 0; } } inline void solve(const vector < int >& v1 , const vector < int >& v2 , int l){ v.clear(); cntREd = 0; for(int i = 0 ; i < v1.size() ; ++i) ans[v1[i]][0][0] = v1[i]; for(int i = 0 ; i < v2.size() ; ++i) ans[v2[i]][1][0] = v2[i]; v.insert(v.end() , v1.begin() , v1.end()); v.insert(v.end() , v2.begin() , v2.end()); sort(v.begin() , v.end() , c); for(int i = 0 ; i < v.size() ; ++i){ Tree3::val[v[i]] += len[v[i]]; if(!headS) s[++headS] = v[i]; else{ int t = LCA(s[headS] , v[i]); if(dep[s[headS]] > dep[t]){ while(dep[s[headS - 1]] > dep[t]){ addEd(REd , Rhead , cntREd , s[headS - 1] , s[headS]); --headS; } addEd(REd , Rhead , cntREd , t , s[headS]); if(s[--headS] != t) s[++headS] = t; } s[++headS] = v[i]; } } while(headS - 1){ addEd(REd , Rhead , cntREd , s[headS - 1] , s[headS]); --headS; } root = s[headS--]; dp(root , l); ans[root][0][0] = ans[root][1][0] = ans[root][0][1] = ans[root][1][1] = 0; for(int i = 0 ; i < v.size() ; ++i) Tree3::val[v[i]] -= len[v[i]]; } } namespace Tree1{ //边分治 Edge Ed[MAXN << 2] , TEd[MAXN << 1]; int head[MAXN << 1] , Thead[MAXN] , size[MAXN << 1] , dis[MAXN << 1]; int cntEd = 1 , cntTEd , cntNode , nowSize , minSize , minInd; bool vis[MAXN << 1]; vector < int > v1 , v2; inline void addEd(Edge* Ed , int* head , int& cntEd , int a , int b , int c){ Ed[++cntEd].end = b; Ed[cntEd].upEd = head[a]; Ed[cntEd].w = c; head[a] = cntEd; } void rebuild(int x , int f){ int pre = ++cntNode , p = x; for(int i = Thead[x] ; i ; i = TEd[i].upEd) if(TEd[i].end != f){ addEd(Ed , head , cntEd , x , pre , 0); addEd(Ed , head , cntEd , pre , x , 0); addEd(Ed , head , cntEd , pre , TEd[i].end , TEd[i].w); addEd(Ed , head , cntEd , TEd[i].end , pre , TEd[i].w); rebuild(TEd[i].end , p); x = pre; pre = ++cntNode; } } inline void input(){ cntNode = N; for(int i = 1 ; i < N ; ++i){ int a = read() , b = read() , c = read(); addEd(TEd , Thead , cntTEd , a , b , c); addEd(TEd , Thead , cntTEd , b , a , c); } rebuild(1 , 0); } void getSize(int x){ vis[x] = 1; ++nowSize; for(int i = head[x] ; i ; i = Ed[i].upEd) if(Ed[i].end != -1 && !vis[Ed[i].end]) getSize(Ed[i].end); vis[x] = 0; } void getRoot(int x){ vis[x] = size[x] = 1; for(int i = head[x] ; i ; i = Ed[i].upEd) if(Ed[i].end != -1 && !vis[Ed[i].end]){ getRoot(Ed[i].end); if(max(size[Ed[i].end] , nowSize - size[Ed[i].end]) < minSize){ minSize = max(size[Ed[i].end] , nowSize - size[Ed[i].end]); minInd = i; } size[x] += size[Ed[i].end]; } vis[x] = 0; } void get(int x , int l , vector < int > &v){ dis[x] = l; vis[x] = 1; if(x <= N) v.push_back(x); for(int i = head[x] ; i ; i = Ed[i].upEd) if(Ed[i].end != -1 && !vis[Ed[i].end]) get(Ed[i].end , Ed[i].w + l , v); vis[x] = 0; } void bfz(int x){ nowSize = 0; minSize = 0x7fffffff; v1.clear(); v2.clear(); getSize(x); if(nowSize == 1) return; getRoot(x); int L = Ed[minInd].end , R = Ed[minInd ^ 1].end; Ed[minInd].end = Ed[minInd ^ 1].end = -1; get(L , 0 , v1); get(R , 0 , v2); if(!v1.empty() && !v2.empty()){ for(int i = 0 ; i < v1.size() ; ++i) Tree3::val[v1[i]] += dis[v1[i]]; for(int i = 0 ; i < v2.size() ; ++i) Tree3::val[v2[i]] += dis[v2[i]]; Tree2::solve(v1 , v2 , Ed[minInd].w); for(int i = 0 ; i < v1.size() ; ++i) Tree3::val[v1[i]] -= dis[v1[i]]; for(int i = 0 ; i < v2.size() ; ++i) Tree3::val[v2[i]] -= dis[v2[i]]; } bool f = v2.empty(); if(!v1.empty()) bfz(L); if(!f) bfz(R); } inline void work(){ bfz(1); } } signed main(){ N = read(); for(int i = 2 ; i <= N << 1 ; ++i) logg2[i] = logg2[i >> 1] + 1; Tree1::input(); Tree2::input(); Tree3::input(); Tree1::work(); cout << all; return 0; }