[BZOJ3653]谈笑风生:线段树合并

分析:

发现对于每个询问p,k的答案为(siz[p]-1)*min(dep[p]-1,k)+sum{siz[q]-1,q在p子树内且dep[p]+1<=dep[q]<=dep[p]+k}。
以深度为下标做线段树合并即可。
一定记住写线段树合并的递归边界l==r。

代码:

#include <iostream>
#include <cstdio>
#include <cstdlib>
#include <algorithm>
#include <vector>
typedef long long LL;

inline int read(){
    int x=0;char ch=getchar();
    while(ch<'0'||ch>'9') ch=getchar();
    while(ch>='0'&&ch<='9'){x=(x<<3)+(x<<1)+ch-'0';ch=getchar();}
    return x;
}

const int MAXN=300005;
int n,q,ecnt,head[MAXN];
int maxd,dep[MAXN],siz[MAXN],root[MAXN],lc[MAXN*22],rc[MAXN*22],ql,qr,tot;
LL sum[MAXN*22],ans[MAXN];
struct Edge{
    int to,nxt;
}e[MAXN<<1];
struct Quest{
    int k,id;
};
std::vector<Quest> v[MAXN];

inline void add_edge(int bg,int ed){
    ecnt++;
    e[ecnt].to=ed;
    e[ecnt].nxt=head[bg];
    head[bg]=ecnt;
}

void dfs1(int x,int pre,int depth){
    maxd=std::max(maxd,dep[x]=depth);
    siz[x]=1;
    for(int i=head[x];i;i=e[i].nxt){
        int ver=e[i].to;
        if(ver==pre) continue;
        dfs1(ver,x,depth+1);
        siz[x]+=siz[ver];
    }
}

#define mid ((l+r)>>1)
int ins(int l,int r,int loc,int x){
    int o=++tot;
    sum[o]=x;
    if(l==r) return o;
    if(loc<=mid) lc[o]=ins(l,mid,loc,x);
    else rc[o]=ins(mid+1,r,loc,x);
    return o;
}

int mer(int x,int y,int l,int r){
    if(!x||!y) return x+y;
    if(l==r){
        sum[x]+=sum[y];
        return x;
    }
    lc[x]=mer(lc[x],lc[y],l,mid);
    rc[x]=mer(rc[x],rc[y],mid+1,r);
    sum[x]=sum[lc[x]]+sum[rc[x]];
    return x;
}

LL query(int o,int l,int r){
    if(!o) return 0;
    if(ql>qr) return 0;
    if(ql<=l&&r<=qr) return sum[o];
    LL ret=0;
    if(mid>=ql) ret+=query(lc[o],l,mid);
    if(mid<qr) ret+=query(rc[o],mid+1,r);
    return ret;
}
#undef mid

void dfs2(int x,int pre){
    for(int i=head[x];i;i=e[i].nxt){
        int ver=e[i].to;
        if(ver==pre) continue;
        dfs2(ver,x);
        root[x]=mer(root[x],root[ver],1,maxd);
    }
    for(int i=0;i<v[x].size();i++){
        int k=v[x][i].k,id=v[x][i].id;
        ans[id]=1ll*(siz[x]-1)*std::min(dep[x]-1,k);
        ql=dep[x]+1,qr=std::min(dep[x]+k,maxd);
        ans[id]+=query(root[x],1,maxd);
    }
}

int main(){
    n=read(),q=read();
    for(int i=1;i<n;i++){
        int u=read(),v=read();
        add_edge(u,v);
        add_edge(v,u);
    }
    for(int i=1;i<=q;i++){
        int p=read(),k=read();
        v[p].push_back((Quest){k,i});
    }
    dfs1(1,0,1);
    for(int i=1;i<=n;i++) root[i]=ins(1,maxd,dep[i],siz[i]-1);
    dfs2(1,0);
    for(int i=1;i<=q;i++) printf("%lld\n",ans[i]);
    return 0;
}

猜你喜欢

转载自www.cnblogs.com/ErkkiErkko/p/9721783.html