题目链接:A+B Problem
Given N integers in the range [−50000,50000], how many ways are there to pick three integers ai, aj, ak, such that i, j, k are pairwise distinct and ai+aj=ak? Two ways are different if their ordered triples (i,j,k) of indices are different.
Input
The first line of input consists of a single integer N (1≤N≤200000). The next line consists of N space-separated integers a1,a2,…,aN.
Output
Output an integer representing the number of ways.
Sample Input 1 Sample Output 1
4
1 2 3 4
4
Sample Input 2 Sample Output 2
6
1 1 3 3 4 6
10
先用一个多项式来代表每个数字出现的次数,然后做一次FFT,就可以求出任意两个数字的出现次数了。
然后数字会有负数,所以我们需要先加一个值,让所有的值都变为正数。
这道题还需要考虑出现0 的情况。
如果当前数字为0:那么当前0可能和任意其他0,对答案贡献,并且乘上A(2,2),因为 i,j 和 j,i 。
如果当前数字不为0,那么当前数字可能和任意0组合,对答案贡献。
AC代码:
#pragma GCC optimize("-Ofast","-funroll-all-loops")
#include<bits/stdc++.h>
//#define int long long
using namespace std;
typedef long long ll;
const int N=2e6+10;
const double PI=acos(-1.0);
int n,a[N],base=5e4,vis[N],zero; ll res,sum[N];
int r[N],lim,l;
struct Complex{
double x,y;
}f[N];
Complex operator + (Complex a,Complex b){return {a.x+b.x,a.y+b.y};}
Complex operator - (Complex a,Complex b){return {a.x-b.x,a.y-b.y};}
Complex operator * (Complex a,Complex b){return {a.x*b.x-a.y*b.y,a.x*b.y+a.y*b.x};}
inline void FFT(Complex *a,int n,int k){
for(int i=0;i<n;i++) if(i<r[i]) swap(a[i],a[r[i]]);
for(int mid=1;mid<n;mid<<=1){
Complex wn={cos(2*PI/(mid<<1)),k*sin(2*PI/(mid<<1))};
for(int i=0;i<n;i+=(mid<<1)){
Complex w={1,0};
for(int j=0;j<mid;j++,w=(w*wn)){
Complex t0=a[i+j],t1=w*a[i+mid+j];
a[i+j]=t0+t1;
a[i+mid+j]=t0-t1;
}
}
}
if(k==-1) for(int i=0;i<n;i++) a[i].x=a[i].x/n+0.5;
}
inline void init(int n){
lim=1,l=0; while(lim<=(n<<1)) lim<<=1,l++;
for(int i=0;i<lim;i++) r[i]=(r[i>>1]>>1)|((i&1)<<(l-1));
}
signed main(){
cin>>n;
for(int i=1;i<=n;i++){
scanf("%d",&a[i]),vis[a[i]+base]++; if(!a[i]) zero++;
}
init(base<<1);
for(int i=0;i<=lim;i++) f[i].x=1.0*vis[i];
FFT(f,lim,1);
for(int i=0;i<=lim;i++) f[i]=f[i]*f[i];
FFT(f,lim,-1);
for(int i=0;i<=lim;i++) sum[i]=f[i].x;
for(int i=1;i<=n;i++) sum[(a[i]+base)<<1]--;
for(int i=1;i<=n;i++){
res+=sum[a[i]+base*2];
if(a[i]==0) res-=(zero-1)*2;
else res-=zero*2;
}
cout<<res;
return 0;
}