https://www.51nod.com/Challenge/Problem.html#!#problemId=1681
给两棵树 问对于每对顶点 有多少除这两个点之外的点 在这两棵树上都是这两个点的公共祖先
考虑每个点的贡献 对于一个点 其子树在两棵树上各对应一段dfs序区间 就看这两个区间里有多少数是相同的(不含该点自身) 设为k 则该点的贡献为C(k,2)
至于查询两个序列的两个区间有多少数是相同的 主席树搞一下就行 这有个这类型的裸题 先把这个做明白再看这道题 http://codeforces.com/problemset/problem/323/C
#include <bits/stdc++.h>
using namespace std;
/* 64-bit integer used for the final answer, which is accumulated as a sum of
   res*(res-1)/2 terms. */
typedef long long ll;
/* Capacity of every per-vertex array in this file (vertices are 1-indexed). */
const int maxn = 100000 + 10;
/* One directed edge stored in a linked adjacency list (see addedge). */
struct node1
{
    int v;    /* vertex the edge points to */
    int next; /* index of the next edge leaving the same source vertex, -1 at list end */
};
/* One node of the persistent segment tree kept in the global `tree` pool. */
struct node2
{
    int l;   /* pool index of the left child */
    int r;   /* pool index of the right child */
    int val; /* number of positions set to 1 inside this node's range */
};
/* Edge pools: one linked adjacency list per input tree. */
node1 edge1[maxn];
node1 edge2[maxn];

/* Node pool of the persistent segment tree (one build plus one update per vertex). */
node2 tree[20*maxn];

/* book[v] != 0 once v has appeared as the second endpoint of an edge. */
int book[maxn];

/* first*[u]: index of the most recently added edge leaving u, or -1 if none. */
int first1[maxn];
int first2[maxn];

/* Tree 1: mp1[v] = dfs order of v, mpp1[k] = vertex with dfs order k,
   sum1[v] = size of the subtree rooted at v. */
int mp1[maxn];
int mpp1[maxn];
int sum1[maxn];

/* Same three arrays for tree 2. */
int mp2[maxn];
int mpp2[maxn];
int sum2[maxn];

/* root[i]: segment-tree version after inserting the first i vertices of
   tree 1's dfs order (root[0] is the empty tree). */
int root[maxn];

int n;   /* number of vertices */
int q;   /* not referenced by the code in this file */
int num; /* shared counter: next edge slot, dfs clock, or next segment-tree node */
int r1;  /* root of tree 1 */
int r2;  /* root of tree 2 */
/*
 * Push the directed edge u -> v onto the front of u's adjacency list.
 *
 * edge  : edge pool to write into
 * first : per-vertex list heads for that pool
 *
 * The slot used is the current value of the global counter `num`, which is
 * incremented afterwards.
 */
void addedge(node1 *edge,int *first,int u,int v)
{
    node1 *slot = edge + num;

    slot->v = v;
    slot->next = first[u]; /* old head becomes the successor */
    first[u] = num;
    num += 1;
}
/*
 * Preorder traversal starting at `cur`.
 *
 * Advances the global clock `num` and records, for every visited vertex x:
 *   mp[x]  - its visiting order,
 *   mpp[k] - the vertex visited k-th,
 *   sum[x] - the number of vertices in x's subtree (x included).
 *
 * Edges are followed only from source to target, so no "parent" check is
 * needed.
 */
void dfs(node1 *edge,int *first,int *mp,int *mpp,int *sum,int cur)
{
    int order = ++num;
    int e = first[cur];

    mp[cur] = order;
    mpp[order] = cur;
    sum[cur] = 1;

    while (e != -1) {
        int child = edge[e].v;

        dfs(edge, first, mp, mpp, sum, child);
        sum[cur] += sum[child];
        e = edge[e].next;
    }
}
/* Recompute the count stored at node `cur` as the sum of its two children. */
void pushup(int cur)
{
    int left = tree[cur].l;
    int right = tree[cur].r;

    tree[cur].val = tree[left].val + tree[right].val;
}
/*
 * Allocate an all-zero segment tree covering [l, r] from the global pool.
 * Nodes are taken from `tree` starting at index `num`.
 * Returns the pool index of the subtree's root.
 */
int build(int l,int r)
{
    int id = num;
    num += 1;

    tree[id].val = 0;
    tree[id].l = 0;
    tree[id].r = 0;

    if (l != r) {
        int mid = (l + r) / 2;
        int left = build(l, mid);
        tree[id].l = left;
        int right = build(mid + 1, r);
        tree[id].r = right;
    }
    return id;
}
/*
 * Create a new version of the tree rooted at `rot` in which position `tar`
 * is set to 1. Only the nodes on the root-to-leaf path are copied; the rest
 * is shared with the old version.
 *
 * [l, r] is the range covered by `rot`. Returns the pool index of the new root.
 */
int update(int rot,int tar,int l,int r)
{
    int id = num++;

    tree[id] = tree[rot]; /* start as a copy so untouched children stay shared */

    if (l == r) {
        tree[id].val = 1;
        return id;
    }

    int mid = (l + r) / 2;
    if (tar > mid) {
        tree[id].r = update(tree[rot].r, tar, mid + 1, r);
    } else {
        tree[id].l = update(tree[rot].l, tar, l, mid);
    }
    pushup(id);
    return id;
}
/*
 * Count the positions in [pl, pr] that are set in version `rrot` but not in
 * version `lrot`, computed as the difference of the two versions' counts.
 *
 * lrot / rrot are the nodes of the two versions that cover the same range
 * [l, r]; the recursion walks both versions in lockstep.
 */
int query(int lrot,int rrot,int pl,int pr,int l,int r)
{
    if (l >= pl && pr >= r) {
        return tree[rrot].val - tree[lrot].val;
    }

    int mid = (l + r) / 2;
    int total = 0;

    if (mid >= pl) {
        total += query(tree[lrot].l, tree[rrot].l, pl, pr, l, mid);
    }
    if (mid < pr) {
        total += query(tree[lrot].r, tree[rrot].r, pl, pr, mid + 1, r);
    }
    return total;
}
int main()
{
ll ans,res;
int i,u,v;
scanf("%d",&n);
memset(first1,-1,sizeof(first1));
memset(book,0,sizeof(book));
num=0;
for(i=1;i<=n-1;i++){
scanf("%d%d",&u,&v);
addedge(edge1,first1,u,v);
book[v]=1;
}
for(i=1;i<=n;i++) if(!book[i]) r1=i;
memset(first2,-1,sizeof(first2));
memset(book,0,sizeof(book));
num=0;
for(i=1;i<=n-1;i++){
scanf("%d%d",&u,&v);
addedge(edge2,first2,u,v);
book[v]=1;
}
for(i=1;i<=n;i++) if(!book[i]) r2=i;
num=0;
dfs(edge1,first1,mp1,mpp1,sum1,r1);
num=0;
dfs(edge2,first2,mp2,mpp2,sum2,r2);
/*
for(i=1;i<=n;i++) printf("%d ",mp1[i]);
printf("\n");
for(i=1;i<=n;i++) printf("%d ",mp2[i]);
printf("\n");
*/
num=0;
root[0]=build(1,n);
for(i=1;i<=n;i++){
root[i]=update(root[i-1],mp2[mpp1[i]],1,n);
}
ans=0;
for(i=1;i<=n;i++){
if(mp2[i]+1<=mp2[i]+sum2[i]-1) res=query(root[mp1[i]],root[mp1[i]+sum1[i]-1],mp2[i]+1,mp2[i]+sum2[i]-1,1,n);
else res=0;
ans+=res*(res-1)/2ll;
}
printf("%lld\n",ans);
return 0;
}