【JZOJ5577】派对

Description

有一棵n个点的树,有m个点要被确定为关键点,一种确定的方案合法当且仅当存在一个点,该点到每个关键点的距离不超过k,求方案数。

Solution

考虑以一个点 x 为中心,设它到达的点个数为 fx ,那么方案数为 Cmfx
显然对于每个点这样算的话会算重,于是我们对于每个点再统计一个与它父亲(确定一个根)距离都小于等于 k 的个数 gx ,那么 ans=CmfxCmgx
fx 可以用点分治很容易求出。
gx 可以在每棵点分树上统计,判断它与父亲的关系。额外的,对于点分中心直接连出去的点(与该点是父子关系)要额外统计一下。

Code

#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<algorithm>
#define fo(i,j,k) for(int i=j;i<=k;++i)
#define fd(i,j,k) for(int i=j;i>=k;--i)
#define rep(i,x) for(int i=ls[x];i;i=nx[i])
#define ll long long
#define mem(a) memset(a,0,sizeof(a))
using namespace std;
const int N=1e5+10,M=2e5+10,mo=998244353;
int to[M],nx[M],ls[N],vl[M],num=0;
void link(int u,int v,int w){
    to[++num]=v,nx[num]=ls[u],ls[u]=num;
    vl[num]=w;
}
ll jc[N],ny[N];
int fa[N],f[N],F[N],g[N],fat[N],sz[N];
ll pow(ll x,int y){
    ll s=1;
    while(y){
        if(y&1) s=s*x%mo;
        y>>=1,x=x*x%mo;
    }
    return s;
}
ll C(int m,int n){
    if(m<n) return 0;
    return jc[m]*ny[n]%mo*ny[m-n]%mo;
}
ll K,ans=0;
int n,m;
bool vis[N];
int mn,rt,cn=0;
void getf(int x){
    rep(i,x) if(to[i]!=fa[x]) fa[to[i]]=x,getf(to[i]);
}
void getsz(int x,int fr){
    sz[x]=1;
    rep(i,x){
        int v=to[i];
        if(v==fr || vis[v]) continue;
        getsz(v,x),sz[x]+=sz[v];
    }
}
void getrt(int x,int fr,int o){
    int mx=0;
    rep(i,x){
        int v=to[i];
        if(v==fr || vis[v]) continue;
        getrt(v,x,o),mx=max(mx,sz[v]);
    }
    mx=max(mx,sz[o]-sz[x]);
    if(mn>mx) mn=mx,rt=x;
}
void findrt(int x,int fr){
    getsz(x,fr);
    mn=sz[x]+1,getrt(x,fr,x);
}
struct node{
    int x,f;
    ll t;
}d[N];
int tot=0;
void get(int x,int fr,ll t,int ft,int p){
    if(t<=K) g[p]++;
    d[++tot].x=x,d[tot].t=t,d[tot].f=ft,fat[x]=fr;
    rep(i,x){
        int v=to[i];
        if(v==fr || vis[v]) continue;
        get(v,x,t+vl[i],ft,p);
    }
}
bool cmpt(node x,node y){
    return x.t<y.t;
}
bool cmpf(node x,node y){
    return x.f<y.f || (x.f==y.f && x.t<y.t);
}
void calc(int l,int r,int z){
    int rr=r;
    fo(i,l,rr){
        while(r>=l && d[i].t+d[r].t>K) r--;
        F[d[i].x]+=z*(r-l+1);
    }
}
void dfs(int x,int fr){
    tot=0,vis[x]=1;
    rep(i,x){
        int v=to[i];
        if(v==fr || vis[v]) continue;
        get(v,x,vl[i],v,fa[v]==x?v:x);
    }
    F[x]++;
    fo(i,1,tot) if(d[i].t<=K) F[x]++,F[d[i].x]++;
    sort(d+1,d+tot+1,cmpt);
    calc(1,tot,1);
    sort(d+1,d+tot+1,cmpf);
    int p=1;
    fo(i,2,tot)
    if(d[i].f!=d[i-1].f) calc(p,i-1,-1),p=i;
    calc(p,tot,-1);
    if(fa[x]!=fr) g[x]+=F[fa[x]];
    fo(i,1,tot){
        p=d[i].x;
        if(fa[p]!=fr) g[p]+=fat[p]==fa[p]?F[p]:F[fa[p]];
    }
    fo(i,1,tot) p=d[i].x,f[p]+=F[p],F[p]=0;
    f[x]+=F[x],F[x]=0;
    rep(i,x){
        int v=to[i];
        if(v==fr || vis[v]) continue;
        findrt(v,x);
        dfs(rt,x);
    }
}
int main()
{
    scanf("%d %d %lld",&n,&m,&K);
    fo(i,2,n){
        int u,v,w;
        scanf("%d %d %d",&u,&v,&w);
        link(u,v,w),link(v,u,w);
    }
    getf(1),findrt(1,0),dfs(rt,0);
    jc[0]=1;
    fo(i,1,n) jc[i]=jc[i-1]*i%mo;
    ny[n]=pow(jc[n],mo-2);
    fd(i,n-1,0) ny[i]=ny[i+1]*(i+1)%mo;
    fo(i,1,n){
        ans=(ans+C(f[i],m))%mo;
        if(i>1) ans=(ans-C(g[i],m)+mo)%mo;
    }
    printf("%lld",ans*jc[m]%mo);
}

猜你喜欢

转载自blog.csdn.net/sadnohappy/article/details/79568902
今日推荐