【LUOGU???】WD与数列 sam 启发式合并

题目大意

  给你一个字符串,求有多少对不相交且相同的子串。

  位置不同算多对。

  \(n\leq 300000\)

题解

  先把后缀树建出来。

  DFS 整棵树,维护当前子树的 right 集合。

  合并两个集合的时候暴力枚举小的那个集合,然后在另一个集合的线段树中查询相应的信息计算贡献。

  怎么计算呢?

  如果两个位置之差 \(>\) 这两个位置的 \(lcp\)(即当前点的深度),那么贡献就是 \(lcp\),否则是位置之差。

  线段树记录区间点数和位置之和即可。

  时间复杂度:\(O(n\log^2n)\),好像能做到 \(O(n\log n)\)

代码

#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cstdlib>
#include<ctime>
#include<functional>
#include<cmath>
#include<vector>
#include<assert.h>
#include<map>
using namespace std;
using std::min;
using std::max;
using std::swap;
using std::sort;
using std::reverse;
using std::random_shuffle;
using std::lower_bound;
using std::upper_bound;
using std::unique;
using std::vector;
typedef long long ll;
typedef unsigned long long ull;
typedef double db;
typedef std::pair<int,int> pii;
typedef std::pair<ll,ll> pll;
void open(const char *s){
#ifndef ONLINE_JUDGE
    char str[100];sprintf(str,"%s.in",s);freopen(str,"r",stdin);sprintf(str,"%s.out",s);freopen(str,"w",stdout);
#endif
}
void open2(const char *s){
#ifdef DEBUG
    char str[100];sprintf(str,"%s.in",s);freopen(str,"r",stdin);sprintf(str,"%s.out",s);freopen(str,"w",stdout);
#endif
}
int rd(){int s=0,c,b=0;while(((c=getchar())<'0'||c>'9')&&c!='-');if(c=='-'){c=getchar();b=1;}do{s=s*10+c-'0';}while((c=getchar())>='0'&&c<='9');return b?-s:s;}
void put(int x){if(!x){putchar('0');return;}static int c[20];int t=0;while(x){c[++t]=x%10;x/=10;}while(t)putchar(c[t--]+'0');}
int upmin(int &a,int b){if(b<a){a=b;return 1;}return 0;}
int upmax(int &a,int b){if(b>a){a=b;return 1;}return 0;}
const int N=300010;
int n;
int a[N];
int d[N];
ll ans;
int min(int a,int b)
{
    return a<b?a:b;
}
int max(int a,int b)
{
    return a>b?a:b;
}
namespace seg
{
    int s1[20000000];
    ll s2[20000000];
    int lc[20000000];
    int rc[20000000];
    int cnt;
    #define mid ((L+R)>>1)
    void mt(int p)
    {
        s1[p]=s1[lc[p]]+s1[rc[p]];
        s2[p]=s2[lc[p]]+s2[rc[p]];
    }
    int insert(int p,int x,int L,int R)
    {
        if(!p)
            p=++cnt;
        if(L==R)
        {
            s1[p]=1;
            s2[p]=x;
            return p;
        }
        if(x<=mid)
            lc[p]=insert(lc[p],x,L,mid);
        else
            rc[p]=insert(rc[p],x,mid+1,R);
        mt(p);
        return p;
    }
    int query1(int p,int l,int r,int L,int R)
    {
        if(!p||(l<=L&&r>=R))
            return s1[p];
        int res=0;
        if(l<=mid)
            res+=query1(lc[p],l,r,L,mid);
        if(r>mid)
            res+=query1(rc[p],l,r,mid+1,R);
        return res;
    }
    ll query2(int p,int l,int r,int L,int R)
    {
        if(!p||(l<=L&&r>=R))
            return s2[p];
        ll res=0;
        if(l<=mid)
            res+=query2(lc[p],l,r,L,mid);
        if(r>mid)
            res+=query2(rc[p],l,r,mid+1,R);
        return res;
    }
    int merge(int p1,int p2,int L,int R)
    {
        if(!p1||!p2)
            return p1+p2;
        if(L==R)
        {
            s1[p1]+=s1[p2];
            s2[p1]+=s2[p2];
            return p1;
        }
        lc[p1]=merge(lc[p1],lc[p2],L,mid);
        rc[p1]=merge(rc[p1],rc[p2],mid+1,R);
        mt(p1);
        return p1;
        
    }
}
namespace sam
{
    map<int,int> next[2*N];
    int fail[2*N];
    int len[2*N];
    int c[2*N];
    int last,cnt;
    void init()
    {
        last=cnt=1;
    }
    void append(int x,int v)
    {
        int np=++cnt;
        int p=last;
        c[np]=x;
        len[np]=len[p]+1;
        for(;p&&!next[p][v];p=fail[p])
            next[p][v]=np;
        if(!p)
            fail[np]=1;
        else
        {
            int q=next[p][v];
            if(len[q]==len[p]+1)
                fail[np]=q;
            else
            {
                int nq=++cnt;
                len[nq]=len[p]+1;
                next[nq]=next[q];
                fail[nq]=fail[q];
                fail[q]=fail[np]=nq;
                for(;p&&next[p][v]==q;p=fail[p])
                    next[p][v]=nq;
            }
        }
        last=np;
    }
    vector<int> g[2*N],*e[2*N];
    int sz[2*N];
    int rt[2*N];
    void merge(int x,int y,int l)
    {
        if(sz[x]<sz[y])
        {
            for(auto v:*e[x])
            {
                e[y]->push_back(v);
                int s1=0,_;
//              if(v-l-1>=1)
//                  ans+=(ll)l*seg::query1(rt[y],1,v-l-1,1,n);
                if(v-1>=1)
                {
                    ans+=(ll)(v-1)*(_=seg::query1(rt[y],max(1,v-l),v-1,1,n))-seg::query2(rt[y],max(1,v-l),v-1,1,n);
                    s1+=_;
                }
                if(v+1<=n)
                {
                    ans+=seg::query2(rt[y],v+1,min(n,v+l),1,n)-(ll)(v+1)*(_=seg::query1(rt[y],v+1,min(n,v+l),1,n));
                    s1+=_;
                }
//              if(v+l+1<=n)
//                  ans+=(ll)l*seg::query1(rt[y],v+l+1,n,1,n);
                ans+=(ll)l*(sz[y]-s1);
            }
            e[x]=e[y];
        }
        else
        {
            for(auto v:*e[y])
            {
                e[x]->push_back(v);
                int s1=0,_;
//              if(v-l-1>=1)
//                  ans+=(ll)l*seg::query1(rt[x],1,v-l-1,1,n);
                if(v-1>=1)
                {
                    ans+=(ll)(v-1)*(_=seg::query1(rt[x],max(1,v-l),v-1,1,n))-seg::query2(rt[x],max(1,v-l),v-1,1,n);
                    s1+=_;
                }
                if(v+1<=n)
                {
                    ans+=seg::query2(rt[x],v+1,min(n,v+l),1,n)-(ll)(v+1)*(_=seg::query1(rt[x],v+1,min(n,v+l),1,n));
                    s1+=_;
                }
//              if(v+l+1<=n)
//                  ans+=(ll)l*seg::query1(rt[x],v+l+1,n,1,n);
                ans+=(ll)l*(sz[x]-s1);
            }
        }
        sz[x]+=sz[y];
        rt[x]=seg::merge(rt[x],rt[y],1,n);
    }
    void dfs(int x)
    {
        e[x]=new vector<int>();
        if(c[x])
        {
            rt[x]=seg::insert(rt[x],c[x],1,n);
            e[x]->push_back(c[x]);
            sz[x]=1;
        }
        for(auto v:g[x])
        {
            dfs(v);
            merge(x,v,len[x]);
        }
    }
    void solve()
    {
        for(int i=2;i<=cnt;i++)
            g[fail[i]].push_back(i);
        dfs(1);
    }
}
int main()
{
    open("c");
    scanf("%d",&n);
    ans+=(ll)n*(n-1)/2;
    for(int i=1;i<=n;i++)
        a[i]=rd();
    n--;
    for(int i=1;i<=n;i++)
        a[i]=a[i]-a[i+1];
    for(int i=1;i<=n;i++)
        d[i]=a[i];
    sort(d+1,d+n+1);
    int t=unique(d+1,d+n+1)-d-1;
    for(int i=1;i<=n;i++)
        a[i]=lower_bound(d+1,d+t+1,a[i])-d;
    sam::init();
    for(int i=n;i>=1;i--)
        sam::append(i,a[i]);
    sam::solve();
    printf("%lld\n",ans);
    return 0;
}

猜你喜欢

转载自www.cnblogs.com/ywwyww/p/10200578.html
今日推荐