POJ 1741 Tree

题意

一棵 \(n\) 个节点的树,求两点间距离不超过 \(k\) 的点对对数。

思路

点分治模板题,点分治一般用来解决树上的路径问题,核心在于找出重心,算经过重心的合法路径,然后以重心把树劈成若干个小树分别计算,每次能把树至少砍一半,至多递归到 $\log n $ 层,而每层总结点数是 \(n\) ,所以复杂度为 \(n \log n\)

点分除了“算经过重心的合法路径”这一步每题不一样外,其他部分差不多,大体代码框架如下:

int sz[N];bool mark[N];
void CFS(int u,int f,int tot,int &C,int &Mi)//Centroid Finding Search
{
    sz[u]=1;int res=0;
    EOR(i,G,u)
    {
        int v=G.to[i];
        if(v==f||mark[v])continue;
        CFS(v,u,tot,C,Mi);
        sz[u]+=sz[v];
        res=max(res,sz[v]);
    }
    res=max(res,tot-sz[u]);
    if(res<Mi)C=u,Mi=res;
}
void solve(int u)
{
    ...
}
void dac(int u,int tot)
{
    int Mi=1e9;
    CFS(u,0,tot,u,Mi);
    mark[u]=1;
    solve(u);
    
    EOR(i,G,u)
    {
        int v=G.to[i];
        if(mark[v])continue;
        dac(v,sz[v]<sz[u]?sz[v]:tot-sz[u]);
    }
}

这道题,对于一个要 \(\text{solve}\) 的节点 \(u\)\(\text{dfs}\) 出所有从 \(u\) 出发的路径的权值,排序,统计任意权值和不超过 \(k\) 的路径对对数,用类似尺取的方法就可以数出经过 \(u\) 的路径条数。然而事实上,当有两条路径同时经过了某个与 \(u\) 相邻的节点 \(v\) ,那么两条路径就有了一条公共边 \((u,v)\) ,就不合法了,所以应把这种情况容斥掉,具体请看代码。

其实不容斥也可以写,\(\text{solve}\) \(u\) 节点的时候只要将子树一棵一棵加进去,每加一棵之前查询一下加进去的子树与马上要加的子树能配出多少条路径,用树状数组维护(用\(\text{short}\) 卡内存) ,需要一个查询的 \(\text{dfs}\) 和更新的 \(\text{dfs}\) 。这里不放这种写法的代码。

代码

#include<iostream>
#include<cmath>
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<algorithm>
#define FOR(i,x,y) for(int i=(x),i##END=(y);i<=i##END;++i)
#define DOR(i,x,y) for(int i=(x),i##END=(y);i>=i##END;--i)
typedef long long LL;
using namespace std;
const int N=3e4+5;
template<const int maxn,const int maxm>struct Linked_list
{
    int head[maxn],to[maxm],cost[maxm],nxt[maxm],tot;
    Linked_list(){clear();}
    void clear(){memset(head,-1,sizeof(head));tot=0;}
    void add(int u,int v,int w){to[++tot]=v,cost[tot]=w,nxt[tot]=head[u],head[u]=tot;}
    #define EOR(i,G,u) for(int i=G.head[u];~i;i=G.nxt[i])
};Linked_list<N,N<<1>G;
int sz[N];bool mark[N];
int A[N],Ac;
int n,K,ans;

void CFS(int u,int f,int tot,int &C,int &Mi)
{
    sz[u]=1;int res=0;
    EOR(i,G,u)
    {
        int v=G.to[i];
        if(v==f||mark[v])continue;
        CFS(v,u,tot,C,Mi);
        sz[u]+=sz[v];
        res=max(res,sz[v]);
    }
    res=max(res,tot-sz[u]);
    if(res<Mi)C=u,Mi=res;
}
void clct(int u,int f,int sum)
{
    A[Ac++]=sum;
    EOR(i,G,u)
    {
        int v=G.to[i],w=G.cost[i];
        if(v==f||mark[v])continue;
        clct(v,u,sum+w);
    }
}
int solve(int u,int l)
{
    Ac=0;clct(u,0,l);
    sort(A,A+Ac);
    int res=0,i=0,j=Ac-1;
    while(i<j)
    {
        while(i<j&&A[i]+A[j]>K)j--;
        res+=j-i;
        i++;
    }
    return res;
}
void dac(int rt,int tot)
{
    int u,Mi=1e9;
    CFS(rt,0,tot,u,Mi);
    mark[u]=1;
    ans+=solve(u,0);
    EOR(i,G,u)
    {
        int v=G.to[i],w=G.cost[i];
        if(mark[v])continue;
        ans-=solve(v,w);
        dac(v,min(sz[v],sz[rt]-sz[u]));
    }
}

int main()
{
    while(scanf("%d%d",&n,&K),n||K)
    {
        G.clear();
        memset(mark,0,sizeof(mark));
        FOR(i,1,n-1)
        {
            int u,v,w;
            scanf("%d%d%d",&u,&v,&w);
            G.add(u,v,w);
            G.add(v,u,w);
        }
        ans=0;
        dac(1,n);
        printf("%d\n",ans);
    }
    return 0;
}

猜你喜欢

转载自www.cnblogs.com/Paulliant/p/10159305.html