UPCOJ-5531 [COCI 2017-2018-2] - Usmjeri

usmjeri(2s256M)

给一棵N个节点的树,编号从1到N,再给定m对点(u,v),你要将树上的每条无向边变为有向边,使得给定的点对都满足u能到达v或v能到达u。问有多少种不同的方案,答案对109+7求余。

输入:

第一行两个正整数​N and ​M(1 ≤ ​N, M≤ 3·10​ 5​ ),表示树的结点个数,和点对的个数。

接下来N-1行,每行两个整数,表示树上的边。

接下来M行,每行两个不同的正整数(ai,bi,表示对应的点对,点对互不相同。

输出

一行一个数,表示不同的方案数模109+7

20%的数据树是一个链,即第i个点连在i+1上。

40%的数据N,M≤​ ​5·10​ 3 .

样例:

input input input

4​ ​1

1​ ​2

2​ ​3

3​ ​4

2​ ​4

Output

4


题目大意:给一颗树,把它的边变成有向边,求有多少种方案可以满足给出的u[i]能到v[i]或v[i]到u[i]


分析:

首先把边放在点上,我们把每个点都多增加一个点,表示和它相反的边那么每次u->lca v->lca 并在一起

如果lca不是u和v 那么把u+n与v并一起,v+n与u并在一起

最后如果有u与u+n在一个并查集,说明矛盾了

否则答案就是2^(并查集个数/2)


附上代码:

 

#include<bits/stdc++.h>
using namespace std;
const int N=3e5+12;
const int mod=1e9+7;
int n,m;

int fa[2*N];
int head[N],net[N*2],to[N*2],p;
void addedge(int a,int b)
{
    to[p]=b;
    net[p]=head[a];
    head[a]=p++;
}

int deep[N],g[N][20];
void DFS(int u,int f)
{
    deep[u]=deep[f]+1;
    g[u][0]=f;
    for(int i=head[u];i!=-1;i=net[i])
    {
        int v=to[i];
        if(v==f) continue;
        DFS(v,u);
    }
}
void init()
{
    for(int j=1;(1<<j)<=n;j++) 
        for(int i=1;i<=n;i++) 
        if(g[i][j-1]!=-1) g[i][j]=g[g[i][j-1]][j-1];
}



inline int lca(int x,int y)
{
    int log=0;
    if(deep[x]<deep[y]) swap(x,y);
    for(log=1;(1<<log)<=deep[x];log++);log--;
    for(int i=log;i>=0;i--) 
    {
        if(deep[x]-(1<<i)>=deep[y]) x=g[x][i];
    }
    if(x==y) return x;
    for(int i=log;i>=0;i--)    
    {
        if(g[x][i]!=g[y][i]&&g[x][i]) 
        {
            x=g[x][i];y=g[y][i];
        }
    }
    if(g[x][0]) x=g[x][0];
    return x;
}
int x[N],y[N],LCA[N];
inline int find(int x) {if(fa[x]==x) return x;else return fa[x]=find(fa[x]);}


void merge(int u,int Lca)
{
    while(deep[g[u][0]]>deep[Lca])
    {
        int f=g[u][0];
        fa[find(u)]=find(f);
        fa[find(u+n)]=find(f+n);
        u=find(f);
    }
}
void Merge(int u,int v)
{
    fa[find(u+n)]=find(v);
    fa[find(v+n)]=find(u);
}

long long qpow(int a,long long b)
{
    long long ans=1LL;
    while(a)
    {
        if(a&1) ans=(ans*b)%mod;
        b=(b*b)%mod;
        a>>=1;
    }
    return ans;
}

int main()
{
    int Size=32<<20;
    char *p=(char*)malloc(Size)+Size;
    __asm__("movl %0, %%esp\n" :: "r"(p));
    freopen("usmjeri.in","r",stdin);
//    freopen("usmjeri.out","w",stdout);
    scanf("%d%d",&n,&m);
    memset(g,-1,sizeof(g));
    memset(head,-1,sizeof(head));
    for(int i=1;i<=2*n;i++) fa[i]=i;
    int a,b;
    for(int i=1;i<n;i++) 
    {
        scanf("%d%d",&a,&b);
        addedge(a,b);addedge(b,a);
    }
    deep[1]=1;
    DFS(1,1);init();
    
    for(int i=1;i<=m;i++)
    {
        scanf("%d%d",&x[i],&y[i]);
        LCA[i]=lca(x[i],y[i]);
        merge(x[i],LCA[i]);
        merge(y[i],LCA[i]);
    }
    for(int i=1;i<=m;i++)
    {
        if(LCA[i]!=x[i]&&LCA[i]!=y[i])
        Merge(x[i],y[i]);
    }
    for(int i=1;i<=n;i++)
    {
        if(find(i)==find(i+n))
        {
            printf("0\n");
            fclose(stdin);fclose(stdout);return 0;
        }
    }
    
    int tot=0;int k=0;
    for(int i=2;i<=n;i++) if(fa[i]==i) tot++;printf("%d ",tot);
    for(int i=n+2;i<=2*n;i++) if(fa[i]==i) tot++,k++;printf("%d\n",k);
    tot/=2;
    printf("%lld\n",qpow(tot,2));
    fclose(stdin);
    fclose(stdout);
    return 0;
}
View Code

猜你喜欢

转载自www.cnblogs.com/Heey/p/8992539.html
今日推荐