字节–异或
一、题目描述
给定整数m以及n各数字A1,A2,..An
,将数列A中所有元素两两异或
,共能得到n(n-1)/2个结果,请求出这些结果中大于m的有多少个
。
-
输入描述:
第一行包含两个整数n,m.
第二行给出n个整数A1,A2,...,An。
数据范围
对于30%的数据,1 <= n, m <= 1000
对于100%的数据,1 <= n, m, Ai <= 10^5
-
输出描述:
输出仅包括一行,即所求的答案
输入例子1:
3 10
6 5 10
输出例子1:
2
二、分析
首先这道题肯定不能考虑n2的解法,肯定会超时,所以解决这道题还需另辟捷径
- 直接计算肯定是超时的,所以这问题
不能使用暴力破解
,考虑到从高位到地位,依次进行位运算, -
如果两个数异或结果在某高位为1,而m的对应位为0
,则肯定任何这两位异或结果为1的都会比m大
。 - 由此,考虑
使用字典树(TrieTree)从高位到低位建立字典
,再使用每个元素依次去字典中查对应高位异或为1, 而m为0的数的个数,相加在除以2既是最终的结果
; - 我想说明下数据范围,n,m,Ai都是[1,1e5]的,(1 << 17)>1e5,
所以一个数至少要17位来存储
,所以trie树的节点个数就是1e5*17,这个不理解的话,仔细查看一下trie树的资料吧。 - TrieTree在搜索的过程中,是
从高位往低位搜索
,那么,如果有一个数与字典中的数异或结果的第k位大于m的第k位,那么该数与对应分支中所有的数异或结果都会大于m
-
否则,就要搜索在第k位异或相等的情况下,更低位的异或结果
。TrieTree中四个分支的作用分别如下:
- aDigit=1, mDigit=1时,
字典中第k位为0,异或结果为1,需要继续搜索更低位,第k位为1,异或结果为0,小于mDigit,不用理会
; - aDigit=0, mDigit=1时,
字典中第k位为1,异或结果为1,需要继续搜索更低位,第k位为0,异或结果为0,小于mDigit,不用理会
; - aDigit=1, mDigit=0时,
字典中第k位为0,异或结果为1,与对应分支所有数异或,结果都会大于m,第k位为1,异或结果为0,递归获得结果
; - aDigit=0, mDigit=0时,
字典中第k位为1,异或结果为1,与对应分支所有数异或,结果都会大于m,第k位为0,异或结果为0,递归获得结果
;
- 对于我来说,关键点在于构建字典树
改进:
-
1.字典树17位即可保证大于100000,移位范围为1~16位,则字典树构建时从16~0即可。
字典树第一层不占位,实际上是15~-1层有数据,这也是数据中next的用法
。 -
2.queryTrieTree函数需要考虑到index为-1时的返回值。
-
时间复杂度:O(n);
-
空间复杂度O(k),k为常数(trie树的高度),因此可以认为O(1)。
三、代码
#include <iostream>
#include <vector>
using namespace std;
//新建结构体,用来构造字典
struct TrieTree
{
int count;//每个节点存的次数
struct TrieTree* next[2]{NULL,NULL};//每个节点存储两个节点指针
TrieTree():count(1)
{}
};
//构造字典
TrieTree* buildTrieTree(const vector<int>& array)
{
TrieTree* trieTree = new TrieTree();
for(int i = 0;i < (int)array.size();++i)
{
TrieTree* cur = trieTree;//从根节点开始
for(int j = 16;j >= 0;--j)
{
int digit = (array[i] >> j) & 1;
if(NULL == cur->next[digit])
cur->next[digit] = new TrieTree();
else
++(cur->next[digit]->count);
cur = cur->next[digit];
}
}
return trieTree;
}
//查询字典树,查询 ai与字典中数异或大于m的数
long long queryTrieTree(TrieTree*& trieTree, const int a, const int m, const int index)
{
if(NULL == trieTree)
return 0;
TrieTree* cur = trieTree;
for(int i = index;i >= 0;--i)
{
//m当前位为1则 只能 ai与不同的位选择才有机会
int aDigit = (a >> i) & 1;
int mDigit = (m >> i) & 1;
if(1 ==a Digit && 1 == mDigit)
{
if(NULL == cur->next[0])
return 0;
cur = cur->next[0];
}
else if(0 == aDigit && 1 == mDigit)
{
if(NULL == cur->next[1])
return 0;
cur = cur->next[1];
}
else if(1 == aDigit && 0 == mDigit)
{
long long val0 = (NULL == cur->next[0]) ? 0 : cur->next[0]->count;
long long val1 = queryTrieTree(cur->next[1],a,m,i - 1);
return val0 + val1;
}
else if(0 == aDigit && 0 == mDigit)
{
long long val0 = queryTrieTree(cur->next[0],a,m,i - 1);
long long val1 = (NULL == cur->next[1]) ? 0 : cur->next[1]->count;
return val0 + val1;
}
}
return 0;//此时index==-1,这种情况肯定返回0(其他情况在循环体中都考虑到了)
}
//结果可能超过了int范围,因此用long long
long long solve(const vector<int>& array, const int& m)
{
TrieTree* trieTree = buildTrieTree(array);
long long result = 0;
for(int i = 0;i < (int)array.size();++i)
{
result += queryTrieTree(trieTree,array[i],m,16);
}
return result / 2;
}
int main()
{
int n,m;
while(cin>>n>>m)
{
vector<int> array(n);
for(int i = 0;i < n;++i)
cin>>array[i];
cout<< solve(array,m) <<endl;
}
return 0;
}
在看个大佬的解法:
#include <bits/stdc++.h>
#define X first
#define Y second
#define MP make_pair
#define PB push_back
#define SZ(X) ((int)(X).size())
#define ALL(X) (X).begin(), (X).end()
#define SORT_UNIQUE(c) (sort(c.begin(),c.end()), c.resize(distance(c.begin(),unique(c.begin(),c.end()))))
#define MS0(X) memset((X), 0, sizeof((X)))
#define MS1(X) memset((X), -1, sizeof((X)))
#define LEN(X) strlen(X)
#define FIO ios::sync_with_stdio(false);
#define bug(x) cout<<"bug("<<x<<")"<<endl;
using namespace std;
typedef long long LL;
typedef unsigned long long ULL;
typedef long double LD;
typedef pair<int,int> PII;
typedef pair<int,PII> N;
typedef vector<int> VI;
typedef vector<LL> VL;
typedef vector<PII> VPII;
const int maxn=1e6+10;
using namespace std;
void FWT(LL a[],int n)
{
for(int d = 1;d < n;d <<= 1)
for(int m = d << 1,i = 0;i < n;i += m)
for(int j = 0;j < d;j++)
{
LL x = a[i + j],y = a[i + j + d];
a[i + j] = (x + y),a[i + j + d] = (x - y);
}
}
void UFWT(LL a[],int n)
{
for(int d = 1;d < n;d <<=1 )
for(int m = d << 1,i = 0;i < n;i += m)
for(int j = 0;j < d;j++)
{
LL x = a[i + j],y = a[i + j + d];
a[i + j] = 1LL * (x + y) / 2,a[i + j + d] = (1LL * (x - y) / 2);
}
}
void solve(LL a[],LL b[],int n)
{
FWT(a,n);
FWT(b,n);
for(int i = 0;i < n;i++)
a[i] = 1LL * a[i] * b[i];
UFWT(a,n);
}
LL a[maxn],b[maxn];
int main()
{
int n,m,x;
FIO
cin>>n>>m;
for(int i = 0;i < n;i++)
cin>>x,++a[x],++b[x];
solve(a,b,1 << 17);
LL ans = 0;
for(int i = m + 1;i < (1 << 17);i++)
ans += a[i];
cout<<ans / 2<<endl;
}