Codeforces 1276D/1259G Tree Elimination (树形DP)

题目链接

http://codeforces.com/contest/1276/problem/D

题解

我什么DP都不会做,吃枣药丸……
\(f_{u,j}\)表示\(u\)子树内,\(j=0\)要求\(u\)点在轮到其父边之前被删,\(j=1\)要求\(u\)点被其父边删掉,\(j=2\)要求\(u\)点在其父边之后被删或者最后没有被删。
转移: 设儿子有\(s\)个,分别为\(v_1,v_2,...,v_s\), 且按边的编号从小到大排序,父边编号位于\(d\)\((d+1)\)之间。
枚举被哪条边删除。
\[f_{u,0}=\sum^d_{i=1}(\prod^{i-1}_{j=1}(f_{v_j,0}+f_{v_j,1})\cdot f_{v_i,2}\cdot \prod^s_{j=i+1}(f_{v_j,0}+f_{v_j,2}))\]
\[f_{u,1}=\prod^d_{i=1}(f_{v_j,0}+f_{v_j,1})\cdot \prod^s_{i=d+1}(f_{v_j,0}+f_{v_j,2})\]
\[f_{u,2}=\sum^s_{i=d+1}(\prod^{i-1}_{j=1}(f_{v_j,0}+f_{v_j,1})\cdot f_{v_i,2}\cdot \prod^s_{j=i+1}(f_{v_j,0}+f_{v_j,2}))+\prod^{s}_{i=1}(f_{v_j,0}+f_{v_j,1})\]
维护前后缀积即可。
时间复杂度\(O(n)\).

代码

#include<bits/stdc++.h>
#define llong long long
#define pii pair<int,int>
#define mkpr make_pair
using namespace std;

inline int read()
{
    int x = 0,f = 1; char ch = getchar();
    for(;!isdigit(ch);ch=getchar()) {if(ch=='-') f = -1;}
    for(; isdigit(ch);ch=getchar()) {x = x*10+ch-48;}
    return x*f;
}

const int N = 2e5;
const int P = 998244353;
vector<pii> adj[N+3];
int fa[N+3],fae[N+3];
llong aux1[N+3],aux2[N+3];
llong f[N+3][3];
int n,en;

void dfs(int u)
{
    sort(adj[u].begin(),adj[u].end()); int faid = -1,adjn = adj[u].size();
    for(int i=0; i<adj[u].size(); i++)
    {
        int o = adj[u][i].first,v = adj[u][i].second;
        if(v==fa[u]) {faid = i; continue;} fa[v] = u,fae[v] = o;
        dfs(v);
    }
    aux1[0] = 1ll;
    for(int i=0; i<adj[u].size(); i++)
    {
        int v = adj[u][i].second; if(v==fa[u]) {aux1[i+1] = aux1[i]; continue;}
        aux1[i+1] = aux1[i]*(f[v][0]+f[v][1])%P;
    }
    aux2[adj[u].size()+1] = 1ll;
    for(int i=(int)adj[u].size()-1; i>=0; i--)
    {
        int v = adj[u][i].second; if(v==fa[u]) {aux2[i+1] = aux2[i+2]; continue;}
        aux2[i+1] = aux2[i+2]*(f[v][0]+f[v][2])%P;
    }
    f[u][0] = 0ll;
    for(int i=0; i<faid; i++)
    {
        int v = adj[u][i].second;
        llong tmp = aux1[i]*f[v][2]%P*aux2[i+2]%P; f[u][0] = (f[u][0]+tmp)%P;
    }
    if(faid!=-1) {f[u][1] = aux1[faid]*aux2[faid+2]%P;}
    f[u][2] = 0ll;
    for(int i=faid+1; i<adj[u].size(); i++)
    {
        int v = adj[u][i].second;
        llong tmp = aux1[i]*f[v][2]%P*aux2[i+2]%P; f[u][2] = (f[u][2]+tmp)%P;
    }
    f[u][2] = (f[u][2]+aux1[adjn])%P;
}

int main()
{
    scanf("%d",&n);
    for(int i=1; i<n; i++)
    {
        int u,v; scanf("%d%d",&u,&v);
        adj[u].push_back(mkpr(i,v)); adj[v].push_back(mkpr(i,u));
    }
    dfs(1);
    printf("%I64d\n",f[1][2]);
    return 0;
}

猜你喜欢

转载自www.cnblogs.com/suncongbo/p/12072371.html