D. Tree Elimination(树形dp)

http://codeforces.com/problemset/problem/1276/D

题意:

给出一棵树的边,按照给出的顺序决策所有边。如果两个点都没被选择,选一个。否则不操作。问选择的序列的可能情况数。

解析:

因为一条边的决策只会影响相邻的边的决策,所以可以树形dp,从下到上,父亲从儿子转移。

d p [ 0 ] dp[0] 表示这个点被父亲边之前的儿子边选中;
d p [ 1 ] dp[1] 表示这个点被父亲边选中;
d p [ 2 ] dp[2] 表示这个点被父亲边之后的儿子边选中或者没有被选中;

在这里插入图片描述
考虑u从儿子转移(假设父亲边序号为P):

  • d p [ u ] [ 0 ] = i d v < P d p [ v ] [ 2 ] i d p r e < i d v d p [ p r e ] [ 0 / 1 ] i d n e x > i d v d p [ n e x ] [ 0 / 2 ] dp[u][0]=\sum_{id_v<P}dp[v][2]\cdot\prod_{id_{pre}<id_v}dp[pre][0/1]\cdot\prod_{id_{nex}>id_v}dp[nex][0/2]
    • 选择一个 v v 儿子,此时v是0,u是0,决策后v是0,u是1。v要求是0,所以不能选择 d p [ v ] [ 0 ] dp[v][0] ;uv边决策选择u,所以也不是 d p [ v ] [ 1 ] dp[v][1] 。只能选择 d p [ v ] [ 2 ] dp[v][2]
    • 对于v之前的儿子pre,要求保持u为0,所以可以选择 d p [ p r e ] [ 0 / 1 ] dp[pre][0/1] ,但是 d p [ p r e ] [ 2 ] dp[pre][2] 说明u已经选选择,所以不能选
    • 对于v之后的儿子nex,要求u已经是1,所以可以选择 d p [ n e x ] [ 0 / 2 ] dp[nex][0/2]
  • d p [ u ] [ 1 ] = i d v < P d p [ v ] [ 0 / 1 ] i d w > P d p [ w ] [ 0 / 2 ] dp[u][1]=\prod_{id_v<P}dp[v][0/1]\cdot\prod_{id_w>P}dp[w][0/2]
    • 父亲边之前的儿子pre,要求保持u为0,所以可以选择 d p [ p r e ] [ 0 / 1 ] dp[pre][0/1]
    • 父亲边之后的儿子nex,要求u已经是1,所以可以选择 d p [ n e x ] [ 0 / 2 ] dp[nex][0/2]
  • d p [ u ] [ 2 ] dp[u][2]
    • 若选,在父亲边之后选择一个儿子,其他与 d p [ u ] [ 0 ] dp[u][0] 相同
    • 若不选,则所有儿子保持u为0,即 d p [ v ] [ 0 / 1 ] dp[v][0/1]

d p [ p r e ] [ 0 / 1 ] , d p [ n e x ] [ 0 / 2 ] \prod dp[pre][0/1],\prod dp[nex][0/2] 可以用前缀和后缀处理(和的乘积)

找父亲边的前后可以依据链式前向星的下标来解决。

代码:

/*
 *  Author : Jk_Chen
 *    Date : 2020-02-24-12.56.51
 */
#include<bits/stdc++.h>
using namespace std;
#define LL long long
#define rep(i,a,b) for(int i=(int)(a);i<=(int)(b);i++)
#define per(i,a,b) for(int i=(int)(a);i>=(int)(b);i--)
#define mmm(a,b) memset(a,b,sizeof(a))
#define pb push_back
#define pill pair<int, int>
#define fi first
#define se second
#define debug(x) cerr<<#x<<" = "<<x<<'\n'
const LL mod=998244353;
const int maxn=2e5+9;
const int inf=0x3f3f3f3f;
LL rd(){ LL ans=0; char last=' ',ch=getchar();
    while(!(ch>='0' && ch<='9'))last=ch,ch=getchar();
    while(ch>='0' && ch<='9')ans=ans*10+ch-'0',ch=getchar();
    if(last=='-')ans=-ans; return ans;
}
#define rd rd()
/*_________________________________________________________begin*/

#define rep_e(i,p,u) for(int i=head[p],u=to[i];i;i=nex[i],u=to[i])
int head[maxn],to[maxn<<1],nex[maxn<<1],now;
void add(int a,int b){
    nex[++now]=head[a];head[a]=now;to[now]=b;
}
void init_edge(){
    memset(head,0,sizeof head);
    now=0;
}
/*_________________________________________________________edge*/

LL dp[maxn][3];
pill Ti[maxn];
LL sum[maxn][2];

LL add(LL a,LL b){
    return (a+b)%mod;
}
LL mul(LL a,LL b){
    return (a*b)%mod;
}

void dfs(int p,int fa,int id_fa){
    int cnt=0;
    rep_e(i,p,u){
        if(u==fa)continue;
        dfs(u,p,i);
        cnt++;
    }
    int tmp=cnt;
    rep_e(i,p,u){
        if(u==fa)continue;
        Ti[tmp--]={i,u};
    }
    sum[0][0]=sum[cnt+1][1]=1;
    rep(i,1,cnt){
        sum[i][0]=mul(add(dp[Ti[i].se][0],dp[Ti[i].se][1]),sum[i-1][0]);
    }
    per(i,cnt,1){
        sum[i][1]=mul(add(dp[Ti[i].se][0],dp[Ti[i].se][2]),sum[i+1][1]);
    }
    int r=0;
    rep(i,1,cnt)
        if(Ti[i].fi<id_fa)
            r=i;
        else
            break;
    // dp 0
    rep(i,1,r){
        LL val=dp[Ti[i].se][2]*sum[i-1][0]%mod*sum[i+1][1]%mod;
        dp[p][0]=add(dp[p][0],val);
    }
    // dp 1
    if(id_fa==1e9)
        dp[p][1]=0;
    else
        dp[p][1]=sum[r][0]*sum[r+1][1]%mod;
    // dp 2
    rep(i,r+1,cnt){
        LL val=dp[Ti[i].se][2]*sum[i-1][0]%mod*sum[i+1][1]%mod;
        dp[p][2]=add(dp[p][2],val);
    }
    dp[p][2]=add(dp[p][2],sum[cnt][0]);
}

int main(){
    int n=rd;
    rep(i,1,n-1){
        int a=rd,b=rd;
        add(a,b);add(b,a);
    }
    dfs(1,-1,1e9);
    printf("%lld\n",add(add(dp[1][0],dp[1][1]),dp[1][2]));
    return 0;
}

/*_________________________________________________________end*/

发布了773 篇原创文章 · 获赞 345 · 访问量 20万+

猜你喜欢

转载自blog.csdn.net/jk_chen_acmer/article/details/104440986