洛谷 P4149 [ IOI 2011 ] Race —— 点分治

题目:https://www.luogu.org/problemnew/show/P4149

仍然是点分治;

不过因为是取 min ,所以不能用容斥,那么子树之间就必须分开算,记录桶时注意这个;

每次 memset 桶会很慢,可以用栈记录修改的地方,然后改回来即可;

注意更新 getrt 中 sum 的方式,可以 dfs 时顺便重新算一下 siz,但也可以利用原树求出来的 siz,判断一下当前的儿子在原树中是儿子还是父亲;

那么就要传个参数,是当前的所有点个数,在原树中是父亲的话就用总个数 - siz[to[i]],这个做法比较快。

代码如下:

#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
int const maxn=2e5+5,maxm=1e6+5,inf=0x3f3f3f3f;
int n,K,hd[maxn],ct,to[maxn<<1],nxt[maxn<<1],w[maxn<<1],dis[maxn],siz[maxn];
int sum,rt,tmp[maxm],ans=inf,mx,sta[maxn][3],f[maxn],top;
bool vis[maxn];
void add(int x,int y,int z){to[++ct]=y; nxt[ct]=hd[x]; w[ct]=z; hd[x]=ct;}
void getrt(int x,int fa)
{
    siz[x]=1; int nmx=0;//局部变量! 
    for(int i=hd[x],u;i;i=nxt[i])
    {
        if((u=to[i])==fa||vis[u])continue;
        getrt(u,x);
        siz[x]+=siz[u]; nmx=max(nmx,siz[u]);
    }
    nmx=max(nmx,sum-siz[x]);
    if(nmx<mx)mx=nmx,rt=x;
}
void dfs(int x,int fa)//siz 不管的话 RE 2个点 
{
    siz[x]=1;
    for(int i=hd[x],u;i;i=nxt[i])
    {
        if((u=to[i])==fa||vis[u])continue;
        dis[u]=dis[x]+w[i]; f[u]=f[x]+1;
        if(dis[u]<=K)
        {
            ans=min(ans,f[u]+tmp[K-dis[u]]);
            sta[++top][0]=dis[u]; sta[top][1]=f[u];
        }
        dfs(u,x); 
        siz[x]+=siz[u];
    }
}
int work(int x,int ss)
{
    vis[x]=1; int p=1;//局部变量 
    for(int i=hd[x],u;i;i=nxt[i])
    {
        if(vis[u=to[i]])continue;
        dis[u]=w[i]; f[u]=1;
        if(dis[u]<=K)
        {
            ans=min(ans,f[u]+tmp[K-dis[u]]);
            sta[++top][0]=dis[u]; sta[top][1]=f[u];
        }
        dfs(u,0);
        for(int w;p<=top;p++)tmp[w=sta[p][0]]=min(tmp[w],sta[p][1]);
    }
    for(int i=1;i<=top;i++)tmp[sta[i][0]]=inf; top=0;
    for(int i=hd[x],u;i;i=nxt[i])
    {
        if(vis[u=to[i]])continue;
        sum=(siz[u]>siz[x]?ss-siz[x]:siz[u]); mx=inf; getrt(u,0); work(rt,sum);
        //可以这样更新sum     //u在原树中是x的儿子或父亲 
    }
}
int main()
{
    scanf("%d%d",&n,&K);
    for(int i=1,x,y,z;i<n;i++)
    {
        scanf("%d%d%d",&x,&y,&z); 
        add(x,y,z); add(y,x,z);
    }
    memset(tmp,0x3f,sizeof tmp); tmp[0]=0;//
    sum=n; mx=inf; getrt(1,0); 
    work(rt,sum);
    printf("%d\n",ans==inf?-1:ans);
    return 0;
}

猜你喜欢

转载自www.cnblogs.com/Zinn/p/9476983.html