Passwords Gym - 101174E (AC自动机上DP)

Problem E: Passwords

\[ Time Limit: 1 s \quad Memory Limit: 256 MiB \]

题意

给出两个正整数\(A,B\),再给出\(n\)个字符串,然后问你满足条件的字符串有多少种,最后答案\(\%1e6+3\)。条件如下
\[ \begin{aligned} 1、&长度在A到B之间\\ 2、&所有子串不存在n个字符串中任意一个\\ 3、&模式串中,把0看成o,1看成i,3看成e,5看成s,7看成t\\ 4、&至少存在一个小写字母,一个大写字母,一个数字 \end{aligned} \]
要注意一下题目中的Additionally, for the purposes of avoiding the blacklist, you cannot use \(l33t\).这句话,我一开始还以为\(l33t\)也算第\(n+1\)个字符串,也不能计数,结果发现读错了,这句话只是引出后面那么限制条件。

思路

很显然这种在多个字符串上跳来跳去的,然后给出一些限制条件,在\(AC自动机\)\(DP\)大部分情况下是\(ok\)的,而且对\(fail\)指针\(build\)时,也很常见,补全成\(Trie\)图,把\(fail[u]\)的信息传给\(u\)就可以了。
那么如何\(dp\)呢,用\(dp[state][i][j]\)\(state\)最大为\(7\),第一位为小写字母,第二问为大写字母,第三次为数字,状压状态,\(i\)表示匹配串已经到了第\(i\)位,\(j\)表示在\(AC自动机\)上的状态,然后状态就很容易得到了。
\[ \begin{aligned} dp[st][i][j] &-> dp[st|1][i+1][k] 小写字母时\\ dp[st][i][j] &-> dp[st|2][i+1][k] 大写字母时\\ dp[st][i][j] &-> dp[st|4][i+1][k] 数字时\\ \end{aligned} \]
最后输出\(dp[7][\sum_A^B][\sum_1^{sz}]\)就可以了。

/***************************************************************
    > File Name    : E.cpp
    > Author       : Jiaaaaaaaqi
    > Created Time : 2019年05月08日 星期三 15时00分56秒
 ***************************************************************/

#include <map>
#include <set>
#include <list>
#include <ctime>
#include <cmath>
#include <stack>
#include <queue>
#include <cfloat>
#include <string>
#include <vector>
#include <cstdio>
#include <bitset>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <algorithm>
#define  lowbit(x)  x & (-x)
#define  mes(a, b)  memset(a, b, sizeof a)
#define  fi         first
#define  se         second
#define  pii        pair<int, int>

typedef unsigned long long int ull;
typedef long long int ll;
const int    maxn = 5e3 + 10;
const int    maxm = 1e5 + 10;
const ll     mod  = 1e6 + 3;
const ll     INF  = 1e18 + 100;
const int    inf  = 0x3f3f3f3f;
const double pi   = acos(-1.0);
const double eps  = 1e-8;
using namespace std;

int n, m, A, B;
int cas, tol, T;

map<int, int> mp;

void handle() {
    mp.clear();
    for(int i='a'; i<='z'; i++) {
        mp[i] = i-'a'+1;
    }
    for(int i='A'; i<='Z'; i++) {
        mp[i] = i-'A'+1;
    }
    int cnt = 26;
    for(int i='0'; i<='9'; i++) {
        if(i=='0')  mp[i] = mp['o'];
        else if(i == '1')   mp[i] = mp['i'];
        else if(i == '3')   mp[i] = mp['e'];
        else if(i == '5')   mp[i] = mp['s'];
        else if(i == '7')   mp[i] = mp['t'];
        else mp[i] = ++cnt;
    }
}
struct AC {
    int node[maxn][35], fail[maxn], cnt[maxn];
    ll dp[10][25][maxn];
    int root, sz;
    int newnode() {
        mes(node[++sz], 0);
        cnt[sz] = 0;
        return sz;
    }
    void init() {
        sz = 0;
        root = newnode();
    }
    void insert(char *s) {
        int len = strlen(s+1);
        int rt = root;
        for(int i=1; i<=len; i++) {
            int k = mp[s[i]];
            if(node[rt][k] == 0)    node[rt][k] = newnode();
            rt = node[rt][k];
        }
        cnt[rt] = 1;
    }
    void build() {
        queue<int> q;
        while(!q.empty())   q.pop();
        fail[root] = root;
        for(int i=1; i<=31; i++) {
            if(node[root][i] == 0) {
                node[root][i] = root;
            } else {
                fail[node[root][i]] = root;
                q.push(node[root][i]);
            }
        }
        while(!q.empty()) {
            int u = q.front();
            cnt[u] |= cnt[fail[u]];
            q.pop();
            for(int i=1; i<=31; i++) {
                if(node[u][i] == 0) {
                    node[u][i] = node[fail[u]][i];
                } else {
                    fail[node[u][i]] = node[fail[u]][i];
                    q.push(node[u][i]);
                }
            }
        }
    }
    ll solve(int A, int B) {
        for(int i=0; i<=7; i++) {
            for(int j=0; j<=B; j++) {
                for(int k=0; k<=sz; k++) {
                    dp[i][j][k] = 0;
                }
            }
        }
        dp[0][0][1] = 1;
        for(int i=0; i<B; i++) {
            for(int j=1; j<=sz; j++) {
                if(cnt[j])  continue;
                for(int st=0; st<=7; st++) {
                    if(dp[st][i][j] == 0)   continue;
                    // printf("dp[%d][%d][%d] = %lld\n", st, i, j, dp[st][i][j]);
                    for(int k='a'; k<='z'; k++) {
                        int nst = node[j][mp[k]];
                        if(cnt[nst])    continue;
                        dp[st|1][i+1][nst] += dp[st][i][j];
                        dp[st|2][i+1][nst] += dp[st][i][j];
                        dp[st|1][i+1][nst] %= mod;
                        dp[st|2][i+1][nst] %= mod;
                    }
                    for(int k='0'; k<='9'; k++) {
                        int nst = node[j][mp[k]];
                        if(cnt[nst])    continue;
                        dp[st|4][i+1][nst] += dp[st][i][j];
                        dp[st|4][i+1][nst] %= mod;
                    }
                }
            }
        }
        ll ans = 0;
        for(int i=A; i<=B; i++) {
            for(int j=1; j<=sz; j++) {
                if(dp[7][i][j] == 0)    continue;
                // printf("dp[7][%d][%d] = %lld\n", i, j, dp[7][i][j]);
                ans = (ans + dp[7][i][j])%mod;
            }
        }
        return ans;
    }
} ac;
char s[maxn];

int main() {
    handle();
    ac.init();
    scanf("%d%d", &A, &B);
    scanf("%d", &n);
    for(int i=1; i<=n; i++) {
        scanf("%s", s+1);
        ac.insert(s);
    }
    ac.build();
    ll ans = ac.solve(A, B);
    printf("%lld\n", ans);
    return 0;
}

猜你喜欢

转载自www.cnblogs.com/Jiaaaaaaaqi/p/10833712.html