title
Description
设 T 为一棵有根树,我们做如下的定义:
• 设 a 和 b 为 T 中的两个不同节点。如果 a 是 b 的祖先,那么称“a 比 b 不知道高明到哪里去了”。
• 设 a 和 b 为 T 中的两个不同节点。如果 a 与 b 在树上的距离不超过某个给定常数 x,那么称“a 与 b 谈笑风生”。
给定一棵 n 个节点的有根树 T,节点的编号为 1 ∼ n,根节点为 1 号节点。你需要回答 q 个询问,询问给定两个整数 p 和 k,问有多少个有序三元组 \((a,b,c)\) 满足:
- a、 b 和 c 为 T 中三个不同的点,且 a 为 p 号节点;
- a 和 b 都比 c 不知道高明到哪里去了;
- a 和 b 谈笑风生。这里谈笑风生中的常数为给定的 k。
analysis
线段树合并。
因为要求三元组 \((a,b,c)\) 中 \(a,b\) 都是 \(c\) 的祖先,所以只有两种情况
\(b\) 是 \(a\) 的祖先
这种情况直接统计就好了,答案为 \(\min\{depth(a)-1,k\}(size(a)-1)\) 其中 \(depth(a)\)为 \(a\) 的深度,\(size(a)\) 为子树大小
\(a\) 是 \(b\) 的祖先
随便选一个满足要求的 \(b\),然后 \(b\) 子树中任选一个 \(c\) 都能满足答案。答案为 \(\sum\limits_{depth(a)+k≥depth(b)} size(b)-1\)。要要计算这个东西,对于每一个结点维护一棵以深度为编号,以 \(size(x)-1\) 为权值的线段树,每次将子树的线段树合并上来。
注意要开 long long
。
code
#include<bits/stdc++.h>
typedef long long ll;
const int maxn=3e5+10;
typedef int iarr[maxn];
namespace IO
{
char buf[1<<15],*fs,*ft;
inline char getc() { return (ft==fs&&(ft=(fs=buf)+fread(buf,1,1<<15,stdin),ft==fs))?0:*fs++; }
template<typename T>inline void read(T &x)
{
x=0;
T f=1, ch=getchar();
while (!isdigit(ch) && ch^'-') ch=getchar();
if (ch=='-') f=-1, ch=getchar();
while (isdigit(ch)) x=(x<<1)+(x<<3)+(ch^48), ch=getchar();
x*=f;
}
char Out[1<<24],*fe=Out;
inline void flush() { fwrite(Out,1,fe-Out,stdout); fe=Out; }
template<typename T>inline void write(T x,char str)
{
if (!x) *fe++=48;
if (x<0) *fe++='-', x=-x;
T num=0, ch[20];
while (x) ch[++num]=x%10+48, x/=10;
while (num) *fe++=ch[num--];
*fe++=str;
}
}
using IO::read;
using IO::write;
template<typename T>inline T min(T a,T b) { return a<b ? a : b; }
template<typename T>inline T max(T a,T b) { return a>b ? a : b; }
int ver[maxn<<1],Next[maxn<<1],head[maxn],len;
inline void add(int x,int y)
{
ver[++len]=y,Next[len]=head[x],head[x]=len;
}
namespace SGT
{
struct Orz{int l,r;ll z;}c[maxn*30];
int num=0;
inline void Change(int &x,int l,int r,int k,int z)
{
if (!x) x=++num;
c[x].z+=z;
if (l==r) return ;
int mid=(l+r)>>1;
if (k<=mid) Change(c[x].l,l,mid,k,z);
else Change(c[x].r,mid+1,r,k,z);
}
inline ll query(int x,int l,int r,int tl,int tr)
{
if (!x || tr<tl) return 0;
if (tl<=l && r<=tr) return c[x].z;
int mid=(l+r)>>1; ll ans=0;
if (tl<=mid) ans+=query(c[x].l,l,mid,tl,tr);
if (tr>mid) ans+=query(c[x].r,mid+1,r,tl,tr);
return ans;
}
inline int merge(int x,int y)
{
if (!x || !y) return x|y;
int t=++num;
c[t].l=merge(c[x].l,c[y].l);
c[t].r=merge(c[x].r,c[y].r);
c[t].z=c[x].z+c[y].z;
return t;
}
}
using SGT::Change;
using SGT::query;
using SGT::merge;
iarr rt,siz,dep;
int n,q;
inline void dfs(int x,int f)
{
siz[x]=1;
dep[x]=dep[f]+1;
for (int i=head[x]; i; i=Next[i])
{
int y=ver[i];
if (y==f) continue;
dfs(y,x);
siz[x]+=siz[y];
}
Change(rt[x],1,n,dep[x],siz[x]-1);
if (f) rt[f]=merge(rt[f],rt[x]);
}
int main()
{
read(n);read(q);
for (int i=1,x,y; i<n; ++i) read(x),read(y),add(x,y),add(y,x);
dfs(1,0);
while (q--)
{
int x,y;read(x);read(y);
write(query(rt[x],1,n,dep[x]+1,dep[x]+y)+1ll*(siz[x]-1)*min(dep[x]-1,y),'\n');
}
IO::flush();
return 0;
}