【JZOJ3316】非回文数字

description

如果一个字符串从后往前读与从前往后读一致,我们则称之为回文字符串。当一个数字不包含长度大于1的子回文数字时称为非回文数字。例如,16276是非回文数字,但17276不是,因为它包含回文数字727。

你的任务是在一个给定的范围内计算非回文数字的总数。


analysis

  • 平生最怂的数位\(DP\),询问自然拆开两段做

  • \(f[i][j][k][0/1]\)表示做到第\(i\)位、第\(i\)位数字为\(j\)、第\(i-1\)位数字为\(k\)、是否刚好顶上界的合法方案数

  • 转移自然就是\(f[i][j][k][0]+=f[i-1][k][num][0/1]\),注意顶到或不到上界的转移不同

  • 可以知道这样前导\(0\)很难搞,可以换个方法

  • \(g[i][j][k]\)表示做到第\(i\)位、第\(i\)位数字为\(j\)、第\(i-1\)位数字为\(k\)的方案数,另外单独转移

  • 也就是说位数不足的一定比原数要小,不需理会上界,相当于这些数都有前导\(0\),都算过了

  • 即钦定\(f\)\(g\)第一位均不为\(0\)开始\(DP\),最后统计答案把\(f\)\(g[1..位数-1]\)加一起就好了


code

#include<stdio.h>
#include<string.h>
#include<algorithm>
#include<math.h>
#define ll long long
#define fo(i,a,b) for (ll i=a;i<=b;++i)
#define fd(i,a,b) for (ll i=a;i>=b;--i)

using namespace std;

ll f[20][10][10][2];
ll g[20][10][10];
ll a[20];
ll l,r,mx;

inline ll read()
{
    ll x=0,f=1;char ch=getchar();
    while (ch<'0' || '9'<ch){if (ch=='-')f=-1;ch=getchar();}
    while ('0'<=ch && ch<='9')x=x*10+ch-'0',ch=getchar();
    return x*f;
}
inline ll get(ll x)
{
    if (x<10)return x;
    mx=floor(log10(x)+1);
    fd(i,mx,1)a[i]=x%10,x/=10;
    memset(f,0,sizeof(f));
    memset(g,0,sizeof(g));
    f[1][a[1]][0][1]=1;
    fo(i,1,a[1]-1)f[1][i][0][0]=1;
    fo(i,1,9)g[1][i][0]=1;
    fo(i,2,mx)
    fo(j,0,9)//第i位的数字
    fo(k,0,9)//第i-1位的数字
    {
        if (j==k)continue;
        fo(num,0,9)//第i-2位的数字
        {
            if (i>2 && num==j)continue;
            g[i][j][k]+=g[i-1][k][num];
            if (j==a[i] && k==a[i-1])
            {
                f[i][j][k][0]+=f[i-1][k][num][0],
                f[i][j][k][1]+=f[i-1][k][num][1];
            }
            else
            {
                if (k<a[i-1] || (k==a[i-1] && j<a[i]))
                    f[i][j][k][0]+=f[i-1][k][num][0]+f[i-1][k][num][1];
                else f[i][j][k][0]+=f[i-1][k][num][0];
            }
        }
    }
    ll ans=0;
    fo(i,1,mx-1)fo(j,0,9)fo(k,0,9)ans+=g[i][j][k];
    fo(i,0,9)fo(j,0,9)ans+=f[mx][i][j][0]+f[mx][i][j][1];
    return ans;
}
int main()
{
    l=read(),r=read();
    printf("%lld\n",get(r)-get(l-1));
    return 0;
}

猜你喜欢

转载自www.cnblogs.com/horizonwd/p/11134754.html