ARC093F Dark Horse 【容斥,状压dp】

题目链接:gfoj

神仙计数题。

可以转化为求\(p_1,p_2,\ldots,p_{2^n}\),使得\(b_i=\min\limits_{j=2^i+1}^{2^{i+1}}p_j\)都不属于\(a_i\)

日常容斥。设\(f(S)\)表示\(i\in S\Rightarrow b_i\in A\)的答案,则答案就是\(ans=\sum_S(-1)^{|S|}f(S)\)

\(f(S)\)使用状压dp。设\(f[i][S]\)表示将\(a_i\)从大到小排序,\(b_i\)\(a\)中出现的下标\(i\)组成的集合\(S\),方案数是多少。

初值\(f[0][0] = 1\)

如果\(a_{i+1}\)不在\(b_i\)中出现,则\(f[i+1][S]\leftarrow f[i][S]\)

如果\(a_{i+1}\)\(b_i\)中出现,枚举\(a_{i+1}=b_k\),那么我们要在\(2^n-S-a_i\)个数中选出\(2^k-1\)个数被\(a_{i+1}\)打掉,组成排列\((2^k)!\)种方案,那么\(f[i+1][S|2^k]\leftarrow f[i][S]\times \dbinom{2^n-S-a_i}{2^k-1}\times (2^k)!\)

然后你发现我们并没有把不在\(b_i\)中出现的\(S\)这些数没有乘上,所以\(f(S)=f[m][S]\times S!\)。然后抄个柿子上去,时间复杂度\(O(nm2^n)\)


code

扫描二维码关注公众号,回复: 7846035 查看本文章
#include<bits/stdc++.h>
#define Rint register int
using namespace std;
typedef long long LL;
const int N = 16, mod = 1e9 + 7;
int n, m, a[N], f[N + 1][1 << N], fac[1 << N], inv[1 << N];
bool siz[1 << N];
inline void upd(int &a, int b){a += b; if(a >= mod) a -= mod;}
inline int kasumi(int a, int b){
    int res = 1;
    while(b){
        if(b & 1) res = (LL) res * a % mod;
        a = (LL) a * a % mod; b >>= 1;
    }
    return res;
}
inline void init(int m){
    fac[0] = 1;
    for(Rint i = 1;i <= m;i ++) fac[i] = (LL) fac[i - 1] * i % mod;
    inv[m] = kasumi(fac[m], mod - 2);
    for(Rint i = m;i;i --) inv[i - 1] = (LL) inv[i] * i % mod;
    siz[0] = 0;
    for(Rint i = 1;i <= m;i ++) siz[i] = !siz[i ^ (i & -i)];
}
inline int C(int n, int m){
    if(n < 0 || m < 0 || n < m) return 0;
    return (LL) fac[n] * inv[m] % mod * inv[n - m] % mod;
}
int main(){
    scanf("%d%d", &n, &m);
    for(Rint i = 0;i < m;i ++) scanf("%d", a + i);
    sort(a, a + m, greater<int>()); init((1 << n) - 1);
    f[0][0] = 1;
    for(Rint i = 0;i < m;i ++)
        for(Rint S = 0;S < (1 << n);S ++){
            upd(f[i + 1][S], f[i][S]);
            int t = (1 << n) - S - a[i];
            for(Rint k = 0;k < n;k ++) if(!((S >> k) & 1))
                upd(f[i + 1][S | (1 << k)], (LL) f[i][S] * C(t, (1 << k) - 1) % mod * fac[1 << k] % mod);
        }
    int ans = 0;
    for(Rint S = 0;S < (1 << n);S ++)
        if(siz[S]) upd(ans, mod - (LL) f[m][S] * fac[(1 << n) - S - 1] % mod);
        else upd(ans, (LL) f[m][S] * fac[(1 << n) - S - 1] % mod);
    printf("%d", (LL) ans * (1 << n) % mod);
}

猜你喜欢

转载自www.cnblogs.com/AThousandMoons/p/11853376.html