题意
\(T\) 组数据 \((1 \le T \le 5)\), 每次给定一棵 \(n\) 个点的树 \((1 \le n \le 299995)\).
设 \(E\) 为树的边集, \(V'_x,\ V'_y\) 分别为删去边 \((x,y)\) 后 点 \(x\) 所在的点集和点 \(y\) 所在的点集.
求 \[ \sum_{(x,y) \in E} \left( \sum_{x \in V'_x} [x 是 V'_x 的重心] * x + \sum_{y \in V'_y} [y 是 V'_y 的重心] * y \right) \]
思路
40 pts
暴力枚举每一条边, 求重心即可.
100 pts
做法 1
总的思路是从枚举边变为枚举点,
具体操作 :
枚举每一个点 \(u\), 对该点的每一棵子树都处理出 \(w_i\) 表示, 该子树内权值为 \(i\) 的边的个数,
这里边的权值定义为 : 以 \(u\) 为根节点时, 删去这条边后, 删掉的节点个数.
再分类讨论当前枚举到的子树是不是权值最大的子树, 找到一个区间 \([l,r]\), 使删去的边的权值 \(e_i \in [l,r]\) 时, 点 \(u\) 为新树的重心, 用树状数组区间求和即可.
问题在于, 我们每次要独立地获得每个子树的边的数量, 所以就要用到线段树合并或者主席树, 不然的话每次重新获取子树的边的复杂度会达到 \(O(n^2)\), 然后本人线段树合并与主席树都不会, 所以....
做法 2
突破口 : 树的重心一定会在根节点的重路径上.
那么我们就可以预处理出重路径, 并用倍增优化, 枚举每一条边时换根转移信息即可.
代码
#include<bits/stdc++.h>
#define ll long long
using namespace std;
const int N=299995+7;
const int L=20;
int T,n,f[N],sz[N],son[N],ses[N],s[N][L+7];
ll ans;
int lst[N],nxt[2*N],to[2*N],tot;
void add(int x,int y){ nxt[++tot]=lst[x]; to[tot]=y; lst[x]=tot; }
void upd(int u){
for(int i=1;i<=L;i++)
s[u][i]=s[s[u][i-1]][i-1];
}
void pre(int u,int fa){
sz[u]=1; f[u]=fa;
son[u]=ses[u]=0;
for(int i=lst[u];i;i=nxt[i]){
int v=to[i];
if(v==fa) continue;
pre(v,u);
sz[u]+=sz[v];
if(sz[v]>sz[son[u]]){
ses[u]=son[u];
son[u]=v;
}
else if(sz[v]>sz[ses[u]]) ses[u]=v;
}
s[u][0]=son[u];
upd(u);
}
void work(int u){
int tot=sz[u],rt=u;
for(int i=L;i>=0;i--)
if(tot-sz[s[u][i]]<=tot/2)
u=s[u][i];
if(sz[s[u][0]]<=tot/2) ans+=u;
if(u!=rt&&sz[u]<=tot/2) ans+=f[u];
}
void run(int u,int fa){
if(fa){ work(fa); work(u); }
int flag=0,t1;
if(n-sz[u]>sz[son[u]]){
flag=1;
t1=ses[u];
ses[u]=son[u];
son[u]=fa;
}
else if(n-sz[u]>sz[ses[u]]){
flag=2;
t1=ses[u];
ses[u]=fa;
}
for(int i=lst[u];i;i=nxt[i]){
int v=to[i];
if(v==fa) continue;
sz[u]+=sz[fa]-sz[v]; f[u]=v;
if(v==son[u]) s[u][0]=ses[u];
else s[u][0]=son[u];
upd(u);
run(v,u);
sz[u]-=sz[fa]-sz[v]; f[u]=fa;
}
if(flag==1){
son[u]=ses[u];
ses[u]=t1;
}
else if(flag==2) ses[u]=t1;
s[u][0]=son[u];
upd(u);
}
int main(){
//freopen("cg.in","r",stdin);
//freopen("cg.out","w",stdout);
cin>>T;
while(T--){
scanf("%d",&n);
memset(lst,0,sizeof(lst));
tot=ans=0;
int x,y;
for(int i=1;i<n;i++){
scanf("%d%d",&x,&y);
add(x,y); add(y,x);
}
pre(1,0);
run(1,0);
printf("%lld\n",ans);
}
return 0;
}