codeforces 993E FFT

题目大意

给一个长度为 n 的序列和 x ,求小于 x 的数恰好有 k 个的区间个数, k 要取遍 1 n

解题思路

显然将所有小于 x 的数转化为1,其他的为0,这就是个01序列,然后求区间和是 k 的区间个数。

那么我们首先要取一个左端点,再取一个右端点。

k 要取遍,并且是 10 5 级的数据,考虑FFT。

考虑左端点取每一个点时,在这个点左边的 1 的个数,若个数为 t ,则 x t 的系数加1,这样构造一个多项式 A

右端点取每一个点时,在这个点右边的 1 的个数,构造一个类似的多项式 B

那么 A B 做卷积,第 t 项的系数就是区间外 1 的个数为 t 的区间个数。

桥豆麻袋,有可能出现右端点取得比左端点小的情况。

发现这种情况下,算出来在区间外 1 的个数一定大于等于整个序列 1 的个数,所以受影响的只有 k = 0 的情况,所以单独用组合的方法计算一下 k = 0 的情况即可。

代码

#include<bits/stdc++.h>
using namespace std;
#define RI register int
int read() {
    int q=0,w=1;char ch=' ';
    while(ch!='-'&&(ch<'0'||ch>'9')) ch=getchar();
    if(ch=='-') w=-1,ch=getchar();
    while(ch>='0'&&ch<='9') q=q*10+ch-'0',ch=getchar();
    return q*w;
}
typedef long long LL;
typedef double db;
const db pi=3.1415926535897384626;
const int N=524300;
struct com{db r,i;}a[N],b[N];
int n,X,tot,len,kn;
int sum[N],rev[N];LL ans[N];
com operator* (com a,com b) {return (com){a.r*b.r-a.i*b.i,a.r*b.i+a.i*b.r};}
com operator- (com a,com b) {return (com){a.r-b.r,a.i-b.i};}
com operator+ (com a,com b) {return (com){a.r+b.r,a.i+b.i};}
com operator/ (com a,db b) {return (com){a.r/b,a.i/b};}
void NTT(com *a,int n,int x) {
    for(RI i=0;i<n;++i) if(rev[i]>i) swap(a[i],a[rev[i]]);
    for(RI i=1;i<n;i<<=1) {
        com wn=(com){cos(pi/i),sin(pi/i)*x};
        for(RI j=0;j<n;j+=(i<<1)) {
            com t1,t2,w=(com){1,0};
            for(RI k=0;k<i;++k,w=w*wn)
                t1=a[j+k],t2=w*a[j+i+k],a[j+k]=t1+t2,a[j+i+k]=t1-t2;
        }
    }
    if(x==-1) for(RI i=0;i<n;++i) a[i]=a[i]/(db)n;
}
int main()
{
    n=read(),X=read();
    for(RI i=1;i<=n;++i) sum[i]=sum[i-1]+(read()<X);
    for(RI i=1;i<=n;++i) ++a[sum[i-1]].r;
    for(RI i=1;i<=n;++i) ++b[sum[n]-sum[i]].r;
    kn=1;while(kn<=sum[n]*2) kn<<=1,++len;
    for(RI i=1;i<kn;++i) rev[i]=(rev[i>>1]>>1)|((i&1)<<(len-1));
    NTT(a,kn,1),NTT(b,kn,1);
    for(RI i=0;i<kn;++i) a[i]=a[i]*b[i];
    NTT(a,kn,-1);
    int js=0;
    for(RI i=1;i<=n;++i)
        if(sum[i]==sum[i-1]) ++js;
        else ans[0]+=1LL*js*(js-1)/2+js,js=0;
    ans[0]+=1LL*js*(js-1)/2+js;
    for(RI i=0;i<sum[n];++i) ans[sum[n]-i]=(LL)(a[i].r+0.2);
    for(RI i=0;i<=n;++i) printf("%lld ",ans[i]);
    puts("");
    return 0;
}

猜你喜欢

转载自blog.csdn.net/litble/article/details/80963180
FFT