【AC自动机】【树状数组】【dfs序】洛谷 P2414 [NOI2011]阿狸的打字机 题解

     这一题是对AC自动机的充分理解和树dfs序的巧妙运用。

题目背景

阿狸喜欢收藏各种稀奇古怪的东西,最近他淘到一台老式的打字机。

题目描述

打字机上只有28个按键,分别印有26个小写英文字母和'B'、'P'两个字母。经阿狸研究发现,这个打字机是这样工作的:

  • 输入小写字母,打字机的一个凹槽中会加入这个字母(这个字母加在凹槽的最后)。
  • 按一下印有'B'的按键,打字机凹槽中最后一个字母会消失。
  • 按一下印有'P'的按键,打字机会在纸上打印出凹槽中现有的所有字母并换行,但凹槽中的字母不会消失。

例如,阿狸输入aPaPBbP,纸上被打印的字符如下:

a aa ab 我们把纸上打印出来的字符串从1开始顺序编号,一直到n。打字机有一个非常有趣的功能,在打字机中暗藏一个带数字的小键盘,在小键盘上输入两个数\((x,y)\)(其中\(1≤x,y≤n\)),打字机会显示第\(x\)个打印的字符串在第\(y\)个打印的字符串中出现了多少次。

阿狸发现了这个功能以后很兴奋,他想写个程序完成同样的功能,你能帮助他么?

输入输出格式

输入格式:

输入的第一行包含一个字符串,按阿狸的输入顺序给出所有阿狸输入的字符。

第二行包含一个整数\(m\),表示询问个数。

接下来\(m\)行描述所有由小键盘输入的询问。其中第\(i\)行包含两个整数\(x,y\),表示第\(i\)个询问为\((x,y)\)。

输出格式:

输出\(m\)行,其中第\(i\)行包含一个整数,表示第\(i\)个询问的答案。

输入输出样例

输入样例#1:
aPaPBbP
3
1 2
1 3
2 3
输出样例#1:
2
1
0

说明

题解:

    AC自动机学习笔记中,我提到了fail树的每一条链上的点代表的字母都是一样的。实际上,根据fail和nxt的定义,在这条链上深度较深的点总包含深度较浅的点。更确切地说,深度较浅的点总是深度较深的点的后缀。

    有了上面的结论,我们就有了一种朴素算法:建立AC自动机,读取询问时让主串向父亲跳来检验fail指针指向的节点是不是模式串的结尾。接着沿fail指针检验一趟看是不是模式串的结尾,比KMP还要暴力一点。换句话说,就是个用AC自动机实现的KMP,而且代码巨短

Code 40pts:

#include<cstdio>
#include<cstring>
struct node
{
    node *fa,*ch[26];
    node *fail;
    node(node *fa)
    {
        this->fa=fa;
        memset(ch,0,sizeof(ch));
        fail=NULL;
    }
    node()
    {
        fa=NULL;
        fail=NULL;
        memset(ch,0,sizeof(ch));
    }
}*root=new node(),*End[100100];
char s[100100];
int cnt=0;
node *q[100100];
int l=0,r=0;
void Fail()
{
    root->fail=root;
    for(int i=0;i<26;++i)
        if(root->ch[i])
        {
            root->ch[i]->fail=root;
            q[++r]=root->ch[i];
        }
        else
            root->ch[i]=root;
    while(l<r)
    {
        node *p=q[++l];
        for(int i=0;i<26;++i)
            if(p->ch[i])
            {
                p->ch[i]->fail=p->fail->ch[i];
                q[++r]=p->ch[i];
            }
            else
                p->ch[i]=p->fail->ch[i];
    }
    return;
}
int main()
{
    scanf("%s",s);
    int l=strlen(s);
    node *now=root;
    for(int i=0;i<l;++i)
        if(s[i]=='P')
            End[++cnt]=now;
        else if(s[i]=='B')
            now=now->fa;
        else
        {
            if(!now->ch[s[i]-'a'])
                now->ch[s[i]-'a']=new node(now);
            now=now->ch[s[i]-'a'];
        }
    Fail();
    int u,v,n;
    scanf("%d",&n);
    while(n--)
    {
        scanf("%d%d",&u,&v);
        node *now=End[v];
        int sum=0;
        while(now!=root)
        {
            node *p=now;
            while(p!=root)
            {
                if(p==End[u])
                    ++sum;
                p=p->fail;
            }
            now=now->fa;
        }
        printf("%d\n",sum);
    }
    return 0;
}

    发现上面这种做法需要递归fail指针,我们可以考虑构建fail树,设主串的结束节点为\(a\),模式串的结束节点为\(b\),那么看fail树中\(b\)的子树中有多少个节点在主串上。(注意这里不是主串结尾,在主串中间也可以)因为子树上的点通过fail边总可以连接到\(b\)这个节点上来,因此整棵子树都可以对\(b\)做出贡献。可以稍微优化一点时间复杂度,但是得分还是在40到50左右。

    在字典树上,有一种肥肠玄妙的优化叫树状数组。我们在这个题中只需要对字典树进行一遍DFS就可以了,在DFS过程中就可以处理各种询问,有点像tarjan求LCA算法的步骤。

    首先我们明确fail树和字典树上的节点是一一对应的。但是fail树需要一个dfs序,这里不要和节点编号或者节点指针搞混了;先对fail树进行dfs,存下每个点的编号(dfn)和子树大小(sz),方便之后对子树进行统计。

    接着对字典树进行dfs,遇到一个节点就把这个节点的dfn在树状数组中+1,表明这个节点在当前访问的串中。当访问到一个作为结束点的节点时,就处理它作为主串的询问。此时有了上面一点点优化(可达50分)的思想,直接查询每个询问的模式串结束节点的子树中有多少个被+1了,因为是子树,在dfs序上是连续的,所以树状数组求区间和就可以了。回溯时不要忘了把这个点的dfn在树状数组中-1,表示不属于接下来访问的串。

    这个题写起来真爽——虽然调的也很爽

Code:

#include<cstdio>
#include<cstring>
#include<vector>
using std::vector;
struct node//AC自动机
{
    node *ch[26],*fail,*fa;
    vector<int> End;//是哪些字符串的结束点
    int sz,num;//dfs序用的子树大小和自己的编号(dfn)
    int head,ahead;
    bool exi[26];
    node(node *fa)
    {
        memset(exi,0,sizeof(exi));
        this->fa=fa;
        head=-1;
        ahead=-1;
        fail=NULL;
        memset(ch,0,sizeof(ch));
    }
    node()
    {
        memset(exi,0,sizeof(exi));
        head=-1;
        ahead=-1;
        fa=NULL;
        fail=NULL;
        memset(ch,0,sizeof(ch));
    }
}*root=new node(),*End[100100];
//存fail树的边
struct edge
{
    node *n;
    int nxt,num;
    edge(node *n,int nxt,int num)
    {
        this->n=n;
        this->nxt=nxt;
        this->num=num;
    }
    edge(node *n,int nxt)
    {
        this->n=n;
        this->nxt=nxt;
    }
    edge(){}
}e[201000],ask[101000];
int ecnt=-1,acnt=-1;
void add(node *from,node *to)
{
    e[++ecnt]=edge(to,from->head);
    from->head=ecnt;
    e[++ecnt]=edge(from,to->head);
    to->head=ecnt;
}
void Add(node *from,node *to,int num)
{
    ask[++acnt]=edge(to,from->ahead,num);
    from->ahead=acnt;
}
//树状数组
int c[100100];
int lowbit(int x)
{
    return x&(-x);
}
int dcnt=0;
void change(int x,int v)
{
    while(x<=dcnt)
    {
        c[x]+=v;
        x+=lowbit(x);
    }
    return;
}
int sum(int x)
{
    int ans=0;
    while(x)
    {
        ans+=c[x];
        x-=lowbit(x);
    }
    return ans;
}
//求Fail
node *q[100100];
int l=0,r=0;
void Fail()
{
    root->fail=root;
    for(int i=0;i<26;++i)
        if(root->ch[i])
        {
            root->ch[i]->fail=root;
            add(root,root->ch[i]);
            q[++r]=root->ch[i];
        }
        else
            root->ch[i]=root;
    while(l<r)
    {
        node *p=q[++l];
        for(int i=0;i<26;++i)
            if(p->ch[i])
            {
                p->ch[i]->fail=p->fail->ch[i];
                add(p->ch[i],p->fail->ch[i]);
                q[++r]=p->ch[i];
            }
            else
                p->ch[i]=p->fail->ch[i];
    }
    return;
}
//dfs求dfs序
void dfs(node *x,node *from)
{
    x->sz=1;
    x->num=++dcnt;
    for(int i=x->head;~i;i=e[i].nxt)
        if(e[i].n!=from)
        {
            dfs(e[i].n,x);
            x->sz+=e[i].n->sz;
        }
    return;
}
//dfs求答案
int ans[100100];
void Dfs(node *x)
{
    change(x->num,1);
    for(int i=x->ahead;~i;i=ask[i].nxt)
    {
        node *p=ask[i].n;
        ans[ask[i].num]=sum(p->num+p->sz-1)-sum(p->num-1);
    }
    for(int i=0;i<26;++i)
        if(x->exi[i])
            Dfs(x->ch[i]);
    change(x->num,-1);
}
char s[101000];
int main()
{
    scanf("%s",s);
    node *now=root;
    int cnt=0;
    for(int i=0;s[i]!='\0';++i)
        if(s[i]=='P')
        {
            now->End.push_back(++cnt);
            End[cnt]=now;
        }
        else if(s[i]=='B')
            now=now->fa;
        else
        {
            if(!now->ch[s[i]-'a'])
                now->ch[s[i]-'a']=new node(now);
            now->exi[s[i]-'a']=1;
            now=now->ch[s[i]-'a'];
        }
    Fail();
    dfs(root,root);
    int n,u,v;
    scanf("%d",&n);
    for(int i=1;i<=n;++i)
    {
        scanf("%d%d",&u,&v);
        Add(End[v],End[u],i);
    }
    Dfs(root);
    for(int i=1;i<=n;++i)
        printf("%d\n",ans[i]);
    return 0;
}

  

猜你喜欢

转载自www.cnblogs.com/wjyyy/p/lg2414.html