poj3415 Common Substrings 后缀数组+单调栈

原题:http://poj.org/problem?id=3415

题解:求满足S = {(i, j, k) | kK, A(i, k)=B(j, k)}.的个数。就是求A,B中各自后缀lcp≥k的个数。将两个字符串(s1,s2)中间加个字符拼起来(s3),设f(A)为A字符串中后缀大于等于k的个数,f(A)可以用后缀数组+单调栈得出。对于本题一种做法是分别求f(s1),f(s2),f(s3),根据容斥原理,答案就是f(s3)-f(s2)-f(s1)。或者扫两遍,对于每个s2串分别求出s1中的答案就行了,又因为对称性,对于每个s1串分别求出s2中的答案。显然是第一种方法更加简单,这里两种方法都给出。

#include<cstdio>
#include<cstring>
#include<algorithm>
#define ll long long 
using namespace std;
const int N=2e5+10;
char s1[N],s2[N],s3[N];
int rnk[N<<1],rnk1[N<<1],t[N],cnt[N],sa[N],h[N],tmp[N],stack[N][2];
int n,l1,l2,l3,k,top;
ll pre(char *s,int n){

	memset(rnk,0,sizeof rnk);
	memset(rnk1,0,sizeof rnk1);
	memset(t,0,sizeof t);
	
	for(int i=1;i<=n;i++) t[s[i]]++;
	for(int i=1;i<=130;i++) t[i]+=t[i-1];
	for(int i=1;i<=n;i++) rnk[i]=t[s[i]];
		
	for(int p=1,k=0;k!=n;p<<=1){
		memset(cnt,0,sizeof cnt);
		for(int i=1;i<=n;i++) cnt[rnk[i+p]]++;
		for(int i=1;i<=n;i++) cnt[i]+=cnt[i-1];
		for(int i=n;i>=1;i--) tmp[cnt[rnk[i+p]]--]=i;
			
		memset(cnt,0,sizeof cnt);
		for(int i=1;i<=n;i++) cnt[rnk[i]]++;
		for(int i=1;i<=n;i++) cnt[i]+=cnt[i-1];
		for(int i=n;i>=1;i--) sa[cnt[rnk[tmp[i]]]--]=tmp[i];
		memcpy(rnk1,rnk,sizeof(rnk)/2);
		k=1;rnk[sa[1]]=k;
		for(int i=2;i<=n;i++){
			if(rnk1[sa[i]]!=rnk1[sa[i-1]] || rnk1[sa[i]+p]!=rnk1[sa[i-1]+p])k++;
				rnk[sa[i]]=k;	
		}
	}
	for(int i=1,k=0;i<=n;i++){
		if(rnk[i]==1){
			h[rnk[i]]=0;continue;
		}
		if(k) k--;
		while(s[i+k]==s[sa[rnk[i]-1]+k]) k++;
		h[rnk[i]]=k;			
	}
	ll ans=0;top=0;ll tmp=0;
	h[n+1]=0;ll sum=0;
	for(int i=1;i<=n+1;i++){
		tmp=0;
		if(h[i]<k){//分块 
			top=0;sum=0;
		}else{
			tmp++;sum+=h[i]-k+1; 
			while(top>0 && h[i]<stack[top][1]){
				tmp+=stack[top][0];
				sum-=1ll*stack[top][0]*1ll*(stack[top][1]-h[i]);
				top--;	
			}	
			stack[++top][0]=tmp;stack[top][1]=h[i];
			ans+=sum;
		} 
	}
	return ans;	
}
int main(){
//	freopen("poj3415.in","r",stdin);
	while(scanf("%d",&k)!=EOF){
		if(k==0) break;
		scanf("%s",s1+1);l1=strlen(s1+1);
		scanf("%s",s2+1);l2=strlen(s2+1);
		for(int i=1;i<=l1;i++) s3[i]=s1[i];
		s3[l1+1]='@';
		for(int i=1;i<=l2;i++) s3[l1+1+i]=s2[i];
		l3=l1+l2+1;
		printf("%lld\n",pre(s3,l3)-pre(s1,l1)-pre(s2,l2));
	}
	return 0;
}

第二种方法:

#include<cstdio>
#include<cstring>
#include<algorithm>
#define ll long long 
using namespace std;
const int N=2e5+10;
char s1[N],s2[N],s3[N];
int rnk[N<<1],rnk1[N<<1],t[N],cnt[N],sa[N],h[N],tmp[N],stack[N][3];
int n,l1,l2,l3,k,top;
ll sum,ans;
void pre(char *s,int n){

	memset(rnk,0,sizeof rnk);
	memset(rnk1,0,sizeof rnk1);
	memset(t,0,sizeof t);
	
	for(int i=1;i<=n;i++) t[s[i]]++;
	for(int i=1;i<=130;i++) t[i]+=t[i-1];
	for(int i=1;i<=n;i++) rnk[i]=t[s[i]];
		
	for(int p=1,k=0;k!=n;p<<=1){
		memset(cnt,0,sizeof cnt);
		for(int i=1;i<=n;i++) cnt[rnk[i+p]]++;
		for(int i=1;i<=n;i++) cnt[i]+=cnt[i-1];
		for(int i=n;i>=1;i--) tmp[cnt[rnk[i+p]]--]=i;
			
		memset(cnt,0,sizeof cnt);
		for(int i=1;i<=n;i++) cnt[rnk[i]]++;
		for(int i=1;i<=n;i++) cnt[i]+=cnt[i-1];
		for(int i=n;i>=1;i--) sa[cnt[rnk[tmp[i]]]--]=tmp[i];
		memcpy(rnk1,rnk,sizeof(rnk)/2);
		k=1;rnk[sa[1]]=k;
		for(int i=2;i<=n;i++){
			if(rnk1[sa[i]]!=rnk1[sa[i-1]] || rnk1[sa[i]+p]!=rnk1[sa[i-1]+p])k++;
				rnk[sa[i]]=k;	
		}
	}
	for(int i=1,k=0;i<=n;i++){
		if(rnk[i]==1){
			h[rnk[i]]=0;continue;
		}
		if(k) k--;
		while(s[i+k]==s[sa[rnk[i]-1]+k]) k++;
		h[rnk[i]]=k;			
	}
}
int main(){
//	freopen("poj3415.in","r",stdin);
	while(scanf("%d",&k)!=EOF){
		if(k==0) break;
		scanf("%s",s1+1);l1=strlen(s1+1);
		scanf("%s",s2+1);l2=strlen(s2+1);
		for(int i=1;i<=l1;i++) s3[i]=s1[i];
		s3[l1+1]='#';
		for(int i=1;i<=l2;i++) s3[l1+1+i]=s2[i];
		l3=l1+l2+1;
		pre(s3,l3);
		n=l3;
		top=0;sum=0;ans=0;
		for(int i=1,tmp;i<=n;i++){
			if(h[i]<k){top=sum=0;continue;}
			tmp=0;
			if(sa[i-1]<=l1) tmp++,sum+=h[i]-k+1;
			while(top && h[i]<stack[top][1]){
				sum-=1ll*stack[top][0]*1ll*(stack[top][1]-h[i]);
				tmp+=stack[top][0];
				top--;
			}
			stack[++top][0]=tmp;stack[top][1]=h[i];
			if(sa[i]>l1) ans+=sum;
		}
		top=0;sum=0;
		
		for(int i=1,tmp;i<=n;i++){
			if(h[i]<k){top=sum=0;continue;} 
			tmp=0;
			if(sa[i-1]>l1) tmp++,sum+=h[i]-k+1;
			while(top && h[i]<stack[top][1]){
				sum-=1ll*stack[top][0]*1ll*(stack[top][1]-h[i]);		
				tmp+=stack[top][0];
				top--;
			}
			stack[++top][0]=tmp;stack[top][1]=h[i];
			if(sa[i]<=l1) ans+=sum;
		}
		printf("%lld\n",ans);
	}
	return 0;
}

猜你喜欢

转载自blog.csdn.net/weixin_39689721/article/details/87955797