FFT+容斥+生成函数——hdu4609

/*
生成函数+FFT/NTT+容斥计数
求(i,j,k)组合,使得a[i],a[j],a[k]可以组成三角形
先给a排个序,然后用cnt数组统计,
假设以每条边a[k]作为最大边(不是数值最大,是排序后位置最靠右),剩下两条边就要满足a[i]+a[j]>a[k]
那么满足这个条件的对数就可以通过FFT求出
    F[k]=sum{cnt[i],cnt[k-i]}表示凑成长度k的组合对数
    F[]里减掉a[i]+a[i]类的贡献,再除2 
    给F[]求个前缀和sum[]
    遍历a[i],求sum[n]-sum[a[i]]
    再容斥掉多算的部分
        有一条在a[i]右边的边
        有两条在a[i]右边的边
        有一条边就是a[i]自己 
*/
#include<bits/stdc++.h>
using namespace std;
#define N 500005
#define ll long long

typedef complex<double> cp;
const double pi = acos(-1.0); 
ll n,m,A[N],cnt[N],F[N],sum[N];
cp a[N],b[N]; 

void FFT(cp *a,int opt)
{
    int lim = 0;
    while((1 << lim) < n) lim++;
    for(register int i = 0; i < n; i++){
        int t = 0;
        for(register int j = 0; j < lim; j++)
            if((i >> j) & 1) t |= (1 << (lim - j - 1));
        if(i < t) swap(a[i], a[t]); // i < t 的限制使得每对点只被交换一次(否则交换两次相当于没交换)
    }
    for (int k=1;k<n;k<<=1)
    {
        cp wn=cp(cos(pi/k),opt*sin(pi/k));
        for (int i=0;i<n;i+=(k<<1))
        {
            cp w=cp(1,0);
            for (int j=0;j<k;++j,w=w*wn)
            {
                cp x=a[i+j],y=w*a[i+j+k];
                a[i+j]=x+y,a[i+j+k]=x-y;
            }
        }
    }
}
void init(){
    memset(a,0,sizeof a);
    memset(b,0,sizeof b);
    memset(F,0,sizeof F);
    memset(sum,0,sizeof sum);
    memset(cnt,0,sizeof cnt);
}

int main(){
    int t;cin>>t;
    while(t--){
        init();
        scanf("%lld",&m);
        ll Max=0;
        for(int i=1;i<=m;i++)
            scanf("%lld",&A[i]),cnt[A[i]]++,Max=max(Max,A[i]);
        n=1;while(n<=Max*2+1)n<<=1;
        sort(A+1,A+1+m);
        
        for(int i=0;i<n;i++)
            a[i].real(cnt[i]),b[i].real(cnt[i]);
        FFT(a,1);FFT(b,1);
        for(int i=0;i<n;i++)a[i]*=b[i];
        FFT(a,-1);
        for(int i=0;i<n;i++)F[i]=(ll)(a[i].real()/n+0.5);
        
        for(int i=1;i<=m;i++)F[2*A[i]]--;
        for(int i=0;i<n;i++)F[i]>>=1;
        sum[0]=F[0]; 
        for(int i=1;i<n;i++)sum[i]=sum[i-1]+F[i];
        
        ll ans=0;
        for(int i=1;i<=m;i++){
            ans+=sum[n-1]-sum[A[i]];
            ans-=(m-i)*(i-1);
            ans-=(m-1);
            ans-=(m-i)*(m-i-1)/2;
        }
        printf("%.7lf\n",1.0*ans/((ll)m*(m-1)*(m-2)/6)); 
    }    
} 
/*
5
10
1 3 3 4 4 5 6 7 7 8
10
1 1 1 1 1 1 1 1 1 1
10 
2 2 2 2 2 3 3 3 3 3
10 
6 6 6 6 6 2 2 2 3 3
*/

猜你喜欢

转载自www.cnblogs.com/zsben991126/p/12309904.html