【POJ 3764】The xor-longest Path【Trie】

题目大意:

题目链接:http://poj.org/problem?id=3764
在一棵树中选择任意两个结点,使得他们之间路径之和最大。


思路:

首先我们设点1为根节点,然后求每个节点与根节点(点1)的路径的异或值。那么设 d [ i ] 为点1到点 i 之间路径异或值,那么必然有

d [ i ] = d [ f a t h e r ]     x o r     d i s [ i ] [ f a t h e r ]

那么在根据 a     x o r     a = 0 ,那么很容易得到点 i 和点 j 的异或值为 d [ i ]     x o r     d [ j ] ,因为如果它们在根结点的两侧,那么上述公式十分显然。如果他们在同一侧,那么若有一条重复路径 i ,那么就绝对会有 d [ i ]     x o r     d [ i ] ,那么就为0,相当于低抵消了。
那么题目就变成了求一个数列中的任意两个数字亦或的最大值,就根 这道题完全一样了。


代码:

#include <cstdio>
#include <cstring>
#include <iostream>
#define N 200005
#define M 3000005
#define up 30
using namespace std;

int n,x,y,k,tot=1,head[N],trie[M][2];
long long d[N],z,ans,sum;

struct edge  //邻接表
{
    int to,next;
    long long dis;
}e[N];

void add(int from,int to,long long dis)  //建图(树)
{
    k++;
    e[k].to=to;
    e[k].dis=dis;
    e[k].next=head[from];
    head[from]=k;
}

void dfs(int x,int fa)  //深搜求点1和点i的异或距离
{
    int v;
    for (int i=head[x];i;i=e[i].next)
    {
        v=e[i].to;
        if (v==fa) continue;  //不能到父节点
        d[v]=d[x]^e[i].dis;  //求距离
        dfs(v,x);
    } 
}

void insert(long long x)  //插入
{
    int p=1;
    for (int i=up;i>=0;i--)
    {
        int id=(x>>i)&1;
        if (!trie[p][id]) trie[p][id]=++tot;
        p=trie[p][id];
    }
}

void find(long long x)  //查找
{
    int p=1;
    for (int i=up;i>=0;i--)
    {
        int id=(x>>i)&1;
        if (trie[p][id^1])
        {
            sum=sum*2+1;
            p=trie[p][id^1];
        }
        else
         if (trie[p][id])
         {
            sum=sum*2;
            p=trie[p][id];
         }
        //else return;
    }
}

int main()
{
    while (scanf("%d",&n)==1)
    {
        memset(head,0,sizeof(head));
        memset(d,0,sizeof(d));
        memset(trie,0,sizeof(trie));
        tot=1;
        k=0;
        ans=0;  //初始化
        for (int i=1;i<n;i++)
        {
            scanf("%d%d%lld",&x,&y,&z);
            x++;
            y++;
            add(x,y,z);
            add(y,x,z);
        }
        dfs(1,-1);
        for (int i=1;i<=n;i++) insert(d[i]);
        for (int i=1;i<=n;i++)
        {
            sum=0;
            find(d[i]);
            ans=max(ans,sum);
        }
        printf("%lld\n",ans);
    }
    return 0;
}

猜你喜欢

转载自blog.csdn.net/SSL_ZYC/article/details/81782163