bzoj 5314 [Jsoi2018]潜入行动 树形dp

题面

题目传送门

解法

细节多到烦死人的树形dp
其实想清楚了自然就没什么问题了

  • 树形背包应该是比较显然的,先考虑如何设计状态
  • f [ i ] [ j ] [ 0 / 1 ] [ 0 / 1 ] 表示点 i 所在的子树中选取 j 个点,是否装有通信设备,是否被其子节点监控的方案数
  • 假设当前我们做到的节点为 x ,枚举到 x 的其中一个儿子为 y ,假设 x 已经被遍历到的子树连 x 已经选取了 i 个点, y 子树中选 j 个点,考虑如何转移
  • 考虑如何求 f [ x ] [ i + j ] [ 0 ] [ 0 ] ,显然,这个就表示 x 这个点啥都没装,且只能被父节点控制,那么我们就可以十分轻松地得到以下转移方程: f [ x ] [ i + j ] [ 0 ] [ 0 ] + = f [ x ] [ i ] [ 0 ] [ 0 ] × f [ y ] [ j ] [ 0 ] [ 1 ] ,因为如果 y 没有被其子节点控制的话,那么这个方案显然是不合法的
  • 考虑如何求 f [ x ] [ i + j ] [ 0 ] [ 1 ] ,显然只要 x 的任意一个儿子装有通信系统即可,而且如果 y 没有装的话就必须被它的子节点控制,那么我们就可以得到如下转移方程: f [ x ] [ i + j ] [ 0 ] [ 1 ] + = f [ x ] [ i ] [ 0 ] [ 0 ] × f [ y ] [ j ] [ 1 ] [ 1 ] + f [ x ] [ i ] [ 0 ] [ 1 ] × ( f [ y ] [ j ] [ 0 ] [ 1 ] + f [ y ] [ j ] [ 1 ] [ 1 ] )
  • 考虑如何求 f [ x ] [ i + j ] [ 1 ] [ 0 ] ,因为 x 没有被它的子节点控制,那么所有儿子也不应该装通信系统,所以就可以得到转移方程: f [ x ] [ i + j ] [ 1 ] [ 0 ] + = f [ x ] [ i ] [ 1 ] [ 0 ] × ( f [ y ] [ j ] [ 0 ] [ 0 ] + f [ y ] [ j ] [ 0 ] [ 1 ] )
  • 考虑如何求 f [ x ] [ i + j ] [ 1 ] [ 1 ] ,这个就没什么限制了,显然我们可以得到: f [ x ] [ i + j ] [ 1 ] [ 1 ] + = f [ x ] [ i ] [ 1 ] [ 0 ] × ( f [ y ] [ j ] [ 1 ] [ 0 ] + f [ y ] [ j ] [ 1 ] [ 1 ] ) + f [ x ] [ i + j ] [ 1 ] [ 1 ] × ( f [ y ] [ j ] [ 0 ] [ 0 ] + f [ y ] [ j ] [ 0 ] [ 1 ] + f [ y ] [ j ] [ 1 ] [ 0 ] + f [ y ] [ j ] [ 1 ] [ 1 ] )
  • 那么我们就求完了四种转移方程
  • 在树形dp的时候注意一下, i , j 的边界考虑一下,不要过多枚举,不要让复杂度退化
  • 时间复杂度: O ( n k )

【注意事项】

  • 在转移的时候不要用 f [ x ] [ i ] [ 0 ] [ 0 ] 之类的直接转移,否则可能会出现很大的问题,即在转移的时候可能会算到一些不合法的情况
  • 如果整棵树的深度已经 > k 了,那么可以直接输出 0 我就是这样苟过去的
  • 转移的时候注意long long问题

代码

#include <bits/stdc++.h>
#define Mod 1000000007
#define N 100010
using namespace std;
template <typename node> void chkmax(node &x, node y) {x = max(x, y);}
template <typename node> void chkmin(node &x, node y) {x = min(x, y);}
template <typename node> void read(node &x) {
    x = 0; int f = 1; char c = getchar();
    while (!isdigit(c)) {if (c == '-') f = -1; c = getchar();}
    while (isdigit(c)) x = x * 10 + c - '0', c = getchar(); x *= f;
}
int n, K, cnt, d[N], siz[N], g[101][2][2], f[N][101][2][2];
vector <int> e[N];
void update(int &x, int y) {x = (x + y) % Mod;}
void dfs(int x, int fa) {
    siz[x] = 1, f[x][1][1][0] = f[x][0][0][0] = 1;
    d[x] = d[fa] + 1;
    if (d[x] > K) {cout << "0\n"; exit(0);}
    for (int k = 0; k < e[x].size(); k++) {
        int y = e[x][k];
        if (y == fa) continue; dfs(y, x);
        for (int i = 0; i <= min(siz[x], K); i++) {
            g[i][0][0] = f[x][i][0][0], f[x][i][0][0] = 0;
            g[i][0][1] = f[x][i][0][1], f[x][i][0][1] = 0;
            g[i][1][0] = f[x][i][1][0], f[x][i][1][0] = 0;
            g[i][1][1] = f[x][i][1][1], f[x][i][1][1] = 0;
        }
        for (int i = 0; i <= min(siz[x], K); i++)
            for (int j = 0; j <= min(siz[y], K - i); j++) {
                update(f[x][i + j][0][0], 1ll * g[i][0][0] * f[y][j][0][1] % Mod);
                update(f[x][i + j][0][1], (1ll * g[i][0][0] * f[y][j][1][1] % Mod + 1ll * g[i][0][1] * (f[y][j][0][1] + f[y][j][1][1]) % Mod) % Mod);
                update(f[x][i + j][1][0], 1ll * g[i][1][0] * (f[y][j][0][1] + f[y][j][0][0]) % Mod);
                update(f[x][i + j][1][1], (1ll * g[i][1][0] * (f[y][j][1][0] + f[y][j][1][1]) % Mod));
                update(f[x][i + j][1][1], 1ll * g[i][1][1] * ((1ll * f[y][j][0][0] + f[y][j][0][1] + f[y][j][1][0] + f[y][j][1][1]) % Mod) % Mod);
            }
        siz[x] += siz[y];
    }
}
int main() {
    read(n), read(K);
    for (int i = 1; i < n; i++) {
        int x, y; read(x), read(y);
        e[x].push_back(y), e[y].push_back(x);
    }
    dfs(1, 0); cout << (f[1][K][0][1] + f[1][K][1][1]) % Mod << "\n";
    return 0;
}

猜你喜欢

转载自blog.csdn.net/emmmmmmmmm/article/details/81987486