[CF286E] Ladies' shop

Description

给出 \(n\)\(\leq m\) 且不同的数 \(a_1,\dots,a_n\),现在要求从这 \(n\) 个数中选出最少的数字,满足这 \(n\) 个数字都可以由选出的数字组合成(就是做一个完全背包能做出来),并且任意组合出来的数字,只要不超过 \(m\),就必须让这个数字在给出的 \(n\) 个数中。问是否可行,如果可行,请求出最少选多少数字。 \(n,m\leq 10^6\)

Sol

先判断是否可行,再看哪些数可以省略。

求出 \(a\) 数组的生成函数,即构造多项式 \(F(x)=\sum f_i\cdot x^i\)\(f_i\)\(1\) 当且仅当 \(a\) 数组中出现 \(a_*=i\)

然后求出 \(G(x)=F^2(i)=\sum g_i\cdot x^i\)。如果 \(g_i>0\) 那就说明给出的这 \(n\) 个数可以合成 \(i\)

于是就得到了从原来的 \(n\) 个数中拿出 \(0\sim 2\) 个的结果。

然而最多拿出 \(m\) 个。

所以还要继续,用快速幂求得 \(f^m\)。如果多项式快速幂的话,复杂度 \(O(n\log^2n)\),用多项式ln+多项式exp求的话,复杂度 \(O(n\log n)\)。但是多项式exp常数太大了!

事实上是有只做 \(1\) 次FFT的方法的。

显然如果 \(f_i>0\) 的话,\(g_i>0\)

那我们只要保证满足 \(f_i=0,g_i>0,i\leq m\)\(i\) 不存在就好了。

如果第一轮不存在这些不合法的,那接下来肯定也不存在。感性理解一下这就相当于构成了一个封闭的集合。

所以只做 \(1\) 次FFT就行了。

然后考虑一下哪些数可以省略

如果一个数 \(i\) 可以被其他数表示出来,那 \(g_i\) 一定 \(>2\)。所以 \(g_i=2\)\(i\) 就是必选的。

时间复杂度 \(O(n\log n)\)

Sol

#pragma GCC optimize(2)
#include<bits/stdc++.h>
using std::min;
using std::max;
using std::swap;
using std::vector;
typedef double db;
typedef long long ll;
#define pb(A) push_back(A)
#define pii std::pair<int,int>
#define all(A) A.begin(),A.end()
#define mp(A,B) std::make_pair(A,B)
const int N=4e6+5;
const int mod=998244353;

int lim,rev[N];
int n,m,a[N],b[N];

int ksm(int a,int b=mod-2,int ans=1){
    while(b){
        if(b&1) ans=1ll*ans*a%mod;
        a=1ll*a*a%mod;b>>=1;
    } return ans;
}

int getint(){
    int X=0,w=0;char ch=getchar();
    while(!isdigit(ch))w|=ch=='-',ch=getchar();
    while( isdigit(ch))X=X*10+ch-48,ch=getchar();
    if(w) return -X;return X;
}

void ntt(int *f,int g){
    for(int i=1;i<lim;i++) if(i<rev[i]) swap(f[i],f[rev[i]]);
    for(int mid=1;mid<lim;mid<<=1){
        int tmp=ksm(g,(mod-1)/(mid<<1));
        for(int R=mid<<1,j=0;j<lim;j+=R){
            int w=1;
            for(int k=0;k<mid;k++,w=1ll*w*tmp%mod){
                int x=f[j+k],y=1ll*w*f[j+k+mid]%mod;
                f[j+k]=(x+y)%mod,f[j+k+mid]=(mod+x-y)%mod;
            }
        }
    } if(g>3) 
        for(int in=ksm(lim),i=0;i<lim;i++) f[i]=1ll*f[i]*in%mod;
}

signed main(){
    n=getint(),m=getint();
    for(int i=1;i<=n;i++){
        int x=getint();
        a[x]=b[x]=1;
    } 
    lim=1;while(lim<=m+m) lim<<=1;
    for(int i=1;i<lim;i++) rev[i]=(rev[i>>1]>>1)|(i&1?lim>>1:0);
    a[0]=1; ntt(a,3);
    for(int i=0;i<lim;i++) a[i]=1ll*a[i]*a[i]%mod;
    ntt(a,(mod+1)/3);
    for(int i=1;i<=m;i++)
        if(a[i] and !b[i]) return printf("NO"),0;
    puts("YES"); int tot=0;
    for(int i=1;i<=m;i++)
        if(a[i]==2) tot++;
    printf("%d\n",tot);
    for(int i=1;i<=m;i++)
        if(a[i]==2) printf("%d ",i);
    return 0;
}

猜你喜欢

转载自www.cnblogs.com/YoungNeal/p/10360660.html