【BZOJ4911】[SDOI2017]切树游戏(动态dp,FWT)

【BZOJ4911】[SDOI2017]切树游戏(动态dp,FWT)

题面

BZOJ
洛谷
LOJ

题解

首先考虑如何暴力\(dp\),设\(f[i][S]\)表示当前以\(i\)节点为根节点,联通子树权值和为\(S\)的方案数,转移就是\(FWT\)的卷积,最后只需要把所有的\(f[i][k]\)全部加起来就可以得到最终的答案。
于是这样子的复杂度就是\(O(Qnmlogm)\)。但实际上转移的时候不需要\(FWT\)回来,直接拿点值表示的数组做就可以了,这样子可以少一个\(log\)
那么我们我们额外设一个变量\(S_u\)表示其子树内所有的\(f[u]\)的和。
令矩阵的每个元素都是一个长度为\(m\)的向量,向量的乘法就是每一位对应乘,加法就是每一位对应加,\(0,1\)分别表示全\(0\)、全\(1\)的向量。那么可以得到转移:
\[\begin{bmatrix}f'_u&0&f'_u\\f'_u&1&f'_u+S'_u\\0&0&1\end{bmatrix}\times \begin{bmatrix}f_v\\S_v\\1\end{bmatrix}=\begin{bmatrix}f_u\\S_u\\1\end{bmatrix}\]
其中\(f'_u,S'_u\)表示只考虑轻儿子的转移的结果,或者说只有重儿子没有转移的结果。
这样子单次矩乘的复杂度似乎是\(27m\),但是发现很多位置是\(0\),所以可以手动算一算结果:
\[\begin{bmatrix}a_1&0&b_1\\c_1&1&d_1\\0&0&1\end{bmatrix}\times\begin{bmatrix}a_2&0&b_2\\c_2&1&d_2\\0&0&1\end{bmatrix}=\begin{bmatrix}a_1a_2&0&a_1b_2+b_1\\c_1a_2+c_2&1&c_1b_2+d_2+d_1\\0&0&1\end{bmatrix}\]
这样子就只需要维护\(4\)个地方的值了,那么常数就大大的减少了。
然后我因为全部都是\(operator\),导致常数空间很大,所以就得开\(short\ int\)
本机、洛谷、LOJ都能过,\(BZOJ\ TLE\)

#include<iostream>
#include<cstdio>
using namespace std;
#define MOD 10007
#define inv2 5004
#define MAX 30030
inline int read()
{
    int x=0;bool t=false;char ch=getchar();
    while((ch<'0'||ch>'9')&&ch!='-')ch=getchar();
    if(ch=='-')t=true,ch=getchar();
    while(ch<='9'&&ch>='0')x=x*10+ch-48,ch=getchar();
    return t?-x:x;
}
int n,m,Q,V[MAX],inv[MOD];char ch[10];
struct Number{short a,z;short v(){return z?0:a;}};
Number operator+(Number a,Number b){return (Number){(a.v()+b.v())%MOD,0};}
Number operator-(Number a,Number b){return (Number){(a.v()-b.v()+MOD)%MOD,0};};
Number operator*(Number a,Number b)
{
    int x=b.v();
    if(x)a.a=a.a*x%MOD;
    else a.z+=1;
    return a;
}
Number operator*(Number a,int b){return a*(Number){b,0};}
Number operator/(Number a,Number b)
{
    int x=b.v();
    if(x)a.a=a.a*inv[x]%MOD;
    else a.z-=1;
    return a;
}
Number operator/(Number a,int b){return a/(Number){b,0};}
struct Array{Number s[128];}f[MAX],S[MAX],pre[129];
Array operator*(Array a,Array b){for(int i=0;i<m;++i)a.s[i]=a.s[i]*b.s[i];return a;}
Array operator+(Array a,Array b){for(int i=0;i<m;++i)a.s[i]=a.s[i]+b.s[i];return a;}
Array operator-(Array a,Array b){for(int i=0;i<m;++i)a.s[i]=a.s[i]-b.s[i];return a;}
Array operator/(Array a,Array b){for(int i=0;i<m;++i)a.s[i]=a.s[i]/b.s[i];return a;}
void FWT(Array &a,int opt)
{
    for(int i=1;i<m;i<<=1)
        for(int p=i<<1,j=0;j<m;j+=p)
            for(int k=0;k<i;++k)
            {
                Number x=a.s[j+k],y=a.s[i+j+k];
                a.s[j+k]=x+y,a.s[i+j+k]=x-y;
                if(opt==-1)a.s[j+k]=a.s[j+k]*inv2,a.s[i+j+k]=a.s[i+j+k]*inv2;
            }
}
struct Line{int v,next;}e[MAX<<1];
int h[MAX],cnt=1;
inline void Add(int u,int v){e[cnt]=(Line){v,h[u]};h[u]=cnt++;}
struct Matrix{Array a,b,c,d;}t[MAX<<2],tmp[MAX];
Matrix operator*(Matrix a,Matrix b){return (Matrix){a.a*b.a,a.a*b.b+a.b,a.c*b.a+b.c,a.d+b.d+a.c*b.b};}
int fa[MAX],dfn[MAX],tim,hson[MAX],size[MAX],top[MAX],bot[MAX],ln[MAX];
void dfs1(int u,int ff)
{
    fa[u]=ff;size[u]=1;
    for(int i=h[u];i;i=e[i].next)
    {
        int v=e[i].v;if(v==ff)continue;
        dfs1(v,u);size[u]+=size[v];
        if(size[v]>size[hson[u]])hson[u]=v;
    }
}
void dfs2(int u,int tp)
{
    top[u]=tp;dfn[u]=++tim,ln[tim]=u;
    if(hson[u])dfs2(hson[u],tp),bot[u]=bot[hson[u]];
    else bot[u]=u;
    for(int i=h[u];i;i=e[i].next)
        if(e[i].v!=fa[u]&&e[i].v!=hson[u])
            dfs2(e[i].v,e[i].v);
}
void dp(int u,int ff)
{
    f[u]=pre[V[u]];
    for(int i=h[u];i;i=e[i].next)
    {
        int v=e[i].v;if(v==ff)continue;dp(v,u);
        f[u]=f[u]*(f[v]+pre[0]);S[u]=S[u]+S[v];
    }
    S[u]=S[u]+f[u];
}
#define lson (now<<1)
#define rson (now<<1|1)
void Build(int now,int l,int r)
{
    if(l==r)
    {
        int u=ln[l];Array f0=pre[V[u]],s0=pre[m];
        for(int i=h[u];i;i=e[i].next)
            if(e[i].v!=hson[u]&&e[i].v!=fa[u])
                f0=f0*(f[e[i].v]+pre[0]),s0=s0+S[e[i].v];
        tmp[l]=t[now]=(Matrix){f0,f0,f0,s0+f0};
        return;
    }
    int mid=(l+r)>>1;
    Build(lson,l,mid);Build(rson,mid+1,r);
    t[now]=t[rson]*t[lson];
}
void Modify(int now,int l,int r,int p)
{
    if(l==r){t[now]=tmp[l];return;}
    int mid=(l+r)>>1;
    if(p<=mid)Modify(lson,l,mid,p);
    else Modify(rson,mid+1,r,p);
    t[now]=t[rson]*t[lson];
}
Matrix Query(int now,int l,int r,int L,int R)
{
    if(L==l&&r==R)return t[now];
    int mid=(l+r)>>1;
    if(R<=mid)return Query(lson,l,mid,L,R);
    if(L>mid)return Query(rson,mid+1,r,L,R);
    return Query(rson,mid+1,r,mid+1,R)*Query(lson,l,mid,L,mid);
}
Matrix GetTop(int x){return Query(1,1,n,dfn[top[x]],dfn[bot[x]]);}
void Modify(int u,int y)
{
    Array f0=tmp[dfn[u]].a,s0=tmp[dfn[u]].d;s0=s0-f0;
    f0=f0/pre[V[u]];f0=f0*pre[y];V[u]=y;
    tmp[dfn[u]]=(Matrix){f0,f0,f0,s0+f0};
    while(u)
    {
        Matrix a=GetTop(u);
        Modify(1,1,n,dfn[u]);
        Matrix b=GetTop(u);
        u=fa[top[u]];if(!u)break;int x=dfn[u];
        f0=tmp[x].a;s0=tmp[x].d;s0=s0-f0;
        f0=f0/(a.c+pre[0]);f0=f0*(b.c+pre[0]);s0=s0-a.d;s0=s0+b.d;
        tmp[x]=(Matrix){f0,f0,f0,s0+f0};
    }
}
int main()
{
    n=read(),m=read();
    for(int i=1;i<=n;++i)V[i]=read();
    for(int i=1,u,v;i<n;++i)u=read(),v=read(),Add(u,v),Add(v,u);
    inv[0]=inv[1]=1;for(int i=2;i<MOD;++i)inv[i]=inv[MOD%i]*(MOD-MOD/i)%MOD;
    for(int i=0;i<m;++i)pre[i].s[i]=(Number){1,0},FWT(pre[i],1);
    dfs1(1,0);dfs2(1,1);dp(1,0);Build(1,1,n);
    Q=read();
    while(Q--)
    {
        scanf("%s",ch);
        if(ch[0]=='Q')
        {
            int k=read();
            Array ans=GetTop(1).d;
            for(int i=0;i<m;++i)ans.s[i].a=ans.s[i].v();
            FWT(ans,-1);
            printf("%d\n",ans.s[k].v());
        }
        else
        {
            int x=read(),y=read();
            Modify(x,y);
        }
    }
    return 0;
}

猜你喜欢

转载自www.cnblogs.com/cjyyb/p/10570636.html