点分治详解

点分治详解

点分治是一个需要自己推导的算法,但是有板子,但是Cal这个函数根据不同题目是会变的。

点分治是解决树上求值的一种算法,比如说一棵树上路径 ( u , v ) <= K 的点对数量。

我们首先思考最笨的想法:

我们可以先求出所以路径长度,然后减去不满足的路径长度(也就是最近公共祖先不是跟的路径长度)。

DFS枚举子树,然后在DFS一趟计算这棵子树路径长度 不满足的,然后排序找答案。

这样子看起来效率很高啊 O ( N l o g 2 N ) ,但是对于一种数据就会很慢,。如果树退化成了链的话,那么时间复杂度是 O ( N 2 ) 的。

所以,点分治的想法,每次保证树的深度是平均的,也就是最大子树最小的点作为根节点。

其实也就这样,对于笨蛋想法的优化。

下面看一道例题:

POJ1741

【题目大意】

树上点对距离小于等于K的个数。

了解一下变量

int n,K;//题目给出
int Rot,RotSize;//根,当前子树大小
int Siz[MAXN];//子树x的大小
int dst[MAXN];//路径长度
int Maxs[MAXN];//x为根最大子树大小
LL Ans;//答案
vector<int> Now;//路径
bool vis[MAXN];
struct Edge{//邻接表存边
    int tot,lnk[MAXN],son[MAXN<<1],nxt[MAXN<<1],W[MAXN<<1];
    void clean(){memset(lnk,0,sizeof(lnk));tot=0;}
    void Add(int x,int y,int z){son[++tot]=y;W[tot]=z;nxt[tot]=lnk[x];lnk[x]=tot;}
}E;

首先我们要找出根节点

子节点的大小好求,父节点所在的子树大小怎么办呢?

其实很好求,树的大小 x的子树大小。

void Get_Rot(int x,int fa){
    Siz[x]=1;Maxs[x]=0;
    for(int j=E.lnk[x];j;j=E.nxt[j])
    if(!vis[E.son[j]]&&E.son[j]!=fa){
        Get_Rot(E.son[j],x);
        Siz[x]+=Siz[E.son[j]];
        Maxs[x]=max(Siz[E.son[j]],Maxs[x]);
    }
    Maxs[x]=max(Maxs[x],RotSize-Siz[x]);
    if(Maxs[x]<Maxs[Rot]) Rot=x;
}

然后就是求路径的长度,那么DFS去刷就可以了

void Get_Dst(int x,int f){
    Siz[x]=1;Now.push_back(dst[x]);
    for(int j=E.lnk[x];j;j=E.nxt[j])
    if(E.son[j]!=f&&!vis[E.son[j]]){
        dst[E.son[j]]=dst[x]+E.W[j];
        Get_Dst(E.son[j],x);Siz[x]+=Siz[E.son[j]];
    }
}

接下来就是计算了

int Cal(int x,int y){
    int Ret=0;
    Now.clear();dst[x]=y;Get_Dst(x,0);//先求出路径长度
    sort(Now.begin(),Now.end());//然后排序找这个<=K的值
    for(int l=0,r=Now.size()-1;l<r;l++){
        while(Now[l]+Now[r]>K&&l<r) r--;
        Ret+=r-l;
    }
    return Ret;
}

枚举子树

void Solve(int x){
    Ans+=Cal(x,0);vis[x]=1;//容斥的想法,计算所有答案
    for(int j=E.lnk[x];j;j=E.nxt[j])
    if(!vis[E.son[j]]){
        Ans-=Cal(E.son[j],E.W[j]);//减去不满足答案
        Maxs[0]=RotSize=Siz[E.son[j]];Rot=0;
        Get_Rot(E.son[j],0);Solve(Rot); 
    }
}

下面贴上完整代码:

#include<cstdio>
#include<vector>
#include<cstring>
#include<iostream>
#include<algorithm>
#define MAXN 10005
#define LL long long
using namespace std;
int n,K,Rot,RotSize;
int Siz[MAXN],dst[MAXN],Maxs[MAXN];LL Ans;
vector<int> Now;
bool vis[MAXN];
struct Edge{
    int tot,lnk[MAXN],son[MAXN<<1],nxt[MAXN<<1],W[MAXN<<1];
    void clean(){memset(lnk,0,sizeof(lnk));tot=0;}
    void Add(int x,int y,int z){son[++tot]=y;W[tot]=z;nxt[tot]=lnk[x];lnk[x]=tot;}
}E;
void Get_Rot(int x,int fa){
    Siz[x]=1;Maxs[x]=0;
    for(int j=E.lnk[x];j;j=E.nxt[j])
    if(!vis[E.son[j]]&&E.son[j]!=fa){
        Get_Rot(E.son[j],x);
        Siz[x]+=Siz[E.son[j]];
        Maxs[x]=max(Siz[E.son[j]],Maxs[x]);
    }
    Maxs[x]=max(Maxs[x],RotSize-Siz[x]);
    if(Maxs[x]<Maxs[Rot]) Rot=x;
}
void Get_Dst(int x,int f){
    Siz[x]=1;Now.push_back(dst[x]);
    for(int j=E.lnk[x];j;j=E.nxt[j])
    if(E.son[j]!=f&&!vis[E.son[j]]){
        dst[E.son[j]]=dst[x]+E.W[j];
        Get_Dst(E.son[j],x);Siz[x]+=Siz[E.son[j]];
    }
}
int Cal(int x,int y){
    int Ret=0;
    Now.clear();dst[x]=y;Get_Dst(x,0);
    sort(Now.begin(),Now.end());
    for(int l=0,r=Now.size()-1;l<r;l++){
        while(Now[l]+Now[r]>K&&l<r) r--;
        Ret+=r-l;
    }
    return Ret;
}
void Solve(int x){
    Ans+=Cal(x,0);vis[x]=1;
    for(int j=E.lnk[x];j;j=E.nxt[j])
    if(!vis[E.son[j]]){
        Ans-=Cal(E.son[j],E.W[j]);
        Maxs[0]=RotSize=Siz[E.son[j]];Rot=0;
        Get_Rot(E.son[j],0);Solve(Rot); 
    }
}
int main(){
    #ifndef ONLINE_JUDGE
    freopen("POJ1741.in","r",stdin);
    freopen("POJ1741.out","w",stdout);
    #endif
    while(scanf("%d%d",&n,&K),n||K){
        memset(vis,0,sizeof(vis));
        E.clean();
        for(int i=1;i<n;i++){
            int x,y,z;scanf("%d%d%d",&x,&y,&z);
            E.Add(x,y,z);E.Add(y,x,z);
        }
        Ans=0;Maxs[0]=RotSize=n;Rot=0;
        Get_Rot(1,0);Solve(Rot);
        printf("%lld\n",Ans);
    }
    return 0;
}

猜你喜欢

转载自blog.csdn.net/qq_41357771/article/details/80853105