题意:给定带点权的有根树,给定正整数k,对于每颗子树,假设根节点是rt,对于每对(x,y)满足,LCA(x,y)==rt,x!=rt,y!=rt
且dis(x,y)==k,的节点,可以给这颗子树贡献val[x]+val[y]的值,求出每颗子树的值。
- 离线+子树查询,我们考虑用dsu on tree。
- 维护
cnt[i]
表示深度为i的点的点权和,num[i]
表示深度为i的点的个数,ret
记录贡献。 - 注意ret这个变量是无法保存给父节点的,所以每求完一个子树就令ret=0,同时若干棵子树的信息有且只能保留一棵,否则会造成答案重复。(于是实锤用dsu on tree
对于任意子树rt,假设其有x个孩子,答案只会产生于不同孩子子树的节点上。所以我们枚举x个孩子节点,先求答案,再把这个孩子子树的信息更新。这样就能保证同一棵孩子子树中满足dis(x,y)==k的节点的答案不会被记录了。在当前子树中,根节点不会产生贡献,所以不需要特别计算。
#include<bits/stdc++.h>
using namespace std;
//#pragma GCC optimize(2)
#define ull unsigned long long
#define ll long long
#define pii pair<int, int>
#define pdd pair<double, double>
#define re register
#define lc rt<<1
#define rc rt<<1|1
const int maxn = 1e5 + 10;
const ll mod = 998244353;
const ll inf = (ll)4e17+5;
const int INF = 1e9 + 7;
const double pi = acos(-1.0);
ll inv(ll b){
if(b==1)return 1;return(mod-mod/b)*inv(mod%b)%mod;}
ll cnt[maxn];//第i层的点权和
int num[maxn];//第i层的点个数
vector<int> g[maxn];
int n,k;
int val[maxn];
int siz[maxn],dep[maxn],son[maxn];
ll ret,ans[maxn];
void dfs1(int rt,int fa)
{
siz[rt]=1;
dep[rt]=dep[fa]+1;
for(int i:g[rt])
{
if(i==fa) continue;
dfs1(i,rt);
siz[rt]+=siz[i];
if(siz[i] > siz[son[rt]]) son[rt]=i;
}
}
int root;
void add(int rt,int fa) //更新答案
{
int d=k+2*dep[root]-dep[rt];
if(d>0)
ret+=1ll*num[d]*val[rt]+cnt[d];
for(int i:g[rt])
{
if(i==fa) continue;
add(i,rt);
}
}
void upd(int rt,int fa,int v) //更新节点信息
{
num[dep[rt]]+=v;
cnt[dep[rt]]+=val[rt]*v;
for(int i:g[rt])
{
if(i==fa) continue;
upd(i,rt,v);
}
}
void dfs2(int rt,int fa,bool ok)
{
for(int i:g[rt])
{
if(i==fa || i==son[rt]) continue;
dfs2(i,rt,0);
}
if(son[rt]) dfs2(son[rt],rt,1);
root=rt;//当前根节点
for(int i:g[rt])
{
if(i==son[rt] || i==fa) continue;
add(i,rt);//顺序是先求答案 再更新节点信息
upd(i,rt,1);
}
num[dep[rt]]++;
cnt[dep[rt]]+=val[rt];
ans[rt]=ret;
ret=0;//注意统计完一棵子树就要清空ret 因为这个ret无法继承给父节点 这里wa了好久
if(!ok) upd(rt,fa,-1);
}
int main()
{
scanf("%d %d",&n,&k);
for(int i=1;i<=n;i++) scanf("%d",val+i);
for(int i=1,u,v;i<n;i++)
{
scanf("%d %d",&u,&v);
g[u].push_back(v);
g[v].push_back(u);
}
dfs1(1,0);
dfs2(1,0,0);
printf("%lld",ans[1]);
for(int i=2;i<=n;i++) printf(" %lld",ans[i]);
return 0;
}