HDU5829 Rikka with Subset

题面

题目描述:

n 个数字 A [ 1 ] ~ A [ n ] ,还有一个数字 K 。对于任意一个 A 的非空子集 S S 的价值等于 S 集合中最大的 m i n ( | S | , K ) 个数之和( | S | 表示集合 S 的大小),数组 A 的价值为所有 A 的非空子集的价值之和。现在给出这 n 个数,他想知道对任意 K 属于 [ 1 , n ] 的数组的价值。

输入:

第一行包含一个数字 t ( 1 t 10 ) 表示数据组数。对于每组数据,第一行有一个整数 n ( 1 n 10 5 ) ,第二行包含 n 个数 A [ 1 ] ~ A [ n ] ( 0 A [ i ] 10 9 )

输出:

对于每组数据,输出一行包含 n 个整数,第 i 个数是当 K = i 时的数组价值。答案可能会很大,所以你只需输出答案对 998244353 取模的结果。

讲解:

直接按照题目要求模拟,每一次询问找出所有的集合,然后把每个集合最大的前 K 个数累加,时间复杂度 O ( 2 n ) ,emmm……,肯定要优化。

首先考虑到是要集合里最大的前几个数,可以将数组 A 排序,思考对于每一个 K 如何求价值,可以发现 K 的答案只是在 K 1 的答案上将所有非空子集的第 K 大的值给加上,而所有非空子集的第 K 大的值又等价于每一个数作为第 K 大的贡献之和,那么就可以通过求出每一个数作为第 K 大的贡献之和,再累加一下求前缀和得出答案,那么令 f [ K ] 为每一个数作为第 K 大的贡献之和,这样就可以列出一个公式:

f [ K ] = i = K n ( i 1 K 1 ) 2 n i A [ i ]

其中 ( i 1 K 1 ) 2 n i 求的是有多少个集合里 A [ i ] 作为第 K 大。预处理组合数时间复杂度 O ( n 2 ) ,求 f [ K ] 时间复杂度也是 O ( n 2 ) ,总时间复杂度 O ( n 2 ) ,比直接暴力优秀很多,但是还是Time Limit Enough,那我们考虑对公式变形:

将组合数拆开:

f [ K ] = i = K n ( i 1 ) ! ( K 1 ) ! ( i K ) ! 2 n i A [ i ]

1 ( K 1 ) ! i 无关,可以提出:

f [ K ] = 1 ( K 1 ) ! i = K n ( i 1 ) ! ( i K ) ! 2 n i A [ i ]

观察等式右边可以发现 n i + i K = n K 是一个定值,很容易想到卷积,那我们令 a [ i ] = 1 i ! b [ i ] = ( n i 1 ) ! 2 i A [ n i ] ,那么公式就变成了下面这个样子:

f [ K ] = 1 ( K 1 ) ! i = K n a [ i K ] b [ n i ]

可以看到 i = K n a [ i K ] b [ n i ] 已经是一个卷积形式,看不出来也没关系,再改一下样子就能看出来了:

f [ K ] = 1 ( K 1 ) ! i + j = n K a [ i ] b [ j ]

所以我们只要求出 a , b 的卷积 X ,就能得到 f [ K ] ,要注意 f [ K ] = 1 ( K 1 ) ! X [ n K ] ,而不是 f [ K ] = 1 ( K 1 ) ! X [ K ]

预处理 a , b 时间复杂度 O ( n ) ,涉及到取模需要用 N T T ,时间复杂度 O ( n log n ) ,总时间复杂度 O ( n log n ) ,到此问题全部解决。

#include<cstdio>
#include<cstring>
#include<iostream>
#include<algorithm>
using namespace std;
const int N=262150,mod=998244353;
int T,n,size,need,ans,a[N],b[N],A[N],fac[N],Pow[N],inv[N],ni[N];
bool cmp(int a,int b)
{
    return a>b;
}
int power(int v,int p)
{
    int ans=1;
    for(;p;p>>=1,v=1ll*v*v%mod)
     if(p&1)ans=1ll*ans*v%mod;
    return ans;
}
void change(int *s,int len)
{
    for(int i=1,j=len/2;i<len-1;i++)
    {
        if(i<j)swap(s[i],s[j]);int k=len/2;
        while(j>=k)j-=k,k/=2;j+=k;
    }
}
void NTT(int *s,int n,int t)
{
    change(s,n);
    for(int i=2;i<=n;i<<=1)
    {
        int half=i/2,wn=power(3,(mod-1)/i);
        if(t==-1)wn=power(wn,mod-2);
        for(int j=0;j<n;j+=i)
        {
            int w=1,u,v;
            for(int k=j;k<j+half;k++,w=1ll*w*wn%mod)
            {
                u=s[k];v=1ll*w*s[k+half]%mod;
                s[k]=(u+v)%mod;if(s[k]<0)s[k]+=mod;
                s[k+half]=(u-v)%mod;if(s[k+half]<0)s[k+half]+=mod;
            }
        }
    }
    if(t==-1)
    {
        for(int i=0;i<n;i++)
         s[i]=1ll*s[i]*need%mod;
    }
}
int main()
{
    scanf("%d",&T);
    fac[0]=Pow[0]=inv[0]=1;
    fac[1]=inv[1]=ni[1]=1;Pow[1]=2;
    for(int i=2;i<=100000;i++)
    {
        fac[i]=1ll*fac[i-1]*i%mod;
        Pow[i]=1ll*Pow[i-1]*2%mod;
        ni[i]=1ll*(mod-mod/i)*ni[mod%i]%mod;
        inv[i]=1ll*inv[i-1]*ni[i]%mod;
    }
    while(T--)
    {
        scanf("%d",&n);
        memset(a,0,sizeof(a));
        memset(b,0,sizeof(b));size=1;
        for(int i=1;i<=n;i++)scanf("%d",&A[i]);
        sort(A+1,A+n+1,cmp);
        for(int i=0;i<n;i++)\\根据公式可以发现只需要用到0~n-1
        {
            a[i]=inv[i]%mod;
            b[i]=1ll*A[n-i]*fac[n-i-1]%mod*Pow[i]%mod;
        }
        while(size<=n+n)size<<=1;need=power(size,mod-2);
        NTT(a,size,1);NTT(b,size,1);
        for(int i=0;i<=size;i++)a[i]=1ll*a[i]*b[i]%mod;
        NTT(a,size,-1);
        for(int i=1;i<=n;i++)
        {
            (ans+=1ll*a[n-i]*inv[i-1]%mod)%=mod;
            printf("%d ",ans);
        }
        printf("\n");ans=0;
    }
    return 0;
}

猜你喜欢

转载自blog.csdn.net/hdxrie/article/details/80961416