字节--异或

字节–异或

一、题目描述

给定整数m以及n各数字A1,A2,..An,将数列A中所有元素两两异或,共能得到n(n-1)/2个结果,请求出这些结果中大于m的有多少个

  • 输入描述:
第一行包含两个整数n,m. 

第二行给出n个整数A1,A2,...,An。

数据范围

对于30%的数据,1 <= n, m <= 1000

对于100%的数据,1 <= n, m, Ai <= 10^5
  • 输出描述:
输出仅包括一行,即所求的答案

输入例子1:
3 10  
6 5 10

输出例子1:
2

二、分析

首先这道题肯定不能考虑n2的解法,肯定会超时,所以解决这道题还需另辟捷径

  • 直接计算肯定是超时的,所以这问题不能使用暴力破解,考虑到从高位到地位,依次进行位运算
  • 如果两个数异或结果在某高位为1,而m的对应位为0,则肯定任何这两位异或结果为1的都会比m大
  • 由此,考虑使用字典树(TrieTree)从高位到低位建立字典再使用每个元素依次去字典中查对应高位异或为1, 而m为0的数的个数,相加在除以2既是最终的结果
  • 我想说明下数据范围,n,m,Ai都是[1,1e5]的,(1 << 17)>1e5,所以一个数至少要17位来存储,所以trie树的节点个数就是1e5*17,这个不理解的话,仔细查看一下trie树的资料吧。
  • TrieTree在搜索的过程中,是从高位往低位搜索,那么,如果有一个数与字典中的数异或结果的第k位大于m的第k位,那么该数与对应分支中所有的数异或结果都会大于m
  • 否则,就要搜索在第k位异或相等的情况下,更低位的异或结果。TrieTree中四个分支的作用分别如下:
  1. aDigit=1, mDigit=1时,字典中第k位为0,异或结果为1,需要继续搜索更低位,第k位为1,异或结果为0,小于mDigit,不用理会
  2. aDigit=0, mDigit=1时,字典中第k位为1,异或结果为1,需要继续搜索更低位,第k位为0,异或结果为0,小于mDigit,不用理会
  3. aDigit=1, mDigit=0时,字典中第k位为0,异或结果为1,与对应分支所有数异或,结果都会大于m,第k位为1,异或结果为0,递归获得结果
  4. aDigit=0, mDigit=0时,字典中第k位为1,异或结果为1,与对应分支所有数异或,结果都会大于m,第k位为0,异或结果为0,递归获得结果
  • 对于我来说,关键点在于构建字典树

改进:

  • 1.字典树17位即可保证大于100000,移位范围为1~16位,则字典树构建时从16~0即可。

    字典树第一层不占位,实际上是15~-1层有数据,这也是数据中next的用法

  • 2.queryTrieTree函数需要考虑到index为-1时的返回值。

  • 时间复杂度:O(n);

  • 空间复杂度O(k),k为常数(trie树的高度),因此可以认为O(1)。

三、代码

#include <iostream>
#include <vector>
using namespace std;
 
//新建结构体,用来构造字典
struct TrieTree
{
    int count;//每个节点存的次数
    struct TrieTree* next[2]{NULL,NULL};//每个节点存储两个节点指针
    TrieTree():count(1)
    {}
};
 
//构造字典
TrieTree* buildTrieTree(const vector<int>& array)
{
    TrieTree* trieTree = new TrieTree();
    for(int i = 0;i < (int)array.size();++i)
    {
        TrieTree* cur = trieTree;//从根节点开始
        for(int j = 16;j >= 0;--j)
        {
            int digit = (array[i] >> j) & 1;
            if(NULL == cur->next[digit])
                cur->next[digit] = new TrieTree();
            else
                ++(cur->next[digit]->count);
            cur = cur->next[digit];
        }
    }
    return trieTree;
}
 
//查询字典树,查询 ai与字典中数异或大于m的数
long long queryTrieTree(TrieTree*& trieTree, const int a, const int m, const int index)
{
    if(NULL == trieTree)
        return 0;
 
    TrieTree* cur = trieTree;
     
    for(int i = index;i >= 0;--i)
    {
    	//m当前位为1则 只能 ai与不同的位选择才有机会
        int aDigit = (a >> i) & 1;
        int mDigit = (m >> i) & 1;
 
        if(1 ==a Digit && 1 == mDigit)
        {
            if(NULL == cur->next[0])
                return 0;
            cur = cur->next[0];
        }
        else if(0 == aDigit && 1 == mDigit)
        {
            if(NULL == cur->next[1])
                return 0;
            cur = cur->next[1];
        }
        else if(1 == aDigit && 0 == mDigit)
        {
            long long val0 =  (NULL == cur->next[0]) ? 0 : cur->next[0]->count;
            long long val1 =  queryTrieTree(cur->next[1],a,m,i - 1);
            return val0 + val1;
        }
        else if(0 == aDigit && 0 == mDigit)
        {
            long long val0 =  queryTrieTree(cur->next[0],a,m,i - 1);
            long long val1 =  (NULL == cur->next[1]) ? 0 : cur->next[1]->count;
            return val0 + val1;
        }
    }   
    return 0;//此时index==-1,这种情况肯定返回0(其他情况在循环体中都考虑到了)
}
 
//结果可能超过了int范围,因此用long long
long long solve(const vector<int>& array, const int& m)
{
    TrieTree* trieTree = buildTrieTree(array);
    long long result = 0;
    for(int i = 0;i < (int)array.size();++i)
    {
        result += queryTrieTree(trieTree,array[i],m,16);
    }
    return result / 2;
}
 
int main()
{
    int n,m;
    while(cin>>n>>m)
    {
        vector<int> array(n);
        for(int i = 0;i < n;++i)
            cin>>array[i];
        cout<< solve(array,m) <<endl;
    }
    return 0;
}

在看个大佬的解法:

#include <bits/stdc++.h>
#define X first
#define Y second
#define MP make_pair
#define PB push_back
#define SZ(X) ((int)(X).size())
#define ALL(X) (X).begin(), (X).end()
#define SORT_UNIQUE(c) (sort(c.begin(),c.end()), c.resize(distance(c.begin(),unique(c.begin(),c.end()))))
#define MS0(X) memset((X), 0, sizeof((X)))
#define MS1(X) memset((X), -1, sizeof((X)))
#define LEN(X) strlen(X)
#define FIO ios::sync_with_stdio(false);
#define bug(x) cout<<"bug("<<x<<")"<<endl;
using namespace std;
typedef long long LL;
typedef unsigned long long ULL;
typedef long double LD;
typedef pair<int,int> PII;
typedef pair<int,PII> N;
typedef vector<int> VI;
typedef vector<LL> VL;
typedef vector<PII> VPII;
const int maxn=1e6+10;
 
using namespace std;
void FWT(LL a[],int n)
{
    for(int d = 1;d < n;d <<= 1)
        for(int m = d << 1,i = 0;i < n;i += m)
            for(int j = 0;j < d;j++)
            {
                LL x = a[i + j],y = a[i + j + d];
                a[i + j] = (x + y),a[i + j + d] = (x - y);
            }
}
 
void UFWT(LL a[],int n)
{
    for(int d = 1;d < n;d <<=1 )
        for(int m = d << 1,i = 0;i < n;i += m)
            for(int j = 0;j < d;j++)
            {
                LL x = a[i + j],y = a[i + j + d];
                a[i + j] = 1LL * (x + y) / 2,a[i + j + d] = (1LL * (x - y) / 2);
            }
}

void solve(LL a[],LL b[],int n)
{
    FWT(a,n);
    FWT(b,n);
    for(int i = 0;i < n;i++) 
    	a[i] = 1LL * a[i] * b[i];
    UFWT(a,n);
}

LL a[maxn],b[maxn];
int main()
{
    int n,m,x;
    FIO
    cin>>n>>m;
    for(int i = 0;i < n;i++)
    	cin>>x,++a[x],++b[x];
    solve(a,b,1 << 17);
    LL ans = 0;
    for(int i = m + 1;i < (1 << 17);i++)
    	ans += a[i];
    cout<<ans / 2<<endl;
}

猜你喜欢

转载自blog.csdn.net/wolfGuiDao/article/details/106725032