求树上距离:LCA与点分治两种解法

杭电2586 戳我

题意

给出一个加权无根树,多组询问每两点之间的距离。

输入格式

第一行T,代表有T组测试
后面代表每一组测试的格式
输入树点的个数n以及输入询问的次数m,再输入n-1条边以及该边的权值以及x,y,代表询问x,y之间的距离
数据范围:T<=10 n<=40000 m<=200

输出格式

输出所有测试询问点对之间的距离

暴力做法

对于每一组询问,x,y,直接dfs一遍求出x到y的距离即可。每一组测试时间复杂度:O(n*m),8e6 明显也能过
代码:

#include<bits/stdc++.h>
using namespace std;
#define IOS ios::sync_with_stdio(false)
#define maxn 40005

struct edge{
    int to,next,len;
}G[maxn*2];
int head[maxn],num;
void add(int from,int to,int len)
{
    G[++num].next=head[from];
    G[num].to=to;
    G[num].len=len;
    head[from]=num;
}

int dis[maxn];
void init()
{
    memset(head,-1,sizeof(head));
    num=0;
}

void dfs(int fa,int cur)
{
    for(int i=head[cur];i!=-1;i=G[i].next)
    {
        int v=G[i].to;
        int len=G[i].len;
        if(v!=fa)
        {
            dis[v]=dis[cur]+G[i].len;
            dfs(cur,v);
        }
    }
}

int main()
{
    IOS;
	int T;
	cin>>T;
	while(T--)
    {
        init();
        int n,m;
        cin>>n>>m;

        for(int i=0;i<n-1;i++)
        {
            int x,y,z;
            cin>>x>>y>>z;
            add(x,y,z),add(y,x,z);
        }

        for(int i=0;i<m;i++)
        {
            int x,y;
            cin>>x>>y;
            dis[x]=0;
            dfs(x,x);
            cout<<dis[y]<<"\n";
        }
    }
	return 0;
}

一、LCA求树上距离

解法:设u,v的最近公共祖先为lca,记录每一点到根节点(这里根节点可以随便设一个,从该点开始搜索即可)的距离,用dir数组表示,则两点之间的距离公式如下:dir[u]+dir[v]-2*dir[lca],即u到根节点的距离+v到根节点的距离 - 2倍u,v公共祖先到根节点的距离。
如下图所示,即因为u,v到根节点有公共部分,减去这一部分即为u,v的距离。

在这里插入图片描述
下面贴上两种求法,欧拉序+ST表O(nlogn)、tarjan+并查集 O(kn)

Dfs+ST表
#include<bits/stdc++.h>
using namespace std;
#define ll long long
#define maxn 40005
#define IOS ios::sync_with_stdio(false)

inline void read(int& x){
    char ch = getchar();
    x = 0;
    for(;ch < '0' || ch > '9';ch = getchar());
    for(;ch >= '0' && ch <= '9';ch = getchar()) x = x*10+(ch-'0');
}

struct edge
{
    int next,to,dis;
};

edge ff[2*maxn];
int head[2*maxn],num;
void add(int from,int to,int dis)
{
    ff[++num].next=head[from];
    ff[num].to=to;
    ff[num].dis=dis;
    head[from]=num;
}

int Eu[2*maxn],id;//记录欧拉序列
int dd[maxn];//每一点的深度
int hId[maxn];//记录每一点对应欧拉序列第一次出现的位置(映射)
int dir[maxn];//记录根节点到每个点的距离
bool vis[maxn];

void dfs(int x,int d)
{
    dd[x]=d;
    Eu[++id]=x;
    hId[x]=id;
    vis[x]=true;
    for(int i=head[x];i;i=ff[i].next)
    {
        if(vis[ff[i].to]==false)
        {
            dir[ff[i].to]=dir[x]+ff[i].dis;
            dfs(ff[i].to,d+1);
            Eu[++id]=x;
        }
    }
}

int f[2*maxn][32];
int res[2*maxn][32];
void ST_Init(int n)
{
    for(int i=1; i<=n; i++)
    {
        f[i][0]=dd[Eu[i]];
        res[i][0]=Eu[i];
    }
    int depth=(int)(log2(n));
    for(int j=1; j<=depth; j++)
        for(int i=1; i<=n-(1<<j)+1; i++)
        {
            if(f[i][j-1]<f[i+(1<<(j-1))][j-1])
            {
                f[i][j]=f[i][j-1];
                res[i][j]=res[i][j-1];
            }
            else
            {
                f[i][j]=f[i+(1<<(j-1))][j-1];
                res[i][j]=res[i+(1<<(j-1))][j-1];
            }
        }
}

int ST_Query(int l,int r)
{
    int p=(int)(log2(r-l+1));
    if(f[l][p]<f[r-(1<<p)+1][p])
        return res[l][p];
    else
        return res[r-(1<<p)+1][p];
}


int main()
{
    //IOS;
    int t;
    read(t);
    while(t--)
    {
        memset(vis,false,sizeof(vis));
        memset(dir,false,sizeof(dir));
        memset(head,0,sizeof(head));
        num=0,id=0;
        int n,m;
        read(n),read(m);

        for(int i=0; i<n-1; i++)
        {
            int x,y,dis;
            read(x),read(y),read(dis);
            add(x,y,dis);
            add(y,x,dis);
        }

        dfs(1,1);
        ST_Init(id);

        for(int i=0; i<m; i++)
        {
            int x,y,l,r;
            read(x),read(y);
            if(hId[x]<hId[y])
                l=hId[x],r=hId[y];
            else
                l=hId[y],r=hId[x];
            int lca=ST_Query(l,r);
            printf("%d\n",dir[x]+dir[y]-2*dir[lca]);
        }
    }

    return 0;
}
Targan+并查集
#include<bits/stdc++.h>
using namespace std;
#define ll long long
#define maxn 40005

struct node
{
    int value,order;
};

struct edge
{
    int next,to,dis;
};

edge ff[2*maxn];
int head[2*maxn],num;
void add_edge(int from,int to,int dis)
{
    ff[++num].next=head[from];
    ff[num].to=to;
    ff[num].dis=dis;
    head[from]=num;
}

vector<node>qq[maxn];
bool vis[maxn];
int yy[maxn],ans[maxn],dir[maxn];
int in_qq[maxn][2];//记录访问的左右端点

inline void read(int& x)
{
    char ch = getchar();
    x = 0;
    for(; ch < '0' || ch > '9'; ch = getchar());
    for(; ch >= '0' && ch <= '9'; ch = getchar()) x = x*10+(ch-'0');
}

int find(int x)
{
    return yy[x]==x?x:yy[x]=find(yy[x]);
}

void dfs(int from,int to)
{
    vis[to]=true;
    for(int i=head[to]; i; i=ff[i].next)
        if(ff[i].to!=from)
        {
            dir[ff[i].to]=dir[to]+ff[i].dis;
            dfs(to,ff[i].to);
        }
    for(int i=0; i<qq[to].size(); i++)
        if(vis[qq[to][i].value]&&ans[qq[to][i].order]==0)
            ans[qq[to][i].order]=find(qq[to][i].value);
    yy[to]=find(from);
}

int main()
{
    int t;
    read(t);
    while(t--)
    {
        num=0;
        memset(head,0,sizeof(head));
        memset(dir,0,sizeof(dir));
        memset(vis,false,sizeof(vis));
        int n,m;
        read(n),read(m);
        for(int i=0; i<n-1; i++)
        {
            int x,y,dis;
            read(x),read(y),read(dis);
            add_edge(x,y,dis);
            add_edge(y,x,dis);
        }

        for(int i=0; i<m; i++)
        {
            int x,y;
            read(x),read(y);
            in_qq[i+1][0]=x,in_qq[i+1][1]=y;
            qq[x].push_back((node)
            {
                y,i+1
            });
            qq[y].push_back((node)
            {
                x,i+1
            });
        }

        for(int i=0; i<=n; i++)yy[i]=i;
        dfs(1,1);
        for(int i=1; i<=m; i++)
        {
            printf("%d\n",dir[in_qq[i][0]]+dir[in_qq[i][1]]-2*dir[ans[i]]);
        }
    }
    return 0;
}

二、点分治求树上距离

点分治的基本步骤:找到当前树的重心,然后dfs一遍求出所有点到重心的距离,很容易发现u到重心的距离+v到重心的距离就是u,v的距离。但是会发现这些得到的答案不完全正确,因为可能u,v可能在同一子树上:

在这里插入图片描述
A为重心,求B,C之间的距离,这样求得的答案就是a+b+2*c。所以还需要继续对每一个子树分治,重复上面的过程,对每次求出两点间的距离取最小值。需要求O(log(n))次重心,每次求重心O(n),复杂度为O(nlog(n)。

#include<bits/stdc++.h>
using namespace std;
#define IOS ios::sync_with_stdio(false)
const int N = 40005, INF = 0x7f7f7f7f;

int n, m,head[N], num, tot;//tot记录当前子树点数
int dis[N], flag[N];
int visited[N],id;
//dis记录子树每一点到根节点的距离,flag用于删除根节点
int size[N], Max[N], root;
int ans[N];
struct node{
    int first,second;
};
node query[N];

struct edge
{
    int next, to,len;
} G[N * 2];
void add(int from, int to,int len)
{
    G[++num].next = head[from];
    G[num].to = to;
    G[num].len=len;
    head[from] = num;
}


void input(void)
{
    scanf("%d",&m);
    for (int i = 1; i<n; i++)
    {
        int x, y,z;
        scanf("%d%d%d",&x,&y,&z);
        add(x, y,z), add(y, x,z);
    }

    for(int i=0; i<m; i++)
    {
        ans[i]=INF;
        int x,y;
        scanf("%d%d",&x,&y);
        query[i]=node{x,y};
    }
}

void dp(int fa, int cur)//求树的重心
{
    size[cur] = 1, Max[cur] = 0;
    for (int i = head[cur]; i; i = G[i].next)
    {
        int v = G[i].to;
        if (flag[v] || v == fa) continue;
        dp(cur, v);
        size[cur] += size[v];
        if(Max[cur]<size[v])
            Max[cur] = size[v];
    }
    if(Max[cur]<tot-size[cur])
        Max[cur] = tot - size[cur];
    if (Max[root] > Max[cur]) root = cur;
}

void dfs(int fa, int cur)
{
    visited[id++]=cur;
    for (int i = head[cur]; i; i = G[i].next)
    {
        int v = G[i].to;
        if (v == fa || flag[v]) continue;
        dis[v]=dis[cur]+G[i].len;
        dfs(cur, v);
    }
}

void calc(int x, int len)
{
    id=0;
    dis[x] = len;
    dfs(0, x);
    /*cout<<"root="<<root<<",";
    for(int i=1;i<=n;i++)
        cout<<dis[i]<<" ";
    cout<<"\n";*/

    for(int i=0; i<m; i++)
    {
        int x=query[i].first,y=query[i].second;
        if(dis[x]!=-1&&dis[y]!=-1&&dis[x]+dis[y]<ans[i])
            ans[i]=dis[x]+dis[y];
    }

    for(int i=0; i<id; i++)
    {
        int id=visited[i];
        dis[id]=-1;
    }
}

void divide(int x)
{
    flag[x] = true;//删去根节点
    calc(x, 0);
    //cout<<"ans="<<ans<<"\n";
    for (int i = head[x]; i; i = G[i].next)
    {
        int y = G[i].to;
        if (flag[y]) continue;
        //ans -= calc(y, G[i].len);//点对在同一子树的情况
        tot = size[y], root = 0;
        dp(0, y);
        divide(root);
    }
}

void reset(void)
{
    num = 0;

    memset(head,0,sizeof(head));
    memset(flag,0,sizeof(flag));
    memset(dis,-1,sizeof(dis));
    tot = n;
    root = 0, Max[0] = INF;
}

int main()
{
    int T;
    scanf("%d",&T);
    while(T--)
    {
        scanf("%d",&n);
        reset();
        input();

        dp(0, 1);
        divide(root);

        for(int i=0; i<m; i++)
            printf("%d\n",ans[i]);
    }

    return 0;
}

发布了41 篇原创文章 · 获赞 2 · 访问量 1219

猜你喜欢

转载自blog.csdn.net/qq_41418281/article/details/104045603