[CodeChef May Challenge 2018]Edges in Spanning Trees

Description

给出两棵树,T1和T2
对于T1中的每一条边e1,你需要求出T2中有多少条边e2满足
1:T1-e1+e2是一棵树
2:T2-e2+e1是一棵树
n<=2e5

Solution

我们考虑一组限制的两种方法,并且这两种方法能够套在一起
首先,我们知道可以对于T2中的每一条边(u,v),在第一棵树上的u,v打上标记,在lca(u,v)处撤销
这样子我们可以在遍历的时候求出那些e2边可能满足条件
接下来对于每个e1(x,y)满足条件的e2需要存在在T2中x到y的路径上
也就是我们需要满足单点修改,路径询问,然后用线段树合并维护
这个并没有必要用树链剖分,可以直接括号序维护
复杂度O(n log n)

Code

#include <cmath>
#include <vector>
#include <cstdio>
#include <cstring>
#include <algorithm>
#define fo(i,a,b) for(int i=a;i<=b;i++)
#define fd(i,a,b) for(int i=a;i>=b;i--)
#define rep(i,a,b) for(int i=lst[b][a];i;i=nxt[i])
using namespace std;

typedef vector<int> vec;
#define pb(a) push_back(a)

int read() {
    char ch;
    for(ch=getchar();ch<'0'||ch>'9';ch=getchar());
    int x=ch-'0';
    for(ch=getchar();ch>='0'&&ch<='9';ch=getchar()) x=x*10+ch-'0';
    return x;
}

void write(int x) {
    if (!x) {putchar('0');putchar(' ');return;}
    char ch[20];int tot=0;
    for(;x;x/=10) ch[++tot]=x%10+'0';
    fd(i,tot,1) putchar(ch[i]);
    putchar(' ');
}

const int N=4e5+5;

int t[N<<1],nxt[N<<1],lst[2][N],l;
void add(int x,int y,int a) {
    t[++l]=y;nxt[l]=lst[a][x];lst[a][x]=l;
}

int n,dfn[2][N],fir[2][N],dep[2][N],tot,f[2][N][19],lg[N];

void dfs(int x,int y,int a) {
    dfn[a][++tot]=x;fir[a][x]=tot;dep[a][x]=dep[a][y]+1;
    rep(i,x,a) if (t[i]!=y) dfs(t[i],x,a),dfn[a][++tot]=x;
}

int lca(int x,int y,int a) {
    x=fir[a][x];y=fir[a][y];
    if (x>y) swap(x,y);
    int z=lg[y-x+1];
    x=f[a][x][z];y=f[a][y-(1<<z)+1][z];
    return dep[a][x]<dep[a][y]?x:y; 
}

int in[N],out[N];

void travel(int x,int y) {
    in[x]=++tot;
    rep(i,x,1) if (t[i]!=y) travel(t[i],x);
    out[x]=++tot;
}

int u1[N],u2[N],v1[N],v2[N],an[N];
vec q[N];

int tr[N<<5],ls[N<<5],rs[N<<5],rt[N],cnt;

void modify(int &v,int l,int r,int x,int y) {
    if (!v) v=++cnt;
    if (l==r) {tr[v]+=y;return;}
    int mid=l+r>>1;
    if (x<=mid) modify(ls[v],l,mid,x,y);
    else modify(rs[v],mid+1,r,x,y);
    tr[v]=tr[ls[v]]+tr[rs[v]];
}

int merge(int x,int y,int l,int r) {
    if (!x||!y) return x+y;
    if (l==r) {tr[x]+=tr[y];return x;}
    int mid=l+r>>1;
    ls[x]=merge(ls[x],ls[y],l,mid);
    rs[x]=merge(rs[x],rs[y],mid+1,r);
    tr[x]=tr[ls[x]]+tr[rs[x]];
    return x;
}

int query(int v,int l,int r,int x,int y) {
    if (!v) return 0;
    if (l==x&&r==y) return tr[v];
    int mid=l+r>>1;
    if (y<=mid) return query(ls[v],l,mid,x,y);
    else if (x>mid) return query(rs[v],mid+1,r,x,y);
    else return query(ls[v],l,mid,x,mid)+query(rs[v],mid+1,r,mid+1,y);
}

void solve(int x,int y) {
    rep(i,x,0) if (t[i]!=y) solve(t[i],x);
    rep(i,x,0) rt[x]=merge(rt[x],rt[t[i]],1,tot);
    if (!q[x].empty())
        fo(i,0,q[x].size()-1) {
            int z=q[x][i];
            if (z>0) modify(rt[x],1,tot,in[z],1),modify(rt[x],1,tot,out[z],-1);
            else modify(rt[x],1,tot,in[-z],-1),modify(rt[x],1,tot,out[-z],1);
        }
    if (y) {
        int z=lca(x,y,1);
        an[x]=query(rt[x],1,tot,1,in[x])+query(rt[x],1,tot,1,in[y])-2*query(rt[x],1,tot,1,in[z]);
    }
}

int main() {
    for(int ty=read();ty;ty--) {
        n=read();
        memset(lst,0,sizeof(lst));l=0;
        fo(i,1,n-1) {
            u1[i]=read();v1[i]=read();
            add(u1[i],v1[i],0);add(v1[i],u1[i],0);
        }
        fo(i,1,n-1) {
            u2[i]=read();v2[i]=read();
            add(u2[i],v2[i],1);add(v2[i],u2[i],1);
        }
        tot=0;dfs(1,0,0);
        tot=0;dfs(1,0,1);
        fo(a,0,1) fo(i,1,tot) f[a][i][0]=dfn[a][i];
        fo(i,1,tot) lg[i]=log(i)/log(2);
        fo(a,0,1)
            fo(j,1,18)
                fo(i,1,tot-(1<<j)+1) {
                    int x=f[a][i][j-1],y=f[a][i+(1<<j-1)][j-1];
                    f[a][i][j]=dep[a][x]<dep[a][y]?x:y;
                }
        tot=0;travel(1,0);
        fo(i,1,cnt) tr[i]=ls[i]=rs[i]=0;cnt=0;
        fo(i,1,n) q[i].clear(),rt[i]=0;
        fo(i,1,n-1) {
            int x=u2[i],y=v2[i],z=lca(x,y,0);
            int val=dep[1][x]<dep[1][y]?y:x;
            q[x].pb(val);q[y].pb(val);
            q[z].pb(-val);q[z].pb(-val);
        }
        solve(1,0);
        fo(i,1,n-1) {
            int x=u1[i],y=v1[i];
            write(dep[0][x]<dep[0][y]?an[y]:an[x]);
        }
        puts("");
    }
    return 0;
} 

猜你喜欢

转载自blog.csdn.net/alan_cty/article/details/80353019