[Nowcoder 2018ACM多校第三场B] Expected Number of Nodes

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/u013578420/article/details/81256390

题目大意:
给你一颗节点数为n的树, 选定一个子集
如果一个未被选择的点度数为1, 该点被删除
如果一个未被选择的点度数为2, 该点被删除, 并将其相连的两个点连接。
求对每一个k, 选定的k子集后期望剩下的点数, 模 10 9 + 7 ( n 5000 )

题目思路:
考虑选定了一个子集大小k后, 考虑每一个点的贡献
对于一个度数小于等于2的点, 他要留下来(对答案的分子产生贡献1)的情况只有他自己被选中, 否则一定会被删, 故有C(n - 1, k - 1)种。
对于一个度数大于2的点, 他要被删掉的情况, 当且仅当所有k个点都选在他的某两个孩子子树内。 这种情况下就是, 其他子树的点会从叶子开始一直删下来删光, 然后他只剩下了度数2, 也被删掉了。
设它有m个孩子, 每个孩子有sz[i], 则被删掉的方案数是

i < j C ( s z [ i ] + s z [ j ] , k ) ( m 2 ) i C ( s z [ i ] , k )

后面那一项是减去算重的部分, 从第一项的式子中可以看出,k个点都选在同一个子树内的情况, 即每个C(sz[i], k)会被算(m-1)次, 故要减去(m-2)个。
然后度数大于2点的贡献就是C(n, k)减去被删掉的方案了。

最后考虑对于每个k来求答案
对于第一种情况, 我们可以预先数出有多少个度数小于等于2的点, 然后乘以C(n - 1, k - 1)即可。
对于第二种情况, 同样可以预处理出每个i<=n, C(i, k)前面的系数, 然后O(n)的算一遍即可。
故总的复杂度是O(n^2)的。

PS:
在dfs部分, 每个节点虽然是O(孩子个数^ 2)的枚举计算, 但是总的复杂度依然是O(n^2), 这种复杂度在树形dp中也很常见。

Code:

#include <map>
#include <set>
#include <map>
#include <bitset>
#include <cmath>
#include <queue>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <algorithm>

#define ll long long
#define db double
#define fi first
#define se second
#define mp(x, y) make_pair(x, y)
#define ls (x << 1)
#define rs ((x << 1) | 1)
#define mid ((l + r) >> 1)

using namespace std;

const int N = (int)5050;
const int mo = (int)1e9 + 7;

int n;
int cnt, lst[N], nxt[N * 2], to[N * 2];
int deg[N], sz[N], tot; ll tim[N], C[N][N];

ll pw(ll x, ll k){
    ll ret = 1;
    while (k){
        if (k & 1) ret = ret * x % mo;
        x = x * x % mo;
        k >>= 1;
    }
    return ret;
}

void add(int u, int v){
    nxt[++ cnt] = lst[u]; lst[u] = cnt; to[cnt] = v;
    nxt[++ cnt] = lst[v]; lst[v] = cnt; to[cnt] = u;
}

void dfs(int u, int fa){
    sz[u] = 1;
    for (int j = lst[u]; j; j = nxt[j]){
        int v = to[j];
        if (v == fa) continue;
        dfs(v, u);
        sz[u] += sz[v];
    }
    if (deg[u] <= 2) tot ++;
    else{
        for (int j = lst[u]; j; j = nxt[j]){
            int v1 = to[j];
            if (v1 == fa) continue;
            for (int k = lst[u]; k && to[k] != v1; k = nxt[k]){
                int v2 = to[k];
                if (v2 == fa) continue;

                (tim[sz[v1] + sz[v2]] -= 1) %= mo;
            }

            (tim[sz[v1]] += (deg[u] - 2)) %= mo;
        }

        if (u != 1){
            for (int j = lst[u]; j; j = nxt[j]){
                int v = to[j];
                if (v == fa) continue;

                (tim[sz[v] + n - sz[u]] -= 1) %= mo;
            }

            (tim[n - sz[u]] += (deg[u] - 2)) %= mo;
        }
    }
}

int main(){
    scanf("%d", &n);
    for (int i = 1, u, v; i < n; i ++){
        scanf("%d %d", &u, &v);
        add(u, v); deg[u] ++, deg[v] ++;
    }

    dfs(1, 0);

    C[0][0] = 1;
    for (int i = 1; i < N; i ++){
        C[i][0] = 1;
        for (int j = 1; j <= i; j ++)
            (C[i][j] = C[i - 1][j] + C[i - 1][j - 1]) %= mo;
    }

    for (int k = 1; k <= n; k ++){
        ll ans = 0;
        (ans += tot * C[n - 1][k - 1] % mo) %= mo;
        (ans += (n - tot) * C[n][k] % mo) %= mo;

        for (int i = 1; i <= n; i ++)
            (ans += tim[i] * C[i][k]) %= mo;

        if (ans < 0) ans += mo;

        ans = ans * pw(C[n][k], mo - 2) % mo;

        printf("%lld\n", ans);
    }

    return 0;
}

猜你喜欢

转载自blog.csdn.net/u013578420/article/details/81256390