[NOI.AC省选模拟赛3.23] 染色 [点分治+BFS序]

题面

传送门

重要思想

真的是没想到,我很久以来一直以为总会有应用的$BFS$序,最终居然是以这种方式出现在题目中

笔记:$BFS$序可以用来处理限制点对距离的题目(综合点分树使用)

思路

本题中首先询问可以拆成两个:所有同色点对距离大于$L-1$的种数减去所有同色点对距离大于$R$的种数

考虑如何解决点对距离大于$k-1$:

我们考虑树的$bfs$序,假设当前按照$bfs$序加入了点$u$,深度为$d$

考虑树里另外两个已经加入了的点$x,y$,显然它们的深度都小于等于$d$

那么有结论:如果$dis(x,u)\leq k$且$dis(y,u)\leq k$,那么$dis(x,y)\leq k$

结论的证明:考虑$u$到$x,y$的路径第一次分离的点

在这个点上,有两种情况:两条路都去某个儿子,或者一个去父亲一个去儿子

可以发现不论哪种情况,最终$dis(x,y)$都不超过$k$

那么我们可以把问题变成:按照$bfs$序加入点,每次每个点不能选择的颜色数等于在加入它的时候同它的距离小于$k$的点数

注意上述的点集中任意两个点的距离都小于$k$,所以一定颜色互不相同

这就避免了原图想法中可能出现“$x,y$和$y,z$颜色不同,但是$x,z$颜色一样”的情况

解决距离小于$k$的点数使用点分树即可

Code

请务必不要吐槽作者的难看的点分树写法【悲】

#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cassert>
#include<queue>
#define MOD 1000000007
#define ll long long
using namespace std;
inline int read(){
    int re=0,flag=1;char ch=getchar();
    while(!isdigit(ch)){
        if(ch=='-') flag=-1;
        ch=getchar();
    }
    while(isdigit(ch)) re=(re<<1)+(re<<3)+ch-'0',ch=getchar();
    return re*flag;
}
int n,m,L,R,first[100010],cnte=-1;
struct edge{
    int to,next;
}a[200010];
inline void add(int u,int v){
    a[++cnte]=(edge){v,first[u]};first[u]=cnte;
    a[++cnte]=(edge){u,first[v]};first[v]=cnte;
}
int vis[100010],pre[100010],siz[100010],son[100010],root,sum;
int dis0[100010][20],dis1[100010][20],st[100010][20],dep[100010];
namespace BIT0{
    int a[4000010],st[100010],lim[100010],cur=0;
    inline void init(int u,int len){
        st[u]=cur;lim[u]=len;
        cur+=len;
    }
    inline void add(int u,int x){
        x++;
        for(;x<=lim[u];x+=(x&(-x))) a[st[u]+x]++;
    }
    inline int sum(int u,int x){
        if(x<0) return 0;
        x++;int re=0;
        x=min(x,lim[u]);
        for(;x;x^=(x&(-x))) re+=a[st[u]+x];
        return re;
    }
}
namespace BIT1{
    int a[4000010],st[100010],lim[100010],cur=0;
    inline void init(int u,int len){
        st[u]=cur;lim[u]=len;
        cur+=len;
    }
    inline void add(int u,int x){
        x++;
        for(;x<=lim[u];x+=(x&(-x))) a[st[u]+x]++;
    }
    inline int sum(int u,int x){
        if(x<0) return 0;
        x++;int re=0;
        x=min(x,lim[u]);
        for(;x;x^=(x&(-x))) re+=a[st[u]+x];
        return re;
    }
}
void getroot(int u,int f){
    int i,v;
    siz[u]=1;son[u]=0;
    for(i=first[u];~i;i=a[i].next){
        v=a[i].to;if(vis[v]||v==f) continue;
        getroot(v,u);
        siz[u]+=siz[v];
        if(son[u]<siz[v]) son[u]=siz[v];
    }
    son[u]=max(son[u],sum-siz[u]);
    if(son[u]<son[root]) root=u;
}
int dfs1(int u,int f,int d,int num){
    int i,v,re=d;
    dis1[u][num]=d;
    for(i=first[u];~i;i=a[i].next){
        v=a[i].to;if(v==f||vis[v]) continue;
        re=max(re,dfs1(v,u,d+1,num));
    }
    return re;
}
int dfs0(int u,int f,int d,int num,int root){
    int i,v,re=d;
    dis0[u][num]=d;st[u][num]=root;
    for(i=first[u];~i;i=a[i].next){
        v=a[i].to;if(v==f||vis[v]) continue;
        re=max(re,dfs0(v,u,d+1,num,root));
    }
    return re;
}
void build(int u,int fa,int ori,int cursum){
    int i,v,tmp;
    dep[u]=dep[fa]+1;
    if(fa){
        tmp=dfs1(ori,fa,1,dep[u]);
        BIT1::init(u,tmp+1);
    }
    tmp=dfs0(u,0,0,dep[u],u);
    BIT0::init(u,tmp+1);
    vis[u]=1;pre[u]=fa;
    for(i=first[u];~i;i=a[i].next){
        v=a[i].to;if(vis[v]) continue;
        sum=(siz[u]>siz[v])?siz[v]:cursum-siz[u];
        root=0;son[root]=sum;
        getroot(v,0);
        build(root,u,v,sum);
    }
}
inline void change(int u){
    BIT0::add(u,0);
    for(int i=dep[u]-1;i>=1;i--){
        //BIT0(dis0[u][i]): u to st[u][i]
        //BIT1(dis1[u][i+1]): u to st[u][i] only in subtree st[u][i+1]
        BIT0::add(st[u][i],dis0[u][i]);
        BIT1::add(st[u][i+1],dis1[u][i+1]);
    }
}
inline int query(int u,int d){
    int re=BIT0::sum(u,d),i;
    for(i=dep[u]-1;i>=1;i--){
        re+=BIT0::sum(st[u][i],d-dis0[u][i]);
        re-=BIT1::sum(st[u][i+1],d-dis1[u][i+1]);
    }
    return re;
}
int bfs[100010];queue<int>q;
int main(){
    memset(first,-1,sizeof(first));
    n=read();m=read();L=read();R=read();int i,u,v,t1,t2;
    int ans0=1,ans1=1;
    for(i=1;i<n;i++){
        t1=read();t2=read();
        add(t1,t2);
    }
    root=0;sum=n;son[root]=sum;
    getroot(1,0);
    build(root,0,1,sum);

    q.push(1);bfs[1]=1;
    while(!q.empty()){
        u=q.front();q.pop();
        t1=m-query(u,L-1);
        t2=max(0,m-query(u,R));
        change(u);
        if(m<=0){puts("0");return 0;}
        ans0=1ll*ans0*t1%MOD;
        ans1=1ll*ans1*t2%MOD;
        for(i=first[u];~i;i=a[i].next){
            v=a[i].to;if(bfs[v]) continue;
            q.push(v);bfs[v]=1;
        }
    }
    cout<<(ans0-ans1+MOD)%MOD<<'\n';
}

猜你喜欢

转载自www.cnblogs.com/dedicatus545/p/10590339.html
今日推荐