【题解】Asterisk Substrings Codeforces 1276F 后缀自动机 树链的并

第一道独立解决的Div1F,嘿嘿,幸好没看题解


把串分为以下几类

不包含star的串

太简单,略

star在最前面的串

star在最后面的串

单独一个star

答案++

单独一个空串

答案++

star在中间的串

注意到,假设star的位置是pos,实际上相当于选择一个右端点为pos-1的串s1,再选择一个左端点为pos+1的串s2,问这样的pair(s1,s2)有多少个

也就是选两个原串的子串,并且这两个子串要满足上面那个条件,问方案数

对原串的前n-2个字符建SAM,称为sam

对原串的后n-2个字符倒过来再建SAM,称为rsam

注意到,SAM上每个点表示的本质不同子串数量是len[u] - len[pa[u]],其中len是点u所表示字符串的最长长度,pa是点u在后缀树上的父亲,记这个值为val[u]

也就是说,问题变成了:枚举sam里面的一个点u,枚举rsam里面的一个点v,如果v的end_pos集合存在一个数字,等于u的end_pos集合里面的某个数字+2,那么ans += val[u] * val[v]

考虑sam里面的每个点u,假设u的end_pos集合是{a1, a2, a3, ..., ak},那么在rsam里面,有哪些点可以和u产生贡献?所有end_pos集合包含某一个ai+2的点可以和u产生贡献,这在rsam的后缀树上,是k条树链的并

sam的后缀树上跑DSU on Tree,维护上述end_pos集合,并时刻维护集合中所有点的树链的并

总复杂度两个log

#include <bits/stdc++.h>

using namespace std;
typedef long long ll;
// Capacity used for every fixed-size array in this file (automaton states,
// tree edges, per-node tables and the input buffer).
const int N = 200010;
// Receives the return value of scanf() in main(); it is not read anywhere
// else in the code shown here.
int _w;

/// Suffix automaton over the 26 lowercase letters. State 0 is the root and
/// new states receive consecutive ids starting from 1.
struct SAM {
    int ch[N][26];  // ch[s][c]: target of the transition s --c-->, 0 means "no transition"
    int len[N];     // len[s]: length of the longest string represented by state s
    int pa[N];      // pa[s]: suffix link (parent in the suffix-link tree), -1 for the root
    int idx;        // id that the next created state will get

    /// Clears all tables so that only the root state exists.
    void init() {
        memset(ch, 0, sizeof ch);
        memset(len, 0, sizeof len);
        memset(pa, 0, sizeof pa);
        pa[0] = -1;
        idx = 1;
    }

    /// Extends the automaton with letter `c` (0..25) after state `p`.
    /// @param p  state that represented the whole string before this letter
    /// @param c  letter index to append
    /// @return   the newly created state representing the extended string
    int append( int p, int c ) {
        const int cur = idx++;
        len[cur] = len[p] + 1;

        // Walk up the suffix links, adding the missing c-transitions.
        for( ; p != -1 && ch[p][c] == 0; p = pa[p] )
            ch[p][c] = cur;

        if( p == -1 ) {
            // Fell off the root: the new letter never occurred before.
            pa[cur] = 0;
            return cur;
        }

        const int q = ch[p][c];
        if( len[q] == len[p] + 1 ) {
            // q can serve as the suffix link directly.
            pa[cur] = q;
            return cur;
        }

        // Otherwise split q: the clone keeps q's transitions and link but
        // has the shorter length len[p] + 1.
        const int clone = idx++;
        memcpy(ch[clone], ch[q], sizeof ch[clone]);
        len[clone] = len[p] + 1;
        pa[clone] = pa[q];
        pa[cur] = clone;
        pa[q] = clone;

        // Redirect every c-transition that pointed at q along the link path.
        for( ; p != -1 && ch[p][c] == q; p = pa[p] )
            ch[p][c] = clone;
        return cur;
    }
};

// Length of the input string; set in main() from strlen(str+1).
int n;
// Input string, stored 1-based: main() reads it into str+1.
char str[N];
// Two automaton instances. `sam` is rebuilt by each solve_* routine;
// `rsam` is built in solve() from the characters str[n] down to str[3].
SAM sam, rsam;

/// Rebuilds `sam` over the whole input str[1..n] and returns the sum of
/// len[s] - len[pa[s]] over all non-root states, i.e. the number of distinct
/// non-empty substrings that contain no star.
ll solve_origin() {
    sam.init();
    for( int pos = 1, tail = 0; pos <= n; ++pos )
        tail = sam.append(tail, str[pos] - 'a');

    ll total = 0;
    for( int s = sam.idx - 1; s >= 1; --s ) {
        const int parent = sam.pa[s];
        total += sam.len[s] - sam.len[parent];
    }
    return total;
}

/// Rebuilds `sam` over str[2..n] (the input without its first character) and
/// returns the sum of len[s] - len[pa[s]] over all non-root states. main()
/// adds this as the count for strings whose star is the first character.
ll solve_before() {
    sam.init();
    int tail = 0;
    int pos = 2;
    while( pos <= n ) {
        tail = sam.append(tail, str[pos] - 'a');
        ++pos;
    }

    ll total = 0;
    for( int s = 1; s != sam.idx; ++s )
        total += sam.len[s] - sam.len[sam.pa[s]];
    return total;
}

/// Rebuilds `sam` over str[1..n-1] (the input without its last character) and
/// returns the sum of len[s] - len[pa[s]] over all non-root states. main()
/// adds this as the count for strings whose star is the last character.
ll solve_after() {
    sam.init();
    const int stop = n - 1;
    int tail = 0;
    for( int pos = 1; pos <= stop; ++pos )
        tail = sam.append(tail, str[pos] - 'a');

    ll total = 0;
    const int states = sam.idx;
    for( int s = 1; s < states; ++s ) {
        const int own = sam.len[s];
        const int inherited = sam.len[sam.pa[s]];
        total += own - inherited;
    }
    return total;
}

/// Directed graph stored as per-vertex singly linked edge lists
/// ("forward star"): head[u] is the most recently added edge of u, nxt[e]
/// the edge added before e at the same vertex, and -1 terminates a list.
struct Graph {
    int head[N], nxt[N], to[N], eid;

    /// Removes all edges.
    void init() {
        memset(head, -1, sizeof head);
        eid = 0;
    }

    /// Adds the edge u -> v at the front of u's list.
    void link( int u, int v ) {
        const int e = eid++;
        to[e] = v;
        nxt[e] = head[u];
        head[u] = e;
    }
};
// g is filled from sam's suffix links and rg from rsam's (see solve()).
Graph g, rg;

/// Heavy-light decomposition of the tree stored in `rg`, rooted at node 0.
/// Provides dfs numbering (dfn / rdfn), per-node weights (val) and LCA queries.
namespace HLD {
    int dfn[N], dfnc, top[N], dep[N];
    int pa[N], sz[N], son[N], val[N];
    int rdfn[N];

    /// First pass: records parent, depth, subtree size and heavy child of
    /// every node, and copies rsam.len[u] into val[u].
    void dfs1( int u, int fa, int d ) {
        pa[u] = fa;
        dep[u] = d;
        sz[u] = 1;
        val[u] = rsam.len[u];
        for( int e = rg.head[u]; e != -1; e = rg.nxt[e] ) {
            const int child = rg.to[e];
            dfs1(child, u, d + 1);
            sz[u] += sz[child];
            // Strict comparison: the first child of maximal size stays heavy.
            const bool heavier = son[u] == -1 || sz[child] > sz[son[u]];
            if( heavier )
                son[u] = child;
        }
    }

    /// Second pass: assigns dfs numbers (heavy child first) and chain tops.
    /// rdfn is the inverse of dfn.
    void dfs2( int u, int tp ) {
        top[u] = tp;
        dfn[u] = ++dfnc;
        rdfn[dfn[u]] = u;
        const int heavy = son[u];
        if( heavy != -1 )
            dfs2(heavy, tp);
        for( int e = rg.head[u]; e != -1; e = rg.nxt[e] ) {
            const int child = rg.to[e];
            if( child == heavy ) continue;
            dfs2(child, child);  // a light child starts its own chain
        }
    }

    /// Runs both passes from the root (node 0, depth 1, no parent).
    void init() {
        memset(son, -1, sizeof son);
        dfs1(0, -1, 1);
        dfs2(0, 0);
    }

    /// Lowest common ancestor of u and v.
    int lca( int u, int v ) {
        // Repeatedly lift the node whose chain top is deeper.
        for( ; top[u] != top[v]; u = pa[top[u]] )
            if( dep[top[u]] < dep[top[v]] )
                swap(u, v);
        // Same chain now: the shallower node is the ancestor.
        return dep[u] < dep[v] ? u : v;
    }
}

// mark[s]      : for the sam state s reached after reading str[1..i], the index i
//                (filled in solve()); 0 for all other states.
// rmark[s]     : for the rsam state s reached after reading str[n], ..., str[i],
//                the index i. Written in solve() but not read in the code shown here.
// rmark2nod[i] : inverse of rmark - the rsam state recorded for index i.
int mark[N], rmark[N], rmark2nod[N];
// solve_ans accumulates the result returned by solve(); `now` is the running
// weight maintained by ins_node() for the nodes currently held in `st`.
ll solve_ans = 0, now = 0;
// HLD::dfn numbers of the rsam nodes inserted so far by ins_node().
set<int> st;

/// Adds the rsam node paired with sam state `u` to the chain-union structure.
///
/// If `u` is the state of a prefix ending at position p (mark[u] == p != 0),
/// the rsam node registered for position p + 2 is inserted into `st`, keyed by
/// its HLD dfs number, and `now` is updated so that it stays equal to the
/// total weight of the union of the root-to-node chains of all nodes in `st`
/// (HLD::val[x] is the weight of the chain from the root down to x).
/// States with mark[u] == 0 are ignored.
///
/// @param u  state id in `sam`.
void ins_node( int u ) {
    const int pos = mark[u];
    if( !pos ) return;  // u is not the state of any prefix: nothing to add
    const int node = rmark2nod[pos + 2];
    const int key = HLD::dfn[node];

    // Neighbours of `key` in dfs order; -1 means "no such neighbour".
    // The predecessor is only looked up when one exists, because decrementing
    // st.begin() is undefined behaviour.
    const auto after = st.lower_bound(key);
    int left = -1, right = -1;
    if( after != st.begin() ) {
        auto before = after;
        --before;
        left = HLD::rdfn[*before];
    }
    if( after != st.end() )
        right = HLD::rdfn[*after];

    // Inclusion-exclusion on the sorted-by-dfn sequence: the new chain adds
    // val[node], minus what it shares with each neighbour, plus the part the
    // two neighbours used to share with each other (counted once before,
    // subtracted twice above).
    now += HLD::val[node];
    if( left != -1 )
        now -= HLD::val[HLD::lca(left, node)];
    if( right != -1 )
        now -= HLD::val[HLD::lca(right, node)];
    if( left != -1 && right != -1 )
        now += HLD::val[HLD::lca(left, right)];

    st.insert(key);
}

/// Calls ins_node() on `u` and then, recursively, on every node of its
/// subtree in `g` (pre-order).
void ins_tree( int u ) {
    ins_node(u);
    int e = g.head[u];
    while( e != -1 ) {
        ins_tree(g.to[e]);
        e = g.nxt[e];
    }
}

// Subtree size and heavy child (-1 if none) of every node of the tree in `g`;
// filled by init_sack() and read by sack().
int sz[N], son[N];

/// Computes sz[] (subtree sizes) and son[] (heavy child, -1 for leaves) for
/// the subtree of `u` in `g`.
void init_sack( int u ) {
    son[u] = -1;
    sz[u] = 1;
    for( int e = g.head[u]; e != -1; e = g.nxt[e] ) {
        const int child = g.to[e];
        init_sack(child);
        sz[u] += sz[child];
        // Keep the current heavy child unless this one is strictly larger.
        if( son[u] != -1 && sz[child] <= sz[son[u]] )
            continue;
        son[u] = child;
    }
}

void sack( int u, bool clr ) {
    // printf( "u = %d\n", u );
    for( int i = g.head[u]; ~i; i = g.nxt[i] )
        if( g.to[i] != son[u] )
            sack( g.to[i], true );
    if( son[u] != -1 )
        sack( son[u], false );
    for( int i = g.head[u]; ~i; i = g.nxt[i] )
        if( g.to[i] != son[u] )
            ins_tree( g.to[i] );
    ins_node(u);
    // printf( "u = %d, now = %lld\n", u, now );
    if( u )
        solve_ans += 1LL * now * (sam.len[u] - sam.len[sam.pa[u]]);
    if( clr ) st.clear(), now = 0;
}

/// Handles the strings that have the star strictly inside.
///
/// Builds `sam` over str[1..n-2] and `rsam` over str[n], str[n-1], ..., str[3],
/// turns both suffix-link structures into trees (`g`, `rg`), records which
/// state corresponds to which position, then runs the HLD set-up and the
/// small-to-large traversal.
///
/// @return the value accumulated in solve_ans
ll solve() {
    const int prefix_end = n - 2;  // last index fed to the forward automaton
    const int suffix_begin = 3;    // last index fed to the reversed automaton

    // Forward automaton and its suffix-link tree.
    sam.init();
    for( int i = 1, cur = 0; i <= prefix_end; ++i )
        cur = sam.append(cur, str[i] - 'a');
    g.init();
    for( int s = 1; s < sam.idx; ++s )
        g.link(sam.pa[s], s);
    // Follow the transitions again to find the state of every prefix.
    for( int i = 1, cur = 0; i <= prefix_end; ++i ) {
        cur = sam.ch[cur][str[i] - 'a'];
        mark[cur] = i;
    }

    // Reversed automaton and its suffix-link tree.
    rsam.init();
    for( int i = n, cur = 0; i >= suffix_begin; --i )
        cur = rsam.append(cur, str[i] - 'a');
    rg.init();
    for( int s = 1; s < rsam.idx; ++s )
        rg.link(rsam.pa[s], s);
    // Same walk for the reversed string, recording both directions of the map.
    for( int i = n, cur = 0; i >= suffix_begin; --i ) {
        cur = rsam.ch[cur][str[i] - 'a'];
        rmark[cur] = i;
        rmark2nod[i] = cur;
    }

    HLD::init();
    init_sack(0);
    sack(0, false);
    return solve_ans;
}

int main() {
    _w = scanf( "%s", str+1 );
    n = (int)strlen(str+1);
    ll ans = 0;
    ans += solve_origin();
    // printf( "after origin = %lld\n", ans );
    if( n >= 2 ) {
        ans += solve_before();
        ans += solve_after();
    }
    // printf( "before after = %lld\n", ans );
    if( n >= 3 ) {
        ans += solve();
    }
    printf( "%lld\n", ans+2 );
    return 0;
}

猜你喜欢

转载自www.cnblogs.com/mlystdcall/p/12315564.html