[51nod 1681]公共祖先(dfs序+线段树合并)

[51nod 1681]公共祖先(dfs序+线段树合并)

题面

给出两棵n(n<=100000)个点的树,对于所有点对求它们在两棵树中公共的公共祖先数量之和。

如图,对于点对(2,4),它们在第一棵树里的公共祖先为{1,3,5},在第二棵树里的公共祖先为{1},因此公共的公共祖先数量为2

把所有点对的这个数量加起来,就得到了最终答案

分析

\(O(n^3)\)的暴力不讲了,先考虑\(O(n^2)\)的做法

枚举点对复杂度太高,不可行。我们考虑每个节点x作为公共的公共祖先的次数。设树A上的节点x,在树B上对应的节点是x'(实际上x'和x的编号是相同的,只是这样方便描述).则如果点对既在x的子树中,对应到B上后又在x'的子树中,则这个点对的公共的公共祖先就包含x .注意一个小细节,如果x是y的父亲,x不算做x和y的祖先,所以这里的“子树”应该不包含x.

如这张图中,A中1的子树中节点有{2,3,4,5},{2,3,4,5}对应到B中均在1的子树内。这4个节点中任选一对,它们的公共祖先都包含1

那么我们只要考虑x的子树中有多少个点对应过去在树B上x'的子树中即可。暴力枚举x子树中的每个节点,然后判断。设这样的点个数为cnt,则x作为公共的公共祖先的次数就是\(C_{cnt}^2\),把它累加进答案

那么我们怎么把它优化呢?我们发现,节点编号是离散的,不好判断。但子树中节点的dfs序是连续的。我们把A中节点x的dfs序标记到树B上对应的位置x‘。然后我们遍历树A的每个节点x,它子树的dfs序范围为[l[x]+1,r[x]] (不包含x)。那么问题就变成在树B上编号为x的节点的子树中有多少个节点的标记落在[l[x]+1,r[x]]的范围内

如图,我们想求A中3的子树中有多少个节点对应到B中也在3的子树里,l[3]=2,r[3]=5,B中3的子树中的dfs序有{2,4},落在[2+1,5]的范围内的只有4,所以有1个节点

这是线段树合并的经典问题。用权值线段树合并就可以了,节点x的线段树的节点[l,r] 存储有x的子树中多少个值落在[l,r]内。(有些题解用了可持久化线段树,其实没有必要)。我们遍历的时候从下往上合并,合并到节点x的时候就更新x的cnt值。

时间复杂度\(O(n\log n)\)

代码

#include<iostream>
#include<cstdio>
#include<cstring>
#include<vector>
#define maxn 100000
#define maxlogn 25 
using namespace std;
int n;
struct segment_tree{
#define lson(x) (tree[x].ls)
#define rson(x) (tree[x].rs) 
    struct node{
        int ls;
        int rs;
        int val;
    }tree[maxn*maxlogn+5];
    int ptr;
    void push_up(int x){
        tree[x].val=tree[lson(x)].val+tree[rson(x)].val;
    } 
    void update(int &x,int upos,int l,int r){
        if(!x) x=++ptr;
        if(l==r){
            tree[x].val++;
            return;
        }
        int mid=(l+r)>>1;
        if(upos<=mid) update(tree[x].ls,upos,l,mid);
        else update(tree[x].rs,upos,mid+1,r);
        push_up(x); 
    }
    int query(int x,int L,int R,int l,int r){
        if(L<=l&&R>=r){
            return tree[x].val;
        }
        int mid=(l+r)>>1;
        int ans=0;
        if(L<=mid) ans+=query(tree[x].ls,L,R,l,mid);
        if(R>mid) ans+=query(tree[x].rs,L,R,mid+1,r);
        return ans;
    }
    int merge(int x,int y,int l,int r){
        if(!x||!y) return x+y;
        if(l==r){
            tree[x].val+=tree[y].val;
            return x;
        }
        int mid=(l+r)>>1;
        tree[x].ls=merge(tree[x].ls,tree[y].ls,l,mid);
        tree[x].rs=merge(tree[x].rs,tree[y].rs,mid+1,r);
        push_up(x);
        return x;
    }
}T;
int root[maxn+5];
int in[maxn+5];

int tim=0;
int dfnl[maxn+5],dfnr[maxn+5];
vector<int>E1[maxn+5],E2[maxn+5];
void dfs1(int x,int fa){
    dfnl[x]=++tim;
    for(int i=0;i<E1[x].size();i++){
        int y=E1[x][i];
        if(y!=fa){
            dfs1(y,x);
        }
    } 
    dfnr[x]=tim;
} 

int cnt[maxn+5];
void dfs2(int x,int fa){
    for(int i=0;i<E2[x].size();i++){
        int y=E2[x][i];
        if(y!=fa){
            dfs2(y,x);
            root[x]=T.merge(root[x],root[y],1,n);
        }
    }
    cnt[x]=T.query(root[x],dfnl[x]+1,dfnr[x],1,n);
}

int main(){
    int u,v;
    int rt1,rt2;
    scanf("%d",&n);
    for(int i=1;i<n;i++){
        scanf("%d %d",&u,&v);
        E1[u].push_back(v);
        E1[v].push_back(u);
        in[v]++;
    }
    for(int i=1;i<=n;i++) if(in[i]==0) rt1=i;//根不一定是1 
    memset(in,0,sizeof(in));
    
    for(int i=1;i<n;i++){
        scanf("%d %d",&u,&v);
        E2[u].push_back(v);
        E2[v].push_back(u);
        in[v]++;
    }
    for(int i=1;i<=n;i++) if(in[i]==0) rt2=i;
    
    dfs1(rt1,0);
    for(int i=1;i<=n;i++){
        T.update(root[i],dfnl[i],1,n);
    }
    dfs2(rt2,0);
    long long ans=0;
    for(int i=1;i<=n;i++){
        ans+=(long long)cnt[i]*(cnt[i]-1)/2;
    }
    printf("%lld\n",ans);
}

猜你喜欢

转载自www.cnblogs.com/birchtree/p/11228847.html