【uoj】【美团杯2020】半前缀计数(后缀自动机)

传送门

思路:
这类计数问题有三种大体思路:

  • 统计并且去重,这可能会涉及一些容斥的东西;
  • 直接进行统计;
  • 统计第一次出现或最后一次出现的串;

这个题我们就直接考虑在每个串最后一次出现时进行统计就行,根据题目的定义:若对于一个前缀\(s_{1,...,i}\),假设我们后面拼接的串为\(s_{j,...,k}\),那么这个串出现最后一次的充要条件即为:\(s_{i+1}\not ={s_j}\),否则我们这个串不会出现在最后一次。
所以对于每个\(i\),我们统计后面有多少个不以\(s_{i+1}\)开头的本质不同的子串个数即可。
可以通过后缀自动机来实现。
详见代码:

/*
 * Author:  heyuhhh
 * Created Time:  2020/5/19 10:59:50
 */
#include <iostream>
#include <algorithm>
#include <cstring>
#include <cstdio>
#include <vector>
#include <cmath>
#include <set>
#include <map>
#include <queue>
#include <iomanip>
#include <assert.h>
#define MP make_pair
#define fi first
#define se second
#define pb push_back
#define sz(x) (int)(x).size()
#define all(x) (x).begin(), (x).end()
#define INF 0x3f3f3f3f
#define Local
#ifdef Local
  #define dbg(args...) do { cout << #args << " -> "; err(args); } while (0)
  void err() { std::cout << std::endl; }
  template<typename T, typename...Args>
  void err(T a, Args...args) { std::cout << a << ' '; err(args...); }
  template <template<typename...> class T, typename t, typename... A> 
  void err(const T <t> &arg, const A&... args) {
  for (auto &v : arg) std::cout << v << ' '; err(args...); }
#else
  #define dbg(...)
#endif
using namespace std;
typedef long long ll;
typedef pair<int, int> pii;
//head
const int N = 1e6 + 5;

int n;
char s[N];

struct node{
    int ch[26];
    int len, fa;
    node(){memset(ch, 0, sizeof(ch)), len = 0;}
}dian[N << 2];
int last = 1, tot = 1;
void add(int c) {
    int p = last;
    int np = last = ++tot;
    dian[np].len = dian[p].len + 1;
    for(; p && !dian[p].ch[c]; p = dian[p].fa) dian[p].ch[c] = np;
    if(!p) dian[np].fa = 1;
    else {
        int q = dian[p].ch[c];
        if(dian[q].len == dian[p].len + 1) dian[np].fa = q;
        else {
            int nq = ++tot; dian[nq] = dian[q];
            dian[nq].len = dian[p].len + 1;
            dian[q].fa = dian[np].fa = nq;
            for(; p && dian[p].ch[c] == q; p = dian[p].fa) dian[p].ch[c] = nq;
        }
    }
}

ll cnt[26];

void run() {
    cin >> (s + 1);
    n = strlen(s + 1);
    ll ans = 0, all = 1;
    auto calc = [&]() {
        return dian[last].len - dian[dian[last].fa].len;
    };
    for (int i = n; i >= 0; i--) {
        ans += all;
        ans -= (i < n ? cnt[s[i + 1] - 'a'] : 0);
        if (i) {
            add(s[i] - 'a');
            cnt[s[i] - 'a'] += calc();
            all += calc();              
        }
    }
    cout << ans << '\n';
}

int main() {
    ios::sync_with_stdio(false);
    cin.tie(0); cout.tie(0);
    cout << fixed << setprecision(20);
    run();
    return 0;
}

猜你喜欢

转载自www.cnblogs.com/heyuhhh/p/12964888.html