CodeChef - COUNTARI Arithmetic Progressions FFT 分块

这题因为ijk大小关系的限制,所以不能像三个傻瓜那题一样直接FFT,排序后排出情况。

所以一开始想到的是对每个位置都做一次FFT,即枚举 Aj ,用 AiAk 做FFT,但这复杂度明显是不行的 O(N30000log30000)

然后看了题解才知道还有分块这种方法。。

具体的分法就不说了,网上有一大堆,最后的复杂度就是 O(N2K+k3000016)
k取到30就能过了,不过为什么看着感觉k取到1000才能过呢。。

#include <iostream>
#include <algorithm>
#include <cstring>
#include <cstdio>
#include <cmath>

using namespace std;
const int MAXN = 262144;
const int LIM = 61000;
typedef long long LL;
const int INF = 0x3f3f3f3f;
const double pi=acos(-1.0);
int num[MAXN],n,block,size;
LL pre[LIM],in[LIM],nex[LIM];
LL ans = 0;

struct cp
{
    double x,y;
    cp() {}
    cp(double x,double y):x(x),y(y) {}
    inline double real() { return x; }
    inline cp operator * (const cp& r) const { return cp(x*r.x-y*r.y,x*r.y+y*r.x); }
    inline cp operator - (const cp& r) const { return cp(x-r.x,y-r.y); }
    inline cp operator + (const cp& r) const { return cp(x+r.x,y+r.y); }
};

cp a[MAXN],b[MAXN];
LL r[MAXN],res[MAXN];
LL ax[MAXN],bx[MAXN];

void fft_init(int nm,int k)
{
    for (int i=0;i<nm;i++) r[i] = (r[i>>1]>>1) | ((i&1) << (k-1));
}

void fft(cp ax[],int nm,int op)
{
    for (int i=0;i<nm;i++) if (i<r[i]) swap(ax[i],ax[r[i]]);
    for (int h=2,m=1;h<=nm;h<<=1,m<<=1)
    {
        cp wn = cp(cos(op*2*pi/h),sin(op*2*pi/h));
        for (int i=0;i<nm;i+=h)
        {
            cp w(1,0);
            for (int j=i;j<i+m;++j,w=w*wn)
            {
                cp t=w*ax[j+m];
                ax[j+m] = ax[j]-t;
                ax[j] = ax[j]+t;
            }
        }
    }
    if (op==-1) for (int i=0;i<nm;i++) ax[i].x /= nm;
}

void trans(LL ax[],LL bx[],int n,int m)
{
    int nm=1,k=0;
    while (nm < 2*n || nm<2*m) nm<<=1,k++;

    for (int i=0;i<n;i++) a[i] = cp(ax[i],0);
    for (int i=0;i<m;i++) b[i] = cp(bx[i],0);
    for (int i=n;i<nm;i++) a[i] = cp(0,0);
    for (int i=m;i<nm;i++) b[i] = cp(0,0);

    fft_init(nm,k);
    fft(a,nm,1);fft(b,nm,1);
    for (int i=0;i<nm;i++) a[i] = a[i]*b[i];
    fft(a,nm,-1);
    nm = n+m-1;
    for (int i=0;i<nm;i++) res[i] = (LL)(a[i].real()+0.5);
}

int main()
{
    while (scanf("%d",&n)!=EOF)
    {
        memset(in,0,sizeof in);
        memset(pre,0,sizeof pre);
        memset(nex,0,sizeof nex);
        for (int i=1;i<=n;i++)
        {
            scanf("%d",&num[i]);
            nex[num[i]] ++;
        }
        block = 30;
        size = (n+block-1)/block;
        for (int b=1;b<=block;b++)
        {
            int s=(b-1)*size +1,e=min(b*size,n);
            for (int i=s;i<=e;i++) nex[num[i]]--;
            trans(pre,nex,30001,30001);
            for (int i=s;i<=e;i++)
            {
                for (int j=i+1;j<=e;j++)
                {
                    if (2*num[i] - num[j]>=1 )
                    {
                        ans += in[ 2*num[i] - num[j] ];//3 in
                        ans += pre[ 2*num[i] - num[j] ];//2 in  1 prev
                    }
                    if ( 2*num[j]-num[i]>=1 )
                        ans += nex[ 2*num[j]-num[i] ];
                }
                ans += res[ 2*num[i] ];
                in [ num[i] ] ++;
            }
            for (int i=s;i<=e;i++) pre[num[i]]++,in[num[i]]--;
        }
        printf("%lld\n",ans);
    }
    return 0;
}

猜你喜欢

转载自blog.csdn.net/z631681297/article/details/78144687