Loj #2479. 「九省联考 2018」制胡窜

Loj #2479. 「九省联考 2018」制胡窜

题目描述

对于一个字符串 \(S\),我们定义 \(|S|\) 表示 \(S\) 的长度。

接着,我们定义 \(S_i\) 表示 \(S\) 中第 \(i\) 个字符,\(S_{L,R}\) 表示由 \(S\) 中从左往右数,第 \(L\) 个字符到第 \(R\) 个字符依次连接形成的字符串。特别的,如果 \(L > R\) ,或者 \(L < [1, |S|]\), 或者 \(R < [1, |S|]\) 我们可以认为 \(S_{L,R}\) 为空串。

给定一个长度为 \(n\) 的仅由数字构成的字符串 \(S\),现在有 \(q\) 次询问,第 \(k\) 次询问会给出 \(S\) 的一个字符串 \(S_{l,r}\) ,请你求出有多少对 \((i, j)\),满足 \(1 \le i < j \le n\)\(i + 1 \lt j\),且 \(S_{l,r}\) 出现在 \(S_{1,i}\) 中或 \(S_{i+1, j−1}\) 中或 \(S_{j,n}\) 中。

输入格式

输入的第一行包含两个整数 \(n, q\)

第二行包含一个长度为 \(n\) 的仅由数字构成的字符串 \(S\)

接下来 \(q\) 行,每行两个正整数 \(l\)\(r\),表示此次询问的子串是 \(S_{l,r}\)

输出格式

对于每个询问,输出一个整数表示合法的数对个数。

数据范围与提示

对于所有测试数据,\(1 \le n \le 10^5\)\(1 \le q \le 3 · 10^5\)\(1 \le l \le r \le n\)

\(\\\)

感觉这道题细节贼烦人,正式考试的话估计可以刚一整场。

首先建后缀自动机,然后在使用线段树合并维护\(endpos\)集合。

询问的时候就先在\(fail\)树上倍增找到给定字符串出现的节点。然后我们将合法的\((i,j)\)二元组分为以下三种情况:

  1. \(S_{1,i}\)中出现
  2. \(S_{1,i}\)中未出现,\(S_{j,n}\)中出现
  3. \(S_{1,i},S_{j,n}\)中为出现,\(S_{i+1,j-1}\)中出现。

前两种情况很好算,找到位置最靠前以及最靠后的\(endpos\)就行了。

下面来考虑第三种情况。假设最靠前的\(endpos\)\(L\),最靠后的是\(R\),字符串长度为\(len\)。显然\(i<L,j>R-len+1\)

我们先考虑一种暴力做法:枚举\(j\in[R-len+2,n]\),然后算对于每个\(j\)有多少个可行的\(i\)。设\(<j\)的最大的\(endpos\)\(mx\),显然可行的\(i\)只与\(mx\)有关,为\(\min\{L,mx-len\}\)

理解了这个暴力做法过后正解就差不多知道了。对于线段树上每个节点,我们令每个位置的权值为其左边第一个\(endpos\)(如果没有则为\(0\)),\(sum\)为这些位置的权值和,\(rmax\)为最右边的\(endpos\)\(lempty\)为左边有多少个位置没有\(endpos\)。注意上述的信息只考虑了线段树所表示的区间,区间外的\(endpos\)不对其产生任何影响。正因为如此,在询问的时候先遍历左儿子,动态更新最右边的\(endpos\),再遍历右儿子计算答案。

道理很简单,就是要注意的边界情况有点多。。。

代码:

#include<bits/stdc++.h>
#define ll long long
#define N 200005

using namespace std;
inline int Get() {int x=0,f=1;char ch=getchar();while(ch<'0'||ch>'9') {if(ch=='-') f=-1;ch=getchar();}while('0'<=ch&&ch<='9') {x=(x<<1)+(x<<3)+ch-'0';ch=getchar();}return x*f;}

int n,m;
char s[N];
int fail[N<<1],mxlen[N<<1];
int ch[N<<1][10];
int last=1,cnt=1;
int pos[N<<1],id[N<<1];
ll ss[N];
void Insert(int f,int P) {
    int p=last;
    int v=++cnt;
    pos[v]=P;
    id[P]=v;
    last=v;
    mxlen[v]=mxlen[p]+1;
    while(p&&!ch[p][f]) ch[p][f]=v,p=fail[p];
    if(!p) return fail[v]=1,void();
    int sn=ch[p][f];
    if(mxlen[sn]==mxlen[p]+1) return fail[v]=sn,void();
    int New=++cnt;
    mxlen[New]=mxlen[p]+1;
    memcpy(ch[New],ch[sn],sizeof(ch[sn]));
    fail[New]=fail[sn];
    fail[sn]=fail[v]=New;
    while(p&&ch[p][f]==sn) ch[p][f]=New,p=fail[p];
}

int fa[N<<1][20];
vector<int>e[N<<1];
int rt[N<<1];
int ls[N*50],rs[N*50];
int tag[N*50];
int emp[N*50],rmax[N*50];
ll sum[N*50];
int tot;
int lx,rx;

void update(int v,int lx,int rx) {
    sum[v]=sum[ls[v]]+sum[rs[v]];
    int mid=lx+rx>>1;
    ll R=rs[v]?emp[rs[v]]:rx-mid;
    sum[v]+=1ll*rmax[ls[v]]*R;
    if(!ls[v]||emp[ls[v]]==mid-lx+1) {
        emp[v]=mid-lx+1+R;
    } else {
        emp[v]=emp[ls[v]];
    }
    if(rs[v]) rmax[v]=rmax[rs[v]];
    else rmax[v]=rmax[ls[v]];
}

void Insert(int &v,int lx,int rx,int p) {
    v=++tot;
    tag[v]=1;
    if(lx==rx) {
        sum[v]=p;
        rmax[v]=lx;
        return ;
    }
    int mid=lx+rx>>1;
    if(p<=mid) Insert(ls[v],lx,mid,p);
    else Insert(rs[v],mid+1,rx,p);
    update(v,lx,rx);
}

int Merge(int a,int b,int lx,int rx) {
    if(!a||!b) return a+b;
    int v=++tot;
    int mid=lx+rx>>1;
    ls[v]=Merge(ls[a],ls[b],lx,mid);
    rs[v]=Merge(rs[a],rs[b],mid+1,rx);
    update(v,lx,rx);
    return v;
}

void dfs(int v) {
    for(int i=1;i<=18;i++) fa[v][i]=fa[fa[v][i-1]][i-1];
    if(pos[v]) Insert(rt[v],lx,rx,pos[v]);
    for(int i=0;i<e[v].size();i++) {
        int to=e[v][i];
        dfs(to);
        rt[v]=Merge(rt[v],rt[to],lx,rx);
    }
}

int Find(int l,int r) {
    int v=id[r];
    for(int i=18;i>=0;i--)
        if(fa[v][i]&&mxlen[fa[v][i]]>=r-l+1)
            v=fa[v][i];
    return v;
}

int query_mn(int v,int lx,int rx,int lim) {
    if(!v||rx<lim) return 0;
    if(lx==rx) return lx;
    int mid=lx+rx>>1;
    int x=query_mn(ls[v],lx,mid,lim);
    if(x) return x;
    else return query_mn(rs[v],mid+1,rx,lim);
}

int query_mx(int v,int lx,int rx) {
    if(lx==rx) return lx;
    int mid=lx+rx>>1;
    if(rs[v]) return query_mx(rs[v],mid+1,rx);
    else return query_mx(ls[v],lx,mid);
}

ll query_s(int v,int lx,int rx,int l,int r,int &L) {
    if(lx>r) return 0;
    if(rx<l) {
        L=max(L,rmax[v]);
        return 0;
    }
    if(l<=lx&&rx<=r) {
        ll x=!v?rx-lx+1:emp[v];
        ll ans=sum[v]+1ll*x*L;
        L=max(L,rmax[v]);
        return ans;
    }
    int mid=lx+rx>>1;
    return query_s(ls[v],lx,mid,l,r,L)+query_s(rs[v],mid+1,rx,l,r,L);
}

ll solve(int v,int len) {
    int mn=query_mn(rt[v],lx,rx,1),mx=query_mx(rt[v],lx,rx);
    ll ans=0;
    if(mn==mx) {
        if(mx<n) ans+=ss[n-mx-1];
        if(mn-len+1>1) ans+=ss[mn-len-1];
        ans+=1ll*(n-mx)*(mn-len);
        return ans;
    }
    if(mn<n) ans+=ss[n-mn-1];
    if(mx-len+1>1) ans+=ss[mx-len-1];
    if(mx-len+1>mn+1) ans-=ss[mx-len-mn];
    int ed=query_mn(rt[v],lx,rx,mn+len-1);
    if(ed) {
        ed=max(ed,mx-len+1);
        ans+=1ll*(n-ed)*(mn-1);
        ed--;
    } else ed=n-1;
    int st=max(mn,mx-len+1);
    if(ed>=st) {
        int L=0;
        ans+=query_s(rt[v],lx,rx,st,ed,L);
        ans-=1ll*len*(ed-st+1);
    }
    return ans;
}

int main() {
    n=Get(),m=Get();
    for(int i=1;i<=n;i++) ss[i]=ss[i-1]+i;
    lx=1,rx=n;
    scanf("%s",s+1);
    for(int i=1;i<=n;i++) Insert(s[i]-'0',i);
    for(int i=2;i<=cnt;i++) {
        e[fail[i]].push_back(i);
        fa[i][0]=fail[i];
    }
    dfs(1);
    int l,r;
    while(m--) {
        l=Get(),r=Get();
        cout<<solve(Find(l,r),r-l+1)<<"\n";
    }
    return 0;
}

猜你喜欢

转载自www.cnblogs.com/hchhch233/p/10827770.html