POJ3252 Round Numbers 【数位dp】

题目链接

POJ3252

题解

为什么每次写出数位dp都如此兴奋?
因为数位dp太苟了
因为我太弱了

\(f[i][0|1][cnt1][cnt0]\)表示到二进制第\(i\)位,之前是否达到上界,前面已经有\(cnt1\)\(1\)\(cnt0\)\(0\)时的方案数
显然当\(cnt1 = 0\)时就不存在任何前导数字了

然后就记忆化搜索 分类讨论各种转移
【为什么我写得好麻烦QAQ是不是我姿势不对】

#include<iostream>
#include<cstdio>
#include<cmath>
#include<map>
#include<cstring>
#include<algorithm>
#define LL long long int
#define Redge(u) for (int k = h[u],to; k; k = ed[k].nxt)
#define REP(i,n) for (int i = 1; i <= (n); i++)
#define cls(s) memset(s,0,sizeof(s))
#define mp(a,b) make_pair<int,int>(a,b)
#define cp pair<int,int>
using namespace std;
const int maxn = 100005,maxm = 100005,INF = 1000000000;
inline int read(){
    int out = 0,flag = 1; char c = getchar();
    while (c < 48 || c > 57){if (c == '-') flag = -1; c = getchar();}
    while (c >= 48 && c <= 57){out = (out << 3) + (out << 1) + c - 48; c = getchar();}
    return out * flag;
}
int bit[maxn],vis[33][2][33][33];
LL f[33][2][33][33];
LL cal(int n,int lim,int cnt1,int cnt0){
    if (!n) return cnt0 >= cnt1;
    if (vis[n][lim][cnt1][cnt0]) return f[n][lim][cnt1][cnt0];
    vis[n][lim][cnt1][cnt0] = true;
    LL& re = f[n][lim][cnt1][cnt0];
    if (!lim && !cnt1) return re = cal(n - 1,0,1,0) + cal(n - 1,0,0,0);
    else if (!lim){
        int tot = cnt1 + cnt0 + n,least = (tot & 1) ? (tot >> 1) + 1 : (tot >> 1);
        least = least - cnt0;
        if (least <= 0) return re = (1 << n);
        LL C = 1;
        for (int i = 1; i <= n; i++){
            C = C * (n - i + 1) / i;
            if (i >= least) re += C;
        }
        return re;
    }
    else if (!cnt1){
        if (!bit[n]) return re = cal(n - 1,1,0,0);
        else return re = cal(n - 1,1,1,0) + cal(n - 1,0,0,0);
    }
    else {
        if (!bit[n]) return re = cal(n - 1,1,cnt1,cnt0 + 1);
        else return re = cal(n - 1,1,cnt1 + 1,cnt0) + cal(n - 1,0,cnt1,cnt0 + 1);
    }
}
LL solve(int x){
    cls(f); cls(vis);
    int n = 0,tmp = x;
    while (tmp) bit[++n] = (tmp & 1),tmp >>= 1;
    return cal(n,1,0,0);
}
int main(){
    int a = read(),b = read();
    if (a > b) swap(a,b);
    printf("%lld\n",solve(b) - solve(a - 1));
    return 0;
}

猜你喜欢

转载自www.cnblogs.com/Mychael/p/9022863.html
今日推荐