学习虚树,你需要的
前置芝士:树形DP、栈、$dfs$序
先来一道题练练手:寻宝游戏
当然了,这道题和今天的内容基本上没什么关系……
先从一道题开始引入虚树
简略题意就是给出一棵$N$个点树,边有边权,$M$组询问,每一次询问$k_i$个重要点,需要切断一些路径使得所有重要点与$1$号点不连通,求切断路径的权值和的最小值。$\sum k_i \leq 500000 , N \leq 250000 , M \geq 1$(鬼知道这个$M$是什么鬼)
首先如果只有单次询问,很容易设计出一个树形$DP$:设$f_{i,j}$表示切断$i$和其子树中所有重要点的最小代价,转移从儿子转移即可。
这样的复杂度是$O(NM)$的,显然难以跑过。
但是这道题有一个很奇怪的数据范围:所有询问的点数之和$\leq 500000$,我们是否可以设计出一种算法使得复杂度与点数有关呢?
这时候虚树诞生!
虚树处理问题的方式
虚树是处理树形$DP$是一种常用的数据结构,一般我们选取若干重要点,将它们与它们两两之间的$LCA$取出来按照原来的深度关系进行连边,这就是一棵虚树。首先一个结论是:加入一个点最多只会额外加入一个$LCA$(考虑$x,y,z$三个点,如果$dep_{LCA(x,y)}=dep_{LCA(y,z)}$,那么三个点必定共一个$LCA$,否则设$dep_{LCA(x,y)}<dep_{LCA(y,z)}$(不满足就交换$z,x$,那么必定$LCA(x,z)=LCA(x,y)$),所以虚树上点的数量是$2k_i$的。
举个栗子,上面那道题样例的第一个询问中,虚树中就有这些点:$10,6,LCA(10,6)=5,1$。在这个题目里因为所有重要点需要与$1$隔离开,所以$1$也要在虚树内。
然后我们把边连好$(1,5)(5,6)(5,10)$,边权设为两个点在原树的路径上的边的最小权值(实际上直接设深度较深的点到$1$的最短路径也是可行的,但是为了叙述清楚我们先这么讲),我们就可以在这一棵虚树上进行$DP$从而获得与原来的$DP$同样的答案了。
虚树的建立
说了虚树解决树形$DP$问题的原理,那么给出一些点,如何建立出一棵虚树成了我们当前的问题。
我们先处理好原树的$dfs$序,对需要建立虚树的点按照$dfs$序排序,然后我们使用一个栈来维护当前正在建立的虚树的一条链。
我们设栈顶元素为$top$,栈的第二个元素为$sec$
那么新加入一个点$x$的步骤就是:
①如果栈为空,直接入栈
②否则求$p=LCA(x,top)$
③如果$p==top$,则表示$top$是$x$的祖先,直接将$x$入栈(注意这个时候不连边)
④如果$p!=top$,意味着$top$是它所在的链的底端,这个时候我们开始弹栈并连边。
我们考虑$dep_{sec}$与$dep_p$
a.$dep_{sec} > dep_p$,意味着$sec$所在的链也被建立完了,连接边$(sec,top)$并弹出$top$,重回第④步
b.$dep_{sec} == dep_p$,意味着$sec==p$,连接边$(sec,top)$并压入$x$
c.$dep_{sec} < dep_p$,意味着$p$在$top$和$sec$之间,我们连接边$(p,top)$,弹出$top$,压入$p$,最后压入$x$。
所有点都压入栈了以后,栈里还会剩下一条没有连边的链,最后再不断地连接边$(sec,top)$,弹出$top$直到栈中元素只有$1$个,我们的虚树就建好了。
话说没有图好抽象啊……我也不想画图了(懒
给出这道题的代码
1 #include<bits/stdc++.h> 2 #define int long long 3 //This code is written by Itst 4 using namespace std; 5 6 inline int read(){ 7 int a = 0; 8 bool f = 0; 9 char c = getchar(); 10 while(c != EOF && !isdigit(c)){ 11 if(c == '-') 12 f = 1; 13 c = getchar(); 14 } 15 while(c != EOF && isdigit(c)){ 16 a = (a << 3) + (a << 1) + (c ^ '0'); 17 c = getchar(); 18 } 19 return f ? -a : a; 20 } 21 22 const int MAXN = 250010; 23 struct Edge{ 24 int end , upEd , w; 25 }Ed[MAXN << 1] , newEd[MAXN << 1]; 26 int num[MAXN << 1] , head[MAXN] , newHead[MAXN] , dep[MAXN] , s[MAXN] , minN[MAXN] , dfn[MAXN] , jump[MAXN][21]; 27 int headS , N , cnt , cntEd , cntNewEd , ts; 28 29 inline void addEd(Edge* Ed , int* head , int& cntEd , int a , int b , int c = 0){ 30 Ed[++cntEd].end = b; 31 Ed[cntEd].upEd = head[a]; 32 Ed[cntEd].w = c; 33 head[a] = cntEd; 34 } 35 36 void dfs(int now , int fa){ 37 jump[now][0] = fa; 38 dep[now] = dep[fa] + 1; 39 dfn[now] = ++ts; 40 for(int i = 1 ; jump[now][i - 1] ; ++i) 41 jump[now][i] = jump[jump[now][i - 1]][i - 1]; 42 for(int i = head[now] ; i ; i = Ed[i].upEd) 43 if(Ed[i].end != fa){ 44 minN[Ed[i].end] = min(minN[now] , Ed[i].w); 45 dfs(Ed[i].end , now); 46 } 47 } 48 49 inline int jumpToLCA(int x , int y){ 50 if(dep[x] < dep[y]) 51 swap(x , y); 52 for(int i = 19 ; i >= 0 ; --i) 53 if(dep[x] - (1 << i) >= dep[y]) 54 x = jump[x][i]; 55 if(x == y) 56 return x; 57 for(int i = 19 ; i >= 0 ; --i) 58 if(jump[x][i] != jump[y][i]){ 59 x = jump[x][i]; 60 y = jump[y][i]; 61 } 62 return jump[x][0]; 63 } 64 65 inline void init(){ 66 cntNewEd = 0; 67 for(int i = 1 ; i <= cnt ; ++i) 68 if(headS == 1) 69 s[++headS] = num[i]; 70 else{ 71 int t = jumpToLCA(s[headS] , num[i]); 72 if(t != s[headS]){ 73 while(dfn[s[headS - 1]] > dfn[t]) 74 addEd(newEd , newHead , cntNewEd , s[--headS] , s[headS]); 75 addEd(newEd , newHead , cntNewEd , t , s[headS--]); 76 if(s[headS] != t) 77 s[++headS] = t; 78 s[++headS] = num[i]; 79 } 80 } 81 while(headS - 1) 82 addEd(newEd , newHead , cntNewEd , s[--headS] , s[headS]); 83 } 84 85 bool cmp(int a , int b){ 86 return dfn[a] < dfn[b]; 87 } 88 89 int dp(int now){ 90 int sum = 0; 91 for(int i = newHead[now] ; i ; i = newEd[i].upEd) 92 sum += dp(newEd[i].end); 93 newHead[now] = 0; 94 return sum ? min(sum , minN[now]) : minN[now]; 95 } 96 97 signed main(){ 98 #ifndef ONLINE_JUDGE 99 freopen("2495.in" , "r" , stdin); 100 //freopen("2495.out" , "w" , stdout); 101 #endif 102 minN[1] = 1ll << 62; 103 s[headS = 1] = 1; 104 N = read(); 105 for(int i = 1 ; i < N ; ++i){ 106 int a = read() , b = read() , c = read(); 107 addEd(Ed , head , cntEd , a , b , c); 108 addEd(Ed , head , cntEd , b , a , c); 109 } 110 dfs(1 , 0); 111 for(int M = read() ; M ; --M){ 112 cnt = read(); 113 for(int i = 1 ; i <= cnt ; ++i) 114 num[i] = read(); 115 sort(num + 1 , num + cnt + 1 , cmp); 116 init(); 117 printf("%lld\n" , dp(1)); 118 } 119 return 0; 120 }
最后给几道练习题:
大工程(DP太显然了直接给代码好了)
1 #include<bits/stdc++.h> 2 #define int long long 3 //This code is written by Itst 4 using namespace std; 5 6 inline int read(){ 7 int a = 0; 8 bool f = 0; 9 char c = getchar(); 10 while(c != EOF && !isdigit(c)){ 11 if(c == '-') 12 f = 1; 13 c = getchar(); 14 } 15 while(c != EOF && isdigit(c)){ 16 a = (a << 3) + (a << 1) + (c ^ '0'); 17 c = getchar(); 18 } 19 return f ? -a : a; 20 } 21 22 const int MAXN = 1000010; 23 struct Edge{ 24 int end , upEd; 25 }Ed[MAXN << 1] , newEd[MAXN]; 26 int jump[MAXN][20] , head[MAXN] , newHead[MAXN] , dep[MAXN] , dfn[MAXN] , s[MAXN] , num[MAXN]; 27 int headS , N , cnt , cntEd , root , cntNewEd , ts , Sum , maxN , minN , sum[MAXN] , size[MAXN] , maxD[MAXN] , minD[MAXN]; 28 29 inline void addEd(Edge* Ed , int* head , int& cntEd , int a , int b){ 30 Ed[++cntEd].end = b; 31 Ed[cntEd].upEd = head[a]; 32 head[a] = cntEd; 33 } 34 35 void init(int x , int fa){ 36 dep[x] = dep[jump[x][0] = fa] + 1; 37 dfn[x] = ++ts; 38 for(int i = 1 ; jump[x][i - 1] ; ++i) 39 jump[x][i] = jump[jump[x][i - 1]][i - 1]; 40 for(int i = head[x] ; i ; i = Ed[i].upEd) 41 if(Ed[i].end != fa) 42 init(Ed[i].end , x); 43 } 44 45 inline int jumpToLCA(int x , int y){ 46 if(dep[x] < dep[y]) 47 swap(x , y); 48 for(int i = 19 ; i >= 0 ; --i) 49 if(dep[x] - (1 << i) >= dep[y]) 50 x = jump[x][i]; 51 if(x == y) 52 return x; 53 for(int i = 19 ; i >= 0 ; --i) 54 if(jump[x][i] != jump[y][i]){ 55 x = jump[x][i]; 56 y = jump[y][i]; 57 } 58 return jump[x][0]; 59 } 60 61 inline void create(){ 62 cntNewEd = sum[root] = maxD[root] = size[root] = 0; 63 minD[root] = (long long)0x3f3f3f3f3f3f3f3f; 64 for(int i = 1 ; i <= cnt ; ++i){ 65 if(!headS) 66 s[++headS] = num[i]; 67 else{ 68 int t = jumpToLCA(num[i] , s[headS]); 69 if(t != s[headS]){ 70 while(dfn[t] < dfn[s[headS - 1]]){ 71 addEd(newEd , newHead , cntNewEd , s[headS - 1] , s[headS]); 72 --headS; 73 } 74 addEd(newEd , newHead , cntNewEd , t , s[headS]); 75 if(s[--headS] != t) 76 s[++headS] = t; 77 } 78 s[++headS] = num[i]; 79 } 80 size[num[i]] = 1; 81 minD[num[i]] = 0; 82 } 83 while(headS - 1){ 84 addEd(newEd , newHead , cntNewEd , s[headS - 1] , s[headS]); 85 --headS; 86 } 87 root = s[headS--]; 88 } 89 90 void dfs1(int now){ 91 for(int i = newHead[now] ; i ; i = newEd[i].upEd) 92 if(dep[newEd[i].end] > dep[now]){ 93 int k = newEd[i].end; 94 dfs1(k); 95 size[now] += size[k]; 96 sum[now] += sum[k] + size[k] * (dep[k] - dep[now]); 97 } 98 } 99 100 void dfs(int now){ 101 for(int i = newHead[now] ; i ; i = newEd[i].upEd) 102 if(dep[newEd[i].end] > dep[now]){ 103 int k = newEd[i].end; 104 dfs(k); 105 Sum += (size[now] - size[k]) * (sum[k] + size[k] * (dep[k] - dep[now])); 106 maxN = max(maxN , maxD[now] + maxD[k] + dep[k] - dep[now]); 107 minN = min(minN , minD[now] + minD[k] + dep[k] - dep[now]); 108 maxD[now] = max(maxD[now] , maxD[k] + dep[k] - dep[now]); 109 minD[now] = min(minD[now] , minD[k] + dep[k] - dep[now]); 110 sum[k] = size[k] = maxD[k] = 0; 111 minD[k] = (long long)0x3f3f3f3f3f3f3f3f; 112 } 113 newHead[now] = 0; 114 } 115 116 bool cmp(int a , int b){ 117 return dfn[a] < dfn[b]; 118 } 119 120 signed main(){ 121 #ifndef ONLINE_JUDGE 122 freopen("4103.in" , "r" , stdin); 123 freopen("4103.out" , "w" , stdout); 124 #endif 125 memset(minD , 0x3f , sizeof(minD)); 126 N = read(); 127 for(int i = 1 ; i < N ; ++i){ 128 int a = read() , b = read(); 129 addEd(Ed , head , cntEd , a , b); 130 addEd(Ed , head , cntEd , b , a); 131 } 132 init(1 , 0); 133 for(int M = read() ; M ; --M){ 134 Sum = maxN = 0; 135 minN = (long long)1 << 62; 136 cnt = read(); 137 for(int i = 1 ; i <= cnt ; ++i) 138 num[i] = read(); 139 sort(num + 1 , num + cnt + 1 , cmp); 140 create(); 141 dfs1(root); 142 dfs(root); 143 printf("%lld %lld %lld\n" , Sum , minN , maxN); 144 } 145 return 0; 146 }