bzoj5287/洛谷P4426/loj2496 毒瘤 虚树+树形dp

题目分析

首先orz一下
这是一道虽然不是数据结构,但也足够毒瘤了的毒瘤题。
原图应该是一棵树+若干非树边,题意转化为求独立集数量。
先考虑暴力做法,对于与非树边相连的点,我们暴力枚举它是选,还是不选。然后再跑一遍树形dp统计答案,方程显而易见:
f[x][0]=f[x][0]*(f[son][0]+f[son[1]);
f[x][1]=f[x][1]*f[son][0];
当然这个做法还是可以用容斥优化的,这样你就可以拿到70~80分的好分数,虽然划算但这并不是我们的重点。
那么有没有什么优美一点的做法呢?考虑我们dp的时候可以不整棵树弄一次,而是只把那些与非树边相连的点拿出来弄一弄…….让我们来建虚树吧!
那么建出虚树后,要怎么快速dp呢?
假设在虚树上有两个点x,y,y是x的父亲。在原树上,y的儿子是t,并且存在:
f[t][0]=k0[x][0]*f[x][0]+k1[x][0]*f[x][1];
f[t][1]=k0[x][1]*f[x][0]+k1[x][1]*f[x][1];
那么假设除去t这棵子树的贡献后,f(y,0)和f(y,1)分别记为orz(y,0),orz(y,1),就有:
f[y][0]=orz[y][0]*(k0[x][0]+k0[x][1])*f[x][0]+orz[y][0]*(k1[x][0]+k1[x][1])*f[x][1]);
f[y][1]=orz[y][1]*k0[x][0]*f[x][0]+orz[y][1]*k1[x][0]*f[x][1];
由此可见,k0和k1是可以进行dp的。我们可以做一次预处理,在不考虑以虚树点为根的子树的前提下弄出orz,然后再对k0和k1进行dp,最后求出来。
然后枚举所有与非树边相连的点是选还是不选,然后在虚树上dp,每次dp到y时,可以O(1)求出f(t,0)和f(t,1),就可以愉♂快地dp啦!

代码

#include<bits/stdc++.h>
using namespace std;
int read() {
    int q=0;char ch=' ';
    while(ch<'0'||ch>'9') ch=getchar();
    while(ch>='0'&&ch<='9') q=q*10+ch-'0',ch=getchar();
    return q;
}
#define RI register int
const int mod=998244353,N=100020;
int n,m,now,K,js,top,ans;
int fa[N][17],dep[N],pos[N],vis[N],u[20],v[20],eno[N<<1];
struct tree{
    int tot,h[N],ne[N<<1],to[N<<1];
    void add(int x,int y) {to[++tot]=y,ne[tot]=h[x],h[x]=tot;}
}T1,T2;

void dfs(int x,int las) {
    pos[x]=++now,dep[x]=dep[las]+1,fa[x][0]=las,vis[x]=1;
    for(RI i=1;i<=16;++i) fa[x][i]=fa[fa[x][i-1]][i-1];
    for(RI i=T1.h[x];i;i=T1.ne[i])
        if(T1.to[i]!=las) {
            if(!vis[T1.to[i]]) dfs(T1.to[i],x);
            else ++K,u[K]=x,v[K]=T1.to[i],eno[i]=eno[i^1]=1;
        }
}
int lca(int x,int y) {
    if(dep[x]<dep[y]) swap(x,y);
    for(RI i=16;i>=0;--i) if(dep[fa[x][i]]>=dep[y]) x=fa[x][i];
    if(x==y) return x;
    for(RI i=16;i>=0;--i)
        if(fa[x][i]!=fa[y][i]) x=fa[x][i],y=fa[y][i];
    return fa[x][0];
}
int inT2[N],st[100],p[100];
bool cmp(int x,int y) {return pos[x]<pos[y];}
void buildT2() {//建虚树
    for(RI i=1;i<=K;++i) {
        if(!inT2[u[i]]) inT2[u[i]]=1,p[++js]=u[i];
        if(!inT2[v[i]]) inT2[v[i]]=1,p[++js]=v[i];
    }
    sort(p+1,p+1+js,cmp);
    if(!inT2[1]) inT2[1]=1,st[++top]=1;
    for(RI i=1;i<=js;++i) {
        int x=p[i],o=0;
        while(top) {
            o=lca(st[top],x);
            if(top>1&&dep[st[top-1]]>dep[o]) T2.add(st[top-1],st[top]),--top;
            else if(dep[st[top]]>dep[o]) {inT2[o]=1,T2.add(o,st[top]),--top;break;}
            else break;
        }
        if(st[top]!=o) inT2[o]=1,st[++top]=o;
        st[++top]=x;
    }
    while(top>1) T2.add(st[top-1],st[top]),--top;
}

int f[N][2],k0[N][2],k1[N][2],g[N][2],bin[25],bj[N];
void predp(int x,int no) {//进行暴力算法中的dp,但是以no为根的子树不考虑
    f[x][0]=f[x][1]=1,inT2[x]=1;
    for(RI i=T1.h[x];i;i=T1.ne[i]) {
        if(T1.to[i]==fa[x][0]||T1.to[i]==no||inT2[T1.to[i]]||eno[i]) continue;
        predp(T1.to[i],no);
        f[x][0]=1LL*f[x][0]*(f[T1.to[i]][0]+f[T1.to[i]][1])%mod;
        f[x][1]=1LL*f[x][1]*f[T1.to[i]][0]%mod;
    }
}
void getk(int x,int y) {//获得k0,k1的值
    k0[x][0]=1,k1[x][1]=1;
    for(RI i=x;fa[i][0]!=y;i=fa[i][0]) {
        predp(fa[i][0],i);
        int t0=k0[x][0],t1=k1[x][0];
        k0[x][0]=1LL*f[fa[i][0]][0]*(k0[x][0]+k0[x][1])%mod;
        k1[x][0]=1LL*f[fa[i][0]][0]*(k1[x][0]+k1[x][1])%mod;
        k0[x][1]=1LL*f[fa[i][0]][1]*t0%mod;
        k1[x][1]=1LL*f[fa[i][0]][1]*t1%mod;
    }
}
void prework(int x) {//预处理
    for(RI i=T2.h[x];i;i=T2.ne[i]) prework(T2.to[i]),getk(T2.to[i],x);
    f[x][0]=f[x][1]=1;
    for(RI i=T1.h[x];i;i=T1.ne[i]) {//对于没有虚树点的子树进行预处理dp
        if(inT2[T1.to[i]]||eno[i]||T1.to[i]==fa[x][0]) continue;
        predp(T1.to[i],0);
        f[x][0]=1LL*f[x][0]*(f[T1.to[i]][0]+f[T1.to[i]][1])%mod;
        f[x][1]=1LL*f[x][1]*f[T1.to[i]][0]%mod;
    }
}
void dp(int x) {//统计答案
    g[x][0]=f[x][0],g[x][1]=f[x][1];
    for(RI i=T2.h[x];i;i=T2.ne[i]) {
        int y=T2.to[i];dp(y);
        int f0=(1LL*k0[y][0]*g[y][0]%mod+1LL*k1[y][0]*g[y][1]%mod)%mod;
        int f1=(1LL*k0[y][1]*g[y][0]%mod+1LL*k1[y][1]*g[y][1]%mod)%mod;
        g[x][0]=1LL*g[x][0]*(f0+f1)%mod,g[x][1]=1LL*g[x][1]*f0%mod;
    }
    if(bj[x]==1) g[x][0]=0;
    if(bj[x]==-1) g[x][1]=0;
}

int qm(int x) {return x>=mod?x-mod:x;}
int main()
{
    int x,y;
    n=read(),m=read(),T1.tot=1;
    for(RI i=1;i<=m;++i) x=read(),y=read(),T1.add(x,y),T1.add(y,x);
    dfs(1,0),buildT2(),prework(1);
    bin[1]=1;for(RI i=2;i<=js+1;++i) bin[i]=bin[i-1]<<1;
    for(RI zt=0;zt<bin[js+1];++zt) {//枚举与非树边相连的点的状态
        for(RI i=1;i<=js;++i)
            if(zt&bin[i]) bj[p[i]]=1;
            else bj[p[i]]=-1;
        int flag=0;
        for(RI i=1;i<=K;++i) if(bj[u[i]]==1&&bj[v[i]]==1) {flag=1;break;}
        if(flag) continue;
        dp(1),ans=qm(ans+qm(g[1][0]+g[1][1]));
    }
    printf("%d\n",ans);
    return 0;
}

猜你喜欢

转载自blog.csdn.net/litble/article/details/80437073