牛客挑战赛30 小G砍树 树形dp

小G砍树

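两次 DFS（换根 DP）：第一次 DFS 以 1 为根，求出每个子树的大小 son[u] 以及子树内部的删除方案数 dp[u]；第二次 DFS 换根，求出每个点作为最后一个被砍掉的点时的方案数（即以该点为根、每个点都在其所有后代之后被砍掉的顺序数，等于 $n!/\prod_v \mathrm{sz}_v$）。每种砍树顺序的最后一个点是唯一的，所以把所有点的方案数相加即为答案，不重不漏。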

#include<bits/stdc++.h>
#define LL long long
#define fi first
#define se second
#define mk make_pair
#define PLL pair<LL, LL>
#define PLI pair<LL, int>
#define PII pair<int, int>
#define SZ(x) ((int)x.size())
#define ull unsigned long long

using namespace std;

// Array bound: every per-node table and factorial table is sized N, so n and all node ids must stay below N.
const int N = 1e5 + 7;
// Generic "infinity" sentinels and floating-point helpers; unused in this file.
const int inf = 0x3f3f3f3f;
const LL INF = 0x3f3f3f3f3f3f3f3f;
// All counts are computed modulo this prime; Power(x, mod - 2) relies on primality (Fermat inverse).
const int mod = 998244353;
const double eps = 1e-8;
const double PI = acos(-1);

// n: number of nodes; ans: running sum (mod `mod`) of the per-node counts added by getAns.
// son[u]: size of u's subtree when the tree is rooted at node 1 (filled by dfs).
// dp[u]: (son[u]-1)! * prod over children v of dp[v] / son[v]!, i.e. the number of ways to order
//        u's proper descendants so that every node comes after all of its own descendants (filled by dfs).
int n, ans, son[N], dp[N];
// Adjacency lists of the undirected tree read in main.
vector<int> G[N];

// Filled by init(): F[i] = i! mod p, Finv[i] = (i!)^{-1} mod p, inv[i] = i^{-1} mod p (for i >= 1).
int F[N], Finv[N], inv[N];

// Precomputes, for every 0 <= k < N:
//   inv[k]  = k^{-1}    (mod p), k >= 1
//   F[k]    = k!        (mod p)
//   Finv[k] = (k!)^{-1} (mod p)
void init() {
    inv[1] = 1;
    F[0] = 1;
    Finv[0] = 1;

    // Linear-time modular inverses: inv[k] = -(p / k) * inv[p % k]  (mod p).
    for (int k = 2; k < N; ++k) {
        inv[k] = 1ll * (mod - mod / k) * inv[mod % k] % mod;
    }

    // Factorials and inverse factorials share one pass; every inv[k] is already known here.
    for (int k = 1; k < N; ++k) {
        F[k] = 1ll * F[k - 1] * k % mod;
        Finv[k] = 1ll * Finv[k - 1] * inv[k] % mod;
    }
}

// Square-and-multiply: returns a^b modulo `mod` for b >= 0.
// With b = mod - 2 this yields the modular inverse of a (mod is prime).
int Power(int a, int b) {
    long long base = a;
    long long result = 1;
    for (; b != 0; b >>= 1) {
        if (b & 1) result = result * base % mod;
        base = base * base % mod;
    }
    return static_cast<int>(result);
}

// First DFS, with the tree rooted at node 1 (fa is u's parent, 0 for the root).
// On return:
//   son[u] = size of u's subtree;
//   dp[u]  = (son[u]-1)! * prod over children v of dp[v] / son[v]!,
//            the number of ways to order u's proper descendants so that each node
//            comes after all of its own descendants.
void dfs(int u, int fa) {
    long long ways = 1;  // running product of dp[v] / son[v]! over processed children
    for (int v : G[u]) {
        if (v != fa) {
            dfs(v, u);
            ways = ways * dp[v] % mod * Finv[son[v]] % mod;
            son[u] += son[v];  // for now son[u] counts only proper descendants
        }
    }
    // Interleave the children's orders: multiply by (number of descendants)!.
    dp[u] = ways * F[son[u]] % mod;
    son[u] += 1;  // include u itself: son[u] becomes the subtree size
}

// Second DFS (rerooting).
// `tmp` is the dp value of the part of the tree outside u's subtree, viewed as a
// child component of u (1 at the root, where that part is empty).
// Adds to `ans` the number of valid removal orders in which u is the last node, then recurses.
void getAns(int u, int fa, int tmp) {
    const int sub = son[u];       // size of u's subtree (rooted at node 1)
    const int outside = n - sub;  // size of the component on the parent side of u

    // With u as the root, the count is (n-1)! * prod over components c around u of dp[c] / |c|!.
    // dp[u] / (sub-1)! is that product over u's children; tmp / outside! is the parent-side factor.
    long long here = 1ll * tmp * dp[u] % mod;
    here = here * F[n - 1] % mod;
    here = here * Finv[sub - 1] % mod;
    here = here * Finv[outside] % mod;
    ans = (ans + here) % mod;

    // Product of dp[c] / |c|! over every component around u, parent side included.
    long long around = 1ll * tmp * Finv[outside] % mod;
    for (int v : G[u])
        if (v != fa) around = around * dp[v] % mod * Finv[son[v]] % mod;

    for (int v : G[u]) {
        if (v == fa) continue;
        // Divide out v's own factor dp[v] / son[v]! (Power(x, mod - 2) is x^{-1}), then multiply by
        // (n - son[v] - 1)! to get the dp of "everything except v's subtree", rooted at u.
        const long long up = around * F[son[v]] % mod * F[n - son[v] - 1] % mod
                             * Power(dp[v], mod - 2) % mod;
        getAns(v, u, static_cast<int>(up));
    }
}

int main() {
    init();
    scanf("%d", &n);
    for(int i = 2; i <= n; i++) {
        int u, v;
        scanf("%d%d", &u, &v);
        G[u].push_back(v);
        G[v].push_back(u);
    }
    dfs(1, 0);
    getAns(1, 0, 1);
    printf("%d\n", ans);
    return 0;
}

/*
*/

猜你喜欢

转载自：https://www.cnblogs.com/CJLHY/p/10509993.html
今日推荐