hdu 5111 Alexandra and Two Trees (主席树+树剖)

题意:

给你两颗树,q次查询,每次查询,问你第一颗树u1到v1,第二棵树u2到v2上,出现的数字交集大小,每颗树上不会有重复数字

思路:

     我们首先先想这样一个问题,给你两个数列 a 1 , a 2 , a 3 . . . a n b 1 , b 2 , b 3 , . . . b m ,怎么统计。我们可以用将 a 1 , a 2 , a 3 . . . a n 离散为1,2,3…,n,然后用b数列与离散的值对应,如果没有对应的值,就标为0。
     如果统计的是两个数列的子区间 l 1 , r 1 , l 2 , r 2 的答案,我们可以对b数列建立一颗主席树,那么我们对于 l 1 , r 1 , l 2 , r 2 的询问,我们就相当于在主席树 l 2 . . . r 2 的这几棵树里面, l 1 . . . r 1 这个区间出现了几个数,这就是主席树的常规查找了
     题目给出的是树,我们用树剖就能化为区间问题了

错误及反思:

     变量名太多了,都想不出名字了,结果一乱wa了好几发

代码:

#include<bits/stdc++.h>
using namespace std;
#define lson l,m,rt<<1
#define rson m+1,r,rt<<1|1
const int N =100010;

struct E{
    int to,next;
}e[N*2];
int tot,tid,n,k,q;

int top[N],si[N],fa[N],first[N],son[N],depth[2][N],id[N],val[2][N];

int to[N][25];
int sum[N*20],ls[N*20],rs[N*20],root[N],Cnt;

vector<int> sonn[N];
map<int,int> mp;

void addedge(int x,int y){
    e[tot].to=y;
    e[tot].next=first[x];
    first[x]=tot++;
    e[tot].to=x;
    e[tot].next=first[y];
    first[y]=tot++;
}


void dfs1(int now,int bef,int dep){
    fa[now]=bef;
    depth[0][now]=dep;
    si[now]=1;
    for(int i=first[now];i!=-1;i=e[i].next)
        if(e[i].to!=bef){
            dfs1(e[i].to,now,dep+1);
            si[now]+=si[e[i].to];
            if(son[now]==-1) son[now]=e[i].to;
            else son[now]=si[e[i].to]>si[son[now]]?e[i].to:son[now];
        }
}

void dfs2(int now,int tp){
    top[now]=tp;
    id[now]=tid;
    mp[val[0][now]]=tid++;
    if(son[now]!=-1) dfs2(son[now],tp);
    for(int i=first[now];i!=-1;i=e[i].next)
        if(e[i].to!=fa[now]&&e[i].to!=son[now])
            dfs2(e[i].to,e[i].to);
}

void initialization(){
    tot=0; tid=1; Cnt=0;
    memset(first,-1,sizeof(first));
    memset(son,-1,sizeof(son));
    mp.clear();
}

int lca(int a,int b){
    if(depth[1][a]>depth[1][b]) swap(a,b);
    for(int i=22;i>=0;i--)
        if(depth[1][to[b][i]]>=depth[1][a])
            b=to[b][i];
    if(a==b) return a;
    for(int i=22;i>=0;i--){
        if(to[a][i]!=to[b][i]){
            a=to[a][i];
            b=to[b][i];
        }
    }
    return to[a][0];
}

int query(int L,int R,int u,int v,int lc,int flc,int l,int r){
    if(L<=l&&R>=r) return sum[u]+sum[v]-sum[lc]-sum[flc];
    int m=(l+r)/2;
    int num=sum[ls[u]]+sum[ls[v]]-sum[ls[lc]]-sum[ls[flc]],ans=0;
    if(m>=L) ans+=query(L,R,ls[u],ls[v],ls[lc],ls[flc],l,m);
    if(m<R) ans+=query(L,R,rs[u],rs[v],rs[lc],rs[flc],m+1,r);
    return ans;
}

void getans(int L,int R,int l,int r){
    int f1=top[L],f2=top[R],tlca=lca(l,r);
    int cnt=0;
    while(f1!=f2){
        if(depth[0][f1]<depth[0][f2]){
            swap(f1,f2);
            swap(L,R);
        }
        cnt+=query(id[f1],id[L],root[l],root[r],root[tlca],root[to[tlca][0]],0,tid-1);
        L=fa[f1];
        f1=top[L];
    }
    if(depth[0][L]>depth[0][R]) swap(L,R);
    cnt+=query(id[L],id[R],root[l],root[r],root[tlca],root[to[tlca][0]],0,tid-1);
    printf("%d\n",cnt);
}

int init(int l,int r){
    int rt=++Cnt;
    sum[rt]=0,ls[rt]=rt,rs[rt]=rt;
    if(l==r) return rt;
    int m=(l+r)/2;
    ls[rt]=init(l,m);
    rs[rt]=init(m+1,r);
    return rt;
}

void dfs(int now,int FA,int dep){
    to[now][0]=FA;
    depth[1][now]=dep;
    for(int i=0;i<sonn[now].size();i++)
        if(sonn[now][i]!=FA)
            dfs(sonn[now][i],now,dep+1);
}

void initlca(){
    for(int i=1;i<23;i++)
        for(int j=0;j<=k;j++)
            to[j][i]=to[to[j][i-1]][i-1];
}



int update(int pre,int x,int l, int r){
    int rt=++Cnt;
    ls[rt]=ls[pre], rs[rt]=rs[pre], sum[rt]=sum[pre]+1;
    if(l==r) return rt;
    int m=(l+r)>>1;
    if(x<=m) ls[rt]=update(ls[pre], x, l,m);
    else rs[rt]=update(rs[pre], x, m+1,r);
    return rt;
}

void build(int now,int fa){
    root[now]=update(root[fa],mp[val[1][now]],0,tid-1);
    for(int i=0;i<sonn[now].size();i++)
        if(sonn[now][i]!=fa)
            build(sonn[now][i],now);
}

int main(){
    while(scanf("%d",&n)!=EOF){
        initialization();
        for(int i=2,v;i<=n;i++){
            scanf("%d",&v);
            addedge(i,v);
        }
        for(int i=1;i<=n;i++) scanf("%d",&val[0][i]);

        dfs1(1,1,1); dfs2(1,1);

        scanf("%d",&k);
        for(int i=2,v;i<=k;i++){
            scanf("%d",&v);
            sonn[v].push_back(i);
        }
        sonn[0].push_back(1);
        for(int i=1;i<=k;i++) scanf("%d",&val[1][i]);
        root[0]=init(0,tid-1);
        dfs(0,0,1); initlca();
        build(1,0);
        scanf("%d",&q);
        while(q--){
            int u1,v1,u2,v2;
            scanf("%d%d%d%d",&u1,&v1,&u2,&v2);
            getans(u1,v1,u2,v2);
        }
        for(int i=0;i<=k;i++) sonn[i].clear();
    }
}

猜你喜欢

转载自blog.csdn.net/roll_keyboard/article/details/80653785