分析:
发现对于每个询问p,k的答案为(siz[p]-1)*min(dep[p]-1,k)+sum{siz[q]-1,q在p子树内且dep[p]+1<=dep[q]<=dep[p]+k}。
以深度为下标做线段树合并即可。
一定记住写线段树合并的递归边界l==r。
代码:
#include <iostream>
#include <cstdio>
#include <cstdlib>
#include <algorithm>
#include <vector>
typedef long long LL;
inline int read(){
int x=0;char ch=getchar();
while(ch<'0'||ch>'9') ch=getchar();
while(ch>='0'&&ch<='9'){x=(x<<3)+(x<<1)+ch-'0';ch=getchar();}
return x;
}
const int MAXN=300005;
int n,q,ecnt,head[MAXN];
int maxd,dep[MAXN],siz[MAXN],root[MAXN],lc[MAXN*22],rc[MAXN*22],ql,qr,tot;
LL sum[MAXN*22],ans[MAXN];
struct Edge{
int to,nxt;
}e[MAXN<<1];
struct Quest{
int k,id;
};
std::vector<Quest> v[MAXN];
inline void add_edge(int bg,int ed){
ecnt++;
e[ecnt].to=ed;
e[ecnt].nxt=head[bg];
head[bg]=ecnt;
}
void dfs1(int x,int pre,int depth){
maxd=std::max(maxd,dep[x]=depth);
siz[x]=1;
for(int i=head[x];i;i=e[i].nxt){
int ver=e[i].to;
if(ver==pre) continue;
dfs1(ver,x,depth+1);
siz[x]+=siz[ver];
}
}
#define mid ((l+r)>>1)
int ins(int l,int r,int loc,int x){
int o=++tot;
sum[o]=x;
if(l==r) return o;
if(loc<=mid) lc[o]=ins(l,mid,loc,x);
else rc[o]=ins(mid+1,r,loc,x);
return o;
}
int mer(int x,int y,int l,int r){
if(!x||!y) return x+y;
if(l==r){
sum[x]+=sum[y];
return x;
}
lc[x]=mer(lc[x],lc[y],l,mid);
rc[x]=mer(rc[x],rc[y],mid+1,r);
sum[x]=sum[lc[x]]+sum[rc[x]];
return x;
}
LL query(int o,int l,int r){
if(!o) return 0;
if(ql>qr) return 0;
if(ql<=l&&r<=qr) return sum[o];
LL ret=0;
if(mid>=ql) ret+=query(lc[o],l,mid);
if(mid<qr) ret+=query(rc[o],mid+1,r);
return ret;
}
#undef mid
void dfs2(int x,int pre){
for(int i=head[x];i;i=e[i].nxt){
int ver=e[i].to;
if(ver==pre) continue;
dfs2(ver,x);
root[x]=mer(root[x],root[ver],1,maxd);
}
for(int i=0;i<v[x].size();i++){
int k=v[x][i].k,id=v[x][i].id;
ans[id]=1ll*(siz[x]-1)*std::min(dep[x]-1,k);
ql=dep[x]+1,qr=std::min(dep[x]+k,maxd);
ans[id]+=query(root[x],1,maxd);
}
}
int main(){
n=read(),q=read();
for(int i=1;i<n;i++){
int u=read(),v=read();
add_edge(u,v);
add_edge(v,u);
}
for(int i=1;i<=q;i++){
int p=read(),k=read();
v[p].push_back((Quest){k,i});
}
dfs1(1,0,1);
for(int i=1;i<=n;i++) root[i]=ins(1,maxd,dep[i],siz[i]-1);
dfs2(1,0);
for(int i=1;i<=q;i++) printf("%lld\n",ans[i]);
return 0;
}