E 旗鼓相当的对手 (dsu on tree)

传送门

题意:给定带点权的有根树,给定正整数k,对于每颗子树,假设根节点是rt,对于每对(x,y)满足,LCA(x,y)==rt,x!=rt,y!=rt且dis(x,y)==k,的节点,可以给这颗子树贡献val[x]+val[y]的值,求出每颗子树的值。

  • 离线+子树查询,我们考虑用dsu on tree。
  • 维护cnt[i]表示深度为i的点的点权和,num[i]表示深度为i的点的个数,ret记录贡献。
  • 注意ret这个变量是无法保存给父节点的,所以每求完一个子树就令ret=0,同时若干棵子树的信息有且只能保留一棵,否则会造成答案重复。(于是实锤用dsu on tree

对于任意子树rt,假设其有x个孩子,答案只会产生于不同孩子子树的节点上。所以我们枚举x个孩子节点,先求答案,再把这个孩子子树的信息更新。这样就能保证同一棵孩子子树中满足dis(x,y)==k的节点的答案不会被记录了。在当前子树中,根节点不会产生贡献,所以不需要特别计算。

#include<bits/stdc++.h>
using namespace std;
//#pragma GCC optimize(2)
#define ull unsigned long long
#define ll long long
#define pii pair<int, int>
#define pdd pair<double, double>
#define re register
#define lc rt<<1
#define rc rt<<1|1
const int maxn = 1e5 + 10;
const ll mod = 998244353;
const ll inf = (ll)4e17+5;
const int INF = 1e9 + 7;
const double pi = acos(-1.0);
ll inv(ll b){
    
    if(b==1)return 1;return(mod-mod/b)*inv(mod%b)%mod;}
ll cnt[maxn];//第i层的点权和
int num[maxn];//第i层的点个数
vector<int> g[maxn];
int n,k;
int val[maxn];
int siz[maxn],dep[maxn],son[maxn];
ll ret,ans[maxn];
void dfs1(int rt,int fa)
{
    
    
	siz[rt]=1;
	dep[rt]=dep[fa]+1;
	for(int i:g[rt]) 
	{
    
    
		if(i==fa) continue;
		dfs1(i,rt);
		siz[rt]+=siz[i];
		if(siz[i] > siz[son[rt]]) son[rt]=i;
	}
}
int root;
void add(int rt,int fa) //更新答案
{
    
    
	int d=k+2*dep[root]-dep[rt];
	if(d>0)
		ret+=1ll*num[d]*val[rt]+cnt[d];
	for(int i:g[rt]) 
	{
    
    
		if(i==fa) continue;
		add(i,rt);
	}
}
void upd(int rt,int fa,int v) //更新节点信息
{
    
    
	num[dep[rt]]+=v;
	cnt[dep[rt]]+=val[rt]*v;
	for(int i:g[rt]) 
	{
    
    
		if(i==fa) continue;
		upd(i,rt,v);
	}
}
void dfs2(int rt,int fa,bool ok) 
{
    
    
	for(int i:g[rt]) 
	{
    
    
		if(i==fa || i==son[rt]) continue;
		dfs2(i,rt,0);
	}
	if(son[rt]) dfs2(son[rt],rt,1);
	root=rt;//当前根节点
	for(int i:g[rt]) 
	{
    
    
		if(i==son[rt] || i==fa) continue;
		add(i,rt);//顺序是先求答案 再更新节点信息
		upd(i,rt,1);
	}
	num[dep[rt]]++;
	cnt[dep[rt]]+=val[rt];
	ans[rt]=ret;
	ret=0;//注意统计完一棵子树就要清空ret 因为这个ret无法继承给父节点 这里wa了好久
	if(!ok) upd(rt,fa,-1);
}
int main()
{
    
    
	scanf("%d %d",&n,&k);
	for(int i=1;i<=n;i++) scanf("%d",val+i);
	for(int i=1,u,v;i<n;i++)
	{
    
    
		scanf("%d %d",&u,&v);
		g[u].push_back(v);
		g[v].push_back(u);
	}
	dfs1(1,0);
	dfs2(1,0,0);
    printf("%lld",ans[1]);
	for(int i=2;i<=n;i++) printf(" %lld",ans[i]);
	return 0;
}

猜你喜欢

转载自blog.csdn.net/qq_46030630/article/details/120650066