CodeForces 1051E Vasya and Big Integers(dp + 树状数组 + 字符串哈希/后缀数组)

版权声明:本文为博主原创文章,转载请著名出处 http://blog.csdn.net/u013534123 https://blog.csdn.net/u013534123/article/details/82828390

大致题意:给你一个很大的数字,然后你可以把这个数字拆分成为任意多个部分,要求每一个部分的数字大小要在一个区间内,问有多少种拆分方式。

由于是给定数字的拆分,所以区间对于拆分的限制,仅仅是限于长度。也即如果拆分的部分的长度介于上界和下界的长度之间,那么直接查分即可。如果长度等于上界或者下界,那么需要按位比较于界限的大小。可以看到,这个过程相当于一个转移的过程,很自然而然的想到用dp去求解。令dp[i]表示剩下的数字的长度位i的方案数,有转移方程dp[i]=Σdp[j],其中的i和j满足:len1<=j-i<=len2。另外,当i-j恰好等于上下界限的时候要特殊判断。

这个dp的复杂度是O(N^2)的,对于长度最大可以到1e5的数据范围来说,不能够满足要求。但是观察这个转移方程,可以看到,一个状态i,可以转移到的状态是[i-len2,i-len1],因此我们可以考虑一次性把这个区间更新完毕。所以用树状数组来维护区间更新,每次更新的量是对应状态的方案数,通过单点查询来得到。于是,这样做时间复杂度看起来可以优化到O(NlogN)。

但是,我们之前说过了,当之前状态的长度于当前长度之差恰好等于边界的时候,需要特殊判断相应数字与对应边界的大小关系。而这个判断,最朴素的办法就是暴力判断,从第一个开始比较。所以说对于这样的比较方式,我完全可以构造数据使得需要比较到一个比较长的长度才能够知道这两个的大小关系,这样就很容易被卡超时。所以我们考虑用字符串哈希的方法来解决这个问题。字符串哈希二分比较两个字符串,两个子串最长的公共前缀。但是由于字符串哈希本身的常数比较大,所以如果只是用字符串哈希比较也会超时。所以说我们折衷一下,当直接比较了40位还没有出结果,那么我就用字符串哈希来求,否则就直接暴力比较解决问题。这样,时间复杂度应该就是O(NlogN)。

最后我们发现,字符串哈希的过程其实就是求一下两个字符串的LCP,然后比较LCP的下一位就可以知道两个字符串的大小关系。所以如果要更加稳定的复杂度,我们可以考虑用后缀数组,求出height数组,然后每一次求一下两个字符串LCP即可。但是论代码复杂度来说可能还是字符串哈希简单一些,具体见代码:

#include<bits/stdc++.h>
#define mod 998244353
#define LL long long
#define pb push_back
#define lb lower_bound
#define ub upper_bound
#define INF 0x3f3f3f3f
#define sf(x) scanf("%d",&x)
#define sc(x,y,z) scanf("%d%d%d",&x,&y,&z)
#define clr(x,n) memset(x,0,sizeof(x[0])*(n+5))
#define file(x) freopen(#x".in","r",stdin),freopen(#x".out","w",stdout)
using namespace std;

const int N = 1000010;

int c[N],len,Len[2];
char s[N],t[2][N];

inline void update(int x,int y)
{
    x++;
    for(int i=x;i<N;i+=i&-i)
        c[i]=(c[i]+y)%mod;
}

inline int getsum(int x)
{
    int res=0; x++;
    for(int i=x;i;i-=i&-i)
        res=(res+c[i])%mod;
    return res;
}

int Hs[3][N]; LL P[N];
const int MOD=100001623;

void init()
{
    P[0]=1;
    for(int i=1;i<N;i++)
        P[i]=P[i-1]*111%MOD;
}

void ins(char *s,int op)
{
    Hs[op][0]=s[0]-'0'+1;
    for(int i=1;s[i];i++)
        Hs[op][i]=(Hs[op][i-1]*111LL%MOD+s[i]-'0'+1)%MOD;
}

inline int Hash(int l,int r,int op)
{
    int tmp1=Hs[op][r];
    if (!l) return tmp1;
    tmp1=(tmp1-Hs[op][l-1]*P[r-l+1]%MOD+MOD)%MOD;
    return tmp1;
}

inline int cmp(char *s,char *t)
{
    for(int i=0;t[i];i++)
    {
        if (s[i]>t[i]) return 1;
        else if (s[i]<t[i])return -1;
        if (i>40) return -3;
    }
    return 0;
}

inline int cmp(int b,int op)
{
    int tmp1,tmp2;
    int tt=cmp(s+b,t[op]);
    if (tt!=-3) return tt;
    int l=0,r=Len[op]-1,mid,res=-1;
    while(l<=r)
    {
        mid=(l+r)>>1;
        tmp1=Hash(b,b+mid,0);
        tmp2=Hash(0,mid,op+1);
        if (tmp1==tmp2) res=mid,l=mid+1;
                           else r=mid-1;
    }
    if (res==Len[op]-1) return 0;
    return s[b+res+1]>t[op][res+1]?1:-1;
}

int main()
{
    init();
    scanf("%s%s%s",s,t[0],t[1]);
    len=strlen(s); update(len,1);
    ins(s,0); ins(t[0],1); ins(t[1],2);
    Len[0]=strlen(t[0]); Len[1]=strlen(t[1]);
    for(int k=len,b=0;k&&k>=Len[0];k--,b++)
    {
        int delta=getsum(k);
        if (!delta) continue;
        if (s[b]=='0')
        {
            if (t[0][0]=='0')
            {
                update(k-Len[0],delta);
                update(k-Len[0]+1,-delta);
            }
            continue;
        }
        if (Len[0]==Len[1]&&cmp(b,1)>0) continue;
        if (cmp(b,0)>=0)
        {
            update(k-Len[0],delta);
            update(k-Len[0]+1,-delta);
        }
        int L=max(k-Len[1]+1,0);
        int R=k-Len[0]-1;
        if (L<=R&&R>=0)
        {
            update(L,delta);
            update(R+1,-delta);
        }
        if (Len[1]>k||Len[0]==Len[1]) continue;
        if (cmp(b,1)<=0)
        {
            update(k-Len[1],delta);
            update(k-Len[1]+1,-delta);
        }
    }
    printf("%d\n",(getsum(0)+mod)%mod);

    return 0;
}

猜你喜欢

转载自blog.csdn.net/u013534123/article/details/82828390