【Codeforces 917 E】Upside Down(Ac自动机 / 后缀数组 / 哈希 / border相关 / Kmp / Exgcd)

传送门

至今写过最长的字符串题
看起来和回忆树差不多
但不能直接 k m p kmp
分成 u l c a , l c a v u\rightarrow lca,lca\rightarrow v 和穿过 l c a lca 分别考虑
前面直接照搬回忆树的做法

考虑穿过 l c a lca
求出 u l c a u\rightarrow lca 中最长的后缀满足是 S S 的前缀的长度
以及 l c a v lca\rightarrow v 最长前缀满足为 S S 后缀的长度
得到了可以直接利用 b o r d e r border l o g log 个等差数列的性质用 e x g c d exgcd 求解
这个是 O ( q l o g 2 n ) O(qlog^2n)

现在只考虑 u l c a u\rightarrow lca
考虑对 S R S^R 建出后缀数组,就是 l c a u lca\rightarrow u 匹配一段后缀
考虑求出和 l c a u lca\rightarrow u l c p lcp 最大的后缀 T T
那么 T T 中第一个满足 l e n l c p len\le lcp b o r d e r   T border\ T' 即为最长匹配
这个 T T' 显然可以和 l c a u lca\rightarrow u 匹配最长且满足条件

l c p lcp 最大可以直接找后缀数组中 r k rk 最近的
二分 r k rk 然后二分哈希比较即可

复杂度 O ( q l o g 2 n ) O(qlog^2n)

#include<bits/stdc++.h>
using namespace std;
#define cs const
#define pb push_back
#define ll long long 
#define pii pair<int,int>
#define fi first
#define se second
#define bg begin
cs int RLEN=(1<<20)|3;
inline char gc(){
	static char ibuf[RLEN],*ib,*ob;
	(ib==ob)&&(ob=(ib=ibuf)+fread(ibuf,1,RLEN,stdin));
	return (ib==ob)?EOF:*ib++;
}
inline int read(){
	char ch=gc();
	int res=0;bool f=1;
	while(!isdigit(ch))f^=ch=='-',ch=gc();
	while(isdigit(ch))res=(res*10)+(ch^48),ch=gc();
	return f?res:-res;
}
inline ll readll(){
	char ch=gc();
	ll res=0;bool f=1;
	while(!isdigit(ch))f^=ch=='-',ch=gc();
	while(isdigit(ch))res=(res*10)+(ch^48),ch=gc();
	return f?res:-res;	
}
inline int readstring(char *s){
	char ch=gc();int l=0;
	while(isspace(ch))ch=gc();
	while(!isspace(ch)&&ch!=EOF)s[++l]=ch,ch=gc();
	return l;
}
inline char readchar(){
	char ch=gc();
	while(isspace(ch))ch=gc();return ch;
}
template<typename tp>inline void chemx(tp &a,tp b){a<b?a=b:0;}
template<typename tp>inline void chemn(tp &a,tp b){a<b?a=b:0;}
cs int N=200005;
struct Ac{
	int nxt[N][27],fail[N],bel[N],tot;
	inline void insert(char *s,int l,int id){
		int p=0;
		for(int i=1;i<=l;i++){
			int c=s[i]-'a'+1;
			if(!nxt[p][c])nxt[p][c]=++tot;
			p=nxt[p][c];
		}bel[id]=p;
	}
	vector<int> e[N];
	int in[N],out[N],dfn;
	void dfs(int u){
		in[u]=++dfn;
		for(int v:e[u])dfs(v);
		out[u]=dfn;
	}
	inline void build(){
		queue<int> q;
		for(int i=1;i<=26;i++){
			if(nxt[0][i])q.push(nxt[0][i]);
		}
		while(q.size()){
			int p=q.front();q.pop();
			for(int c=1;c<=26;c++){
				int v=nxt[p][c];
				if(!v)nxt[p][c]=nxt[fail[p]][c];
				else fail[v]=nxt[fail[p]][c],q.push(v);
			}
		}
		for(int i=1;i<=tot;i++)e[fail[i]].pb(i);
		dfs(0);
	}
	int tr[N];
	#define lb(x) (x&(-x))
	inline int go(int u,int c){return nxt[u][c];} 
	inline void update(int u,int k){int p=in[u];
		for(;p>0;p-=lb(p))tr[p]+=k;
	}
	inline int ask(int p,int res=0){for(;p<=dfn;p+=lb(p))res+=tr[p];return res;}
	inline int query(int i){int p=bel[i];
		return ask(in[p])-ask(out[p]+1);
	}
	#undef lb
}t[2];
typedef unsigned long long ul;
typedef pair<ul,ul> node;
cs node bas=node(31,79);
inline node operator +(cs node &a,cs node &b){return node(a.fi+b.fi,a.se+b.se);}
inline node operator -(cs node &a,cs node &b){return node(a.fi-b.fi,a.se-b.se);}
inline node operator *(cs node &a,cs node &b){return node(a.fi*b.fi,a.se*b.se);}
inline bool operator ==(cs node &a,cs node &b){return (a.fi==b.fi)&&(a.se==b.se);}
node sm[N],pw[N];
vector<pii> e[N];
int dep[N],fa[N][19];
int len[N],vf[N];
inline node thas(int u,int tp){
	return sm[u]-sm[tp]*pw[dep[u]-dep[tp]];
}
void dfs1(int u){
	for(int i=1;i<19;i++)fa[u][i]=fa[fa[u][i-1]][i-1];
	for(cs pii &x:e[u])if(x.fi!=fa[u][0]){
		vf[x.fi]=x.se;
		sm[x.fi]=sm[u]*bas+node(x.se,x.se);
		fa[x.fi][0]=u,dep[x.fi]=dep[u]+1,dfs1(x.fi);
	}
}
inline int Lca(int u,int v){
	if(dep[u]<dep[v])swap(u,v);
	for(int i=18;~i;i--)if(dep[fa[u][i]]>=dep[v])u=fa[u][i];
	if(u==v)return u;
	for(int i=18;~i;i--)if(fa[u][i]!=fa[v][i])u=fa[u][i],v=fa[v][i];
	return fa[u][0];
}
inline int find(int u,int k){
	for(int i=18;~i;i--)if(k&(1<<i))u=fa[u][i];
	return u;
}
int ans[N];
struct ask{
	int p,id,kd,coef;
};
vector<ask> q[N];
void dfs2(int u,int nd0,int nd1){
	t[0].update(nd0,1),t[1].update(nd1,1);
	for(cs ask &x:q[u]){
		ans[x.id]+=x.coef*t[x.kd].query(x.p);
	}
	for(cs pii &x:e[u])if(x.fi!=fa[u][0]){
		dfs2(x.fi,t[0].go(nd0,x.se),t[1].go(nd1,x.se));
	}
	t[0].update(nd0,-1),t[1].update(nd1,-1);
}
int buc[N];
struct arr{
	int s,t,del;
};
struct Kmp{
	int *del,*nxt,*be,n;
	inline void build(char *s,int len){
		nxt=new int [len+2]();
		del=new int [len+2]();
		be=new int [len+2]();
		del[1]=1,be[1]=1;
		for(int i=0,j=2;j<=len;j++){
			while(i&&s[i+1]!=s[j])i=nxt[i];
			if(s[i+1]==s[j])i++;
			nxt[j]=i,del[j]=j-i;
			be[j]=(del[j]==del[i])?be[i]:j;
		}
	}
	inline void get(int p,int len,arr *ret,int &num)cs{
		for(;p>len;p=nxt[p]);
		while(p>0){
			int s=be[p],t=p,d=del[p];
			ret[++num]=(arr){s,t,d};
			p=nxt[s];
		}
	}
}p1[N],p2[N];
struct Sa{
	int n,m,*a,*sa,*rk,*sa2;
	node *h;
	inline node has(int l,int r){
		return h[r]-h[l-1]*pw[r-l+1];
	}
	inline void Sort(){
		for(int i=1;i<=m;i++)buc[i]=0;
		for(int i=1;i<=n;i++)buc[rk[i]]++;
		for(int i=1;i<=m;i++)buc[i]+=buc[i-1];
		for(int i=n;i>=1;i--)sa[buc[rk[sa2[i]]]--]=sa2[i];
	}
	inline void build(char *s,int len){
		n=len,m=26;
		a=new int[len+2]();
		sa=new int[len+2]();
		rk=new int[len+2]();
		sa2=new int[len+2]();
		h=new node[len+2]();
		for(int i=1;i<=n;i++)a[i]=s[i]-'a'+1,h[i]=h[i-1]*bas+node(a[i],a[i]);
		for(int i=1;i<=n;i++)sa2[i]=i,rk[i]=a[i];
		Sort();
		for(int i=1,pos=0;i<=n&&pos<n;i<<=1){
			pos=0;
			for(int j=n-i+1;j<=n;j++)sa2[++pos]=j;
			for(int j=1;j<=n;j++)if(sa[j]>i)sa2[++pos]=sa[j]-i;
			Sort();
			swap(rk,sa2);
			pos=1,rk[sa[1]]=1;
			for(int j=2;j<=n;j++)
			rk[sa[j]]=(sa2[sa[j]]==sa2[sa[j-1]]&&sa2[sa[j]+i]==sa2[sa[j-1]+i])?pos:++pos;
			m=pos;
		}
	}
	inline int lcp(int p,int u,int lca){
		if(u==lca)return 0;
		p=sa[p];int len=n-p+1;
		//if same return else find the lowest dif one
		if(len>=dep[u]-dep[lca]&&thas(u,lca)==has(p,p+dep[u]-dep[lca]-1))return dep[u]-dep[lca];
		for(int i=18;~i;i--)if(dep[u]-dep[lca]>(1<<i)){
			if(len<(dep[fa[u][i]]-dep[lca])||(thas(fa[u][i],lca)!=has(p,p+dep[fa[u][i]]-dep[lca]-1)))
			u=fa[u][i];
		}u=fa[u][0];
		return dep[u]-dep[lca];
	}
	inline bool check(int mid,int u,int lca){
		int l=lcp(mid,u,lca);
		if(l==n-sa[mid]+1)return 1;
		if(l==dep[u]-dep[lca])return 0;
		return vf[find(u,dep[u]-dep[lca]-l-1)]>a[sa[mid]+l];
	}
	inline void calc(int u,int lca,cs Kmp &kp,arr *ret,int &num){
		int l=1,r=n,res=0;
		while(l<=r){
			int mid=(l+r)>>1;
			if(check(mid,u,lca))res=mid,l=mid+1;
			else r=mid-1;
		}num=0;
		if(!res)return;
		int len=lcp(res,u,lca);
		int p=n-sa[res]+1;
		kp.get(p,len,ret,num);
	}
}s1[N],s2[N];
namespace chain{
	arr v1[61],v2[61];
	int cnt1,cnt2;
	inline int exgcd(int a,int b,int &x,int &y){
		if(!b){x=1,y=0;return a;}
		int t=exgcd(b,a%b,y,x);
		y-=a/b*x;return t;
	}
	inline int calc(int a,int b,int la,int lb,int k){
		int x,y;int d=exgcd(a,b,x,y);
		if(k%d!=0)return 0;
		x*=k/d,y*=k/d;
		int db=b/d,da=a/d,mx=x%db;
		if(mx<0)mx+=db;
		y+=((x-mx)/db)*da,x=mx;
		if(y<0)return 0;
		if(y>lb){
			int ti=(y-lb-1)/da+1;
			y-=ti*da,x+=ti*db;
		}
		if(x>la||x<0||y>lb||y<0)return 0;
		return min((la-x)/db+1,y/da+1);
	}
	inline int solve(arr x,arr y,int len){
		return calc(x.del,y.del,(x.t-x.s)/x.del,(y.t-y.s)/y.del,len-x.s-y.s);
	}
	inline void solve(int tt,int id,int u,int v,int lca){
		s2[id].calc(u,lca,p1[id],v1,cnt1);
		s1[id].calc(v,lca,p2[id],v2,cnt2);
		for(int i=1;i<=cnt1;i++)
		for(int j=1;j<=cnt2;j++)
		if(v1[i].s+v2[j].s<=len[id]&&v1[i].t+v2[j].t>=len[id])
		ans[tt]+=solve(v1[i],v2[j],len[id]);
	}
}
int n,m,Q;
char str[N];
int main(){
	#ifdef Stargazer
	freopen("lx.in","r",stdin);
	freopen("my.out","w",stdout);
	#endif
	n=read(),m=read(),Q=read();
	pw[0]=node(1,1);
	for(int i=1;i<N;i++)pw[i]=pw[i-1]*bas;
	for(int i=1;i<n;i++){
		int u=read(),v=read(),c=readchar()-'a'+1;
		e[u].pb(pii(v,c)),e[v].pb(pii(u,c));
	}dep[1]=1,dfs1(1);
	for(int i=1;i<=m;i++){
		len[i]=readstring(str);
		t[0].insert(str,len[i],i);
		s1[i].build(str,len[i]);
		p1[i].build(str,len[i]);
		reverse(str+1,str+len[i]+1);
		t[1].insert(str,len[i],i);
		s2[i].build(str,len[i]);
		p2[i].build(str,len[i]);
	}
	t[0].build(),t[1].build();
	for(int i=1;i<=Q;i++){
		int u=read(),v=read(),id=read();
		int lca=Lca(u,v);
		if(dep[u]-dep[lca]>=len[id])
		q[u].pb((ask){id,i,1,1}),
		q[find(u,dep[u]-dep[lca]-len[id]+1)].pb((ask){id,i,1,-1});
		if(dep[v]-dep[lca]>=len[id])
		q[v].pb((ask){id,i,0,1}),
		q[find(v,dep[v]-dep[lca]-len[id]+1)].pb((ask){id,i,0,-1});
		chain::solve(i,id,u,v,lca);
	}
	dfs2(1,0,0);
	for(int i=1;i<=Q;i++)cout<<ans[i]<<'\n';
	return 0;
}

猜你喜欢

转载自blog.csdn.net/qq_42555009/article/details/106243336