【NOI2018】洛谷4770你的名字题解(SAM+可持久化线段树合并)

版权声明:转载请注明原出处啦QAQ(虽然应该也没人转载): https://blog.csdn.net/hzk_cpp/article/details/88758956

题目:luogu4770.
题目大意:给定一个串 S S ,和 m m 组询问,每组询问包含一个串 T i T_i 和一个区间 [ l i , r i ] [l_i,r_i] ,表示询问 T i T_i 有多少个本质不同的子串不是 S [ l i , r i ] S[l_i,r_i] 的子串.
1 S , T i 5 1 0 5 , 1 T i 1 0 6 1\leq|S|,|T_i|\leq 5*10^5,1\leq \sum |T_i|\leq 10^6 .

把原问题转化为求 T i T_i 本质不同字串个数减去既是 S [ l i , r i ] S[l_i,r_i] 子串又是 T i T_i 子串的串个数,然后考虑 l i = 1 , r i = S l_i=1,r_i=|S| 的部分分.

先对 S S 建立SAM,发现并不是很好操作,我们并不能每次大力扫一遍 S S 来处理,所以考虑对每个询问的 T i T_i 建立一个SAM.现在考虑求出 T i T_i 上每个状态的一个 a n s ans 值,表示这个状态可以表示的串中最长的是 S S 子串的串长.

同时把 T i T_i 放入 S S T i T_i 的SAM上运行,当 S S 运行到某个状态时,找到这个状态的一个在parent树上深度最深且有连向下一个字符的转移边的状态转移,然后对应的在 T i T_i 上转移.

这个 68 68 分的做法很简单,但是要注意不能让跑到空状态里(即下标为 0 0 的状态).

然后考虑满分做法.很容易发现满分做法与 68 68 分做法相差在在 S S 的SAM上运行到某个状态时,发现这个状态并不能表示一个在 [ l i , r i ] [l_i,r_i] 内的串,所以考虑直接维护Right集合来判定是否能够往下跳,不能往下跳就缩小长度继续尝试.

维护Right集合可以通过线段树合并来实现,不过由于要维护任意时刻任意节点的Right集合,所以需要可持久化,实现时只需要把本来直接覆盖在一棵树上的节点直接新开节点存即可.

时空复杂度 O ( n ( Σ + log n ) ) O(n(\Sigma+\log n)) .

代码如下:

#include<bits/stdc++.h>
  using namespace std;

#define Abigail inline void
typedef long long LL;

const int N=1000000,C=26;

int n;
char c[N+9];
struct segment_tree{
  
  struct tree{
  	int s[2],sum;
  }tr[N*2*C+9];
  int cn;
  
  int new_node(int x){tr[++cn]=tree();tr[cn].sum=x;return cn;}
  
  int insert(int x,int l,int r,int k){
    if (!k) k=new_node(0);
    ++tr[k].sum;
	if (l==r) return k;
	int mid=l+r>>1;
	if (x<=mid) tr[k].s[0]=insert(x,l,mid,tr[k].s[0]);
	else tr[k].s[1]=insert(x,mid+1,r,tr[k].s[1]);
	return k;
  }
  
  int query(int L,int R,int l,int r,int k){
  	if (!k||L>R) return 0;
  	if (l==L&&r==R) return tr[k].sum;
  	int mid=l+r>>1;
  	if (R<=mid) return query(L,R,l,mid,tr[k].s[0]);
  	else if (L>mid) return query(L,R,mid+1,r,tr[k].s[1]);
  	  else return query(L,mid,l,mid,tr[k].s[0])+query(mid+1,R,mid+1,r,tr[k].s[1]);
  }
  
  int merge(int u,int v){
  	if (u==0||v==0) return u+v;
  	int k=new_node(tr[u].sum+tr[v].sum);
  	tr[k].s[0]=merge(tr[u].s[0],tr[v].s[0]);
  	tr[k].s[1]=merge(tr[u].s[1],tr[v].s[1]);
  	return k;
  }
  
}tr;

struct suffix_automaton{
  
  struct automaton{
  	int s[C],len,par;
  }tr[N*2+9];
  int cn,last,cnt[N*2+9];
  
  void build(){tr[cn=last=1]=automaton();}
  int par(int x){return tr[x].par;}
  int len(int x){return tr[x].len;}
  int son(int k,int x){return tr[k].s[x];}
  int tru(int x){return cnt[x];}
  
  void extend(int x){
  	int p=last,np=++cn;
  	tr[np]=automaton();tr[np].len=tr[p].len+1;cnt[np]=1;
  	last=np;
  	while (p&&!tr[p].s[x]) tr[p].s[x]=np,p=tr[p].par;
  	if (!p) tr[np].par=1;
  	else{
  	  int q=tr[p].s[x];
  	  if (tr[p].len+1==tr[q].len) tr[np].par=q;
  	  else{
  	  	tr[++cn]=tr[q];tr[cn].len=tr[p].len+1;cnt[cn]=0;
  	  	tr[q].par=tr[np].par=cn;
  	  	while (p&&tr[p].s[x]==q) tr[p].s[x]=cn,p=tr[p].par;
  	  }
  	}
  }
  
}s,t;

int v[N+9],q[N*2+9],rot[N*2+9];

void get_right(){
  for (int i=1;i<=s.cn;++i) ++v[s.len(i)];
  for (int i=1;i<=n;++i) v[i]+=v[i-1];
  for (int i=s.cn;i>=1;--i) q[v[s.len(i)]--]=i;
  for (int i=1;i<=s.cn;++i){
    rot[i]=tr.new_node(0);
    if (s.tru(i)) rot[i]=tr.insert(s.len(i),1,n,rot[i]);
  }
  for (int i=s.cn;i>=2;--i)
    rot[s.par(q[i])]=tr.merge(rot[s.par(q[i])],rot[q[i]]);
}

int ans[N*2+9];

LL solve(){
  int m,tp=1,sp=1,len=0,l,r;
  LL sum=0;
  scanf("%s%d%d",c+1,&l,&r);
  m=strlen(c+1);
  t.build();
  for (int i=1;i<=m;++i)
    t.extend(c[i]-'a');
  for (int i=1;i<=t.cn;++i) ans[i]=0;
  for (int i=1;i<=m;++i){
  	while (sp&&!(s.son(sp,c[i]-'a')&&tr.query(l+len,r,1,n,rot[s.son(sp,c[i]-'a')]))){
  	  --len;
  	  if (s.len(s.par(sp))>=len) sp=s.par(sp);
  	}
  	if (!sp) sp=1,len=0;
  	else sp=s.son(sp,c[i]-'a'),++len;
  	tp=t.son(tp,c[i]-'a');
  	while (t.par(tp)&&t.len(t.par(tp))>=len) tp=t.par(tp);
  	ans[tp]=max(ans[tp],len);
  }
  for (int i=1;i<=m;++i) v[i]=0;
  for (int i=1;i<=t.cn;++i) ++v[t.len(i)];
  for (int i=1;i<=m;++i) v[i]+=v[i-1];
  for (int i=t.cn;i>=1;--i) q[v[t.len(i)]--]=i;
  for (int i=t.cn;i>=2;--i)
    if (ans[q[i]]) ans[t.par(q[i])]=t.len(t.par(q[i]));
  for (int i=2;i<=t.cn;++i){
  	sum+=(LL)t.len(i)-t.len(t.par(i));
  	if (ans[i]) sum-=(LL)ans[i]-t.len(t.par(i));
  }
  return sum;
}

Abigail into(){
  scanf("%s",c+1);
  n=strlen(c+1);
}

Abigail work(){
  s.build();
  for (int i=1;i<=n;++i)
    s.extend(c[i]-'a');
  get_right();
}

Abigail getans(){
  int q;
  scanf("%d",&q);
  while (q--)
    printf("%lld\n",solve());
}

int main(){
  into();
  work();
  getans();
  return 0; 
}

猜你喜欢

转载自blog.csdn.net/hzk_cpp/article/details/88758956