Arithmetic Progressions
code
#include<bits/stdc++.h>
#define IO ios::sync_with_stdio(false);cin.tie(0);cout.tie(0)
#define ACM_LOCAL
using namespace std;
typedef long long ll;
// Contest-template constants and global state.
// Only PI and N are referenced by the visible code; eps, MOD, M, m and ans
// are not used anywhere in this file (template boilerplate).
const double PI = std::acos(-1.0);
const double eps = 1e-4;
const int MOD = 1000000007;
const int M = 10000010;
const int N = 200010;   // capacity of every per-index / per-value array
int n;                  // number of input elements
int m;
int ans[N];
int Lim = 1;            // FFT length; main() doubles it into a power of two
int L;                  // exponent with Lim == 2^L (maintained by main())
int R[N];               // bit-reversal permutation for indices < Lim
// Minimal complex number for the FFT: x is the real part, y the imaginary part.
// A and B are the two global transform buffers of length N.
struct Complex {
    double x;
    double y;
    Complex(double re = 0, double im = 0) : x(re), y(im) {}
} A[N], B[N];
// Complex product: (a + bi)(c + di) = (ac - bd) + (ad + bc)i.
Complex operator * (Complex lhs, Complex rhs) {
    const double real = lhs.x * rhs.x - lhs.y * rhs.y;
    const double imag = lhs.x * rhs.y + lhs.y * rhs.x;
    return Complex(real, imag);
}
// Component-wise difference of two complex numbers.
Complex operator - (Complex lhs, Complex rhs) {
    Complex diff = lhs;
    diff.x -= rhs.x;
    diff.y -= rhs.y;
    return diff;
}
// Component-wise sum of two complex numbers.
Complex operator + (Complex lhs, Complex rhs) {
    Complex sum(lhs.x + rhs.x, lhs.y + rhs.y);
    return sum;
}
// In-place iterative radix-2 FFT over J[0..Lim-1].
// `type` is the sign of the twiddle angle: Conv() passes 1 for the forward
// transform and -1 for the inverse. The inverse is NOT scaled by 1/Lim; the
// caller divides afterwards. Uses the global bit-reversal table R.
void FFT(Complex *J, double type) {
    // Bit-reversal permutation; the index check swaps each pair only once.
    for (int idx = 0; idx < Lim; ++ idx) {
        if (idx < R[idx]) swap(J[idx], J[R[idx]]);
    }
    // Butterfly passes over sub-transforms of length 2 * half.
    for (int half = 1; half < Lim; half <<= 1) {
        const double angle = PI / half;
        const Complex step(cos(angle), type * sin(angle));
        const int span = half << 1;
        for (int start = 0; start < Lim; start += span) {
            Complex tw(1, 0);   // running twiddle factor step^k
            for (int k = 0; k < half; ++ k) {
                const Complex even = J[start + k];
                const Complex odd = tw * J[start + half + k];
                J[start + k] = even + odd;
                J[start + half + k] = even - odd;
                tw = tw * step;
            }
        }
    }
}
// Cyclic convolution of the global buffers A and B (length Lim):
// forward-transforms both, multiplies them pointwise, then inverse-transforms A.
// The inverse FFT is unscaled, so afterwards A[i].x holds Lim times the i-th
// coefficient of the product; callers divide by Lim. B keeps its forward
// transform.
void Conv() {
    FFT(A, 1);
    FFT(B, 1);
    // The transform has exactly Lim points; index Lim lies outside it.
    for (int i = 0; i < Lim; ++ i)
        A[i] = A[i] * B[i];
    FFT(A, -1);
}
int cnt_a;      // not referenced by the visible code
int cnt_b;      // not referenced by the visible code
// Per-value occurrence counts used by solve():
//   vis[1] - values in blocks before the current block,
//   vis[2] - values in the current block before the current position,
//   vis[3] - values in blocks after the current block (filled by main()).
// Row 0 is not used by the visible code.
int vis[4][N];
int maxx;       // largest input value
int temp[N];    // the input sequence, indices 0..n-1
// Counts index triples i < j < k with temp[i] + temp[k] == 2 * temp[j]
// (3-term arithmetic progressions) and prints the count.
//
// The sequence is cut into at most 30 contiguous blocks. While one block is
// processed, vis[1] counts values in earlier blocks, vis[2] counts values of
// the current block seen so far, and vis[3] counts values in later blocks.
//  * Triples with two or three members in the current block are found by
//    brute force over the in-block pairs.
//  * Triples whose middle element is in the current block and whose outer
//    elements lie in an earlier and a later block are found by convolving
//    vis[1] with vis[3] (one FFT round per block).
//
// Expects the state main() prepares: vis[3] holds the counts of all of
// temp[0..n-1], vis[1]/vis[2] are zero, maxx is the largest value, and Lim/R
// describe a transform with Lim > 2 * maxx.
// NOTE(review): values are used directly as array indices, so they are assumed
// to lie in [0, N) — confirm against the problem limits.
void solve(ll res = 0) {
    if (n <= 0) {
        // No elements: nothing to count, and min(n, 30) below would be 0,
        // making n / block a division by zero.
        cout << res << endl;
        return;
    }
    int block = min(n, 30);
    int one = n / block;
    if (n % block) one ++;
    for (int idx = 1; idx <= block; ++ idx) {
        int begin = one * (idx - 1), end = min(n - 1, one * idx - 1);
        // Rounding `one` up can leave trailing blocks past the end of the array.
        // Once begin > end every later block is empty too, so stop instead of
        // running FFT rounds that add nothing.
        if (begin > end) break;
        // The current block no longer counts as "later".
        for (int i = begin; i <= end; ++ i) vis[3][temp[i]] --;
        for (int i = begin; i <= end; ++ i) {
            for (int j = i + 1; j <= end; ++ j) {
                // temp[i] as the middle: the first element comes before i,
                // either in an earlier block or earlier in this block.
                int an = 2 * temp[i] - temp[j];
                if (an >= 0 && an <= maxx) {
                    res += vis[2][an];
                    res += vis[1][an];
                }
                // temp[j] as the middle, with the last element in a later block.
                // A last element inside this block is counted when j itself is
                // taken as the middle above.
                an = 2 * temp[j] - temp[i];
                if (an >= 0 && an <= maxx) {
                    res += vis[3][an];
                }
            }
            vis[2][temp[i]] ++;
        }
        // Load earlier-block and later-block counts into the transform buffers.
        for (int i = 0; i < Lim; ++ i) {
            if (i <= maxx) {
                A[i].x = vis[1][i];
                B[i].x = vis[3][i];
                A[i].y = B[i].y = 0;
            } else A[i].x = A[i].y = B[i].x = B[i].y = 0;
        }
        Conv();
        // A[s].x / Lim is the number of (earlier, later) pairs whose values sum
        // to s. +0.5 rounds away the floating-point error.
        for (int i = begin; i <= end; ++ i) {
            int an = 2 * temp[i];
            res += (ll)(A[an].x / Lim + 0.5);
        }
        // The block becomes "earlier"; the in-block counts are cleared.
        for (int i = begin; i <= end; ++ i) {
            vis[1][temp[i]] ++;
            vis[2][temp[i]] --;
        }
    }
    cout << res << endl;
}
// Reads n and the n values, records every value in vis[3] and tracks the
// largest one in maxx, sizes the FFT (Lim = 2^L > 2 * maxx) together with its
// bit-reversal table R, then runs solve() once.
signed main() {
    IO;
#ifdef ACM_LOCAL
    // NOTE(review): ACM_LOCAL is #defined unconditionally near the top of this
    // file, so these redirects always run. Drop that #define before submitting
    // to a judge that uses stdin/stdout.
    freopen("input", "r", stdin);
    freopen("output", "w", stdout);
#endif
    int o = 1;
    while (o --) {
        cin >> n;
        for (int i = 0; i < n; ++ i) {
            cin >> temp[i];
            vis[3][temp[i]] ++;
            maxx = max(maxx, temp[i]);
        }
        // Smallest power of two above maxx, then doubled. This lets sums of two
        // values (at most 2 * maxx) fit in the cyclic convolution without
        // wrapping around.
        while (Lim <= maxx) Lim <<= 1, L ++;
        Lim <<= 1;
        L ++;
        // Reverse each index over L bits. FFT only reads R[0..Lim-1].
        for (int i = 0; i < Lim; ++ i) {
            R[i] = (R[i >> 1] >> 1) | ((i & 1) << (L - 1));
        }
        solve();
    }
    return 0;
}