AC自动机 + DP or 矩阵 题型总结

AC自动机首先是一个多模式匹配算法,简单来说就是有多个模式串,然后查询的结果就是被查询串中有多少个模式串。模式串之间可以互相重叠。
实际上,AC自动机就是在一颗Trie树上添加了一fail指针,这个指针和KMP中的Next数组的作用是一样的:代表着失配后应该转移的位置。如果fail指针指向了root,那么说明在trie树中的前缀没有出现在被查询串的后缀中了。

我们在AC自动机上j进行DP时,其实利用的是AC自动机所创建的一个状态图,我们把每trie树上的节点之间的连接和fail指针当成边,那么状态就可以在这个图上进行转移,我们可以利用这个给每个节点创建一个状态矩阵,代表有多少个节点能到到达,这个在我们求一些问题比如不包含某些串的串有多少个,就很方便了。

在AC自动机上进行DP,一般来说要进行状态压缩,在设计转移时,主要考虑的时AC自动机上的连接状态,即当前状态的下一个可能转移到的状态一定是在AC自动机上进行的,有了这一个指导思想,我们可以按照题目需要将DP状态改造成我们想要的样子。

除了AC自动机和DP结合,AC自动机还可以一些矩阵题结合。比如一些题目要求你求出给定长度然后不能出现某些串的串有多少个的时候,我们就要用到矩阵了。我们利用AC自动机创建走一步的一个转移矩阵,然后根据图论上点之间转移的理论,就能得出长度为n的转移矩阵。

在改造AC自动机时,我们应该注意AC自动机插入的串的形式和作用,根据他们的不同来执行对应的变化。同时还要注意标记的传递,在构造fail指针的时候,我们可能会将标记传给需要更新的点。下面几道例题会具体分析。

AC自动机 + 矩阵 POJ - 2778

这道题的大意是要求你得出长度为n的,不包含有害DNA的串的数量。这里就要用到我们之前说过的Tire转移图。我们设计一个矩阵M,其中 M i j 代表第i个自动机上的节点,转移到他的第i个节点有多少种方案。然后用矩阵快速幂算出来n转移的方案,求个和就行。

#include <stdio.h>
#include <algorithm>
#include <iostream>
#include <string.h>
#include <queue>
using namespace std;
const int maxn = 111;
const int MOD=100000;
struct Mat
{
    int m[maxn][maxn],n;
    Mat(){};
    Mat(int _n)
    {
        n = _n;
        for(int i = 0; i<n; i++)
            for(int j = 0; j<n; j++)
                m[i][j] = 0;
    }
    Mat operator * (const Mat &b) const
    {
        Mat res = Mat(n);
        for(int i = 0; i<n; i++)
        {
            for(int j = 0; j<n; j++)
            {
                for(int k = 0; k<n; k++)
                {
                    res.m[i][j] = (res.m[i][j] + ((long long)m[i][k]* b.m[k][j]) % MOD) %MOD;
                }
            }
        }
        return res;
    }
};
int id[300];
struct Trie
{
    int next[maxn][26],fail[maxn];
    bool end[maxn];
    int root,L;
    int newnode()
    {
        for(int i = 0; i < 4; i++)
            next[L][i] = -1;
        end[L++] = 0;
        return L-1;
    }
    void init()
    {
        L = 0;
        root = newnode();
    }
    void insert(char buf[])
    {
        int len = strlen(buf);
        int now = root;
        for(int i = 0; i < len; i++)
        {
            if(next[now][id[buf[i]]] == -1)
                next[now][id[buf[i]]] = newnode();
            now = next[now][id[buf[i]]];
        }
        end[now] =1;
    }
    void build()
    {
        queue<int>Q;
        fail[root] = root;
        for(int i = 0; i < 4; i++)
            if(next[root][i] == -1)
                next[root][i] = root;
            else
            {
                fail[next[root][i]] = root;
                Q.push(next[root][i]);
            }
        while( !Q.empty() )
        {
            int now = Q.front();
            Q.pop();
            if(end[fail[now]] == 1)
            {
                end[now] = 1;
            }   //注意此处将fail指针上的标记下传了
            for(int i = 0; i < 4; i++)
                if(next[now][i] == -1)
                    next[now][i] = next[fail[now]][i];
                else
                {
                    fail[next[now][i]]=next[fail[now]][i];
                    Q.push(next[now][i]);
                }
        }
    }
    Mat getMat()
    {
        Mat res = Mat(L);
        for(int i = 0; i<L; i++)
        {
            for(int j = 0; j <4; j++)
            {
                if(!end[next[i][j]])
                    res.m[i][next[i][j]]++;
            }
        }
        return res;
    }
} ac;
Mat pow_mod(Mat a,int k)
{
    Mat res = Mat(a.n);
    //res.print();
    Mat temp =a;
    for(int i = 0;i<res.n; i++)
    {
            res.m[i][i] = 1;
    }
    while(k)
    {
        if(k & 1)
            res = res* temp;
        temp = temp * temp;
        k >>= 1;
    }
    return res;
}
char buf[30];
int main()
{
    id['A'] = 0;
    id['T'] = 1;
    id['G'] = 2;
    id['C'] = 3;
    int n,m;
    while(scanf("%d%d",&m,&n)!= EOF)
    {
        ac.init();
        for(int i = 0;i<m; i++)
        {
            scanf("%s",buf);
            ac.insert(buf);
        }
        ac.build();
        Mat mat = ac.getMat();
        //mat.print();
        mat = pow_mod(mat,n);
        //res.print();
        int ans = 0;
        for(int i = 0; i<mat.n; i++)
        {
            ans = (ans + mat.m[0][i]) % MOD;
        }
        printf("%d\n",ans);
    }
    return 0;
}

AC自动机 + DP HDU - 2825
这道题题意就是一个含有至少k个给定串,长度为n的有多少。给定串之间可以重叠。这是一个DP优化问题,首先设计状态,dp[i][j][k]代表前i个字符,当前处于自动机的节点j,使用了的给定串情况是k,k是一个状态压缩后的变量,我们在插入给定串的时候将每个串的编号先压成二进制,在向i+1,next[j][t]转移的时候,k 也相对应的转移end[next[j][k]],之后对数组统计一下就行。

#include <stdio.h>
#include <algorithm>
#include <iostream>
#include <string.h>
#include <queue>
#include <map>
using namespace std;
const int MOD =20090717;
int dp[40][101][1<<10];
int n,m,k;
int num[1<<10];
struct Trie
{
    int next[110][26],fail[110],end[5010];
    int root,L;
    int newnode()
    {
        for(int i = 0; i < 26; i++)
            next[L][i] = -1;
        end[L++] = 0;
        return L-1;
    }
    void init()
    {
        L = 0;
        root = newnode();
    }
    void insert(char buf[],int id)
    {
        int len = strlen(buf);
        int now = root;
        for(int i = 0; i < len; i++)
        {
            if(next[now][buf[i] - 'a'] == -1)
                next[now][buf[i] - 'a'] = newnode();
            now = next[now][buf[i] - 'a'];
        }
        end[now] |= (1<<id);
    }
    void build()
    {
        queue<int>Q;
        fail[root] = root;
        for(int i = 0; i < 26; i++)
            if(next[root][i] == -1)
                next[root][i] = root;
            else
            {
                fail[next[root][i]] = root;
                Q.push(next[root][i]);
            }
        while( !Q.empty() )
        {
            int now = Q.front();
            Q.pop();
            end[now] |= end[fail[now]];
            for(int i = 0; i < 26; i++)
                if(next[now][i] == -1)
                    next[now][i] = next[fail[now]][i];
                else
                {
                    fail[next[now][i]]=next[fail[now]][i];
                    Q.push(next[now][i]);
                }
        }
    }
    int solve()
    {
        for(int i = 0;i<=n;i++)
        {
            for(int j = 0;j<L;j++)
            {
                for(int p = 0;p<(1<<m);p++)
                {
                    dp[i][j][p] = 0;
                }
            }
        }
        dp[0][0][0] = 1;
        for(int i = 0;i<n;i++)
        {
            for(int j = 0;j<L;j++)
            {
                for(int p = 0;p< (1<<m) ;p ++)
                {
                    if(dp[i][j][p] > 0)
                    {
                        for(int x = 0;x<26;x++)
                        {
                            int ni = i+1;
                            int nj = next[j][x];
                            int np = p | end[nj];
                            dp[ni][nj][np] = (dp[ni][nj][np] + dp[i][j][p]) % MOD;
                        }
                    }
                }
            }
        }
        int ans = 0;
        for(int p = 0;p<(1<<m);p++)
        {
            if(num[p] < k) continue;
            for(int j = 0;j<L;j++)
            {
                ans  = (ans + dp[n][j][p]) % MOD;
            }
        }
        printf("%d\n",ans);

    }
}ac;
char buf[1011];
int main()
{
    for(int i = 0;i<(1<<10);i++)
    {

        for(int j = 0;j<10;j++)
        {
            if(i & (1<<j))
            {
                num[i] ++;
            }
        }
    }
    while(scanf("%d%d%d",&n,&m,&k) != EOF)
    {
        if(n == 0 && m == 0) break;
        ac.init();
        for(int i =0;i<m;i++)
        {
            scanf("%s",&buf);
            ac.insert(buf,i);
        }

        ac.build();
        ac.solve();
    }
    return 0;
}


HDU 2296
这道题和上面的题有类似之处,只不过还要构造出最优解。
题意就是给定N个串,这N个串每个串都有一个权值,整个串的权值是给定串的权值乘以出现的次数。

这道题的状态dp[i][j] 代表构造到第i位时,选择自动机上的第j个节点获得的权值,我们再设置一个str[i][j] ,代表在这个dp状态对应的结果是什么。我们在更新的时候就将答案也一并更新,注意比较方式。

#include <stdio.h>
#include <algorithm>
#include <iostream>
#include <string.h>
#include <queue>
#include <map>
using namespace std;
const int INF = 0x3f3f3f3f;
int dp[55][1101];
int n,m,k;
int num[111];
char str[55][1111][55];
bool strcmp2(char a[],char b[])
{
    if(strlen(a) != strlen(b))
    {
        return strlen(a) < strlen(b);
    }
    else
    {
        return strcmp(a,b) <0 ;
    }
}
struct Trie
{
    int next[1110][26],fail[1110],end[1101];
    int root,L;
    int newnode()
    {
        for(int i = 0; i < 26; i++)
            next[L][i] = -1;
        end[L++] = -1;
        return L-1;
    }
    void init()
    {
        L = 0;
        root = newnode();
    }
    void insert(char buf[],int id)
    {
        int len = strlen(buf);
        int now = root;
        for(int i = 0; i < len; i++)
        {
            if(next[now][buf[i] - 'a'] == -1)
                next[now][buf[i] - 'a'] = newnode();
            now = next[now][buf[i] - 'a'];
        }
        end[now] = id;;
    }
    void build()
    {
        queue<int>Q;
        fail[root] = root;
        for(int i = 0; i < 26; i++)
            if(next[root][i] == -1)
                next[root][i] = root;
            else
            {
                fail[next[root][i]] = root;
                Q.push(next[root][i]);
            }
        while( !Q.empty() )
        {
            int now = Q.front();
            Q.pop();
            for(int i = 0; i < 26; i++)
                if(next[now][i] == -1)
                    next[now][i] = next[fail[now]][i];
                else
                {
                    fail[next[now][i]]=next[fail[now]][i];
                    Q.push(next[now][i]);
                }
        }
    }
    void solve(int n)
    {
        for(int i = 0; i<=n; i++)
        {
            for(int j = 0; j<L; j++)
            {
                dp[i][j] = -INF;
            }
        }
        dp[0][0] = 0;
        char ans[55];
        char tmp[55];

        strcpy(str[0][0],"");
        strcpy(ans,"");
        int Max = 0;
        for(int i = 0; i<n; i++)
        {
            for(int j = 0; j<L; j++)
            {
                if(dp[i][j] >= 0)
                {
                    strcpy(tmp,str[i][j]);
                    int len = strlen(tmp);
                    for(int k = 0; k<26; k++)
                    {
                        int _next=  next[j][k];

                        tmp[len] = 'a' + k;
                        tmp[len+1] = 0;
                        int t = dp[i][j];
                        if(end[_next] != -1)
                            t += num[end[_next]];

                        if(dp[i+1][_next] < t || (dp[i+1][_next] == t && strcmp2(tmp,str[i+1][_next])))
                        {
                            dp[i+1][_next] = t;
                            strcpy(str[i+1][_next],tmp);
                            if(t > Max || (t == Max && strcmp2(tmp,ans)))
                            {
                                Max = t;
                                strcpy(ans,tmp);
                            }
                        }

                    }
                }
            }
        }
        printf("%s\n",ans);
    }
} ac;
char buf[55];
int main()
{
    int ca;
    scanf("%d",&ca);
    while(ca--)
    {
        scanf("%d%d",&n,&m);
        ac.init();
        for(int i = 1; i<=m; i++)
        {
            scanf("%s",buf);
            ac.insert(buf,i);
        }
        for(int i = 1; i<=m; i++)
        {
            scanf("%d",&num[i]);
        }
        ac.build();
        ac.solve(n);
    }
    return 0;
}

AC自动机 + 状态压缩dp HDU - 3341

扫描二维码关注公众号,回复: 2084913 查看本文章

这道题的题意是:有一个基因序列,把它重新排列,使得这个新的序列含有题目中给的序列的个数最多。
首先设计状态,很容易想到dp[p][i][j] 代表长度为p时,在自动机第i个节点,ATGC四个碱基各加了多少个的时候含有多少个给定的基因。其实我们不用考虑第一维,因为长度的限制我们可以在第三维解决,相当于作了一个滚动数组的样子。保存状态我们用了类似hash的办法。在转移的过程中,一定要注意是否还可以用当前的枚举到的子节点。最后找出满状态下(使用了全部的字符)dp数组的最大值就行。

#include <cstdio>
#include <cstring>
#include <cmath>
#include <bitset>
#include <queue>
#include <iostream>
#include <algorithm>
#define ll long long
using namespace std;
struct Trie
{
    int next[510][26],fail[510],end[510];
    int root,L;
    int newnode()
    {
        for(int i = 0; i < 4; i++)
            next[L][i] = -1;
        end[L++] = 0;
        return L-1;
    }
    void init()
    {
        L = 0;
        root = newnode();
    }
    int ha(char x)
    {
        if(x == 'A') return 0;
        if(x == 'C') return 1;
        if(x == 'G') return 2;
        if(x =='T') return 3;
    }
    void insert(char buf[])
    {
        int len = strlen(buf);
        int now = root;
        for(int i = 0;i < len;i++)
        {
            if(next[now][ha(buf[i])] == -1)
                next[now][ha(buf[i])] = newnode();
            now = next[now][ha(buf[i])];
        }
        end[now] ++;
    }
    void build()
    {
        queue<int>Q;
        fail[root] = root;
        for(int i = 0; i < 4; i++)
            if(next[root][i] == -1)
                next[root][i] = root;
            else
            {
                fail[next[root][i]] = root;
                Q.push(next[root][i]);
            }
        while( !Q.empty() )
        {
            int now = Q.front();
            Q.pop();
            end[now] += end[fail[now]];
            for(int i = 0; i < 4; i++)
                if(next[now][i] == -1)
                    next[now][i] = next[fail[now]][i];
                else
                {
                    fail[next[now][i]]=next[fail[now]][i];
                    Q.push(next[now][i]);
                }
        }
    }
    int dp[510][11*11*11*11+5];
    int bit[4];
    int num[4];
    int solve(char buf[])
    {
        int len = strlen(buf);
        memset(num,0,sizeof(num));
        for(int i = 0;i < len;i++)
            num[ha(buf[i])]++;
        bit[0] = (num[1]+1)*(num[2]+1)*(num[3]+1);
        bit[1] = (num[2]+1)*(num[3]+1);
        bit[2] = (num[3]+1);
        bit[3] = 1;
        memset(dp,-1,sizeof(dp));
        dp[root][0] = 0;
        for(int A = 0;A <= num[0];A++)
            for(int B = 0;B <= num[1];B++)
                for(int C = 0;C <= num[2];C++)
                    for(int D = 0;D <= num[3];D++)
                    {
                        int s = A*bit[0] + B*bit[1] + C*bit[2] + D*bit[3];
                        for(int i = 0;i < L;i++)
                            if(dp[i][s] >= 0)
                            {
                                for(int k = 0;k < 4;k++)
                                {
                                    if(k == 0 && A == num[0])continue;
                                    if(k == 1 && B == num[1])continue;
                                    if(k == 2 && C == num[2])continue;
                                    if(k == 3 && D == num[3])continue;
                                    dp[next[i][k]][s+bit[k]] = max(dp[next[i][k]][s+bit[k]],dp[i][s]+end[next[i][k]]);
                                }
                            }
                    }
        int ans = 0;
        int status = num[0]*bit[0] + num[1]*bit[1] + num[2]*bit[2] + num[3]*bit[3];
        for(int i = 0;i < L;i++)
            ans = max(ans,dp[i][status]);
        return ans;
    }
    void debug()
    {
        for(int i = 0; i < L; i++)
        {
            printf("id = %3d,fail = %3d,end = %3d,chi = [",i,fail[i],end[i]);
            for(int j = 0; j < 26; j++)
                printf("%2d",next[i][j]);
            printf("]\n");
        }
    }
}ac;
char buf[1001];
int main()
{
    int cat = 1;
    int n;
    while(scanf("%d",&n) != EOF)
    {
        if(n == 0) break;
        ac.init();
        for(int i = 0;i<n;i++)
        {
            scanf("%s",buf);
            ac.insert(buf);
        }

        ac.build();
        scanf("%s",buf);
        printf("Case %d: %d\n",cat++,ac.solve(buf));

    }
    return 0;
}

AC自动机 + DP + SPFA
没猜对做法系列,看到题解的那一瞬间我整个人都崩溃了……
题意是这样的:给出一些文件串和病毒串,由01组成。构造出最短的,不包含病毒串但是包含全部文件串的串。
我们还是先设计状态 dp[state][j] 代笔走到第j个自动机上的可用节点的时候,状态为state,也就是那些文件串被使用了,的最优解。可用节点代表它自己不是病毒串的结尾,他的fail节点也不是,总之就是沿着这个节点走不会走到病毒串。有了这个前提我们再来看转移,由于可用节点不一定是直接连在一起,甚至到不了,我们没办法直接使用next数组。我们就需要在自动机上预处理出来任意两个可用节点之间的最短路。之后的就和上面的解法是相同的了。

#include <stdio.h>
#include <algorithm>
#include <iostream>
#include <string.h>
#include <queue>
using namespace std;
const int INF = 0x3f3f3f3f;
struct Trie
{
    int next[60010][2],fail[60010],end[60010];
    int root,L;
    int newnode()
    {
        for(int i = 0; i < 2; i++)
            next[L][i] = -1;
        end[L++] = 0;
        return L-1;
    }
    void init()
    {
        L = 0;
        root = newnode();
    }
    void insert(char buf[],int id)
    {
        int len = strlen(buf);
        int now = root;
        for(int i = 0; i < len; i++)
        {
            if(next[now][buf[i]-'0'] == -1)
                next[now][buf[i]-'0'] = newnode();
            now = next[now][buf[i]-'0'];
        }
        end[now] = id;
    }
    void build()
    {
        queue<int>Q;
        fail[root] = root;
        for(int i = 0; i < 2; i++)
            if(next[root][i] == -1)
                next[root][i] = root;
            else
            {
                fail[next[root][i]] = root;
                Q.push(next[root][i]);
            }
        while( !Q.empty() )
        {
            int now = Q.front();
            Q.pop();
            if(end[fail[now]] == -1) end[now] = -1;
            for(int i = 0; i < 2; i++)
                if(next[now][i] == -1)
                    next[now][i] = next[fail[now]][i];
                else
                {
                    fail[next[now][i]]=next[fail[now]][i];
                    Q.push(next[now][i]);
                }
        }
    }

    int g[11][11];
    int dp[1050][11];
    int cnt;
    int pos[11];
    int dist[1000010];

    void spfa(int k)
    {
        queue<int> q;
        memset(dist,-1,sizeof(dist));
        dist[pos[k]] = 0;
        q.push(pos[k]);
        while(!q.empty())
        {
            int now = q.front();
            q.pop();
            for(int i = 0;i<2;i++)
            {
                int tmp = next[now][i];
                if(dist[tmp] <0 && end[tmp] >= 0)
                {
                    dist[tmp] = dist[now] +1;
                    q.push(tmp);
                }
            }
        }
        for(int i = 0;i<cnt;i++)
        {
            g[k][i] = dist[pos[i]];
        }
    }

    int solve(int n)
    {
        pos[0] = 0;
        cnt = 1;
        for(int i = 0;i<L;i++)
        {
            if(end[i] > 0) pos[cnt++]  = i;
        }

        for(int i = 0;i<cnt;i++)
        {
            spfa(i);
        }

        for(int i = 0;i<(1<<n);i++)
        {
            for(int j = 0;j<cnt;j++)
            {
                dp[i][j] = INF;
            }
        }

        dp[0][0] = 0;
        for(int i = 0;i<(1<<n);i++)
        {
            for(int j = 0;j<cnt;j++)
            {
                if(dp[i][j] <INF)
                {
                    for(int k = 0;k<cnt;k++)
                    {
                        if(g[j][k] <0) continue;

                        if(j == k) continue;

                        dp[i|end[pos[k]]][k] = min(dp[i|end[pos[k]]][k],dp[i][j]+g[j][k]);
                    }
                }
            }
        }

        int ans = INF;
        for(int i = 0;i<cnt;i++)
        {
            ans = min(ans,dp[(1<<n)-1][i]);
        }
        return ans;
    }
};
char buf[1000010];
Trie ac;
int main()
{


    int n,m;
    while(scanf("%d%d",&n,&m) != EOF)
    {
        if(n == 0) break;
        ac.init();
        for(int i = 0;i<n;i++)
        {
            scanf("%s",buf);
            ac.insert(buf,1<<i);
        }

        for(int i = 0;i<m;i++)
        {
            scanf("%s",buf);
            ac.insert(buf,-1);
        }

        ac.build();
        printf("%d\n",ac.solve(n));
    }
    return 0;
}

AC自动机 + 数位DP ZOJ - 3494 神题啊!
这道题的意思是求出A和B之间,(8421)BCD码表示的数,不含有非法串的个数。
那么对于一个数位dp,我们只需要两维来表示位数和转台就行了,这里的状态是走到AC自动机的哪一个节点上。这里首先要预处理出来转移的目标,因为非法串的表示是01串。我们定义bcd[i][j] 是自动机的第i个节点,走一个数字j,走到的节点编号,在预处理的时候按照数位关系保存节点信息或者保存-1(走不到).这里AC自动机的使命就结束了。
在数位DP的时候,要特别处理一下前导0的问题,然后就和数位DP的其他套路是一样的了。

#include <cstdio>
#include <cstring>
#include <cmath>
#include <bitset>
#include <queue>
#include <iostream>
#include <algorithm>
#define ll long long
using namespace std;
const int INF = 0x3f3f3f3f;
const int mod = 1000000009;
struct Trie
{
    int next[2510][2],fail[2510];
    bool end[2510];
    int root,L;
    int newnode()
    {
        for(int i = 0; i < 2; i++)
            next[L][i] = -1;
        end[L++] = 0;
        return L-1;
    }
    void init()
    {
        L = 0;
        root = newnode();
    }
    void insert(char buf[])
    {
        int len = strlen(buf);
        int now = root;
        for(int i = 0;i < len;i++)
        {
            if(next[now][buf[i]- '0'] == -1)
                next[now][buf[i] -'0'] = newnode();
            now = next[now][buf[i] - '0'];
        }
        end[now] = 1;
    }
    void build()
    {
        queue<int>Q;
        fail[root] = root;
        for(int i = 0; i < 2; i++)
            if(next[root][i] == -1)
                next[root][i] = root;
            else
            {
                fail[next[root][i]] = root;
                Q.push(next[root][i]);
            }
        while( !Q.empty() )
        {
            int now = Q.front();
            Q.pop();
            if(end[fail[now]]) end[now] = 1;
            for(int i = 0; i < 2; i++)
                if(next[now][i] == -1)
                    next[now][i] = next[fail[now]][i];
                else
                {
                    fail[next[now][i]]=next[fail[now]][i];
                    Q.push(next[now][i]);
                }
        }
    }
}ac;
int bcd[2010][10];
int getBCD(int pre,int num)
{
    if(ac.end[pre]) return -1;
    int cur = pre;
    for(int i = 3;i>= 0;i--)
    {
        if(ac.end[ac.next[cur][(num >> i) & 1]]) return -1;
        cur = ac.next[cur][(num >> i) & 1];
    }
    return cur;
}
void preproc()
{
    for(int i = 0;i<ac.L;i++)
    {
        for(int j = 0;j<10;j++)
        {
            bcd[i][j] = getBCD(i,j);
        }
    }
}
int bit[210];
ll dp[210][2010];
ll dfs(int pos,int s,bool flag,bool z)
{
    if(pos == -1) return 1;
    if(!flag  && dp[pos][s] != -1) return dp[pos][s];
    ll ans = 0;
    if(z)
    {
        ans += dfs(pos-1,s,flag && bit[pos] == 0 ,true);
        ans %= mod;
    }
    else
    {
        if(bcd[s][0] != -1) ans += dfs(pos-1,bcd[s][0],flag && bit[pos] == 0,false);
        ans %= mod;
    }

    int en = flag ? bit[pos] : 9;
    for(int i = 1;i<= en;i++)
    {
        if(bcd[s][i] != -1)
        {
            ans += dfs(pos-1,bcd[s][i],flag && i == en ,0);
            ans %= mod;
        }
    }
    if(!flag  && !z) dp[pos][s] = ans;
    return ans;
}
ll calc(char s[])
{
    int len = strlen(s);
    for(int i = 0;i<len;i++)
    {
        bit[i] = s[len -1 - i] - '0';
    }
    return dfs(len-1,0,1,1);
}
char buf[211];
int main()
{

    int ca;
    scanf("%d",&ca);
    while(ca--)
    {
        int n;
        scanf("%d",&n);
        ac.init();
        for(int i = 0;i<n;i++)
        {
            scanf("%s",buf);
            ac.insert(buf);
        }
        ac.build();

        preproc();
        memset(dp,-1,sizeof(dp));
        int ans = 0;
        scanf("%s",buf);
        int len = strlen(buf);
        for(int i = len-1;i>=0;i--)
        {
            if(buf[i] > '0')
            {
                buf[i] --;
                break;
            }
            else
            {
                buf[i] = '9';
            }
        }
        ans -= calc(buf);
        ans %=  mod;
        scanf("%s",buf);
        ans += calc(buf);
        ans %= mod;
        if(ans < 0) ans += mod;
        printf("%d\n",ans);
    }
    return 0;
}

总结:使用AC自动机作为预处理,之后套上其他的算法已经是一个很常见的题目了,AC自动机本身比较简单,但是和已经学过的其他算法结合就会变难。其实这类题都比较套路,看题目描述和输入输出就知道用什么算法了,在解AC自动机+dp题时,尤其要注意状态和转移是否是基于自动机的。

猜你喜欢

转载自blog.csdn.net/lingzidong/article/details/80714431