【洛谷P5283】异或粽子【Trie】【堆】

题目大意:

题目链接:https://www.luogu.org/problemnew/show/P5283
给出一个序列,找到 m m 个区间 [ l i , r i ] [l_i,r_i] 使得这些区间的异或和最大。


思路 :

先做一遍前缀异或,这样问题就被转换成找 m m l , r l,r 使得 l   x o r   r \sum l \ xor\ r 尽量大。
注意到 a [ l ]   x o r   a [ r ] = a [ r ]   x o r   a [ l ] a[l]\ xor\ a[r]=a[r]\ xor\ a[l] ,所以如果选了 l , r l,r ,那么也会选一遍 r , l r,l 。这样的话我们可以将询问次数 × 2 \times 2 ,这样每连续的两组查询就是一组相同的 ( l , r ) ( r , l ) (l,r)(r,l) 。最终答案除以2就行了。
我们维护一个堆,将每一个 i i 所对应的 a [ i ]   x o r   a [ j ] a[i]\ xor\ a[j] 最大的 j j 扔到堆里面,每次询问取堆顶,将下一大的 a [ i ]   x o r   a [ j ] a[i]\ xor\ a[j] 扔到堆中。
对于 x x y y 使得 a [ x ]   x o r   a [ y ] a[x]\ xor\ a[y] 为第 k k 大显然是可以用 T r i e Trie 维护的。每次利用权值来判断往子树的左右。
这样就可以不使用可持久化 T r i e Trie 来完成这道题了。时间复杂度 O ( n log n ) O(n\log n)


代码:

#include <queue>
#include <cstdio>
#include <string>
#include <iostream>
#define mp make_pair
using namespace std;
typedef long long ll;

const int N=500010,LG=35;
int n,m,tot=1,trie[N*LG][2],cnt[N],size[N*LG];
ll a[N],ans;
priority_queue<pair<ll,int> > q;

ll read()
{
    ll d=0;
    char ch=getchar();
    while (!isdigit(ch)) ch=getchar();
    while (isdigit(ch))
        d=(d<<3)+(d<<1)+(ll)ch-48LL,ch=getchar();
    return d;
}

void insert(ll x)
{
    int p=1;
    for (int i=LG;i>=0;i--)
    {
        int id=(x>>(ll)i)&1;
        if (!trie[p][id]) trie[p][id]=++tot;
        p=trie[p][id];
        size[p]++;
    }
}

ll find(ll x,int k)
{
    int p=1; 
    ll ans=0;
    for (int i=LG;i>=0;i--)
    {
        int id=(x>>(ll)i)&1;
        if (trie[p][id^1]&&size[trie[p][id^1]]>=k)
        {
            ans=(ans<<1)|1;
            p=trie[p][id^1];
        }
        else if (trie[p][id])
        {
            k-=size[trie[p][id^1]];
         	ans<<=1;
         	p=trie[p][id];
        }
    }
    return ans;
}

int main()
{
    scanf("%d%d",&n,&m);
    insert(0); cnt[0]=1;
    for (int i=1;i<=n;i++)
    {
        a[i]=read();
        a[i]^=a[i-1];
        insert(a[i]);
        cnt[i]=1;
    }
    for (int i=0;i<=n;i++)
        q.push(mp(find(a[i],1),i));
    m*=2;
    while (m--)
    {
    	if (!q.size()) break;
        ans+=q.top().first;
        int i=q.top().second;
        q.pop();
        cnt[i]++;
        if (cnt[i]<=n+1)
            q.push(mp(find(a[i],cnt[i]),i));
    }
    cout<<ans/2LL;
    return 0;
}

猜你喜欢

转载自blog.csdn.net/SSL_ZYC/article/details/94644518