luogu4199 万径人踪灭

答案 = 关于某条对称轴对称的所有合法子序列的个数 − 其中位置连续的合法子序列(即回文子串)的个数。

后者就是该串中回文子串的数目,直接用 Manacher 算法求出即可。

对于前者,设关于某条对称轴共有 \(k\) 对位置关于该轴对称且字符相同(对称轴恰好落在某个字符上时,该字符与自身配成一对)。每一对可以选或不选,但不能全部不选,所以关于这条轴对称的子序列共有 \(2^k-1\) 个。

注意到关于同一条轴对称的每一对字符,其下标之和都相等。因此可以用这个下标和 \(i\) 作为对称轴的编号,并定义 \(f_i=\sum_{j=0}^{i}[s_j=s_{i-j}]\)。每个无序对 \((j,i-j)\) 在 \(f_i\) 中被统计两次,而 \(i\) 为偶数时位于轴上的字符 \(j=i/2\) 只被统计一次,所以 \(k=\lfloor\frac{f_i+1}{2}\rfloor\)。

再定义两个辅助函数 \(g_i=[s_i=\texttt{a}]\)、\(h_i=[s_i=\texttt{b}]\),则 \(f=g*g+h*h\)(这里 \(*\) 表示卷积)。用 FFT 做两次多项式乘法即可。

#include<iostream>
#include<string.h>
#include<string>
#include<stdio.h>
#include<algorithm>
#include<vector>
#include<bitset>
#include<math.h>
#include<stack>
#include<queue>
#include<set>
#include<map>
using namespace std;
typedef long long ll;
typedef long double db;
typedef pair<int,int> pii;
// Array bound: 100 slots of slack above the 1e5 input-length limit.
const int N=100000+100;
const db pi=acos(-1.0);
// Function-like macros are fully parenthesized so they compose inside larger
// expressions (e.g. `lowbit(x)==4`, `sqr(a+b)`).
#define lowbit(x) ((x)&(-(x)))
#define sqr(x) ((x)*(x))
// Loop helpers. No `register`: that storage class was removed in C++17 and
// clang rejects it by default; it never had any effect on modern compilers.
#define rep(i,a,b) for (int i=(a);i<=(b);i++)
#define per(i,a,b) for (int i=(a);i>=(b);i--)
#define go(u,i) for (int i=head[u];i;i=sq[i].nxt)
#define fir first
#define sec second
#define mp make_pair
#define pb push_back
// Prime modulus used by all modular arithmetic below.
#define maxd 1000000007
#define eps 1e-8
// Reads one decimal integer from stdin, skipping any non-digit characters
// before it; a '-' seen among those skipped characters makes the result
// negative. Returns 0 if end of input is reached before any digit.
inline int read()
{
    int x=0,f=1;
    // int, not char: getchar() returns EOF (negative) at end of input, which a
    // char cannot reliably represent or distinguish.
    int ch=getchar();
    while ((ch<'0') || (ch>'9'))
    {
        if (ch==EOF) return 0;  // no digits left; stop instead of spinning
        if (ch=='-') f=-1;
        ch=getchar();
    }
    while ((ch>='0') && (ch<='9')) {x=x*10+(ch-'0');ch=getchar();}
    return x*f;
}

namespace My_Math{
    // Inside this namespace N is a macro (factorial-table bound) that shadows
    // the global constant N; it is #undef'd at the end of the namespace.
    #define N 100000

    // fac[i] = i! mod maxd and invfac[i] = (i!)^{-1} mod maxd for 0 <= i <= N.
    // Only valid after math_init() has been called.
    int fac[N+100],invfac[N+100];

    // Modular add / subtract; both operands must already lie in [0, maxd).
    int add(int x,int y) {return x+y>=maxd?x+y-maxd:x+y;}
    int dec(int x,int y) {return x<y?x-y+maxd:x-y;}
    // Modular multiply through a 64-bit intermediate.
    int mul(int x,int y) {return 1ll*x*y%maxd;}

    // Returns x^y mod maxd by binary exponentiation; y must be >= 0.
    ll qpow(ll x,int y)
    {
        // Reduce the base into [0, maxd) first: mixing an unreduced 64-bit base
        // into 32-bit arithmetic would truncate it, and a negative base would
        // produce negative residues.
        x%=maxd;
        if (x<0) x+=maxd;
        ll ans=1;
        for (;y;y>>=1)
        {
            if (y&1) ans=ans*x%maxd;
            x=x*x%maxd;
        }
        return ans;
    }
    // Modular inverse by Fermat's little theorem (maxd is prime).
    int inv(int x) {return qpow(x,maxd-2);}

    // Binomial coefficient C(n, m) mod maxd; 0 when m is outside [0, n].
    // Requires math_init() and n <= N.
    int C(int n,int m)
    {
        if ((n<m) || (n<0) || (m<0)) return 0;
        return mul(mul(fac[n],invfac[m]),invfac[n-m]);
    }

    // Fills fac[0..N] and invfac[0..N]; invfac is built backwards from a single
    // modular inverse. Always returns 0 (an int function must return a value).
    int math_init()
    {
        fac[0]=invfac[0]=1;
        for (int i=1;i<=N;i++) fac[i]=mul(fac[i-1],i);
        invfac[N]=inv(fac[N]);
        for (int i=N-1;i>=1;i--) invfac[i]=mul(invfac[i+1],i+1);
        return 0;
    }
    #undef N
}
using namespace My_Math;

namespace polynomial{
    struct complex{
        double x,y;
        complex(double _x=0.0,double _y=0.0) {x=_x;y=_y;}
    };
    
    complex operator +(complex a,complex b) {return complex(a.x+b.x,a.y+b.y);}
    complex operator -(complex a,complex b) {return complex(a.x-b.x,a.y-b.y);}
    complex operator *(complex a,complex b) {return complex(a.x*b.x-a.y*b.y,a.x*b.y+a.y*b.x);}
    
    int r[N<<2];
    
    void calcr(int &lim,int len)
    {
        int cnt=0;
        while (lim<len) {lim<<=1;cnt++;}
        rep(i,0,lim-1)
            r[i]=((r[i>>1]>>1)|((i&1)<<(cnt-1)));
    }
    
    void fft(int lim,complex *a,int typ)
    {
        rep(i,0,lim-1)
            if (i<r[i]) swap(a[i],a[r[i]]);
        for (int mid=1;mid<lim;mid<<=1)
        {
            complex wn=complex(cos(pi/mid),sin(pi/mid)*typ);
            int len=(mid<<1);
            for (int sta=0;sta<lim;sta+=len)
            {
                complex w=complex(1,0);
                for (int j=0;j<mid;j++,w=w*wn)
                {
                    complex x=a[sta+j],y=a[sta+j+mid]*w;
                    a[sta+j]=x+y;a[sta+j+mid]=x-y;
                }
            }
        }
        if (typ==-1)
            rep(i,0,lim-1) a[i].x/=lim;
    }
}
using namespace polynomial;
// FFT work buffers (transform size lim <= N<<2) and the zero / one constants.
complex f[N<<2],g[N<<2],emp=complex(0,0),one=complex(1,0);
// n: input length; m: last index of the Manacher string t; p: Manacher radii.
// t holds "*#c0#...#c_{n-1}#%" (2n+3 chars), so t and p need 2n+3 slots; for n
// up to N-1 (the longest string s[] can hold) that is 2N+1, hence the +2.
int n,m,p[(N<<1)+2];
char s[N],t[(N<<1)+2];

// Counts the palindromic substrings of s[0..n-1] (n is the global length),
// modulo maxd. Side effects: rebuilds the global transformed string t (last
// index stored in m) and fills the radius array p.
ll manacher(char *s)
{
    // t = "*#c0#c1#...#c_{n-1}#%": the two different sentinels at the ends stop
    // the expansion loop without explicit bounds checks.
    t[0]='*';
    t[1]='#';
    m=1;
    for (int k=0;k<n;k++)
    {
        t[++m]=s[k];
        t[++m]='#';
    }
    t[++m]='%';

    ll total=0;
    // The rightmost-reaching palindrome found so far is centered at `center`
    // and covers t up to (but not including) index `reach`.
    int center=1,reach=1;
    for (int i=1;i<m;i++)
    {
        // Reuse the mirrored radius when i lies inside that palindrome.
        p[i]=(i<reach)?min(p[2*center-i],reach-i):1;
        while (t[i-p[i]]==t[i+p[i]]) ++p[i];
        // A radius of p[i] in t contributes p[i]/2 palindromes of s.
        total=add(total,p[i]/2);
        if (reach<i+p[i])
        {
            reach=i+p[i];
            center=i;
        }
    }
    return total;
}

int main()
{
    scanf("%s",s);
    n=strlen(s);
    int lim=1;calcr(lim,n<<1);
    rep(i,0,n-1) if (s[i]=='a') f[i]=one;
    fft(lim,f,1);
    rep(i,0,lim-1) g[i]=g[i]+f[i]*f[i];
    rep(i,0,lim-1) f[i]=emp;
    rep(i,0,n-1) if (s[i]=='b') f[i]=one;
    fft(lim,f,1);
    rep(i,0,lim-1) g[i]=g[i]+f[i]*f[i];
    fft(lim,g,-1);
    ll ans=0;
    rep(i,0,lim-1) 
    {
        int tmp=(int)(g[i].x+0.5);
        tmp=(tmp+1)>>1;
        ans=add(ans,qpow(2,tmp)-1);
    }
    //cout << ans << endl;
    ans=dec(ans,manacher(s));
    printf("%lld\n",ans);
    return 0;
}

猜你喜欢

转载自www.cnblogs.com/encodetalker/p/12387428.html