在某些时候,我们需要维护树上选一些点所得到的东西。
这些东西要满足这样一个性质:未选的点可以通过某种方式删除而不影响最终的结果。
最典型的就是求被选出的节点在原树上的距离之和。
既然我们知道未选的点可以删掉,那么我们就想办法建一棵树,使得树上的未选点尽量少。
这棵树就叫虚树。
那么要怎么建立一棵虚树呢?
首先我们在原树上跑一遍
,并得出树上节点的
序,记为
。
(顺便树链剖分维护一下
)
然后我们按照这个
从小到大把节点插入虚树。
维护一个栈,它表示在当前的这棵虚树上,以最后一个插入的点为终点的
链。
设最后插入的点为
(就是栈顶的点),当前要加入的点为
。我们想把
插入到我们已经构建的虚树上去。
求出
,记为
。有两种情况:
和
分立在
的两棵子树下。
是
。
/*为什么
不可能是
呢?
因为如果
是
,说明
,而我们是按照
序号从小到大选点的,于是
,矛盾。*/
那么对于第二种情况,显然只需要把
连到
上面就行了。
对于第一种情况呢,有
,那么这说明什么呢?
这说明我们已经把
所在的子树中,
所在的子树全部遍历完了。
/*为什么遍历完了呢?
如果没有遍历完,那么肯定有一个未加入的点
,满足
,我们按照
序号递增顺序遍历的话,肯定会把
加进来了才到
。*/
这样,我们就直接构建
为根的,
所在的那个子树。
由于我们的栈维护的是当前的
链,所以显然我们可以在退栈的时候连边,那么考虑一下不是退栈需要连的边。
所在的子树如果还有其它部分,它一定在之前就构建好了(所有退栈的点都已经被正确地连入树中了),就剩那条
链。
那么要如何正确地连
到
的边呢?
设栈顶的节点为
,栈顶第二个节点为
。
重复以下操作:
如果
,可以直接连边
,然后退栈。
如果
,说明
就是
,直接连边
,此时子树已经构建完毕。
如果
,说明
被
与
夹在中间,此时连边
,退一次栈,再把
入栈。
这样就连完了,接下来把
入栈即可。
好像很复杂对吧,我们观察一下这样连边的本质是什么。
/* 你快观察呀.jpg */
然后我们会发现,这么讨论太复杂了,我们直接利用它们在原树中的深度关系来连边即可。
首先我们得到一个点,还是记为
。
然后我们重复以下操作:
得到栈顶和
的
,还是记为
。
并且把栈顶记为
,栈顶第二个节点(如果有)记为
。
如果
,则连边
;
否则如果
,则连边
;
否则跳出。
接着如果
还不是栈顶的话,则把
入栈,然后把
入栈即可。
这样写起来简单些。
接下来我们看道例题:
计蒜客 青云的机房组网方案
给出一棵树,每个点有点权,边权均为
。
求所有点权互质的点对的距离和。
注意到本题可以转化为求 树上所有的点对距离之和
所有不互质的点对距离之和。
我们可以利用容斥原理来计算不互质的点对之和,这个步骤可以写个线性筛质数预处理。
然而对于每一个因数,我们在树上取的点并不多,且多次取的总和是
级别的,又发现对于两个顶点,它们之间的距离不会因为这两点之间路径上点的多少而改变。因此就可以考虑对于每一次询问(一个因数相当于一个询问),我们根据原树的信息重新建一棵树,让这棵树里面尽量少包含未选择的节点。(于是这棵树就是虚树)然后在这棵虚树上跑一个树形
就行了。
代码如下:
#include <bits/stdc++.h>
#define R register
#define LL long long
#define Max(__a,__b) (__a<__b?__b:__a)
#define Min(__a,__b) (__a<__b?__a:__b)
using namespace std;
template<class TT>inline void read(R TT &x){
x=0;R bool f=false;R char c=getchar();
for(;c<48||c>57;c=getchar())f|=(c=='-');
for(;c>47&&c<58;c=getchar())x=(x<<1)+(x<<3)+(c^48);
(f)&&(x=-x);
}
template<class orzyrt>inline orzyrt Abs(R orzyrt x){
if(x<0)return -x;
else return x;
}
int n;
namespace non_baoli{
#define N 100010
int mul[N];
char com[N];
int pri[N];
inline void get_prime(R int cnt=0){
mul[1]=1;
for(R int i=2;i<100001;++i){
if(!com[i])pri[cnt++]=i,mul[i]=-1;
for(R int j=0;j<cnt&&i*pri[j]<100001;++j){
com[i*pri[j]]=1;
if(i%pri[j]==0){
mul[i*pri[j]]=0;
break;
}else mul[i*pri[j]]=-mul[i];
}
}
}
struct Edge{
int to;
Edge *next;
}E[N<<1],E1[N<<1],*head[N],*head1[N],*e=E,*st=E1;
inline void add(R int u,R int v){
*e=(Edge){v,head[u]};head[u]=e++;
}
inline void add1(R int u,R int v){
*st=(Edge){v,head1[u]};head1[u]=st++;
}
int fa[N],son[N],dep[N],siz[N],top[N],dfn[N],dfs_clo;
void dfs1(R int u,R int f){
dfn[u]=++dfs_clo;
siz[u]=1;fa[u]=f;
dep[u]=dep[f]+1;
R int v;
for(R Edge *i=head[u];i;i=i->next){
if((v=i->to)==f)continue;
dfs1(v,u);
siz[u]+=siz[v];
if(siz[son[u]]<siz[v])son[u]=v;
}
}
void dfs2(R int u,R int tp){
top[u]=tp;
if(son[u])dfs2(son[u],tp);
for(R Edge*i=head[u];i;i=i->next){
R int v=i->to;
if(v!=fa[u]&&v!=son[u])dfs2(v,v);
}
}
inline int lca(R int a,R int b){
while(top[a]!=top[b]){
dep[top[a]]>dep[top[b]]?
a=fa[top[a]]:b=fa[top[b]];
}
return dep[a]<dep[b]?a:b;
}
int a[N],Top,cnt,sz[N],stk[N];
LL val;
void dfs3(R int u,R int f){
for(R Edge *i=head[u];i;i=i->next){
if(i->to!=f){
dfs3(i->to,u);
sz[u]+=sz[i->to];
}
}
if(f)val+=1ll*Abs(dep[u]-dep[f])*sz[u]*(cnt-sz[u]);
}
void del(R int u,R int f){
for(R Edge *i=head[u];i;i=i->next){
if(i->to!=f)del(i->to,u);
}
sz[u]=0;
head[u]=0;
}
inline bool cmp(R int a,R int b){return dfn[a]<dfn[b];}
inline LL solve(R int num){
cnt=Top=0;
for(R int u=num;u<100001;u+=num){
for(R Edge *i=head1[u];i;i=i->next){
a[cnt++]=i->to;
}
}
if(cnt<=1)return 0;
e=E;
val=0;
sort(a,a+cnt,cmp);
for(R int i=0;i<cnt;++i)sz[a[i]]=1;
for(R int i=0,Lca,now;i<cnt;++i){
Lca=0;now=a[i];
while(Top>0){
Lca=lca(now,stk[Top]);
if(Top>1&&dep[Lca]<dep[stk[Top-1]]){
R int u=stk[Top],v=stk[Top-1];
add(u,v);add(v,u);
Top--;
}else if(dep[Lca]<dep[stk[Top]]){
R int u=Lca,v=stk[Top];
add(u,v);add(v,u);
Top--;
break;
}else break;
}
if(stk[Top]!=Lca)stk[++Top]=Lca;
stk[++Top]=now;
}
while(Top>1){
R int u=stk[Top],v=stk[Top-1];
add(u,v);add(v,u);
Top--;
}
dfs3(a[0],0);
del(a[0],0);
return val*mul[num];
}
inline void work(){
get_prime();
for(R int i=1,x;i<=n;++i){
read(x);
add1(x,i);
}
for(R int i=1,u,v;i<n;++i){
read(u);read(v);
add(u,v);add(v,u);
}
dfs_clo=0;
dfs1(1,0);
dfs2(1,1);
R LL ans=0;
memset(head,0,sizeof head);
for(R int i=1;i<100001;++i){
if(mul[i])ans+=solve(i);
}
printf("%lld\n",ans);
}
}
int main(){
read(n);
non_baoli::work();
return 0;
}