#3123. 「CTS2019 | CTSC2019」重复

简单dp

我们考虑一下能不能造个自动机识别所有存在一段区间字典序比给定的模式串少的串

这是可行的,你把kmp自动机稍微魔改一下就可以了

具体来讲在一开始连出一条链的时候加一些边指向终止节点,然后把这个东西拉过去跑ac自动机的bfs

注意bfs的时候不把终止节点入队就可以了

然后配合上thusc的一些科技以及高妙的卡常技巧就能得到一个70分的好成绩

接下来是正解

事实上,任何一个不合法的字符串都可以和自动机上一条起点和终点相同,并且长度为m的路径一一对应

直接在自动机上dp可以得到\[O(N^2m)\]的大暴力

优化需要一点性质,具体来讲是这样的

我们使用归纳法证明这样一个命题

一个点除了指向根的边之外,最多只有一条不指向终止节点的出边

反证一下,假如有别的出边,那么这个边一定不是trie链上原来的边,而是fail树上节点的出边

容易发现小的出边会被大的出边ban掉,所以就只剩下一条出边了

总而言之,此时我们就可以将环分成两类计数,不经过根的,和经过根的

前者可以暴力,后者是个简单dp

于是复杂度\[O(nm)\]

#include<cstdio>
#include<algorithm>
#include<map>
#include<queue>
using namespace std;const int N=1e5+10;typedef long long ll;const ll mod=998244353;
inline ll po(ll a,ll p){ll r=1;for(;p;p>>=1,a=a*a%mod)if(p&1)r=r*a%mod;return r;}
ll dp[2010][2010];
int n;int m;char mde[N];ll bk[N];int fr[N];
struct automaton
{
    int mp[N][30];int fil[N];
    queue <int> q;int ct;
    inline void build(char * mde,int len)
    {
        ct=1;int ed=n+2;
        for(int i=1;i<=n;i++)
        {
            int id=mde[i]-'a'+1;
            mp[i][id]=++ct;
        //  printf("link %d %d %d\n",i,id,ct);
            for(int j=1;j<id;j++)
                mp[i][j]=ed;
        }
        /*for(int i=1;i<=ct;i++)
        {
            for(int j=1;j<=26;j++)
                printf("%d ",mp[i][j]);
            printf("\n");
        }*/
        for(int j=1;j<=26;j++)
        {
            if(mp[1][j]==0)
            {
                mp[1][j]=1;
            }
            else if(mp[1][j]!=ed)
            {
                fil[mp[1][j]]=1;q.push(mp[1][j]);
            }
        }
        while(!q.empty())
        {
            int nw=q.front();q.pop();
            for(int j=1;j<=26;j++)
                if(mp[nw][j]==0)
                {
                    mp[nw][j]=mp[fil[nw]][j];
                }
                else if(mp[nw][j]!=ed)
                {
                    fil[mp[nw][j]]=mp[fil[nw]][j];
                    q.push(mp[nw][j]);
                }
        }
        for(int i=1;i<=26;i++)
            mp[ed][i]=ed;
        /*for(int i=1;i<=ct+1;i++)
        {
            for(int j=1;j<=26;j++)
                printf("%d ",mp[i][j]);printf("\n");
        }*/
    }
    inline void prit()
    {
        int ed=n+2;
        for(int i=1;i<=ct;i++)
        {
            for(int j=1;j<=26;j++)
                if(mp[i][j]==1)
                {
                    //printf("e ");
                    bk[i]++;
                }
            for(int j=1;j<=26;j++)
                if(mp[i][j]!=1&&mp[i][j]!=ed)
                {
                    //printf("%d ",mp[i][j]);
                    //fr[i]++;
                    fr[i]=mp[i][j];
                }
        //  printf("\n");
        }
    }
}tr;
# define md(x) (x=(x>=mod)?x-mod:x)
inline ll dfs(int pos,int tim,int len)
{
    if(len==m)return pos==tim;
    ll ans=0;
    for(int i=1;i<=26;i++)
        (ans+=dfs(tr.mp[pos][i],tim,len+1))%=mod;
    return ans;
}
int main()
{
    scanf("%d",&m);
    scanf("%s",mde+1);
    while(mde[n+1]!='\0')n++;
    tr.build(mde,n);
    tr.prit();
    dp[0][1]=1;
    //for(int j=1;j<=tr.ct;j++)
//      printf("%lld ",bk[j]);printf("\n");
//  for(int j=1;j<=tr.ct;j++)
//      printf("%d ",fr[j]);printf("\n");
    for(int i=0;i<m;i++)
    {
        ll * p=dp[i+1];ll* q=dp[i];
        for(int j=1;j<=tr.ct;j++)
        {
            if(fr[j])(p[fr[j]]+=q[j])%=mod,
            (p[1]+=q[j]*bk[j])%=mod;
        }
    }
    /*for(int i=0;i<=m;i++)
    {
        for(int j=1;j<=tr.ct;j++)
            printf("%lld ",dp[i][j]);
        printf("\n");
    }*/
    ll ans=0;//int tot=0;
    //printf("1 ret=%lld\n",dp[m][1]);
    for(int i=2;i<=n+1;i++)
    {
        int st=i;
        ll ret=0;int cnt=0;
        for(int k=1;k<=m;k++)
        {
        //  printf("%d->",st);
            (ret+=bk[st]*dp[m-k][i])%=mod;
            if(fr[st]!=0)st=fr[st];else break;
            cnt++;
        }
        //if(cnt==m&&st==i)printf("hit!\n");
        (ret+=(cnt==m&&st==i))%=mod;
        (ans+=ret)%=mod;
        //printf("%d ret=%lld\n",i,ret);
    }
    
    //printf("%lld %lld\n",(po(26,m)+mod-11873600)%mod,ans);
    (ans+=dp[m][1])%=mod;
    
    printf("%lld",(po(26,m)+mod-ans)%mod);
    //printf("%lld\n",dfs(1,0));
//  ll bruret=0;
//  for(int i=1;i<=tr.ct;i++)
//  {
//      ll tmp=dfs(i,i,0);
//      printf("%d bret=%lld\n",i,tmp);
//      (bruret+=tmp)%=mod;
//  }
//  printf("%lld %lld\n",bruret,(po(26,m)+mod-bruret)%mod);
    return 0;
}

猜你喜欢

转载自www.cnblogs.com/sweetphoenix/p/11072180.html