Wannafly Camp 2020 Day 2B 萨博的方程式 - 数位dp

给定 \(n\) 个数 \(m_i\),求 \((x_1,x_2,...,x_n)\) 的个数,使得 \(x_1 \ xor\ x_2\ xor\ ...\ xor\ x_n = k\),且 $0 \leq x_i \leq m_i$

Solution

从最高位开始看起,毫无疑问,如果 \(m_i\) 的某一位是 $0$,那么 \(x_i\) 的这一位只能填 $0$,所以只有那些 \(m_i\) 最高位是 $1$ 的才具有选择权。

考虑从最高位数起,哪一位 \(pos\) 开始,存在一个 \(i\) 使得 \(x_i \neq m_i\),很显然这个 \(pos\) 是有范围的,它一定是从最高位开始往下的一段区间,因为如果某一位上,\(m_i\) 这一位的异或和 \(\neq k\) 的这一位,更低的位就可以扔掉了。

假设从第 \(pos\) 位开始,存在一个 \(i\) 使得 \(x_i \neq m_i\),我们需要统计所有满足这个条件的答案,不妨把这个部分的贡献称作第 \(pos\) 位的贡献。

\(f[i][j]\) 表示前 \(i\)\(m\) 的当前位是 $1$ 的数中,选择了 \(j\)\(x\) 的当前位是 $1$,\(i-j\) 个是 $0$ 的方案数,那么

\[ f[i][j]=f[i-1][j-1]\cdot (x_i \ mod \ 2^{pos} + 1) + f[i-1][j]\cdot 2^{pos} \]

考虑如何统计第 \(pos\) 位的贡献,假设这位 $1$ 的个数为 \(cnt\),那么 \(f[cnt][j]\) 答案的贡献是 \(f[cnt][j]/2^{pos}\),当 \(j\)\(k\) 该位的奇偶性相同时产生。

#include <bits/stdc++.h>
using namespace std;

#define int long long
const int N = 105;
const int mod = 1e9+7;
inline void exgcd(int a,int b,int &x,int &y) {
    if(!b) {
        x=1,y=0;
        return;
    }
    exgcd(b,a%b,x,y);
    int t=x;
    x=y,y=t-(a/b)*y;
}

inline int inv(int a,int b) {
    int x,y;
    return exgcd(a,b,x,y),(x%b+b)%b;
}

int n,k,m[N],f[N][N];

int solve(int pos) {
    if(pos<0) return 1; //!
    int ret=0,cnt=0;
    memset(f,0,sizeof f);
    f[0][0]=1;
    for(int i=1;i<=n;i++) {
        if((m[i]>>pos)&1) {
            ++cnt;
            f[cnt][0]=f[cnt-1][0]*(1<<pos)%mod; //!
            for(int j=1;j<=cnt;j++) {
                f[cnt][j]=f[cnt-1][j-1]*(m[i]%(1<<pos)+1)
                        +f[cnt-1][j]*(1<<pos);
                f[cnt][j]%=mod;
            }
        }
        else {
            for(int j=0;j<=cnt;j++) {
                f[cnt][j]=f[cnt][j]*(m[i]+1); //!
                f[cnt][j]%=mod;
            }
        }
    }
    int r=inv(1<<pos,mod);
    for(int j=(k>>pos&1);j<cnt;j+=2) {
        ret+=f[cnt][j]*r;
        ret%=mod;
    }
    if((cnt&1) == ((k>>pos)&1)) {
        for(int i=1;i<=n;i++) {
            if(m[i]>>pos&1) m[i]^=(1<<pos);
        }
        return (solve(pos-1) + ret)%mod;
    }
    else return ret;
}

signed main() {
    ios::sync_with_stdio(false);
    while(cin>>n>>k) {
        for(int i=1;i<=n;i++) cin>>m[i];
        cout<<solve(31)<<endl;
    }
}

猜你喜欢

转载自www.cnblogs.com/mollnn/p/12336422.html