POJ 1625 Censored!(AC自动机 + DP + 大数 + 拓展ASCII处理)题解

题意:给出由n个字符组成的字符集和p个病毒串,求由这些字符组成、长度为m且不包含任何病毒串的串的个数。

思路:题目不取模,最坏情况下答案可达$50^{50}$,所以要用高精度模板。因为m比较小,可以直接在AC自动机的状态上做DP:dp[i][j]表示长度为i、当前停在自动机节点j、且不包含任何病毒串的串的个数。

因为给出的字符可能包含扩展ASCII码(128~255),而char是有符号还是无符号由编译器决定(取值范围可能是-128~127,也可能是0~255),所以读入一个扩展ASCII字符后它的值可能是负数。处理时可以把字符值加上130作为数组下标的偏移,保证下标非负;也可以直接用map映射,或者用unsigned char读入。

代码(第一份用数组加130偏移映射字符,第二份用map映射):

#include<cmath>
#include<set>
#include<map>
#include<queue>
#include<cstdio>
#include<vector>
#include<cstring>
#include <iostream>
#include<algorithm>
using namespace std;
// Integer shorthands carried over from the author's template.
typedef long long ll;
typedef unsigned long long ull;

// maxn bounds the trie node pool (Aho::node) and dp's second dimension.
const int maxn = 105;
// Template constants that the code below does not reference.
const int M = 55;
const ull seed = 131;
const double INF = 1e20;
const int MOD = 100000;

// n: alphabet size, m: target string length, p: number of forbidden words.
int n;
int m;
int p;
// tol: number of alphabet symbols mapped so far (dense indices 0..tol-1).
int tol;
// reflect[c + 130] -> dense alphabet index of character c; the offset keeps
// the index non-negative whether plain char is signed or unsigned.
int reflect[600];
// Non-negative big integer in base 10000 (4 decimal digits per limb), least
// significant limb first. Capacity is fixed at CAP limbs (200 decimal digits);
// a result that would need more limbs is truncated (arithmetic modulo 10^200)
// rather than written past the end of a[].
struct BigInt{
    const static int mod = 10000;   // limb base; output() prints 4 digits per lower limb
    const static int CAP = 50;      // number of limbs in a[]
    int a[CAP], len;                // len: number of significant limbs, always >= 1
    // Construct the value 0.
    BigInt(){
        memset(a, 0, sizeof(a));
        len = 1;
    }
    // Assign the value v (expects v >= 0).
    void set(int v){
        memset(a, 0, sizeof(a));
        len = 0;
        do{
            a[len++] = v % mod;
            v /= mod;
        }while(v);
    }

    // Sum of *this and b, truncated to CAP limbs.
    BigInt operator + (const BigInt &b) const{
        BigInt res;                 // zero-initialised by the constructor
        res.len = std::max(len, b.len);
        int carry = 0;
        for(int i = 0; i < res.len; i++){
            int sum = carry + ((i < len)? a[i] : 0) + ((i < b.len)? b.a[i] : 0);
            res.a[i] = sum % mod;
            carry = sum / mod;
        }
        // Append the final carry only if a[] has room; never write a[CAP].
        if(carry > 0 && res.len < CAP) res.a[res.len++] = carry;
        // Truncation can leave leading zero limbs; keep len normalised.
        while(res.len > 1 && res.a[res.len - 1] == 0) res.len--;
        return res;
    }

    // Product of *this and b, truncated to CAP limbs.
    BigInt operator * (const BigInt &b) const{
        BigInt res;
        for(int i = 0; i < len; i++){
            int up = 0;
            // Limbs at index >= CAP are dropped instead of overflowing a[].
            for(int j = 0; j < b.len && i + j < CAP; j++){
                // At most 9999*9999 + 2*9999 < 2^31, so int cannot overflow.
                int temp = a[i] * b.a[j] + res.a[i + j] + up;
                res.a[i + j] = temp % mod;
                up = temp / mod;
            }
            // Slot i + b.len is still untouched at this point, so plain assignment is fine.
            if(up != 0 && i + b.len < CAP) res.a[i + b.len] = up;
        }
        res.len = len + b.len;
        if(res.len > CAP) res.len = CAP;
        while(res.a[res.len - 1] == 0 && res.len > 1) res.len--;
        return res;
    }

    // Print the value in decimal followed by a newline.
    void output() const{
        printf("%d", a[len - 1]);
        for(int i = len - 2; i >= 0; i--){
            printf("%04d", a[i]);
        }
        printf("\n");
    }
};
// dp[i][j]: number of length-i strings that contain no forbidden word and
// leave the automaton in node j. Rows 0..m are used, so m must be < 55.
BigInt dp[55][maxn];

// Aho-Corasick automaton over the dense alphabet indices produced by reflect[].
struct Aho{
    static const int SIGMA = 51;    // number of transition slots per node
    struct state{
        int next[SIGMA];            // trie child, then completed goto after build(); 0 = root
        int fail, cnt;              // failure link; cnt = 1 marks a forbidden state
    }node[maxn];
    int size;                       // number of allocated nodes
    queue<int> q;                   // BFS queue used by build()

    // Reset the automaton to a lone root node (id 0).
    void init(){
        size = 0;
        newtrie();
        while(!q.empty()) q.pop();
    }

    // Allocate a node with no children and return its id.
    int newtrie(){
        memset(node[size].next, 0, sizeof(node[size].next));
        node[size].cnt = node[size].fail = 0;
        return size++;
    }

    // Insert forbidden word s; characters are translated via reflect[c + 130],
    // so every character of s must already be mapped by the caller.
    void insert(const char *s){
        int now = 0;
        for(int i = 0; s[i]; i++){
            int c = reflect[int(s[i]) + 130];
            if(node[now].next[c] == 0){
                node[now].next[c] = newtrie();
            }
            now = node[now].next[c];
        }
        node[now].cnt = 1;
    }

    // Compute failure links in BFS order and complete the transition table:
    // a missing edge of u points where u's failure node goes on that symbol.
    // A node whose failure node is forbidden is marked forbidden too, so a
    // word ending as a suffix of the current path is still detected.
    void build(){
        node[0].fail = -1;
        q.push(0);
        while(!q.empty()){
            int u = q.front();
            q.pop();
            // Test u first: the root's fail is -1 and node[-1] is out of bounds.
            if(u && node[node[u].fail].cnt) node[u].cnt = 1;
            for(int i = 0; i < SIGMA; i++){
                int &child = node[u].next[i];
                if(child){
                    // The failure node is shallower, so it was dequeued earlier
                    // and its transitions are already complete: one lookup suffices.
                    node[child].fail = u ? node[node[u].fail].next[i] : 0;
                    q.push(child);
                }
                else if(u){
                    child = node[node[u].fail].next[i];
                }
            }
        }
    }

    // Print the number of length-m strings over the first tol symbols that
    // contain none of the inserted words. Must be called after build().
    void query(){
        for(int i = 0; i <= m; i++){
            for(int j = 0; j < size; j++){
                dp[i][j].set(0);
            }
        }
        dp[0][0].set(1);            // the empty string starts at the root
        for(int i = 0; i < m; i++){
            for(int j = 0; j < size; j++){
                for(int k = 0; k < tol; k++){
                    int v = node[j].next[k];
                    if(node[v].cnt == 0){
                        dp[i + 1][v] = dp[i + 1][v] + dp[i][j];
                    }
                }
            }
        }
        BigInt ans;
        ans.set(0);
        for(int i = 0; i < size; i++){
            if(node[i].cnt == 0){
                ans = ans + dp[m][i];
            }
        }
        ans.output();
    }

}ac;
char s[100];   // input buffer for the alphabet and for each forbidden word

// Reads test cases while "n m p" parses: the n-character alphabet, then p
// forbidden words; prints the number of safe length-m strings for each case.
int main(){
    while(scanf("%d%d%d", &n, &m, &p) == 3){
        scanf("%99s", s);               // bounded read: s holds at most 99 chars + NUL
        // Map each alphabet character to a dense index 0..n-1. The +130 offset
        // keeps the index non-negative whether plain char is signed or not.
        tol = 0;
        for(int i = 0; i < n; i++){
            reflect[int(s[i]) + 130] = tol++;
        }
        ac.init();
        while(p--){
            scanf("%99s", s);
            ac.insert(s);
        }
        ac.build();
        ac.query();
    }
    return 0;
}
#include<cmath>
#include<set>
#include<map>
#include<queue>
#include<cstdio>
#include<vector>
#include<cstring>
#include <iostream>
#include<algorithm>
using namespace std;
// Integer shorthands carried over from the author's template.
typedef long long ll;
typedef unsigned long long ull;

// maxn bounds the trie node pool (Aho::node) and dp's second dimension.
const int maxn = 105;
// Template constants that the code below does not reference.
const int M = 55;
const ull seed = 131;
const double INF = 1e20;
const int MOD = 100000;

// n: alphabet size, m: target string length, p: number of forbidden words.
int n;
int m;
int p;
// tol: number of alphabet symbols mapped so far (dense indices 0..tol-1).
int tol;
// Character -> dense alphabet index; keying by char sidesteps the
// signed/unsigned char question entirely.
std::map<char, int> reflect;
// Non-negative big integer in base 10000 (4 decimal digits per limb), least
// significant limb first. Capacity is fixed at CAP limbs (200 decimal digits);
// a result that would need more limbs is truncated (arithmetic modulo 10^200)
// rather than written past the end of a[].
struct BigInt{
    const static int mod = 10000;   // limb base; output() prints 4 digits per lower limb
    const static int CAP = 50;      // number of limbs in a[]
    int a[CAP], len;                // len: number of significant limbs, always >= 1
    // Construct the value 0.
    BigInt(){
        memset(a, 0, sizeof(a));
        len = 1;
    }
    // Assign the value v (expects v >= 0).
    void set(int v){
        memset(a, 0, sizeof(a));
        len = 0;
        do{
            a[len++] = v % mod;
            v /= mod;
        }while(v);
    }

    // Sum of *this and b, truncated to CAP limbs.
    BigInt operator + (const BigInt &b) const{
        BigInt res;                 // zero-initialised by the constructor
        res.len = std::max(len, b.len);
        int carry = 0;
        for(int i = 0; i < res.len; i++){
            int sum = carry + ((i < len)? a[i] : 0) + ((i < b.len)? b.a[i] : 0);
            res.a[i] = sum % mod;
            carry = sum / mod;
        }
        // Append the final carry only if a[] has room; never write a[CAP].
        if(carry > 0 && res.len < CAP) res.a[res.len++] = carry;
        // Truncation can leave leading zero limbs; keep len normalised.
        while(res.len > 1 && res.a[res.len - 1] == 0) res.len--;
        return res;
    }

    // Product of *this and b, truncated to CAP limbs.
    BigInt operator * (const BigInt &b) const{
        BigInt res;
        for(int i = 0; i < len; i++){
            int up = 0;
            // Limbs at index >= CAP are dropped instead of overflowing a[].
            for(int j = 0; j < b.len && i + j < CAP; j++){
                // At most 9999*9999 + 2*9999 < 2^31, so int cannot overflow.
                int temp = a[i] * b.a[j] + res.a[i + j] + up;
                res.a[i + j] = temp % mod;
                up = temp / mod;
            }
            // Slot i + b.len is still untouched at this point, so plain assignment is fine.
            if(up != 0 && i + b.len < CAP) res.a[i + b.len] = up;
        }
        res.len = len + b.len;
        if(res.len > CAP) res.len = CAP;
        while(res.a[res.len - 1] == 0 && res.len > 1) res.len--;
        return res;
    }

    // Print the value in decimal followed by a newline.
    void output() const{
        printf("%d", a[len - 1]);
        for(int i = len - 2; i >= 0; i--){
            printf("%04d", a[i]);
        }
        printf("\n");
    }
};
// dp[i][j]: number of length-i strings that contain no forbidden word and
// leave the automaton in node j. Rows 0..m are used, so m must be < 55.
BigInt dp[55][maxn];

// Aho-Corasick automaton over the dense alphabet indices produced by reflect.
struct Aho{
    static const int SIGMA = 51;    // number of transition slots per node
    struct state{
        int next[SIGMA];            // trie child, then completed goto after build(); 0 = root
        int fail, cnt;              // failure link; cnt = 1 marks a forbidden state
    }node[maxn];
    int size;                       // number of allocated nodes
    queue<int> q;                   // BFS queue used by build()

    // Reset the automaton to a lone root node (id 0).
    void init(){
        size = 0;
        newtrie();
        while(!q.empty()) q.pop();
    }

    // Allocate a node with no children and return its id.
    int newtrie(){
        memset(node[size].next, 0, sizeof(node[size].next));
        node[size].cnt = node[size].fail = 0;
        return size++;
    }

    // Insert forbidden word s; characters are translated via the reflect map,
    // so every character of s should already be mapped by the caller
    // (operator[] silently maps an unknown character to index 0).
    void insert(const char *s){
        int now = 0;
        for(int i = 0; s[i]; i++){
            int c = reflect[s[i]];
            if(node[now].next[c] == 0){
                node[now].next[c] = newtrie();
            }
            now = node[now].next[c];
        }
        node[now].cnt = 1;
    }

    // Compute failure links in BFS order and complete the transition table:
    // a missing edge of u points where u's failure node goes on that symbol.
    // A node whose failure node is forbidden is marked forbidden too, so a
    // word ending as a suffix of the current path is still detected.
    void build(){
        node[0].fail = -1;
        q.push(0);
        while(!q.empty()){
            int u = q.front();
            q.pop();
            // Test u first: the root's fail is -1 and node[-1] is out of bounds.
            if(u && node[node[u].fail].cnt) node[u].cnt = 1;
            for(int i = 0; i < SIGMA; i++){
                int &child = node[u].next[i];
                if(child){
                    // The failure node is shallower, so it was dequeued earlier
                    // and its transitions are already complete: one lookup suffices.
                    node[child].fail = u ? node[node[u].fail].next[i] : 0;
                    q.push(child);
                }
                else if(u){
                    child = node[node[u].fail].next[i];
                }
            }
        }
    }

    // Print the number of length-m strings over the first tol symbols that
    // contain none of the inserted words. Must be called after build().
    void query(){
        for(int i = 0; i <= m; i++){
            for(int j = 0; j < size; j++){
                dp[i][j].set(0);
            }
        }
        dp[0][0].set(1);            // the empty string starts at the root
        for(int i = 0; i < m; i++){
            for(int j = 0; j < size; j++){
                for(int k = 0; k < tol; k++){
                    int v = node[j].next[k];
                    if(node[v].cnt == 0){
                        dp[i + 1][v] = dp[i + 1][v] + dp[i][j];
                    }
                }
            }
        }
        BigInt ans;
        ans.set(0);
        for(int i = 0; i < size; i++){
            if(node[i].cnt == 0){
                ans = ans + dp[m][i];
            }
        }
        ans.output();
    }

}ac;
char s[100];   // input buffer for the alphabet and for each forbidden word

// Reads test cases while "n m p" parses: the n-character alphabet, then p
// forbidden words; prints the number of safe length-m strings for each case.
int main(){
    while(scanf("%d%d%d", &n, &m, &p) == 3){
        scanf("%99s", s);               // bounded read: s holds at most 99 chars + NUL
        // Map each alphabet character to a dense index 0..n-1. Keying the map
        // by char works whether plain char is signed or unsigned.
        tol = 0;
        for(int i = 0; i < n; i++){
            reflect[s[i]] = tol++;
        }
        ac.init();
        while(p--){
            scanf("%99s", s);
            ac.insert(s);
        }
        ac.build();
        ac.query();
    }
    return 0;
}

猜你喜欢

转载自www.cnblogs.com/KirinSB/p/11185424.html