str2int HDU - 4436 (后缀自动机)

str2int

\[ Time Limit: 3000 ms\quad Memory Limit: 131072 kB \]

题意

算是自己写出来的第一题吧。
给出 \(n\) 个串,求出这 \(n\) 个串所有子串代表的数字的和。

思路

首先可以把这些串合并起来,串与串之间用没出现过的字符分隔开来,然后构建后缀自动机,因为后缀自动机上从 \(root\) 走到的任意节点都是一个子串,所有可以利用这个性质来做。

一开始我的做法是做 \(dfs\),令 \(dp[i]\) 表示节点 \(i\) 的贡献,转移就是 \(dp[v] = dp[v]+tmp*10+j\),表示从 \(root\)\(u\) 的权值是\(tmp\)\(v\)\(u\)\(j\)走的下一个点。结果显然超时了。

我们发现对于\(dp[u]->dp[v]\)过程,如果之前走到 \(dp[u]\) 的有 \(12\)\(2\) 两步,假设现在往 \(3\) 这条边走,得到 \(12*10+3\)\(2*10+3\),那么其实这些值的贡献是可以一次性计算的,无论之前走到 \(dp[u]\) 的有几条路,都需要让他们全部 \(*10\),而 \(3\) 的贡献则是由走到 \(dp[u]\) 的路径数确定的。

那么我们就可以得到第二个方程:

  1. \(dp1[i]\) 表示节点 \(i\) 的贡献
  2. \(dp2[i]\) 表示之前有多少种方案走到 \(i\)
  3. \(dp1[v] = dp1[v] + dp1[u]*10 + dp2[u]*j\)
  4. \(dp2[v] = dp[2[v] + dp2[v]\)

最后为了去除前导零,只要控制从 \(root\) 出来的边最少从 \(1\) 开始就可以了。
如此计算后,\(\sum dp1[i]\) 就是最后的答案。

#include <map>
#include <set>
#include <list>
#include <ctime>
#include <cmath>
#include <stack>
#include <queue>
#include <cfloat>
#include <string>
#include <vector>
#include <cstdio>
#include <bitset>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <algorithm>
#define  lowbit(x)  x & (-x)
#define  mes(a, b)  memset(a, b, sizeof a)
#define  fi         first
#define  se         second
#define  pii        pair<int, int>
#define  INOPEN     freopen("in.txt", "r", stdin)
#define  OUTOPEN    freopen("out.txt", "w", stdout)

typedef unsigned long long int ull;
typedef long long int ll;
const int    maxn = 2e5 + 10;
const int    maxm = 1e5 + 10;
const ll     mod  = 2012;
const ll     INF  = 1e18 + 100;
const int    inf  = 0x3f3f3f3f;
const double pi   = acos(-1.0);
const double eps  = 1e-8;
using namespace std;

int n, m;
int cas, tol, T;

struct Sam {
    struct Node {
        int next[20];
        int fa, len;
        void init() {
            mes(next, 0);
            fa = len = 0;
        }
    } node[maxn<<1];
    int dp1[maxn<<1], dp2[maxn<<1];
    bool vis[maxn<<1];
    int tax[maxn<<1], gid[maxn<<1];
    int last, sz;
    void init() {
        mes(dp1, 0);
        mes(dp2, 0);
        last = sz = 1;
        node[sz].init();
    }
    void insert(int k) {
        int p = last, np = last = ++sz;
        node[np].init();
        node[np].len = node[p].len + 1;
        for(; p&&!node[p].next[k]; p=node[p].fa)
            node[p].next[k] = np;
        if(p == 0) {
            node[np].fa = 1;
        } else {
            int q = node[p].next[k];
            if(node[q].len == node[p].len+1) {
                node[np].fa = q;
            } else {
                int nq = ++sz;
                node[nq] = node[q];
                node[nq].len = node[p].len+1;
                node[np].fa = node[q].fa = nq;
                for(; p&&node[p].next[k]==q; p=node[p].fa)
                    node[p].next[k] = nq;
            }
        }
    }
    void solve() {
        int ans = 0;
        for(int i=0; i<=sz; i++)    tax[i] = 0;
        for(int i=1; i<=sz; i++)    tax[node[i].len]++;
        for(int i=1; i<=sz; i++)    tax[i] += tax[i-1];
        for(int i=1; i<=sz; i++)    gid[tax[node[i].len]--] = i;
        dp2[1] = 1;
        for(int i=1; i<=sz; i++) {
            int u = gid[i];
            ans = (ans+dp1[u])%mod;
//          printf("%d %d %d\n", u, dp1[u], dp2[u]);
            for(int j=(u==1 ? 1:0); j<=9; j++) {
                if(node[u].next[j+1] == 0)  continue;
                int nst = node[u].next[j+1];
                dp1[nst] = (dp1[nst] + dp1[u]*10 + j*dp2[u])%mod;
                dp2[nst] = (dp2[nst] + dp2[u])%mod;
            }
        }
        printf("%d\n", ans);
    }
} sam;
char s[maxn], t[maxn];

int main() {
    while(~scanf("%d", &T)) {
        mes(s, 0);
        n = 0;
        while(T--) {
            scanf("%s", t+1);
            int tlen = strlen(t+1);
            for(int i=1; i<=tlen; i++) {
                s[++n] = t[i]-'0'+1;
            }
            s[++n] = 11;
        }
        sam.init();
        for(int i=1; i<=n; i++) {
            sam.insert(s[i]);
        }
        sam.solve();
    }
    return 0;
}

猜你喜欢

转载自www.cnblogs.com/Jiaaaaaaaqi/p/10971119.html
今日推荐