2020 CCPC Wannafly Winter Camp Day2 B.萨博的方程式(数位DP)

题意:

萨博有个方程式:

x 1   x o r   x 2     x o r   x n = k ( x i [ 0 , m i ] ) x_1\ xor\ x_2\ \dots\ xor \ x_n=k\quad (x_i\in[0,m_i])
n 50 ,   k , m i < 2 31 n\le50,\ k,m_i<2^{31}
求有多少组x在满足条件的情况下使得等式成立,答案对 1 e 9 + 7 1e9+7 取模。

题解:

以下为邦邦老师的ppt,这一部分讲的挺清晰的:
在这里插入图片描述
dp状态的解释也很清晰:
在这里插入图片描述
这样的做法复杂度是 O ( T n 2 l o g ( m ) ) O(T*n^2log(m)) 的。
我自己做的时候一开始不知道得到dp这个dp状态之后如何得到这一位不全选1的解对答案的贡献,后来想明白了:假设这一位有 x x 1 1 ,那么 F ( x , j ) F(x,j) 对答案的贡献是 F ( x , j ) / 2 p o s F(x,j)/2^{pos} ,pos指的是最高位的位数。能做贡献的前提是j的奇偶性和k的这一位是一样的。
为什么呢。因为 F ( x , j ) F(x,j) 相当于是总的方案数,然后我们因为有一个k的限制,所以可以让其中一个高位本可选1但是选了0的数字去和其他的凑,也就是说其他的定了它也就定了,不能乱动。这和它原本可以选择 2 p o s 2^{pos} 种方案相比,变成了只有一种选择。
想清楚这一点之后我们发现,实际上不用记录选择了几个1,而只要记录选择了奇数个1还是偶数个1,是否为全选即可得到贡献,这样优化了一层记录选择1的个数的循环。 d p ( i , 0 / 1 , 0 / 1 ) dp(i,0/1,0/1) 表示考虑前i个1,用了偶/奇数个1,未全选/全选了1的方案数。
这样复杂度可以优化为 O ( T n l o g ( m ) ) O(T*nlog(m)) 。这样数据范围的 n n 就可以出到 1 e 5 1e5 啦~

未优化的代码:

#include<bits/stdc++.h>
#define ll long long
#define lowbit(x) ((x)&(-(x)))
#define mid ((l+r)>>1)
#define lson rt<<1, l, mid
#define rson rt<<1|1, mid+1, r
using namespace std;
const ll mod = 1e9 + 7;
ll dp[55][55];
int n;
ll k;
ll m[55];
ll qm(ll a, ll b){
    ll res = 1;
    while(b){
        if(b&1) res = res*a%mod;
        a = a*a%mod;
        b >>= 1;
    }return res;
}
ll sol(int pos){
    //cout<<"pos:"<<pos<<endl;
    if(pos < 0) return 1;
    ll res = 0;
    memset(dp, 0, sizeof dp);
    dp[0][0] = 1;
    int cur = 0;
    for(int i = 1; i <= n; ++i){
        if(m[i]>>pos&1){
            cur++;
            dp[cur][0] = dp[cur-1][0]*(1LL<<pos)%mod;
            for(int j = 1; j <= cur; ++j){
                dp[cur][j] = (dp[cur-1][j]*(1LL<<pos)%mod + dp[cur-1][j-1]*(m[i]-(1LL<<pos)+1)%mod)%mod;
            }
        }else{
            for(int j = 0; j <= cur; ++j) dp[cur][j] = dp[cur][j]*(m[i]+1)%mod;
        }
    }
    ll inv = qm(1LL<<pos, mod-2);
    for(int i = (k>>pos&1); i < cur; i+=2){
        res = (res + dp[cur][i]*inv)%mod;
    }
    //cout<<"res:"<<res<<endl;
    if((cur&1) == (k>>pos&1)){
        for(int i = 1; i <= n; ++i){
            if(m[i]>>pos&1) m[i] ^= (1LL<<pos);
        }
        return (res + sol(pos-1))%mod;
    }else return res;
}
int main()
{
    while(scanf("%d%lld", &n, &k)!=EOF){
        for(int i = 1; i <= n; ++i) scanf("%lld", &m[i]);
        ll ans = sol(31);
        ans = (ans + mod)%mod;
        cout<<ans<<endl;
    }
}

优化后的代码

#include<bits/stdc++.h>
#define ll long long
#define lowbit(x) ((x)&(-(x)))
#define mid ((l+r)>>1)
#define lson rt<<1, l, mid
#define rson rt<<1|1, mid+1, r
using namespace std;
const ll mod = 1e9 + 7;
ll dp[55][2][2];
int n;
ll k;
ll m[55];
ll qm(ll a, ll b){
    ll res = 1;
    while(b){
        if(b&1) res = res*a%mod;
        a = a*a%mod;
        b >>= 1;
    }return res;
}
ll sol(int pos){
    //cout<<"pos:"<<pos<<endl;
    if(pos < 0) return 1;
    ll res = 0;
    memset(dp, 0, sizeof dp);
    dp[0][0][1] = 1;
    int cur = 0;
    for(int i = 1; i <= n; ++i){
        if(m[i]>>pos&1){
            cur++;
            dp[cur][0][0] =( (m[i]-(1LL<<pos)+1)*dp[cur-1][1][0]%mod + (1LL<<pos)*(dp[cur-1][0][0]+dp[cur-1][0][1])%mod )%mod;
            dp[cur][0][1] = (m[i]-(1LL<<pos)+1)*dp[cur-1][1][1]%mod;
            dp[cur][1][0] = ( (m[i]-(1LL<<pos)+1)*dp[cur-1][0][0]%mod + (1LL<<pos)*(dp[cur-1][1][0] + dp[cur-1][1][1])%mod )%mod;
            dp[cur][1][1] = (m[i]-(1LL<<pos)+1)*dp[cur-1][0][1]%mod;
        }else{
            dp[cur][0][0] = (dp[cur][0][0]*(m[i]+1))%mod;
            dp[cur][1][0] = (dp[cur][1][0]*(m[i]+1))%mod;
            dp[cur][0][1] = (dp[cur][0][1]*(m[i]+1))%mod;
            dp[cur][1][1] = (dp[cur][1][1]*(m[i]+1))%mod;
        }
    }
    ll inv = qm(1LL<<pos, mod-2);
    res = dp[cur][k>>pos&1][0]*inv%mod;
    if((cur&1) == (k>>pos&1)){
        for(int i = 1; i <= n; ++i){
            if(m[i]>>pos&1) m[i] ^= (1LL<<pos);
        }
        return (res + sol(pos-1))%mod;
    }else return res;

}
int main()
{
    while(scanf("%d%lld", &n, &k)!=EOF){
        for(int i = 1; i <= n; ++i) scanf("%lld", &m[i]);
        ll ans = sol(31);
        ans = (ans + mod)%mod;
        cout<<ans<<endl;
    }
}

发布了102 篇原创文章 · 获赞 30 · 访问量 1万+

猜你喜欢

转载自blog.csdn.net/qq_43202683/article/details/104062852