[CodeChef] COUNTARI

Problem Description

Given an array A[] of length N, find the number of triples i, j, k (1<=i<j<k<=N) such that A[k]-A[j]=A[j]-A[i], i.e. A[i], A[j], A[k] form an arithmetic progression.

Input format

The first line contains an integer N (N<=10^5).
The next line contains N integers A[i] (A[i]<=30000).

Output format

Print a single integer: the number of such triples.

Sample input

10
3 5 3 6 3 4 10 4 5 2

Sample output

9
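To check an implementation on small inputs, a simple O(N^2) reference counter can be useful. A minimal sketch, assuming 0-indexed positions and non-negative values; the function name countAriBrute is hypothetical and not part of the original solution:

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Reference O(n^2) counter: counts triples i<j<k with A[k]-A[j] == A[j]-A[i],
// i.e. A[i] + A[k] == 2*A[j]. Values are assumed to be non-negative.
long long countAriBrute(const std::vector<int>& a) {
    int maxV = 0;
    for (int x : a) { assert(x >= 0); maxV = std::max(maxV, x); }
    // right[v] = how many positions after the current middle hold value v
    std::vector<long long> right(maxV + 1, 0);
    for (int x : a) right[x]++;
    long long ans = 0;
    for (size_t j = 0; j < a.size(); ++j) {
        right[a[j]]--;                   // a[j] is the middle now, not a right element
        for (size_t i = 0; i < j; ++i) {
            int t = 2 * a[j] - a[i];     // value the right element must have
            if (t >= 0 && t <= maxV) ans += right[t];
        }
    }
    return ans;
}
```

For the sample above, countAriBrute({3, 5, 3, 6, 3, 4, 10, 4, 5, 2}) returns 9.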


Solution

Note: the "middle number" mentioned below is A[j] in a triple with A[k]-A[j]=A[j]-A[i].
A straightforward approach: for each position i, put the values A[1] to A[i-1] into an array c1[] indexed by value (up to 30000), so that c1[v] records how many of the first i-1 numbers equal v, and put the values A[i+1] to A[n] into c2[] the same way. Convolve c1 and c2 with FFT to obtain c3; then c3[2*A[i]] is the number of answers with A[i] as the middle number, because c3[2*A[i]] is the sum of c1[l]*c2[k] over all l+k == 2*A[i].
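Since only the single entry c3[2*A[i]] is ever read, it can also be evaluated directly from c1 and c2 in O(MaxV) per middle number, without any FFT. A minimal sketch of that formula, assuming 0-indexed positions and non-negative values; countAriByMiddle is a hypothetical name:

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Direct evaluation of c3[2*A[i]] = sum over l + k == 2*A[i] of c1[l] * c2[k]
// for every middle index i. Total cost O(n * MaxV).
long long countAriByMiddle(const std::vector<int>& a) {
    int maxV = 0;
    for (int x : a) { assert(x >= 0); maxV = std::max(maxV, x); }
    std::vector<long long> c1(maxV + 1, 0), c2(maxV + 1, 0);
    for (int x : a) c2[x]++;                  // initially everything is "right"
    long long ans = 0;
    for (int x : a) {
        c2[x]--;                              // x is the middle: drop it from the right side
        int target = 2 * x;
        // only l in [target - maxV, target] keeps k = target - l inside [0, maxV]
        for (int l = std::max(0, target - maxV); l <= std::min(maxV, target); ++l)
            ans += c1[l] * c2[target - l];
        c1[x]++;                              // x joins the left side for later middles
    }
    return ans;
}
```

This reads one position per middle number, which is exactly the waste discussed next: a full convolution would compute every position, yet only one is used.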
This is "excellent": it raises the time complexity from O(n^2) to O(n^2 log n)...
Why is it slow even with a good FFT? Because after each FFT we use the value at only one position!
In other words, we want to extract more information from each FFT.
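For contrast, a single FFT-based convolution produces c3[t] for every t at once. A minimal sketch, assuming a standard iterative radix-2 FFT (this is not the author's fly routine); fftInPlace and convolveCounts are hypothetical names:

```cpp
#include <cassert>
#include <cmath>
#include <complex>
#include <utility>
#include <vector>

using cd = std::complex<double>;

// Iterative radix-2 FFT. invert=true computes the inverse transform (scaled by 1/n).
// a.size() must be a power of two.
void fftInPlace(std::vector<cd>& a, bool invert) {
    const size_t n = a.size();
    assert(n > 0 && (n & (n - 1)) == 0);
    for (size_t i = 1, j = 0; i < n; ++i) {          // bit-reversal permutation
        size_t bit = n >> 1;
        for (; j & bit; bit >>= 1) j ^= bit;
        j ^= bit;
        if (i < j) std::swap(a[i], a[j]);
    }
    const double pi = std::acos(-1.0);
    for (size_t len = 2; len <= n; len <<= 1) {       // merge halves of size len/2
        double ang = 2 * pi / static_cast<double>(len) * (invert ? -1 : 1);
        cd wlen(std::cos(ang), std::sin(ang));
        for (size_t i = 0; i < n; i += len) {
            cd w(1);
            for (size_t j = 0; j < len / 2; ++j) {
                cd u = a[i + j], v = a[i + j + len / 2] * w;
                a[i + j] = u + v;
                a[i + j + len / 2] = u - v;
                w *= wlen;
            }
        }
    }
    if (invert) for (cd& x : a) x /= static_cast<double>(n);
}

// c[t] = sum over l + k == t of c1[l] * c2[k], for every t in one pass.
std::vector<long long> convolveCounts(const std::vector<long long>& c1,
                                      const std::vector<long long>& c2) {
    if (c1.empty() || c2.empty()) return {};
    size_t need = c1.size() + c2.size() - 1, n = 1;
    while (n < need) n <<= 1;                         // no cyclic wrap-around
    std::vector<cd> fa(c1.begin(), c1.end()), fb(c2.begin(), c2.end());
    fa.resize(n);
    fb.resize(n);
    fftInPlace(fa, false);
    fftInPlace(fb, false);
    for (size_t i = 0; i < n; ++i) fa[i] *= fb[i];
    fftInPlace(fa, true);
    std::vector<long long> c(need);
    for (size_t t = 0; t < need; ++t) c[t] = std::llround(fa[t].real());
    return c;
}
```

For example, convolveCounts({1, 2}, {3, 4}) gives {3, 10, 8}. The block trick below is about reading many of these positions per FFT instead of one.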
Following an expert's advice, we split the sequence into blocks (sqrt decomposition).
Let each block have length s, so the sequence is divided into n/s blocks.
For each number A[i], the positions where the other two numbers can lie are divided into four regions:
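1: left of the block (outside it); 2: inside the block, left of i; 3: inside the block, right of i; 4: right of the block (outside it).
Therefore, with A[i] as the middle number, the other two numbers lie in one of the region pairs (1,3), (1,4), (2,3), (2,4).
(1,4) can be handled with FFT: once per block, convolve the value counts of region 1 with those of region 4, then read the result at 2*A[i] for every A[i] in the block. The other three cases each have a second number inside the block, so they are handled by brute force over pairs inside the block, looking up the third number in a count array.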
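The case split can be sanity-checked with a sketch that follows the same block structure but replaces the per-block FFT with a direct O(MaxV) sum for each number in the block, so it runs in roughly O(n*(MaxV + s)) instead of the FFT bound. It assumes 0-indexed positions and non-negative values; countAriBlocks is a hypothetical name:

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Block decomposition with the (1,4) convolution done naively (no FFT).
// s is the block length. Region numbers follow the text above.
long long countAriBlocks(const std::vector<int>& a, int s) {
    assert(s > 0);
    const int n = static_cast<int>(a.size());
    int maxV = 0;
    for (int x : a) { assert(x >= 0); maxV = std::max(maxV, x); }
    std::vector<long long> before(maxV + 1, 0), after(maxV + 1, 0);
    for (int x : a) after[x]++;
    long long ans = 0;
    for (int lo = 0; lo < n; lo += s) {
        int hi = std::min(n, lo + s);                  // current block is [lo, hi)
        for (int j = lo; j < hi; ++j) after[a[j]]--;   // after = region 4 only
        // (1,4): left in region 1, right in region 4 (the FFT case in the real code)
        for (int j = lo; j < hi; ++j) {
            int t = 2 * a[j];
            for (int l = std::max(0, t - maxV); l <= std::min(maxV, t); ++l)
                ans += before[l] * after[t - l];
        }
        // (1,3): middle a[j] and right a[k] inside the block, left in region 1
        for (int j = lo; j < hi; ++j)
            for (int k = j + 1; k < hi; ++k) {
                int t = 2 * a[j] - a[k];
                if (t >= 0 && t <= maxV) ans += before[t];
            }
        // (2,3) and (2,4): left a[i] and middle a[j] inside the block,
        // right anywhere after j; scan j right to left
        std::vector<long long> right = after;
        for (int j = hi - 1; j >= lo; --j) {
            for (int i = lo; i < j; ++i) {
                int t = 2 * a[j] - a[i];
                if (t >= 0 && t <= maxV) ans += right[t];
            }
            right[a[j]]++;
        }
        for (int j = lo; j < hi; ++j) before[a[j]]++;  // block joins region 1
    }
    return ans;
}
```

Every block length s should give the same count; for the sample input the result is 9 for any s.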
Total time complexity: O(n*s + (n/s)*MaxV*log MaxV), i.e. O(s^2) brute force plus one FFT for each of the n/s blocks. In practice the FFT has a large constant factor, so the block length should be chosen on the large side.
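A rough way to pick s, assuming the FFT work for one block is about c * L * log2(L) for FFT length L (the constant c is an assumption, not a measurement) and the brute-force part totals about n*s:

```latex
T(s) \approx n\,s + \frac{n}{s}\,c\,L\log_2 L,
\qquad \frac{dT}{ds} = 0 \;\Rightarrow\; s^{*} = \sqrt{c\,L\log_2 L}

\text{MaxV} = 30000 \Rightarrow L = 65536,\quad
L\log_2 L = 65536 \cdot 16 = 2^{20},\quad
s^{*} = 2^{10}\sqrt{c} = 1024\sqrt{c}
```

The code uses s = 3*floor(sqrt(N)), which is 948 for N = 10^5, the same order of magnitude as s*.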
Pitfalls:
1. If you write the FFT yourself, make sure it is well written and fast;
2. Do not allocate a two-dimensional array: each complex<double> takes 16 B, so you will run out of memory.
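For scale, a hypothetical example (the exact two-dimensional array the author had in mind is not stated): keeping one FFT-length row of complex<double> per block, with N = 10^5, s = 948 and FFT length 65536, would need

```latex
\left\lceil \frac{10^5}{948} \right\rceil = 106 \text{ blocks},
\qquad 106 \times 65536 \times 16\,\text{B} = 111{,}149{,}056\,\text{B} = 106\,\text{MiB}
```

whereas each one-dimensional array of 70001 elements in the code takes 70001 × 16 B ≈ 1.1 MB.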


Code

#include <cstdio>
#include <iostream>
#include <algorithm>
#include <cmath>
#include <complex>
#define db double
#define cd complex<db>
#define ll long long
using namespace std;
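const double pi=3.1415926535897932;
cd w[70001];    // twiddle factors for the current FFT stage
ll n,s;         // n: FFT length (a power of two), s: block length
// In-place iterative FFT of length n: f=1 forward, f=-1 inverse (scaled by 1/n).
void fly(cd a[],ll f){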
    ll i,j,k,m;
    // bit-reversal permutation
    for(i=j=0;i<n;i++){
        if(i<j)swap(a[i],a[j]);
        for(k=n>>1;(j^=k)<k;k>>=1);
    }
    w[0]=1;
    // combine transforms of length m into transforms of length 2m
    for(m=1;m<n;m<<=1){
        cd ha=exp(cd(0,pi*(db)f/(db)m));
        for(i=1;i<m;i++)w[i]=w[i-1]*ha;
        for(i=0;i<n;i+=(m<<1))
            for(j=0;j<m;j++){
                cd p=a[i+j],q=a[i+j+m]*w[j];
                a[i+j]=p+q,a[i+j+m]=p-q;
            }
    }
    if(f==1)return;
    cd tmp=1.0/(db)n;
    for(i=0;i<n;i++)
        a[i]*=tmp;
}
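// round the real part of x to the nearest integer
ll gg(cd x)
{return (ll)floor(x.real()+0.5);}
cd ti[70001];
// c = cyclic convolution of a and b (length n); a and b are left unchanged
void FFT(cd a[],cd b[],cd c[]){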
    for(ll i=0;i<n;i++)ti[i]=b[i],c[i]=a[i];
    fly(c,1),fly(ti,1);
    for(ll i=0;i<n;i++)c[i]*=ti[i];
    fly(c,-1);
}
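// first pass: ge[v] = count of value v before the current block (region 1),
// now[v] = count of value v after the current block (region 4), tmp = their convolution
cd ge[70001],now[70001],tmp[70001];
ll v[100005];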
int main()
{
    ll i,k,j,ans=0,all,maxc=0;
    scanf("%lld",&all);
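    s=(ll)sqrt(all)*3;                      // block length
    for(i=1;i<=all;i++)scanf("%lld",&v[i]),maxc=max(maxc,v[i]);
    for(n=1;n<=maxc*2;n<<=1);               // FFT length > 2*maxc, so the convolution does not wrap around
    ll lb=(all-1)/s+1;                      // number of blocks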
    for(i=1;i<=all;i++)now[v[i]]+=1;
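    for(i=1;i<=lb;i++){
        // remove the current block from now, so now = region 4
        for(j=(i-1)*s+1;j<=i*s&&j<=all;j++)
            now[v[j]]-=1;
        FFT(now,ge,tmp);
        // (1,3): middle v[j] and right v[k] inside the block, left in region 1
        for(j=(i-1)*s+1;j<=i*s&&j<=all;j++)
            for(k=j+1;k<=i*s&&k<=all;k++)
                if(2*v[j]-v[k]>=0)ans+=gg(ge[2*v[j]-v[k]]);
        // (1,4): read the convolution at 2*v[j]
        for(j=(i-1)*s+1;j<=i*s&&j<=all;j++)
            ans+=gg(tmp[v[j]*2]);
        // the current block becomes part of region 1 for later blocks
        for(j=(i-1)*s+1;j<=i*s&&j<=all;j++)
            ge[v[j]]+=1;
    }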
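    // (2,3) and (2,4): scan right to left; now[] counts the values after position i
    for(i=0;i<=maxc;i++)now[i]=0;
    for(i=all;i;--i){
        ll nb=(i-1)/s+1;                    // block containing i
        for(j=i-1;j>=(nb-1)*s+1;--j)
            if(2*v[i]-v[j]>=0)ans+=gg(now[2*v[i]-v[j]]);
        now[v[i]]+=1;
    }
    printf("%lld\n",ans);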
    return 0;
}


Origin http://43.154.161.224:23101/article/api/json?id=325652220&siteId=291194637